# Data
---

In [1]:
import os
import json
import heapq
import pickle
import random
import multiprocessing

import spacy
from tqdm.auto import tqdm

from valerie.utils import get_logger
from valerie.preprocessing import extract_words_from_url, clean_text
from valerie.scoring import validate_predictions_phase2, compute_score_phase2
from valerie.modeling import SequenceClassificationModel, SequenceClassificationDataset, SequenceClassificationExample

In [2]:
nproc=4

In [3]:
nlp = spacy.load("en_core_web_lg")

In [4]:
_logger = get_logger()

In [8]:
with open("data/phase2-validation/processed/responses.pkl", "rb") as fi:
    responses = pickle.load(fi)

In [9]:
# responses = responses[:8]

In [10]:
len(responses)

500

In [11]:
def compute_responses_score(results, claims_dict):
    predictions = {}
    perfect_predictions = {}
    labels = {}

    for k, hits in results.items():
        claim = claims_dict[k]
        labels[claim.id] = claim.to_dict()
        
        hits = sorted(hits, key=lambda x: x[1], reverse=True) # sort by score
        predictions[claim.id] = {
            "label": claim.label,
            "explanation": "",
            "related_articles": {
                i + 1: x
                for i, x in enumerate([v[0] for v in hits[:2]])
            }
        }
        perfect_predictions[claim.id] = {
            "label": claim.label,
            "explanation": "",
            "related_articles": {
                i + 1: x
                for i, x in enumerate([v[0] for v in hits if v[0] in claim.related_articles.values()][:2])
            }
        }

    validate_predictions_phase2(predictions)
    score = compute_score_phase2(labels, predictions)
    validate_predictions_phase2(perfect_predictions)
    perfect_score = compute_score_phase2(labels, perfect_predictions)
    return {
        "perfect_rerank_score": perfect_score["score"],
        "perfect_rerank_error": perfect_score["error"],
        "api_score": score["score"],
        "api_error": score["error"],
    }

In [12]:
def create_text_a(claim):
    text_a = claim.claim
    text_a += " "
    text_a += claim.claimant if claim.claimant else "no claimant"
    text_a += " "
    text_a += claim.date.split()[0].split("T")[0] if claim.date else "no date"
    return clean_text(text_a)

def create_text_b_content(article):
    text_b = ""
    if article.source:
        text_b += article.source + ". "
    if article.title:
        text_b += article.title + ". "
    if article.url:
        url_words = extract_words_from_url(article.url)
        if url_words:
            text_b += " ".join(url_words) + ". "
    if article.content:
        text_b += article.content
    return clean_text(text_b)

# Examples
---

In [13]:
def create_text_b_curated(article, claim):
    support = []
    for sent in article.nlp.sents:
        support.append({
            "text": sent.text,
            "score": claim.nlp.similarity(sent)
        })
    support = heapq.nlargest(16, support, key=lambda x: x["score"])
    claim.support[article.index] = support
    text_b = clean_text(" ".join([s["text"] for s in support]))
    return text_b

In [14]:
claims_dict = {}
articles_dict = {}

examples = []
for res in tqdm(responses):
    claim = res["claim"]
    claim.text_a = create_text_a(claim)
    claim.nlp = nlp(claim.text_a, disable=["textcat", "tagger", "parser", "ner"])
    claim.res = res
    claim.support = {}
    
    claims_dict[claim.index] = claim

    hits = [hit["article"] for hit in res["res"]["hits"]["hits"]]
    related_articles_url_set = set(claim.related_articles.values())

    for article in hits[:16]:
        if not hasattr(article, "nlp"):
            article.nlp = nlp(create_text_b_content(article), disable=["textcat", "tagger", "ner"])
        if not hasattr(article, "text_b"):
            article.text_b = create_text_b_curated(article, claim)
        articles_dict[article.index] = article
            
        examples.append(SequenceClassificationExample(
            guid=claim.index,
            text_a=claim.text_a,
            text_b=article.text_b,
            label=1 if article.url in related_articles_url_set else 0,
            art_id=article.index
        ))

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))

  





In [16]:
print(len(claims_dict))
print(len(claims_dict)*30)
print(len(articles_dict))

500
15000
3839


In [None]:
# def _visit(claim):
#     _examples = []
#     hits = [(hit["score"], hit["article"]) for hit in res["res"]["hits"]["hits"]][:16]
#     related_articles_url_set = set(claim.related_articles.values())
    
#     for score, article in hits:
#         text_b = create_text_b_curated(article, claim)
            
#         _examples.append(SequenceClassificationExample(
#             guid=claim.index,
#             text_a=claim.text_a,
#             text_b=text_b,
#             label=1 if article.url in related_articles_url_set else 0,
#             art_id=article.index
#         ))
#     return _examples

# pool = multiprocessing.Pool(16)
# examples = []
# misses = 0
# for result in tqdm(pool.imap_unordered(_visit, responses), total=len(responses)):
#     if result is None:
#         misses += 1
#     else:
#         examples.extend(result)

In [21]:
examples[0]

{
  "guid": "Phase2ValidationDataset/383",
  "text_a": "\"Huge! Results From Breaking Chloroquine Study Show 100% Cure Rate For Patients Infected With The Coronavirus.\" Facebook post 2020-03-19",
  "text_b": "school of medicine advisor announced a 100% cure rate in a controlled study done in france of 40 people with the #chinacoronavirus with a malaria drug called #hydroxychloroquine. results from breaking chloroquine study show 100% cure rate for patients infected with the coronavirus | tea party (gateway pundit) \u2013 on monday dr. anthony fauci, director of the national institute of allergy and infectious diseases, announced that the first trial vaccine for the coronavirus is now being tested. the test includes 45 people age 18-55 and they are receiving two injections, one at zero days, one at 28 days. tea party huge results from breaking chloroquine study show cure rate for patients infected with the coronavirus. dr. rigano said their study found that those covid-19 patients who 

In [25]:
list(claims_dict.values())[-10].claim

'An Ohio man died from complications related to the COVID-19 coronavirus disease weeks after he downplayed the virus on social media.'

In [None]:
print(json.dumps(list(claims_dict.values())[-10].support, indent=2))

In [26]:
print(len(responses)*16)
print(len(examples))

8000
8000


# Predict
---

In [27]:
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "dryrun"
os.environ["WANDB_WATCH"] = "false"

In [28]:
# "castorini/monot5-base-msmarco"
# "castorini/monobert-large-msmarco"
# "nboost/pt-bert-large-msmarco"]:
pretrained_model_name_or_path = "castorini/monobert-large-msmarco"

In [29]:
model = SequenceClassificationModel.from_pretrained(pretrained_model_name_or_path)

[2020-07-12 04:01:32,139] INFO:transformers.configuration_utils: loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/castorini/monobert-large-msmarco/config.json from cache at /home/jay/.cache/torch/transformers/643500d870067d59f219f7b5652919267c01bfa98024e2e74f53b28c1b6aff2b.4c88e2dec8f8b017f319f6db2b157fee632c0860d9422e4851bd0d6999f9ce38
[2020-07-12 04:01:32,140] INFO:transformers.configuration_utils: Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2020-07-12 04:01:32,140] INFO:transformers.tokenization_utils_base: Model name 'castorini/monobert-large-msmarco' not f

In [30]:
examples_dataset = model.create_dataset(examples, nproc=16)

[2020-07-12 04:01:42,715] INFO:valerie.modeling: ... converting examples to features ...


HBox(children=(FloatProgress(value=0.0, description='converting examples to features', max=8000.0, style=Progr…




In [31]:
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"
!echo $CUDA_VISIBLE_DEVICES

0,1,2,3,4,5,6,7


In [32]:
predict_output = model.predict(examples_dataset, predict_batch_size=256)

[2020-07-12 04:02:00,398] INFO:transformers.training_args: PyTorch: setting up devices
[2020-07-12 04:02:04,803] INFO:transformers.trainer: Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[2020-07-12 04:02:04,806] INFO:transformers.trainer: ***** Running Prediction *****
[2020-07-12 04:02:04,806] INFO:transformers.trainer:   Num examples = 8000
[2020-07-12 04:02:04,807] INFO:transformers.trainer:   Batch size = 2048


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=4.0, style=ProgressStyle(description_wid…






In [33]:
claims_dict = {res["claim"].index: res["claim"] for res in responses if res["res"]}
api_scores_dict = {
    res["claim"].index: {
        hit["article"].index: hit["score"] for hit in res["res"]["hits"]["hits"]
    }
    for res in responses
    if res["res"]
}


rerank_just_api_responses = {
    res["claim"].index: [
        (hit["article"].index, hit["score"]) for hit in res["res"]["hits"]["hits"]
    ]
    for res in responses
    if res["res"]
}

rerank_just_trans_responses = {res["claim"].index: [] for res in responses if res["res"]}

rerank_both_responses = {
    res["claim"].index: []
    for res in responses
    if res["res"]
}

for example, proba in tqdm(zip(examples, predict_output.predictions)):
    proba = float(proba[1]) # get probability that the article is related

    rerank_just_trans_responses[example.guid].append((example.art_id, proba))
    rerank_both_responses[example.guid].append((example.art_id, proba + api_scores_dict[example.guid][example.art_id]))
    
print('api')
print(json.dumps(compute_responses_score(rerank_just_api_responses, claims_dict), indent=2))
print()
print('trans')
print(json.dumps(compute_responses_score(rerank_just_trans_responses, claims_dict), indent=2))
print()
print('both')
print(json.dumps(compute_responses_score(rerank_both_responses, claims_dict), indent=2))
print()
print()
print()

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


api
{
  "perfect_rerank_score": 1.1557801493610302,
  "perfect_rerank_error": "'None'",
  "api_score": 0.6617083840074162,
  "api_error": "'None'"
}

trans
{
  "perfect_rerank_score": 1.0491232931485748,
  "perfect_rerank_error": "'None'",
  "api_score": 0.7396992039496483,
  "api_error": "'None'"
}

both
{
  "perfect_rerank_score": 1.0491232931485748,
  "perfect_rerank_error": "'None'",
  "api_score": 0.7086715926336704,
  "api_error": "'None'"
}





# Results
---

In [34]:
compute_responses_score(rerank_just_api_responses, claims_dict)

{'perfect_rerank_score': 1.1557801493610302,
 'perfect_rerank_error': "'None'",
 'api_score': 0.6617083840074162,
 'api_error': "'None'"}

In [35]:
compute_responses_score(rerank_just_trans_responses, claims_dict)

{'perfect_rerank_score': 1.0491232931485748,
 'perfect_rerank_error': "'None'",
 'api_score': 0.7396992039496483,
 'api_error': "'None'"}

In [36]:
compute_responses_score(rerank_both_responses, claims_dict)

{'perfect_rerank_score': 1.0491232931485748,
 'perfect_rerank_error': "'None'",
 'api_score': 0.7086715926336704,
 'api_error': "'None'"}