# HW12: Scientific Claim Verification

**<span style="color:red">Important Instructions, read carefully!</span>** 

* Remember that these homeworks are graded on completion. The homework is structured in two parts: (1) information retrieval from a fact base and (2) claim verification. If you have already submitted 9 homeworks and only want to submit a 10th notebook, it is fine to do just the first part. 

* In this notebook, we will build an automated claim verification system for scientific claims based on the [SciFact dataset](https://arxiv.org/abs/2004.14974). 

* In case you need additional computational resources (GPUs), please get in touch. It is possible to solve the homework (on a downsampled dataset) without these. If you want to build really cool systems, you probably want to use the whole dataset and train models which require compute not feasible on your local machine -- get in touch.

* The best models to date on this dataset still perform poorly. If you build a system yielding competitive scores, it is possible to do follow-up work! There exists a [global leaderboard](https://leaderboard.allenai.org/scifact/submissions/public) where you can submit test set results if you like. The [baseline system](https://arxiv.org/abs/2004.14974) achieves around 40% F1, a system we built in January obtains 55% F1, and the [current state of the art](https://arxiv.org/pdf/2010.11930.pdf) is at around 65% F1, so there is ample room for improvement.

* You can find the GitHub repo for SciFact with additional information (and possibly example code which could be useful for solving this exercise) [here](https://github.com/allenai/scifact).

* For the first part of the assignment, we don't expect you to train your own transformer model, but you obviously can. For the second part of the assignment, we expect you to train your own textual entailment model.

* Lastly, we don't expect you to have a competitive system at the end of this homework; anything works as long as you have put in some effort.

**All instructions provided can be substituted by your own ideas and only serve as a rough guideline for how to tackle the task!**

* If you want, it is also possible to build a similar system for question answering (e.g. [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/explore/1.1/dev/)), automated fact checking (e.g. [Climate-FEVER](https://www.sustainablefinance.uzh.ch/en/research/climate-fever.html), [FEVER](https://fever.ai/)) or other involved NLP tasks completely freestyle on a dataset of your choice. If so, please get in touch with a suggestion.

* You are allowed to use any tools and help you can find online to tackle this task.

**An Example from SciFact**

Consider the claim *"Consumption of whole fruits increases the risk of type 2 diabetes."*

We are given a fact base of 5000 abstracts and have to find evidence from the fact base which either supports or refutes the claim. In this example, the goal is to find the following sentence from the corpus:

*'Greater consumption of specific whole fruits, particularly blueberries, grapes, and apples, is significantly associated with a lower risk of type 2 diabetes, whereas greater consumption of fruit juice is associated with a higher risk.'*

It is easy to see that this sentence contradicts the claim. The goal of the task is to return all sentences in the fact base which contradict or support a claim, together with the corresponding label. In this case, we would return

`"evidence": {"1974176": [{"sentences": [11], "label": "CONTRADICT"}]}`

where "1974176" is the doc_id of the abstract we found the evidence in, and 11 is the index of the evidence sentence in that abstract's list of sentences.
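
To make the annotation format concrete, the following minimal sketch walks over an evidence dictionary shaped like the example above and lists every (doc_id, sentence index, label) triple. The `example_claim` record (including its `id` value) and the helper name `iter_evidence` are made up for illustration and are not part of the SciFact code.

```python
# Hypothetical claim record, shaped like the SciFact example above.
example_claim = {
    "id": 0,  # made-up id, for illustration only
    "claim": "Consumption of whole fruits increases the risk of type 2 diabetes.",
    "evidence": {"1974176": [{"sentences": [11], "label": "CONTRADICT"}]},
}

def iter_evidence(claim):
    """Yield (doc_id, sentence_index, label) for every annotated evidence sentence."""
    for doc_id, rationales in claim["evidence"].items():
        for rationale in rationales:
            for sentence_index in rationale["sentences"]:
                yield doc_id, sentence_index, rationale["label"]

triples = list(iter_evidence(example_claim))
print(triples)
```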

In [None]:
# obtain the data
name="https://scifact.s3-us-west-2.amazonaws.com/release/latest/data.tar.gz"
!wget $name

!tar -xvf data.tar.gz
!rm data.tar.gz

**Part 1 of the Assignment: Information Retrieval**

In this section, we will build a document retrieval system which takes as input a claim and returns a number of candidate abstracts which are similar to the claim. Commonly, we start with a recall-oriented system which returns abstracts likely to contain evidence sentences. Then, we follow up with a more advanced model which selects only relevant sentences from the retrieved abstracts.

In [None]:
# install some helper utils
!pip install jsonlines

In [None]:
import jsonlines

# load the corpus (the fact base with the abstracts)
corpus = {str(doc['doc_id']): doc for doc in jsonlines.open("data/corpus.jsonl")}

# load the claims

# if you don't want to work with GPUs , you probably want to only consider 200/50 train/dev examples
# otherwise, this exercise might take too long

cpu_only = True
if cpu_only:
    claims_train = [claim for claim in jsonlines.open("data/claims_train.jsonl") if claim["evidence"]][:200]
    claims_dev = [claim for claim in jsonlines.open("data/claims_dev.jsonl") if claim["evidence"]][:50]
else:
    # with GPUs: use every claim that has annotated evidence
    claims_train = [claim for claim in jsonlines.open("data/claims_train.jsonl") if claim["evidence"]]
    claims_dev = [claim for claim in jsonlines.open("data/claims_dev.jsonl") if claim["evidence"]]

print (len(claims_train))
print (len(claims_dev))

**Inspect the data**

Let's have a look at the corpus first. Every abstract has a unique doc_id, the title of the paper, the abstract (already split into sentences) and a flag "structured" which is not relevant for this task.

In [None]:
print ("number of documents in the corpus", len(corpus))
print (corpus["1974176"])
print (corpus["1974176"].keys()) 
# dict_keys(['doc_id', 'title', 'abstract', 'structured'])
# abstract is a list of sentences

Next, we look at an example claim in more detail. It has a unique id, the claim as a string, and annotated evidence. The evidence is a dictionary where each key is the doc_id of an abstract in the corpus. Each value is a list whose entries contain the sentence indices in the corresponding abstract and a label stating whether these sentences contradict or support the claim.

In [None]:
print(claims_train[0])
print(claims_train[0].keys())

**Random Baseline for Abstract Retrieval**

As mentioned before, we need a system which retrieves abstracts from the corpus. Some ideas of how to tackle this include 
* create document embeddings via TF-IDF, SBERT, Universal Sentence Encoder or any embedding technique you like. Embed each claim and each abstract. Then find the closest abstracts for each claim
* use BM25 for document retrieval
* do something else which works
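
As one way to approach the BM25 idea from the list above, here is a minimal, dependency-free sketch of Okapi BM25 scoring. The class name `TinyBM25`, the crude regex tokenizer, the parameter values `k1=1.5` and `b=0.75`, and the toy documents are all assumptions made for illustration; a dedicated retrieval library would likely be faster and more robust.

```python
import math
import re
from collections import Counter

def tokenize(text):
    """Lowercase and keep alphanumeric runs (a deliberately crude tokenizer)."""
    return re.findall(r"[a-z0-9]+", text.lower())

class TinyBM25:
    """Minimal Okapi BM25 index; k1 and b are common defaults, not tuned values."""

    def __init__(self, docs, k1=1.5, b=0.75):
        # docs: dict mapping doc_id -> text
        self.k1, self.b = k1, b
        self.doc_ids = list(docs)
        self.term_freqs = [Counter(tokenize(docs[d])) for d in self.doc_ids]
        self.doc_lens = [sum(tf.values()) for tf in self.term_freqs]
        self.avg_len = sum(self.doc_lens) / len(self.doc_lens)
        # document frequency: in how many documents each term occurs
        doc_freq = Counter(term for tf in self.term_freqs for term in tf)
        n = len(self.doc_ids)
        self.idf = {t: math.log(1 + (n - df + 0.5) / (df + 0.5))
                    for t, df in doc_freq.items()}

    def score(self, query_terms, i):
        tf, length = self.term_freqs[i], self.doc_lens[i]
        norm = self.k1 * (1 - self.b + self.b * length / self.avg_len)
        return sum(self.idf[t] * tf[t] * (self.k1 + 1) / (tf[t] + norm)
                   for t in query_terms if t in tf)

    def retrieve(self, query, k):
        """Return the doc_ids of the k highest-scoring documents."""
        terms = tokenize(query)
        ranked = sorted(range(len(self.doc_ids)),
                        key=lambda i: self.score(terms, i), reverse=True)
        return [self.doc_ids[i] for i in ranked[:k]]

# Toy documents (made up) standing in for the real corpus.
toy_docs = {
    "a": "Whole fruit consumption is associated with a lower risk of type 2 diabetes.",
    "b": "Deep learning models for protein structure prediction.",
    "c": "Fruit juice intake and cardiovascular outcomes in adults.",
}
index = TinyBM25(toy_docs)
top = index.retrieve("Consumption of whole fruits increases the risk of type 2 diabetes.", k=2)
print(top)
```

For the real corpus, the `docs` argument could be built from each document's `title` and the joined sentences of its `abstract`.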

We provide a random baseline and evaluate recall for this method. Not surprisingly, this does not work well.

In [None]:
# random baseline
import random
def retrieve(claim, corpus, k):
    # random.sample needs a sequence; passing dict keys directly fails on recent Python versions
    return random.sample(list(corpus), k=k)

# note: every pass overwrites claim["doc_ids"], so only the k=10 results remain after the loop
for k in (3,5,10):
    for claim in claims_train:
        result = retrieve(claim["claim"], corpus, k)
        claim["doc_ids"] = result

    for claim in claims_dev:
        result = retrieve(claim["claim"], corpus, k)
        claim["doc_ids"] = result

In [None]:
# evaluate
def evaluate(claims):
    TP, FP, FN = 0, 0, 0
    for claim in claims:
        # relevant abstracts
        if claim["evidence"]:
            true_abstracts = set(claim["evidence"].keys())
            retrieved_abstracts = set(claim["doc_ids"])
            TP += len(true_abstracts.intersection(retrieved_abstracts))
            FN += len(true_abstracts.difference(retrieved_abstracts))
            FP += len(retrieved_abstracts.difference(true_abstracts))
        else:
            FP += len(claim["doc_ids"])
    try:
        pr = TP / (TP + FP)
        rc = TP / (TP + FN)
        f1 = 2 * pr * rc / (pr + rc)
    except ZeroDivisionError:
        pr, rc, f1 = 0,0,0
    print ("precision",pr, "recall",rc, "f1",f1)

print ("train claims")
evaluate(claims_train)
print ("dev claims")
evaluate(claims_dev)


In [None]:
##TODO create your own abstract retrieval system


In [None]:
##TODO evaluate your system. 
# If it operates on document level, we suggest evaluating your system with k=3, k=5, k=10 retrieved documents

# else, evaluate it using some reasonable method

In [None]:
# save your results, we suggest for k=3, which makes the rest of this exercise less time consuming

k=3
for claim in claims_train:
    result = retrieve(claim["claim"], corpus, k)
    claim["doc_ids"] = result

for claim in claims_dev:
    result = retrieve(claim["claim"], corpus, k)
    claim["doc_ids"] = result
    
import json
with open("data/claims_train_with_retrieved_documents.jsonl", "w") as outfile:
    for claim in claims_train:
        json.dump(claim, outfile)
        outfile.write("\n")
        
with open("data/claims_dev_with_retrieved_documents.jsonl", "w") as outfile:
    for claim in claims_dev:
        json.dump(claim, outfile)
        outfile.write("\n")

**Sentence Retrieval**

Now, we have candidate documents for every claim. As we have seen before, the precision achieved is not very convincing. So, we train a second module which takes a claim and a sentence as input and decides whether this sentence is evidence which supports or refutes the claim. This is just another pairwise sentence classification task and is usually tackled as binary classification.
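
To make the pairwise setup concrete, here is a minimal sketch of how (claim, sentence, label) training pairs could be assembled from the annotated evidence. The helper name `build_sentence_pairs`, the choice to draw negatives from non-evidence sentences of the same abstract, and the toy data are assumptions for illustration; other negative sampling schemes are equally plausible.

```python
import random

def build_sentence_pairs(claims, corpus, negatives_per_positive=1, seed=0):
    """Return (claim_text, sentence, label) triples.

    label 1 = annotated evidence sentence, label 0 = sampled non-evidence sentence.
    """
    rng = random.Random(seed)
    pairs = []
    for claim in claims:
        for doc_id, rationales in claim["evidence"].items():
            abstract = corpus[doc_id]["abstract"]
            gold = {i for rationale in rationales for i in rationale["sentences"]}
            for i in sorted(gold):
                pairs.append((claim["claim"], abstract[i], 1))
            # Assumption: negatives come from the same abstract, outside the gold set.
            candidates = [i for i in range(len(abstract)) if i not in gold]
            n_neg = min(len(candidates), negatives_per_positive * len(gold))
            for i in rng.sample(candidates, n_neg):
                pairs.append((claim["claim"], abstract[i], 0))
    return pairs

# Toy data (made up) in the same shape as `corpus` and `claims_train`.
toy_corpus = {"7": {"abstract": ["Sentence zero.", "Sentence one.", "Sentence two."]}}
toy_claims = [{"id": 1, "claim": "A toy claim.",
               "evidence": {"7": [{"sentences": [1], "label": "SUPPORT"}]}}]
pairs = build_sentence_pairs(toy_claims, toy_corpus)
print(pairs)
```

Each pair can then be fed to a sequence classifier as a two-segment input, with the claim as the first segment and the candidate sentence as the second.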

* If you want to train your own model on CPU, we suggest using [DistilRoBERTa](https://huggingface.co/distilroberta-base) (which took me 20 minutes to fine-tune for one epoch on CPU). If you have access to GPUs, there is a variety of models to choose from, e.g. [here](https://huggingface.co/transformers/pretrained_models.html) or [here](https://huggingface.co/models).

* We also provide a model [here](https://www.dropbox.com/s/mh3lrg3z626d0xw/scibert_model.zip?dl=0) which is a BertForSequenceClassification checkpoint fine-tuned from [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) which you could use in this task. The model is trained to predict class 1 for annotated evidence sentences and class 0 for randomly sampled negative examples. You can download this model with `wget "https://www.dropbox.com/s/mh3lrg3z626d0xw/scibert_model.zip?dl=0"`

* You can use any other model/method you like if you think it works reasonably well on this specific task

* We again provide a random baseline for demonstration purposes

In [None]:
# random baseline

import numpy as np
from tqdm import tqdm

for claim in tqdm(claims_dev):
    doc_ids = claim["doc_ids"]
    predicted_evidence = {}
    for doc_id in doc_ids:
        sentences = corpus[doc_id]["abstract"]
        predictions = np.random.normal(loc=-1, size=len(sentences))
        predicted_sentences = [i for i,j in enumerate(predictions) if j > 0]
        if predicted_sentences:
            predicted_evidence[doc_id] = {"sentences": predicted_sentences}
    claim["predicted_evidence"] = predicted_evidence

with open("data/claims_dev_with_predicted_sentences.jsonl", "w") as outfile:
    for claim in claims_dev:
        json.dump(claim, outfile)
        outfile.write("\n")

In [None]:
##TODO for every claim in the development set, and for every sentence in each retrieved abstract
#predict whether it is evidence or not

In [None]:
# some utils to evaluate this using the official metrics for SciFact
from collections import Counter
def safe_divide(num, denom):
    if denom == 0:
        return 0
    else:
        return num / denom

def compute_f1(counts, difficulty=None):
    correct_key = "correct" if difficulty is None else f"correct_{difficulty}"
    precision = safe_divide(counts[correct_key], counts["retrieved"])
    recall = safe_divide(counts[correct_key], counts["relevant"])
    f1 = safe_divide(2 * precision * recall, precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}

def is_correct(pred_sentence, pred_sentences, gold_sets):
    """
    A predicted sentence is correctly identified if it is part of a gold
    rationale, and all other sentences in the gold rationale are also
    predicted rationale sentences.
    """
    for gold_set in gold_sets:
        gold_sents = gold_set["sentences"]
        if pred_sentence in gold_sents:
            if all([x in pred_sentences for x in gold_sents]):
                return True
            else:
                return False

    return False


def evaluate_sentence_retrieval(dataset, rationale_selection):
    counts = Counter()
    for data, retrieval in zip(dataset, rationale_selection):
        assert data['id'] == retrieval['id']

        # Count all the gold evidence sentences.
        for doc_key, gold_rationales in data["evidence"].items():
            for entry in gold_rationales:
                counts["relevant"] += len(entry["sentences"])

        for doc_id, prediction in retrieval['predicted_evidence'].items():
            # predictions are stored as {"sentences": [...]}, so unpack the list of indices
            pred_sentences = prediction["sentences"]
            true_evidence_sets = data['evidence'].get(doc_id) or []
            for pred_sentence in pred_sentences:
                counts["retrieved"] += 1
                if is_correct(pred_sentence, pred_sentences, true_evidence_sets):
                    counts["correct"] += 1
    f1 = compute_f1(counts)
    print(f1)
    




In [None]:
# evaluate 
evaluate_sentence_retrieval(claims_dev, claims_dev)
# and we find that our random baseline behaves poorly :(


In [None]:
##TODO evaluate your predictions

**Part 2 of the Assignment: Claim Verification**

To recap: for every claim, we have retrieved possible evidence sentences. Now, we want to determine whether these sentences support or contradict the claim. Usually, this is handled via textual entailment: if the evidence entails the claim, the claim is supported; if the evidence contradicts it, the claim is refuted. For this task, you should train your own model. We propose starting from a [distilbert checkpoint](https://huggingface.co/typeform/distilbert-base-uncased-mnli) which has been pre-trained on MNLI. 
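
As a rough illustration of the data this step needs, the sketch below turns annotated evidence into (claim, evidence text, label id) examples. It uses the label mapping CONTRADICT = 0, NOT_ENOUGH_INFO = 1, SUPPORT = 2 from this notebook; the helper name `build_entailment_examples`, the rule of adding one non-annotated sentence per abstract as a neutral example, and the toy data are assumptions made for illustration.

```python
LABELS = {"CONTRADICT": 0, "NOT_ENOUGH_INFO": 1, "SUPPORT": 2}

def build_entailment_examples(claims, corpus):
    """Return (claim_text, evidence_text, label_id) triples for fine-tuning."""
    examples = []
    for claim in claims:
        for doc_id, rationales in claim["evidence"].items():
            abstract = corpus[doc_id]["abstract"]
            gold = set()
            for rationale in rationales:
                gold.update(rationale["sentences"])
                # Join the sentences of one rationale into a single evidence string.
                evidence_text = " ".join(abstract[i] for i in rationale["sentences"])
                examples.append((claim["claim"], evidence_text, LABELS[rationale["label"]]))
            # Assumption: a non-annotated sentence of the same abstract counts as neutral.
            for i, sentence in enumerate(abstract):
                if i not in gold:
                    examples.append((claim["claim"], sentence, LABELS["NOT_ENOUGH_INFO"]))
                    break  # one neutral example per abstract keeps the classes balanced
    return examples

# Toy data (made up) in the same shape as `corpus` and `claims_train`.
toy_corpus = {"7": {"abstract": ["Sentence zero.", "Sentence one.", "Sentence two."]}}
toy_claims = [{"id": 1, "claim": "A toy claim.",
               "evidence": {"7": [{"sentences": [1, 2], "label": "CONTRADICT"}]}}]
examples = build_entailment_examples(toy_claims, toy_corpus)
print(examples)
```

Before fine-tuning, it is worth checking that these label ids match the label order of the checkpoint you start from.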

Again, we provide a random baseline and evaluate this baseline.

In [None]:
# random baseline

# same mapping as LABELS in the evaluation utilities, inverted
id2label = {0: "CONTRADICT", 1: "NOT_ENOUGH_INFO", 2: "SUPPORT"}
for claim in claims_dev:
    predicted_evidence = claim["predicted_evidence"]
    labels = {}
    for doc_id, sentence_ids in predicted_evidence.items():
        abstract = corpus[doc_id]["abstract"]
        # the text a real model would classify; the random baseline does not use it
        sentences = " ".join(abstract[i] for i in sentence_ids["sentences"])
        label = id2label[np.random.choice([0,1,2])]
        # neutral predictions are kept here; evaluate_labels scores them as NOT_ENOUGH_INFO
        labels[doc_id] = {"label": label}
    claim["labels"] = labels
            
        
claims_dev[0]

In [None]:
##TODO create an appropriate dataset to fine-tune your model
# The input to your model should be: [CLS] claim [SEP] evidence_sentence [SEP]
# You may need to sample some sentences which are not annotated as evidence and use them as
# "neutral" / "NOT_ENOUGH_INFO" examples.
# Hint: the label mapping used here is LABELS = {'CONTRADICT': 0, 'NOT_ENOUGH_INFO': 1, 'SUPPORT': 2};
# check that it matches the label order of the checkpoint you start from.

In [None]:
# some utils to evaluate this using the official metrics for SciFact
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
def evaluate_labels(dataset, label_prediction):
    LABELS = {'CONTRADICT': 0, 'NOT_ENOUGH_INFO': 1, 'SUPPORT': 2}
    pred_labels = []
    true_labels = []

    for data, prediction in zip(dataset, label_prediction):
        assert data['id'] == prediction['id']

        if not prediction['labels']:
            continue

        claim_id = data['id']
        for doc_id, pred in prediction['labels'].items():
            pred_label = pred['label']
            true_label = {es['label'] for es in data['evidence'].get(doc_id) or []}
            assert len(true_label) <= 1, 'Currently support only one label per doc'
            true_label = next(iter(true_label)) if true_label else 'NOT_ENOUGH_INFO'
            pred_labels.append(LABELS[pred_label])
            true_labels.append(LABELS[true_label])

    # the built-in round() works whether sklearn returns a numpy scalar or a plain float
    print(f'Accuracy           {round(sum([pred_labels[i] == true_labels[i] for i in range(len(pred_labels))]) / len(pred_labels), 4)}')
    print(f'Macro F1:          {round(f1_score(true_labels, pred_labels, average="macro"), 4)}')
    print(f'Macro F1 w/o NEI:  {round(f1_score(true_labels, pred_labels, average="macro", labels=[0, 2]), 4)}')
    print()
    print('                   [C      N      S     ]')
    print(f'F1:                {f1_score(true_labels, pred_labels, average=None).round(4)}')
    print(f'Precision:         {precision_score(true_labels, pred_labels, average=None).round(4)}')
    print(f'Recall:            {recall_score(true_labels, pred_labels, average=None).round(4)}')
    print()
    print('Confusion Matrix:')
    print(confusion_matrix(true_labels, pred_labels))
evaluate_labels(claims_dev, claims_dev)



In [None]:
##TODO evaluate your own predictions