Adapted from notebook `5_msmarco_pretrained_on_all_responses.ipynb`, but evaluated only on the first 1000 responses so that the results can be compared directly with notebook 7.

In [8]:
import os
import json
import pickle
import random
import multiprocessing

from tqdm.notebook import tqdm
from valerie.utils import get_logger
from valerie.preprocessing import extract_words_from_url, clean_text
from valerie.scoring import validate_predictions_phase2, compute_score_phase2
from valerie.modeling import SequenceClassificationModel, SequenceClassificationDataset, SequenceClassificationExample

In [2]:
# Logger from valerie.utils (not referenced again in this notebook).
_logger = get_logger()

In [3]:
# Load the cached search responses. Each entry is used below as a dict with a
# "claim" object and a "res" search result; a falsy "res" is treated as a miss.
# NOTE: pickle.load can execute arbitrary code -- only load files this project produced.
with open("data/phase2-3/processed/responses.pkl", "rb") as fi:
    responses = pickle.load(fi)

In [4]:
# Keep only the first 1000 responses so the scores are comparable with notebook 7.
responses = responses[:1000]

In [5]:
# Sanity check: number of responses kept after the slice above.
len(responses)

1000

In [6]:
def compute_responses_score(results, claims_dict):
    """Score a ranking of retrieved articles with the phase-2 metric.

    Args:
        results: mapping from a claim key (the same keys as ``claims_dict``) to a
            list of ``(article_id, score)`` pairs, in any order.
        claims_dict: mapping from claim key to claim object; each claim's ``id``,
            ``label``, ``related_articles`` and ``to_dict()`` are used.

    Two submissions are built per claim, both copying the gold label and leaving
    the explanation empty:

    * the ranking's own top two articles (highest score first), and
    * an oracle "perfect rerank" that keeps, in ranked order, only the retrieved
      articles listed in ``claim.related_articles`` (an upper bound for
      reordering these hits).

    Each submission is checked with ``validate_predictions_phase2`` and scored
    with ``compute_score_phase2`` against ``claim.to_dict()`` labels.

    Returns:
        dict with ``perfect_rerank_score`` / ``perfect_rerank_error`` for the
        oracle and ``api_score`` / ``api_error`` for the ranking passed in
        (the ``api_`` prefix applies to whatever ranking ``results`` holds).
    """
    labels, predictions, perfect_predictions = {}, {}, {}

    def _submission(label, article_ids):
        # Articles are numbered 1, 2, ... in rank order.
        return {
            "label": label,
            "explanation": "",
            "related_articles": dict(enumerate(article_ids, start=1)),
        }

    for claim_key, hits in results.items():
        claim = claims_dict[claim_key]
        labels[claim.id] = claim.to_dict()

        by_score = sorted(hits, key=lambda hit: hit[1], reverse=True)
        ranked_ids = [hit[0] for hit in by_score]
        gold_ids = claim.related_articles.values()

        predictions[claim.id] = _submission(claim.label, ranked_ids[:2])
        perfect_predictions[claim.id] = _submission(
            claim.label, [art_id for art_id in ranked_ids if art_id in gold_ids][:2]
        )

    validate_predictions_phase2(predictions)
    ranking_result = compute_score_phase2(labels, predictions)
    validate_predictions_phase2(perfect_predictions)
    oracle_result = compute_score_phase2(labels, perfect_predictions)
    return {
        "perfect_rerank_score": oracle_result["score"],
        "perfect_rerank_error": oracle_result["error"],
        "api_score": ranking_result["score"],
        "api_error": ranking_result["error"],
    }

# Examples
---

In [7]:
def create_text_a(claim):
    """Build the claim-side text (text_a) of a claim/article pair.

    Concatenates, space-separated, the claim text, the claimant (or
    "no claimant") and the claim's date (or "no date"), then normalises the
    result with ``clean_text``.

    The date is cut at the first whitespace and then at a "T" separator, so
    both "2018-01-13 00:00:00" and "2018-01-13T00:00:00" become "2018-01-13".
    A missing, empty or whitespace-only date yields "no date" (a
    whitespace-only date used to raise IndexError).
    """
    date_tokens = claim.date.split() if claim.date else []
    date_text = date_tokens[0].split("T")[0] if date_tokens else "no date"
    claimant_text = claim.claimant if claim.claimant else "no claimant"
    return clean_text(" ".join([claim.claim, claimant_text, date_text]))

def create_text_b(article):
    """Build the article-side text (text_b) of a claim/article pair.

    The article source, its title and the words extracted from its URL are
    each appended followed by ". " (only when present / non-empty), the raw
    content is appended last, and the whole string is normalised with
    ``clean_text``.
    """
    pieces = []
    for field in (article.source, article.title):
        if field:
            pieces.append(field + ". ")

    url_words = extract_words_from_url(article.url) if article.url else None
    if url_words:
        pieces.append(" ".join(url_words) + ". ")

    if article.content:
        pieces.append(article.content)

    return clean_text("".join(pieces))

In [9]:
def _visit(res):
    _examples = []
    claim = res["claim"]
    if not res["res"]:
        return None
    hits = [(hit["score"], hit["article"]) for hit in res["res"]["hits"]["hits"]]
    
    text_a = create_text_a(claim)
    
    related_articles_url_set = set(claim.related_articles.values())
    
    for score, article in hits:
        text_b = create_text_b(article)
            
        _examples.append(SequenceClassificationExample(
            guid=claim.index,
            text_a=text_a,
            text_b=text_b,
            label=1 if article.url in related_articles_url_set else 0,
            art_id=article.index
        ))
    return _examples

# Build (claim, article) examples in parallel (48 workers, same as `nproc` in the
# config cell). The pool is a context manager so its worker processes are shut
# down once every response has been consumed, instead of lingering in the kernel.
# `imap` (ordered) keeps `examples` in response order, so reruns are reproducible.
examples = []
misses = 0  # responses whose search result was empty (_visit returned None)
with multiprocessing.Pool(48) as pool:
    for result in tqdm(pool.imap(_visit, responses), total=len(responses)):
        if result is None:
            misses += 1
        else:
            examples.extend(result)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [10]:
# Inspect one constructed example (the last one in `examples`).
examples[-1]

{
  "guid": "Phase2Dataset/10277",
  "text_a": "\"First indictment issued in Russian bribery case tied to Obama-era Uranium One deal.\" Republican News 2018-01-13",
  "text_b": "timesofisrael. Besides Sa'ar, another top Likud MK to challenge PM if primaries held -- report | The Times of Israel. times israel live blog entry besides saar another top likud challenge primaries held report. Report: State prosecutor says Netanyahu can\u2019t form government under indictment While law allows a prime minister to keep serving while facing criminal charges, Shai Nitzan said to tell his staff PM can\u2019t get mandate to form new coalition in such a situation The Times of Israel is liveblogging Thursday\u2019s events as they unfold. After Gideon Sa\u2019ar, another senior Likud lawmaker intends to run against Prime Minister Benjamin Netanyahu if party primaries are held, the Kan public broadcaster reports. The report does not name that MK, but says he\u2019ll likely announce his bid publicly in t

In [11]:
# Number of responses skipped because their search result was empty.
misses

0

In [12]:
# Expected number of examples if every response returned 30 hits (presumably the
# search hit limit -- confirm) vs. the number actually built (one per hit).
print(len(responses)*30)
print(len(examples))

30000
26629


# Predict
---

In [13]:
# Turn off Weights & Biases logging for the prediction run.
# NOTE(review): the prediction log below still says "Automatic Weights & Biases
# logging enabled". transformers (used via valerie.modeling, imported in the first
# cell) likely reads WANDB_DISABLED at import time, so these probably need to be set
# before that import to take effect. WANDB_MODE=dryrun should at least keep runs offline.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "dryrun"
os.environ["WANDB_WATCH"] = "false"

In [14]:
# Alternative MS MARCO re-ranker checkpoints (swap one in to compare):
#   "castorini/monot5-base-msmarco"
#   "castorini/monobert-large-msmarco"
#   "nboost/pt-bert-large-msmarco"
pretrained_model_name_or_path = "castorini/monobert-large-msmarco"
nproc = 48  # worker processes for converting examples to features (create_dataset)
# Prediction batch size. Set to 256 to match what the predict cell actually passes;
# the trainer log reports "Batch size = 2048" with 8 visible GPUs, so it appears
# to be per device.
predict_batch_size = 256

In [15]:
# Load the pretrained re-ranker selected in the config cell.
model = SequenceClassificationModel.from_pretrained(pretrained_model_name_or_path)

[2020-07-12 14:58:03,871] INFO:transformers.configuration_utils: loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/castorini/monobert-large-msmarco/config.json from cache at /home/jay/.cache/torch/transformers/643500d870067d59f219f7b5652919267c01bfa98024e2e74f53b28c1b6aff2b.4c88e2dec8f8b017f319f6db2b157fee632c0860d9422e4851bd0d6999f9ce38
[2020-07-12 14:58:03,872] INFO:transformers.configuration_utils: Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2020-07-12 14:58:03,873] INFO:transformers.tokenization_utils_base: Model name 'castorini/monobert-large-msmarco' not f

In [16]:
# Convert the (claim, article) examples into model input features using `nproc` processes.
examples_dataset = model.create_dataset(examples, nproc=nproc)

[2020-07-12 14:58:15,665] INFO:valerie.modeling: ... converting examples to features ...


HBox(children=(FloatProgress(value=0.0, description='converting examples to features', max=26629.0, style=Prog…




In [17]:
# Make all eight GPUs visible for prediction.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet in this kernel; safer to set it in the config cell before any
# torch/transformers import.
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"
!echo $CUDA_VISIBLE_DEVICES

0,1,2,3,4,5,6,7


In [18]:
# Score every (claim, article) pair with the cross-encoder.
# NOTE(review): the literal 256 overrides the `predict_batch_size` config value. The
# log reports "Batch size = 2048" (256 x 8 visible GPUs), so it appears to be per device.
predict_output = model.predict(examples_dataset, predict_batch_size=256)

[2020-07-12 14:58:56,511] INFO:transformers.training_args: PyTorch: setting up devices
[2020-07-12 14:59:00,458] INFO:transformers.trainer: Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[2020-07-12 14:59:00,463] INFO:transformers.trainer: ***** Running Prediction *****
[2020-07-12 14:59:00,463] INFO:transformers.trainer:   Num examples = 26629
[2020-07-12 14:59:00,464] INFO:transformers.trainer:   Batch size = 2048


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=14.0, style=ProgressStyle(description_wi…






In [19]:
# Re-rank each claim's retrieved articles in three ways and score each ranking:
#   api   - the search API's own hit score,
#   trans - the cross-encoder output for the "related" class (column 1),
#   both  - the unweighted sum of the two (their scales are not normalised).
claims_dict = {res["claim"].index: res["claim"] for res in responses if res["res"]}
api_scores_dict = {
    res["claim"].index: {
        hit["article"].index: hit["score"] for hit in res["res"]["hits"]["hits"]
    }
    for res in responses
    if res["res"]
}

rerank_just_api_responses = {
    res["claim"].index: [
        (hit["article"].index, hit["score"]) for hit in res["res"]["hits"]["hits"]
    ]
    for res in responses
    if res["res"]
}
rerank_just_trans_responses = {claim_index: [] for claim_index in claims_dict}
rerank_both_responses = {claim_index: [] for claim_index in claims_dict}

# zip() silently stops at the shorter input, which would hide a mismatch between
# the examples and the model outputs -- fail loudly instead.
assert len(predict_output.predictions) == len(examples), (
    len(predict_output.predictions),
    len(examples),
)

for example, output in tqdm(zip(examples, predict_output.predictions), total=len(examples)):
    # Treated as the probability that the article is related.
    # NOTE(review): confirm model.predict returns softmax probabilities, not raw logits.
    proba = float(output[1])
    rerank_just_trans_responses[example.guid].append((example.art_id, proba))
    rerank_both_responses[example.guid].append(
        (example.art_id, proba + api_scores_dict[example.guid][example.art_id])
    )

for ranking_name, reranked in [
    ("api", rerank_just_api_responses),
    ("trans", rerank_just_trans_responses),
    ("both", rerank_both_responses),
]:
    print(ranking_name)
    print(json.dumps(compute_responses_score(reranked, claims_dict), indent=2))
    print()

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


api
{
  "perfect_rerank_score": 0.8630615319780865,
  "perfect_rerank_error": "'None'",
  "api_score": 0.448855518144285,
  "api_error": "'None'"
}

trans
{
  "perfect_rerank_score": 0.8630615319780865,
  "perfect_rerank_error": "'None'",
  "api_score": 0.4973083704115188,
  "api_error": "'None'"
}

both
{
  "perfect_rerank_score": 0.8630615319780865,
  "perfect_rerank_error": "'None'",
  "api_score": 0.4954667497853678,
  "api_error": "'None'"
}





# Results
---

In [20]:
# Ranking by the search API score alone (same numbers as printed above).
compute_responses_score(rerank_just_api_responses, claims_dict)

{'perfect_rerank_score': 0.8630615319780865,
 'perfect_rerank_error': "'None'",
 'api_score': 0.448855518144285,
 'api_error': "'None'"}

In [21]:
# Ranking by the cross-encoder output alone (same numbers as printed above).
compute_responses_score(rerank_just_trans_responses, claims_dict)

{'perfect_rerank_score': 0.8630615319780865,
 'perfect_rerank_error': "'None'",
 'api_score': 0.4973083704115188,
 'api_error': "'None'"}

In [22]:
# Ranking by API score + cross-encoder output (same numbers as printed above).
compute_responses_score(rerank_both_responses, claims_dict)

{'perfect_rerank_score': 0.8630615319780865,
 'perfect_rerank_error': "'None'",
 'api_score': 0.4954667497853678,
 'api_error': "'None'"}