# Data
---

In [None]:
import os
import json
import heapq
import pickle
import random
import multiprocessing

import spacy
from tqdm.auto import tqdm

from valerie.utils import get_logger
from valerie.preprocessing import extract_words_from_url, clean_text
from valerie.scoring import validate_predictions_phase2, compute_score_phase2
from valerie.modeling import SequenceClassificationModel, SequenceClassificationDataset, SequenceClassificationExample

In [None]:
# Load the large English spaCy pipeline once; it is reused below for both the
# claims and the articles. Doc.similarity is called on its output later on.
# NOTE(review): presumably the "lg" model is chosen for its word vectors - confirm.
nlp = spacy.load("en_core_web_lg")

In [None]:
# Project logger. NOTE(review): `_logger` is not referenced by any other cell
# visible in this notebook.
_logger = get_logger()

In [None]:
# Load the cached search responses. Later cells read each entry as a mapping
# with a "claim" object and a "res" payload (falsy when the search failed).
# NOTE(review): pickle.load can execute arbitrary code while unpickling - only
# load this file if it was produced by a trusted run of the pipeline.
with open("data/phase2-3/processed/responses.pkl", "rb") as fi:
    responses = pickle.load(fi)

In [None]:
# Sanity check: number of responses loaded (including ones with an empty "res").
print(len(responses))

In [None]:
def compute_responses_score(results, claims_dict):
    """Score ranked article hits, both as given and with an oracle filter.

    Args:
        results: mapping of claim key -> iterable of ``(article_id, score)``
            pairs.
        claims_dict: mapping of the same claim keys -> claim objects exposing
            ``id``, ``label``, ``related_articles`` and ``to_dict()``.

    Returns:
        dict with ``api_score`` / ``api_error`` (top two hits by score) and
        ``perfect_rerank_score`` / ``perfect_rerank_error`` (top two hits by
        score among those that appear in ``claim.related_articles``).
    """

    def _as_prediction(claim, article_ids):
        # Only the two best-ranked articles are submitted, keyed 1 and 2.
        return {
            "label": claim.label,
            "explanation": "",
            "related_articles": {
                rank: article_id
                for rank, article_id in enumerate(article_ids[:2], start=1)
            },
        }

    labels = {}
    predictions = {}
    perfect_predictions = {}

    for key, hits in results.items():
        claim = claims_dict[key]
        labels[claim.id] = claim.to_dict()

        # Highest score first; ties keep their incoming order (stable sort).
        by_score = sorted(hits, key=lambda hit: hit[1], reverse=True)
        ranked_ids = [hit[0] for hit in by_score]
        predictions[claim.id] = _as_prediction(claim, ranked_ids)

        # Oracle variant: drop every hit that is not a gold related article.
        gold_ids = [
            article_id
            for article_id in ranked_ids
            if article_id in claim.related_articles.values()
        ]
        perfect_predictions[claim.id] = _as_prediction(claim, gold_ids)

    validate_predictions_phase2(predictions)
    api_result = compute_score_phase2(labels, predictions)
    validate_predictions_phase2(perfect_predictions)
    perfect_result = compute_score_phase2(labels, perfect_predictions)

    return {
        "perfect_rerank_score": perfect_result["score"],
        "perfect_rerank_error": perfect_result["error"],
        "api_score": api_result["score"],
        "api_error": api_result["error"],
    }

In [None]:
def create_text_a(claim):
    """Build the query-side text: claim, claimant and calendar date.

    Missing claimant / date are replaced by the literal placeholders
    "no claimant" / "no date". Only the part of the date before the first
    whitespace or "T" is kept. The result is passed through ``clean_text``.
    """
    claimant = claim.claimant or "no claimant"
    if claim.date:
        date = claim.date.split()[0].split("T")[0]
    else:
        date = "no date"
    return clean_text(" ".join([claim.claim, claimant, date]))

def create_text_b_content(article):
    """Build the document-side text from an article's available fields.

    Concatenates, in order and skipping empty fields: source, title, the words
    extracted from the URL, and the body content. Each of the first three is
    terminated with ". ". The result is passed through ``clean_text``.
    """
    pieces = []
    if article.source:
        pieces.append(article.source + ". ")
    if article.title:
        pieces.append(article.title + ". ")
    if article.url:
        url_words = extract_words_from_url(article.url)
        if url_words:
            pieces.append(" ".join(url_words) + ". ")
    if article.content:
        pieces.append(article.content)
    return clean_text("".join(pieces))

# Run spaCy on Data

### Claims

In [None]:
# Collect the claims whose search request produced a response, preparing the
# attributes later cells rely on; responses with an empty "res" are counted
# in `misses` and skipped.
misses = 0
claims_list = []
for res in tqdm(responses):
    if res["res"]:
        claim = res["claim"]
        claim.text_a = create_text_a(claim)
        claim.res = res
        claim.support = {}
        claims_list.append(claim)
    else:
        misses += 1

In [None]:
# Texts in the same order as claims_list, so the docs produced from them can be
# zipped back onto the claims by position.
claims_texts = [claim.text_a for claim in claims_list]

In [None]:
# Run spaCy over all claim texts with 16 worker processes. The parser is
# disabled here as well: only Doc.similarity is used on claim docs later on.
claims_docs = list(
    tqdm(
        nlp.pipe(
            claims_texts,
            n_process=16,
            disable=["textcat", "tagger", "parser", "ner"],
        ),
        total=len(claims_texts),
    )
)

In [None]:
# Attach each spaCy Doc to its claim and index the claims by `claim.index`.
# Docs are matched to claims purely by position, and zip() would silently drop
# the tail if the two lists went out of sync (e.g. cells re-run out of order),
# so fail loudly instead.
assert len(claims_list) == len(claims_docs), (
    "claims_docs does not match claims_list; re-run the spaCy cell above"
)

claims_dict = {}
# zip() has no length, so give tqdm the total explicitly.
for claim, doc in tqdm(zip(claims_list, claims_docs), total=len(claims_list)):
    claim.doc = doc
    claims_dict[claim.index] = claim

### Articles

In [None]:
# Gather every article returned for any claim (duplicates included); responses
# with an empty "res" are counted in `misses` and skipped. The article text is
# not built here - that happens in a later, parallel cell.
misses = 0
articles_list = []
for res in tqdm(responses):
    if res["res"]:
        articles_list.extend(hit["article"] for hit in res["res"]["hits"]["hits"])
    else:
        misses += 1

In [None]:
# De-duplicate articles while keeping first-seen order. dict.fromkeys relies on
# the same __hash__/__eq__ as set(), but unlike set() its iteration order is
# deterministic, so the article order (and everything zipped against it below)
# is reproducible across kernel restarts.
articles_list = list(dict.fromkeys(articles_list))

In [None]:
def _text_b_text(article):
    return article, create_text_b_content(article)

# Build text_b for every article in 16 worker processes. Results arrive in
# arbitrary order (imap_unordered), so they are keyed by article index.
articles_texts = {}
# The context manager tears the workers down even if the loop raises or is
# interrupted; the original left them running until a later cell was run by
# hand. `pool` stays bound afterwards, and calling pool.close() on an already
# terminated pool is a harmless no-op.
with multiprocessing.Pool(16) as pool:
    for article, text_b in tqdm(pool.imap_unordered(_text_b_text, articles_list), total=len(articles_list)):
        articles_texts[article.index] = text_b

In [None]:
# Tell the worker pool that no more tasks will be submitted so its processes
# can exit.
pool.close()

In [None]:
# Copy the computed texts onto the in-process article objects, looked up by
# article index.
# NOTE(review): the articles yielded by the pool are presumably pickled copies,
# which would be why the text is attached here rather than in the worker - confirm.
for article in articles_list:
    article.text_b = articles_texts[article.index]

In [None]:
# Texts in the same order as articles_list, ready for nlp.pipe.
# NOTE(review): this rebinds `articles_texts` from the index->text dict built
# above to a list, so the cell that assigns article.text_b from
# articles_texts[article.index] cannot be re-run after this one.
articles_texts = [article.text_b for article in tqdm(articles_list)]

In [None]:
# Run spaCy over all article texts with 16 worker processes. Unlike the claims
# run, "parser" is left enabled: article docs are iterated via doc.sents later.
articles_docs = list(
    tqdm(
        nlp.pipe(
            articles_texts,
            n_process=16,
            disable=["textcat", "tagger", "ner"],
        ),
        total=len(articles_texts),
    )
)

In [None]:
# Attach each spaCy Doc to its article and index the articles by
# `article.index`. Docs are matched to articles purely by position, and zip()
# would silently drop the tail if the two lists went out of sync (e.g. the
# de-duplication cell re-run after the spaCy cell), so fail loudly instead.
assert len(articles_list) == len(articles_docs), (
    "articles_docs does not match articles_list; re-run the spaCy cell above"
)

articles_dict = {}
# zip() has no length, so give tqdm the total explicitly.
for article, doc in tqdm(zip(articles_list, articles_docs), total=len(articles_list)):
    article.doc = doc
    articles_dict[article.index] = article

# Examples
---

In [None]:
def create_text_b_curated(article, claim, top_k=32):
    """Build text_b from the article sentences most similar to the claim.

    Every sentence of ``article.doc`` is scored with ``claim.doc.similarity``.
    The ``top_k`` best-scoring sentences are stored on
    ``claim.support[article.index]`` as ``{"text", "score"}`` dicts and joined,
    best first, into one string that is passed through ``clean_text``.

    Args:
        article: object exposing ``index`` and a parsed ``doc`` with ``sents``.
        claim: object exposing ``doc`` and a mutable ``support`` dict.
        top_k: maximum number of sentences to keep (default 32, the value that
            was previously hard-coded).

    Returns:
        The cleaned, space-joined text of the selected sentences.
    """
    # Stream the scored sentences straight into nlargest so that only the
    # top_k candidates are ever held, not one dict per sentence of the article.
    scored_sentences = (
        {"text": sent.text, "score": claim.doc.similarity(sent)}
        for sent in article.doc.sents
    )
    support = heapq.nlargest(top_k, scored_sentences, key=lambda item: item["score"])
    claim.support[article.index] = support
    return clean_text(" ".join(item["text"] for item in support))

def gen_examples(claim):
    """Create one classification example per article retrieved for ``claim``.

    Each hit's article is looked up in the module-level ``articles_dict`` by
    the hit's "url". As a side effect the article's ``text_b`` is overwritten
    with the claim-specific curated text. An example is labelled 1 when the
    article URL is one of the claim's related articles, otherwise 0.
    """
    retrieved = [
        articles_dict[hit["url"]] for hit in claim.res["res"]["hits"]["hits"]
    ]
    gold_urls = set(claim.related_articles.values())

    examples_to_add = []
    for article in retrieved:
        article.text_b = create_text_b_curated(article, claim)
        label = 1 if article.url in gold_urls else 0
        examples_to_add.append(
            SequenceClassificationExample(
                guid=claim.index,
                text_a=claim.text_a,
                text_b=article.text_b,
                label=label,
                art_id=article.index,
            )
        )
    return examples_to_add

In [None]:
# examples = []
# for examples_to_add in tqdm(pool.imap_unordered(gen_examples, claims_dict.values()), total=len(claims_dict)):
#     examples.extend(examples_to_add)

# Build one example per (claim, retrieved article) pair. This runs in-process
# and delegates to gen_examples instead of repeating its body line for line,
# so the two cannot drift apart.
examples = []
for claim in tqdm(claims_dict.values()):
    examples.extend(gen_examples(claim))

In [None]:
# Sanity checks: number of claims and of distinct articles kept.
print(len(claims_dict))
print(len(articles_dict))
print()
# Compare an expected example count against the actual one.
# NOTE(review): 30 is presumably the expected number of hits per claim, but a
# later cell compares against 16 per response - confirm which one is right.
print(len(claims_dict)*30)
print(len(examples))

In [None]:
# Eyeball the first generated example.
print(examples[0])

In [None]:
# Show the raw claim text of the first claim in claims_dict.
print(list(claims_dict.values())[0].claim)

In [None]:
# print(json.dumps(list(claims_dict.values())[0].support, indent=2))

In [None]:
# Compare an expected example count against the actual one.
# NOTE(review): this uses 16 per response (counting responses with an empty
# "res" too), while an earlier cell uses 30 per claim - confirm which applies.
print(len(responses)*16)
print(len(examples))

# Predict
---

In [None]:
# Weights & Biases environment switches, set before the model is created.
# NOTE(review): presumably meant to keep this prediction run from logging to
# W&B - confirm which of these variables the installed versions honor.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "dryrun"
os.environ["WANDB_WATCH"] = "false"

In [None]:
# Pretrained re-ranker checkpoints noted here (presumably the candidates that
# were tried); change the assignment below to switch models:
#   "castorini/monot5-base-msmarco"
#   "castorini/monobert-large-msmarco"
#   "nboost/pt-bert-large-msmarco"
pretrained_model_name_or_path = "castorini/monobert-large-msmarco"

In [None]:
# Instantiate the project's sequence-classification wrapper from the chosen
# pretrained checkpoint.
model = SequenceClassificationModel.from_pretrained(pretrained_model_name_or_path)

In [None]:
# Convert the examples into the model's dataset format.
# NOTE(review): nproc=16 presumably sets the number of worker processes - confirm.
examples_dataset = model.create_dataset(examples, nproc=16)

In [None]:
# Score every example. The evaluation cell below zips `examples` with
# `predict_output.predictions` and reads column 1 of each row as the
# "related" score.
predict_output = model.predict(examples_dataset, predict_batch_size=256)

In [1]:
# Only responses whose search request returned something take part in scoring.
valid_responses = [res for res in responses if res["res"]]

claims_dict = {res["claim"].index: res["claim"] for res in valid_responses}

# claim index -> {article index -> score returned by the search API}
api_scores_dict = {
    res["claim"].index: {
        hit["article"].index: hit["score"] for hit in res["res"]["hits"]["hits"]
    }
    for res in valid_responses
}

# Three rankings to compare, each claim index -> [(article index, score), ...]:
#   api   - search API score only
#   trans - transformer score only
#   both  - sum of the two
rerank_just_api_responses = {
    res["claim"].index: [
        (hit["article"].index, hit["score"]) for hit in res["res"]["hits"]["hits"]
    ]
    for res in valid_responses
}
rerank_just_trans_responses = {res["claim"].index: [] for res in valid_responses}
rerank_both_responses = {res["claim"].index: [] for res in valid_responses}

# Predictions are matched to examples purely by position, and zip() would
# silently drop the tail on a mismatch (e.g. `examples` rebuilt after
# predicting), yielding plausible-looking but wrong scores. Fail loudly instead.
assert len(examples) == len(predict_output.predictions), (
    "predict_output does not match examples; re-run the prediction cells"
)

# zip() has no length, so give tqdm the total explicitly.
for example, proba in tqdm(zip(examples, predict_output.predictions), total=len(examples)):
    proba = float(proba[1]) # get probability that the article is related
    # NOTE(review): if the model emits raw logits rather than probabilities, the
    # sum below mixes two differently scaled scores - confirm.

    rerank_just_trans_responses[example.guid].append((example.art_id, proba))
    rerank_both_responses[example.guid].append((example.art_id, proba + api_scores_dict[example.guid][example.art_id]))

for name, reranked in (
    ("api", rerank_just_api_responses),
    ("trans", rerank_just_trans_responses),
    ("both", rerank_both_responses),
):
    print(name)
    print(json.dumps(compute_responses_score(reranked, claims_dict), indent=2))
    print()
print()
print()

api
{
  "perfect_rerank_score": 0.918748461309641,
  "perfect_rerank_error": "'None'",
  "api_score": 0.48984172731830333,
  "api_error": "'None'"
}

trans
{
  "perfect_rerank_score": 0.918748461309641,
  "perfect_rerank_error": "'None'",
  "api_score": 0.5750774971552669,
  "api_error": "'None'"
}

both
{
  "perfect_rerank_score": 0.918748461309641,
  "perfect_rerank_error": "'None'",
  "api_score": 0.5405780150774728,
  "api_error": "'None'"
}
