# Data
---

In [1]:
import os
import json
import pickle
import random
import multiprocessing

from tqdm.auto import tqdm
from valerie.utils import get_logger
from valerie.preprocessing import extract_words_from_url, clean_text
from valerie.scoring import validate_predictions_phase2, compute_score_phase2
from valerie.modeling import SequenceClassificationModel, SequenceClassificationDataset, SequenceClassificationExample

In [5]:
# Obtain a logger through valerie.utils.get_logger().
_logger = get_logger()

In [6]:
# Load the preprocessed responses. Later cells index each entry as
# res["claim"] and res["res"]["hits"]["hits"] (each hit has "score" and "article").
# NOTE: pickle.load can execute arbitrary code from the file; only load pickles you produced yourself.
with open("data/phase2-3/processed/responses.pkl", "rb") as fi:
    responses = pickle.load(fi)

In [7]:
# Debugging aid: uncomment to work on a random subsample of 100 responses.
# responses = random.sample(responses, k=100)

In [8]:
len(responses)

13061

In [9]:
def compute_responses_score(results, claims_dict):
    """Score ranked article hits for each claim with the phase-2 scorer.

    Two prediction sets are built per claim and scored separately:

    * the actual ranking: the two highest-scoring hits;
    * a "perfect rerank": the two highest-scoring hits among those whose
      article id appears in ``claim.related_articles.values()``, i.e. the best
      result any reranker could reach given the retrieved candidates.

    Both prediction sets reuse the gold ``claim.label``, so only the
    related-article selection differs between them.

    Args:
        results: mapping of claim key -> iterable of ``(article_id, score)``
            tuples.
        claims_dict: mapping of the same claim keys -> claim objects exposing
            ``id``, ``label``, ``related_articles`` (a dict) and ``to_dict()``.

    Returns:
        dict with keys ``perfect_rerank_score``, ``perfect_rerank_error``,
        ``api_score`` and ``api_error`` taken from the scorer's ``"score"`` /
        ``"error"`` fields.

    Raises:
        KeyError: if a key of ``results`` is missing from ``claims_dict``.
    """
    max_articles = 2  # number of related articles kept per prediction

    def _build_prediction(claim, article_ids):
        # related_articles is keyed by 1-based rank.
        return {
            "label": claim.label,
            "explanation": "",
            "related_articles": {
                rank: article_id
                for rank, article_id in enumerate(article_ids, start=1)
            },
        }

    predictions = {}
    perfect_predictions = {}
    labels = {}

    for k, hits in results.items():
        claim = claims_dict[k]
        labels[claim.id] = claim.to_dict()

        # Highest score first; sorted() is stable, so ties keep their input order.
        ranked_ids = [hit[0] for hit in sorted(hits, key=lambda x: x[1], reverse=True)]

        # Build the gold set once per claim instead of scanning the dict values
        # for every hit (same approach as _visit).
        gold_ids = set(claim.related_articles.values())
        gold_ranked_ids = [article_id for article_id in ranked_ids if article_id in gold_ids]

        predictions[claim.id] = _build_prediction(claim, ranked_ids[:max_articles])
        perfect_predictions[claim.id] = _build_prediction(claim, gold_ranked_ids[:max_articles])

    validate_predictions_phase2(predictions)
    score = compute_score_phase2(labels, predictions)
    validate_predictions_phase2(perfect_predictions)
    perfect_score = compute_score_phase2(labels, perfect_predictions)
    return {
        "perfect_rerank_score": perfect_score["score"],
        "perfect_rerank_error": perfect_score["error"],
        "api_score": score["score"],
        "api_error": score["error"],
    }

# Examples
---

In [10]:
def create_text_a(claim):
    """Build the claim-side text: claim text, claimant and date, space-separated.

    A falsy claimant becomes the literal "no claimant" and a falsy date becomes
    "no date". For a present date only the leading part is kept: the text before
    the first whitespace, then before the first "T". The joined string is passed
    through clean_text.
    """
    claim_text = claim.claim
    claimant = claim.claimant or "no claimant"
    if claim.date:
        day = claim.date.split()[0].split("T")[0]
    else:
        day = "no date"
    return clean_text(" ".join((claim_text, claimant, day)))

def create_text_b(article):
    """Build the article-side text from source, title, URL words and content.

    Each of source, title and the space-joined URL words (from
    extract_words_from_url) is followed by ". " when present; content is
    appended last without a suffix. Falsy fields are skipped. The concatenated
    string is passed through clean_text.
    """
    segments = []

    if article.source:
        segments.append(article.source + ". ")
    if article.title:
        segments.append(article.title + ". ")

    # Only parse the URL when there is one; an empty word list adds nothing.
    url_words = extract_words_from_url(article.url) if article.url else None
    if url_words:
        segments.append(" ".join(url_words) + ". ")

    if article.content:
        segments.append(article.content)

    return clean_text("".join(segments))

In [None]:
def _visit(res):
    _examples = []
    claim = res["claim"]
    if not res["res"]:
        return None
    hits = [(hit["score"], hit["article"]) for hit in res["res"]["hits"]["hits"]]
    
    text_a = create_text_a(claim)
    
    related_articles_url_set = set(claim.related_articles.values())
    
    for score, article in hits:
        text_b = create_text_b(article)
            
        _examples.append(SequenceClassificationExample(
            guid=claim.index,
            text_a=text_a,
            text_b=text_b,
            label=1 if article.url in related_articles_url_set else 0,
            art_id=article.index
        ))
    return _examples

# Build the rerank examples in parallel, one task per response.
# The pool is used as a context manager so its 48 worker processes are torn
# down once every result has been consumed; previously the pool was never
# closed and the workers lingered for the lifetime of the kernel.
# NOTE: imap_unordered yields results in completion order, so the order of
# `examples` can differ between runs. Later cells zip `examples` with the
# model predictions, so cached features/predictions only line up with the
# `examples` list from the same run.
examples = []
misses = 0  # number of responses for which _visit returned None
with multiprocessing.Pool(48) as pool:
    for result in tqdm(pool.imap_unordered(_visit, responses), total=len(responses)):
        if result is None:
            misses += 1
        else:
            examples.extend(result)

In [18]:
# Write the generated examples to disk so they can be reloaded without rebuilding them.
with open("data/phase2-3/processed/rerank_examples.pkl", "wb") as fo:
    pickle.dump(examples, fo)

In [11]:
# Reload the examples written to rerank_examples.pkl.
# NOTE: pickle.load can execute arbitrary code from the file; only load pickles you produced yourself.
with open("data/phase2-3/processed/rerank_examples.pkl", "rb") as fi:
    examples = pickle.load(fi)

ERROR! Session/line number was not unique in database. History logging moved to new session 52


In [13]:
examples[-1]

{
  "guid": "phase2/10461",
  "text_a": "\"The Embassy urges the Department of State and Department of Justice to take urgent measures to respect the legitimate rights of the Russian citizen, as well as to ensure proper conditions of Pyotr (Peter) Levashov\u2019s detention and the protection of his human dignity,\" the statement said. \"The Embassy also demands that the Russian citizen is provided with medicine to treat his diagnosed diseases. We also expect human rights organizations to intervene in this situation.\" Consulate General of Russia in New York 2018-02-08",
  "text_b": "newyorker. Four Women Accuse New York\u2019s Attorney General, Eric Schneiderman, of Physical Abuse | The New Yorker. new yorker news news desk four women accuse new york attorney general physical abuse. Four Women Accuse New York\u2019s Attorney General of Physical Abuse Eric Schneiderman has raised his profile as a voice against sexual misconduct. Now, after suing Harvey Weinstein, he faces a #MeToo recko

In [14]:
misses

0

In [15]:
# Compare the number of examples with (number of responses x 30).
# NOTE(review): 30 is presumably the number of hits requested per claim from the
# search API; confirm. A smaller example count would then mean some claims
# returned fewer hits.
print(len(responses)*30)
print(len(examples))

391830
349041


# Predict
---

In [16]:
# Environment variables intended to switch off Weights & Biases logging;
# set here, before the model is created and used for prediction.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "dryrun"
os.environ["WANDB_WATCH"] = "false"

In [17]:
# Model identifiers kept for reference:
#   "castorini/monot5-base-msmarco"
#   "castorini/monobert-large-msmarco"
#   "nboost/pt-bert-large-msmarco"
pretrained_model_name_or_path = "castorini/monobert-large-msmarco"
nproc=48  # passed to model.create_dataset
predict_batch_size=64  # NOTE(review): unused by the model.predict call in this notebook, which passes 256 explicitly

In [18]:
# Instantiate the sequence-classification model from the configured pretrained name.
model = SequenceClassificationModel.from_pretrained(pretrained_model_name_or_path)

[2020-07-11 23:28:24,572] INFO:transformers.configuration_utils: loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/castorini/monobert-large-msmarco/config.json from cache at /home/jay/.cache/torch/transformers/643500d870067d59f219f7b5652919267c01bfa98024e2e74f53b28c1b6aff2b.4c88e2dec8f8b017f319f6db2b157fee632c0860d9422e4851bd0d6999f9ce38
[2020-07-11 23:28:24,573] INFO:transformers.configuration_utils: Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

[2020-07-11 23:28:24,573] INFO:transformers.tokenization_utils_base: Model name 'castorini/monobert-large-msmarco' not f

In [19]:
# Convert the examples into a dataset for the model (logged as
# "converting examples to features"); nproc is forwarded to create_dataset.
examples_dataset = model.create_dataset(examples, nproc=nproc)

[2020-07-11 23:28:39,318] INFO:valerie.modeling: ... converting examples to features ...


HBox(children=(FloatProgress(value=0.0, description='converting examples to features', max=349041.0, style=Pro…




In [20]:
# Save the converted features to disk.
examples_dataset.save("data/phase2-3/processed/rerank_examples_dataset.bin")

[2020-07-11 23:37:50,143] INFO:valerie.modeling: .. saving features to cached file data/phase2-3/processed/rerank_examples_dataset.bin ...


In [21]:
!du -csh data/phase2-3/processed/rerank_examples_dataset.bin

1.2G	data/phase2-3/processed/rerank_examples_dataset.bin
1.2G	total


In [24]:
# Expose GPUs 0-7 to CUDA and echo the value back as a sanity check.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA
# is initialised in this process; the model is already loaded at this point, so
# confirm the setting is actually honoured.
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"
!echo $CUDA_VISIBLE_DEVICES

0,1,2,3,4,5,6,7


In [25]:
# Run inference over the full examples dataset.
# NOTE(review): the batch size is hard-coded to 256 here instead of using the
# predict_batch_size value (64) from the config cell; confirm which is intended.
predict_output = model.predict(examples_dataset, predict_batch_size=256)

[2020-07-11 23:48:23,567] INFO:transformers.training_args: PyTorch: setting up devices
[2020-07-11 23:48:23,578] INFO:transformers.trainer: Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[2020-07-11 23:48:23,580] INFO:transformers.trainer: ***** Running Prediction *****
[2020-07-11 23:48:23,580] INFO:transformers.trainer:   Num examples = 349041
[2020-07-11 23:48:23,580] INFO:transformers.trainer:   Batch size = 1024


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=341.0, style=ProgressStyle(description_w…




In [27]:
# Write the prediction output to disk so it can be reloaded later.
with open("data/phase2-3/processed/rerank_predict_output.pkl", "wb") as fo:
    pickle.dump(predict_output, fo)

In [26]:
# Claims that received an API response, keyed by claim index.
claims_dict = {res["claim"].index: res["claim"] for res in responses if res["res"]}

# claim index -> {article index -> API score}
api_scores_dict = {
    res["claim"].index: {
        hit["article"].index: hit["score"] for hit in res["res"]["hits"]["hits"]
    }
    for res in responses
    if res["res"]
}

# Ranking 1: API score only.
rerank_just_api_responses = {
    res["claim"].index: [
        (hit["article"].index, hit["score"]) for hit in res["res"]["hits"]["hits"]
    ]
    for res in responses
    if res["res"]
}

# Ranking 2: model score only (filled in by the loop below).
rerank_just_trans_responses = {res["claim"].index: [] for res in responses if res["res"]}

# Ranking 3: model score + API score (filled in by the loop below).
rerank_both_responses = {
    res["claim"].index: []
    for res in responses
    if res["res"]
}

# zip() silently stops at the shorter input, which would drop examples without
# warning if `examples` and the predictions came from different runs.
if len(examples) != len(predict_output.predictions):
    raise ValueError(
        "examples and predictions are misaligned: "
        f"{len(examples)} examples vs {len(predict_output.predictions)} predictions"
    )

for example, proba in tqdm(zip(examples, predict_output.predictions), total=len(examples)):
    # Column 1 is the score for label 1 ("related").
    # NOTE(review): if these are raw logits rather than probabilities, the sum
    # with the API score below mixes two different scales; confirm.
    proba = float(proba[1])

    rerank_just_trans_responses[example.guid].append((example.art_id, proba))
    rerank_both_responses[example.guid].append((example.art_id, proba + api_scores_dict[example.guid][example.art_id]))

# Print the scores of the three rankings in the same layout as before.
for name, reranked in (
    ('api', rerank_just_api_responses),
    ('trans', rerank_just_trans_responses),
    ('both', rerank_both_responses),
):
    print(name)
    print(json.dumps(compute_responses_score(reranked, claims_dict), indent=2))
    print()
print()
print()

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


api
{
  "perfect_rerank_score": 0.9141486415632643,
  "perfect_rerank_error": "'None'",
  "api_score": 0.48792873363807776,
  "api_error": "'None'"
}

trans
{
  "perfect_rerank_score": 0.9141486415632643,
  "perfect_rerank_error": "'None'",
  "api_score": 0.5452106483278761,
  "api_error": "'None'"
}

both
{
  "perfect_rerank_score": 0.9141486415632643,
  "perfect_rerank_error": "'None'",
  "api_score": 0.5376940740316324,
  "api_error": "'None'"
}





# Results
---

In [29]:
compute_responses_score(rerank_just_api_responses, claims_dict)

{'perfect_rerank_score': 0.9141486415632643,
 'perfect_rerank_error': "'None'",
 'api_score': 0.48792873363807776,
 'api_error': "'None'"}

In [30]:
compute_responses_score(rerank_just_trans_responses, claims_dict)

{'perfect_rerank_score': 0.9141486415632643,
 'perfect_rerank_error': "'None'",
 'api_score': 0.5452106483278761,
 'api_error': "'None'"}

In [31]:
compute_responses_score(rerank_both_responses, claims_dict)

{'perfect_rerank_score': 0.9141486415632643,
 'perfect_rerank_error': "'None'",
 'api_score': 0.5376940740316324,
 'api_error': "'None'"}