<a href="https://colab.research.google.com/github/kevin-rn/Grounding-LM/blob/main/fact_check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Task
Extract sentences from CNN/DailyMail articles and index them, using claim detection or evidence sentence selection models to choose which sentences to keep. Treat each model-generated summary as a claim and retrieve the closest sentences from the index. Then use an off-the-shelf stance detection model to verify the summary against the retrieved evidence.
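The task above chains four stages: sentence extraction, claim detection, retrieval and stance detection. A minimal sketch of that data flow, assuming each stage can be treated as a plain callable; `verify_summary` and the toy lambdas are hypothetical stand-ins for NLTK, the ClaimBuster classifier, the Annoy lookup and the fact-checking model used in this notebook.

```python
from typing import Callable, List


def verify_summary(
    document: str,
    summary: str,
    split: Callable[[str], List[str]],
    is_claimworthy: Callable[[str], bool],
    retrieve: Callable[[str, List[str]], List[str]],
    classify: Callable[[str, str], str],
) -> str:
    """Hypothetical skeleton: split, filter, retrieve evidence, then classify the summary."""
    sentences = split(document)                                   # sentence extraction
    evidence_pool = [s for s in sentences if is_claimworthy(s)]   # claim detection
    evidence = retrieve(summary, evidence_pool)                   # nearest-neighbour retrieval
    return classify(summary, " ".join(evidence))                  # stance detection


# Toy stand-ins for the real models, only to show how data moves between stages.
label = verify_summary(
    "Cats purr. The sky is green. Dogs bark.",
    "Dogs bark loudly.",
    split=lambda text: [s.strip() + "." for s in text.split(".") if s.strip()],
    is_claimworthy=lambda s: "sky" not in s,
    retrieve=lambda claim, pool: [s for s in pool if claim.split()[0] in s],
    classify=lambda claim, evidence: "support" if evidence else "neutral",
)
print(label)  # → support
```

Keeping the stages as separate callables makes it easy to swap, for example, the claim detector for an evidence sentence selection model without touching retrieval or verification.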


In [1]:
import os
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/Grounding_LM/

Mounted at /content/drive
/content/drive/MyDrive/Grounding_LM


In [2]:
%pip install -q transformers
%pip install -q sentence-transformers
%pip install -q -U annoy


In [3]:
from annoy import AnnoyIndex
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer
import torch
from tqdm.auto import tqdm
import nltk
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

nltk.download('punkt')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tqdm.pandas()

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


### Load data

In [4]:
df_test = pd.read_csv("results/generated summaries/t5_large_cnn_dailymail.csv", index_col=0)
df_test.head()

| Unnamed: 0 | text | summary | id | generated |
|---|---|---|---|---|
| 0 | (CNN)The Palestinian Authority officially beca... | Membership gives the ICC jurisdiction over all... | f001ec5c4704938247d27a44948eebb37ae98d01 | The Palestinians have become a member of the I... |
| 1 | (CNN)Never mind cats having nine lives. A stra... | Theia, a bully breed mix, was apparently hit b... | 230c522854991d053fe98a718b1defa077a8efef | A dog that was apparently buried alive after b... |
| 2 | (CNN)If you've been following the news lately,... | Mohammad Javad Zarif has spent more time with ... | 4495ba8f3a340d97a9df1476f8a35502bcce1f69 | It's been a busy week for Iran. |
| 3 | (CNN)Five Americans who were monitored for thr... | 17 Americans were exposed to the Ebola virus w... | a38e72fed88684ec8d60dd5856282e999dc8c0ca | Five Americans who were being treated for Ebol... |
| 4 | (CNN)A Duke student has admitted to hanging a ... | Student is no longer on Duke University campus... | c27cf1b136cc270023de959e7ab24638021bc43f | A student at Duke University has admitted hang... |



### Claim detection

1. Load a pre-trained claim detection model (BERT fine-tuned on the ClaimBuster dataset).
2. Split each source document into sentences using NLTK's `sent_tokenize`.
3. Keep only the sentences the model classifies as claim-worthy.
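Step 3 reduces to an argmax over each sentence's class logits followed by a class check. A minimal pure-Python sketch, assuming (as `extract_claimworthy` below does) that class index 0 is the claim-worthy label; `select_by_class` and `CLAIMWORTHY_CLASS` are hypothetical names and the logits are made up.

```python
from typing import List, Sequence

# Assumption carried over from the notebook: class index 0 means "claim-worthy".
CLAIMWORTHY_CLASS = 0


def select_by_class(
    sentences: Sequence[str],
    logits: Sequence[Sequence[float]],
    target: int = CLAIMWORTHY_CLASS,
) -> List[str]:
    """Keep each sentence whose highest-scoring class equals `target`."""
    if len(sentences) != len(logits):
        raise ValueError("need exactly one row of logits per sentence")
    kept = []
    for sentence, row in zip(sentences, logits):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        if predicted == target:
            kept.append(sentence)
    return kept


# Toy logits for three sentences and two classes (values are made up).
print(select_by_class(["s0", "s1", "s2"], [[2.0, -1.0], [0.1, 0.9], [1.5, 1.4]]))  # → ['s0', 's2']
```

Returning a plain list means zero or one matching sentence needs no special handling, which is the edge case the tensor-based version has to guard against.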

In [5]:
claim_tokenizer = AutoTokenizer.from_pretrained("Nithiwat/bert-base_claimbuster")
claim_model = AutoModelForSequenceClassification.from_pretrained("Nithiwat/bert-base_claimbuster").to(device)


In [6]:
def extract_claimworthy(sentences):
    """Return the sentences the ClaimBuster model classifies as claim-worthy."""
    tokenized_inputs = claim_tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = claim_model(**tokenized_inputs).logits.cpu()
    # Class index 0 is treated as claim-worthy; confirm against claim_model.config.id2label.
    predictions = logits.argmax(dim=1)
    # as_tuple=True yields a 1-D index tensor even for zero or one match.
    label_indices = (predictions == 0).nonzero(as_tuple=True)[0].tolist()
    return [sentences[idx] for idx in label_indices]

In [7]:
df_test['sentences'] = df_test['text'].apply(sent_tokenize)

In [15]:
# df_test['claims'] = df_test['sentences'].progress_apply(extract_claimworthy)
# df_test.to_csv('claims.csv', index=False)
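`extract_claimworthy` tokenizes every sentence of an article as one padded batch, so a long article produces a large batch. A sketch of splitting the sentence list into fixed-size chunks first, assuming memory is the constraint; `batched` and the chunk size are hypothetical choices, not part of the notebook.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `size` items, preserving order."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(items)
    while chunk := list(islice(iterator, size)):
        yield chunk


print(list(batched(range(5), 2)))  # → [[0, 1], [2, 3], [4]]
```

In the notebook this could be used as `[s for chunk in batched(doc_sentences, 32) for s in extract_claimworthy(chunk)]`, where `doc_sentences` and the size 32 are illustrative. Python 3.12+ also ships `itertools.batched`, which yields tuples instead of lists.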

In [8]:
sentences = extract_claimworthy(df_test['sentences'][0])

print(f"evidence: {' '.join(sentences)} \nclaim: {df_test['generated'][0]}")

evidence: The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. These are substantive commitments, which cannot be taken lightly," she said. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." 
claim: The Palestinians have become a member of the International Criminal Court (ICC).


### Construct Index

1. Load a sentence-transformers model to embed sentences and short paragraphs.
2. Compute an embedding for each claim-worthy sentence.
3. Store the embeddings in an Annoy index for approximate nearest-neighbour retrieval.
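Annoy approximates an exact nearest-neighbour search, and its angular metric ranks neighbours in the same order as cosine similarity. The exact version is easy to write for a handful of vectors and makes one point explicit: returned ids are positions in the list that was indexed, so they must be mapped back through that same list. `cosine`, `exact_top_k` and the toy 2-D vectors (standing in for 384-dimensional embeddings) are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def exact_top_k(query: Sequence[float], vectors: Sequence[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    """Score every stored vector and return the k best (id, similarity) pairs."""
    scored = [(item_id, cosine(query, vec)) for item_id, vec in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-D "embeddings"; ids are positions in indexed_sentences.
indexed_sentences = ["s0", "s1", "s2"]
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hits = exact_top_k([1.0, 0.1], vectors, k=2)
print([indexed_sentences[item_id] for item_id, _ in hits])  # → ['s0', 's2']
```

Brute force costs one similarity per stored vector per query, which is fine for a single article; Annoy's trees trade a little recall for much faster lookups over large corpora.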

In [9]:
model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # 384 dimensional dense vector space


In [24]:
def get_embeddings(txt_inputs):
    # SentenceTransformer.encode batches a list of strings and returns one embedding row per input.
    return model.encode(list(txt_inputs))

In [11]:
def index_annoy(phrase_embeddings, embedding_dim=384, number_of_trees=100):
    # Item ids are positions in phrase_embeddings; keep the source sentence list in the same order.
    ann = AnnoyIndex(embedding_dim, metric="angular")
    for index, embed in enumerate(phrase_embeddings):
        ann.add_item(index, embed)
    ann.build(number_of_trees)
    ann.save("data/cnn_claims.annoy")
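Annoy persists only integer item ids and their vectors, so the sentence text has to be stored separately to map a loaded index back to evidence. A sketch of a JSON sidecar written next to the index file, assuming one index per document; `save_id_map`, `load_id_map` and the `.json` naming are hypothetical.

```python
import json
from pathlib import Path
from typing import List


def save_id_map(sentences: List[str], index_path: str) -> Path:
    """Write sentences in Annoy-id order to a JSON file beside the index."""
    sidecar = Path(index_path).with_suffix(".json")  # e.g. data/cnn_claims.json
    sidecar.write_text(json.dumps(sentences, ensure_ascii=False), encoding="utf-8")
    return sidecar


def load_id_map(index_path: str) -> List[str]:
    """Read the sentences back; position i holds the text for Annoy item i."""
    sidecar = Path(index_path).with_suffix(".json")
    return json.loads(sidecar.read_text(encoding="utf-8"))
```

Usage would be `save_id_map(sentences, "data/cnn_claims.annoy")` right after building the index, then `load_id_map("data/cnn_claims.annoy")[i]` for each id returned by a query.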

In [12]:
index_annoy(get_embeddings(sentences))

### Fact-checking
1. Retrieve the top-k claim-worthy sentences from the Annoy index for a given claim (the generated summary).
2. Compute the cosine similarity between the claim and each retrieved sentence, and keep only those above a threshold (`p`, default 0.5).
3. Load a pre-trained fact-checking model and infer whether the evidence supports, refutes, or is neutral towards the claim.
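The Annoy documentation defines angular distance as d = sqrt(2(1 - cos(u, v))), so the distances returned by `get_nns_by_vector(..., include_distances=True)` already carry the cosine similarity needed for step 2. A sketch of that conversion and threshold, assuming the usual float rounding between Annoy and scikit-learn is acceptable; `angular_to_cosine` and `filter_by_similarity` are hypothetical helpers.

```python
from typing import List, Sequence


def angular_to_cosine(distance: float) -> float:
    """Invert Annoy's angular distance d = sqrt(2 * (1 - cos))."""
    return 1.0 - distance ** 2 / 2.0


def filter_by_similarity(
    ids: Sequence[int],
    distances: Sequence[float],
    sentences: Sequence[str],
    p: float = 0.5,
) -> List[str]:
    """Return sentences[id] for hits with cosine similarity above p, most similar first."""
    scored = [(angular_to_cosine(d), sentences[i]) for i, d in zip(ids, distances)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [sentence for similarity, sentence in scored if similarity > p]


# Made-up ids/distances in the shape returned by get_nns_by_vector(..., include_distances=True).
print(filter_by_similarity([2, 0, 1], [0.5, 1.0, 1.2], ["a", "b", "c"]))  # → ['c']
```

In the notebook this would read `ids, distances = annoy.get_nns_by_vector(model.encode(claim), 15, include_distances=True)` followed by `filter_by_similarity(ids, distances, sentences)`, which avoids re-encoding the retrieved sentences.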

In [33]:
checkpoint = 'Dzeniks/roberta-fact-check'
factcheck_model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)
factcheck_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Label order assumed by this notebook; compare with factcheck_model.config.id2label.
label_mapping = ['support', 'refute', 'neutral']
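`label_mapping` is indexed by the argmax of the model's logits. A softmax over the same logits also yields a confidence, and checking the logit count against the label list catches a checkpoint whose output size (`factcheck_model.config.num_labels`) differs from the assumed three labels. A pure-Python sketch; `label_with_confidence` is hypothetical and the logits are made up.

```python
import math
from typing import Sequence, Tuple

# Order copied from the notebook's label_mapping; not verified against the checkpoint's config.
LABELS = ("support", "refute", "neutral")


def label_with_confidence(logits: Sequence[float], labels: Sequence[str] = LABELS) -> Tuple[str, float]:
    """Softmax the logits and return (most likely label, its probability)."""
    if len(logits) != len(labels):
        # Catches a checkpoint whose output size differs from the label list.
        raise ValueError(f"got {len(logits)} logits for {len(labels)} labels")
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # shift by the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


label, confidence = label_with_confidence([2.0, 0.0, 0.0])
print(label, round(confidence, 3))  # → support 0.787
```

With the real model, the input would be `factcheck_model(**features).logits[0].tolist()`.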


In [31]:
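def get_top_nn_neighbours(annoy, claim, df_index, k=15, p=0.5, indexed_sentences=None):
    """Return indexed sentences whose cosine similarity to `claim` exceeds `p`, most similar first."""
    # Annoy item ids are positions in the list that was embedded in index_annoy.
    # Pass that list (e.g. the claim-worthy `sentences`) as `indexed_sentences`;
    # the fallback to the full sentence list only lines up if every sentence was indexed.
    if indexed_sentences is None:
        indexed_sentences = df_test["sentences"][df_index]
    new_emb = model.encode(claim)
    top_matches = annoy.get_nns_by_vector(new_emb, k)
    evidence_sentences = [indexed_sentences[i] for i in top_matches]
    evidence_embeddings = get_embeddings(evidence_sentences)
    sim_scores = cosine_similarity([new_emb], evidence_embeddings)[0]

    top_sentences = []
    for idx, similarity in sorted(enumerate(sim_scores), key=lambda x: x[1], reverse=True):
        if similarity > p:
            top_sentences.append(evidence_sentences[idx])
    return top_sentences

def fact_check(claim, evidence):
    """Classify `evidence` against `claim` and return a label from label_mapping."""
    factcheck_model.eval()
    # Put the inputs on the same device as the model to avoid a CPU/GPU mismatch.
    features = factcheck_tokenizer(claim, evidence, truncation=True, max_length=512, return_tensors="pt").to(factcheck_model.device)
    with torch.no_grad():
        logits = factcheck_model(**features).logits
    return label_mapping[logits.argmax(dim=-1).item()]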

In [13]:
annoy = AnnoyIndex(384, metric="angular")
annoy.load("data/cnn_claims.annoy")

True

In [32]:
claim = df_test['generated'][0]
results = get_top_nn_neighbours(annoy, claim, 0)
results

['The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014."',
 '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories.']

In [36]:
label = fact_check(claim, ' '.join(results))
print(f"Label: {label}")

Label: support
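The cell above merges all retrieved sentences into a single evidence string, which the tokenizer truncates at 512 tokens. An alternative is to check each sentence separately, e.g. `[fact_check(claim, s) for s in results]`, and combine the verdicts. The rule below is one hypothetical, conservative choice (any refutation wins), not something the notebook does.

```python
from typing import Iterable


def aggregate_verdicts(verdicts: Iterable[str]) -> str:
    """Hypothetical conservative rule: any 'refute' wins, then any 'support', else 'neutral'."""
    seen = set(verdicts)
    if "refute" in seen:
        return "refute"
    if "support" in seen:
        return "support"
    return "neutral"


print(aggregate_verdicts(["support", "neutral"]))  # → support
```

A majority vote or a confidence-weighted average are equally plausible rules; which one fits depends on whether missing a refutation or flagging a correct summary is the costlier error.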
