# Evidence retrieval
Import libraries, load the evidence corpus, and tokenize/stem it for BM25 retrieval.

In [1]:
import bm25s as bm25
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import Stemmer

# Load the evidence corpus and keep only the "evidence" text column.
# NOTE(review): dropna() runs on the whole frame, so a NaN in any other column
# also drops that evidence row; dropna(subset=["evidence"]) may be what was meant.
evidences_raw = pd.read_csv("../../data/processed/evidences.csv")
df_evidences = evidences_raw.dropna()["evidence"]

# Tokenize the corpus for BM25: English stopwords removed, tokens stemmed.
# Row order of corpus_stemmed matches df_evidences, which later cells rely on
# to map retrieved positions back to text with .iloc.
stemmer = Stemmer.Stemmer("english")
corpus_stemmed = bm25.tokenize(
    df_evidences.values,
    stopwords="en",
    stemmer=stemmer,
)

  from .autonotebook import tqdm as notebook_tqdm
                                                       

Retrieve the top-k evidence passages for each claim with BM25.

In [2]:
# Number of evidence passages to retrieve per claim
k = 5

# Build the BM25 index over the stemmed evidence corpus
retriever = bm25.BM25()
retriever.index(corpus_stemmed)

# Load the claims and keep only the "claim" text column.
# NOTE(review): dropna() runs on the whole frame, so a NaN in any other column
# also drops that claim.
claims_df = pd.read_csv("../../data/processed/claims.csv")
claims_df = claims_df.dropna()
claims_df = claims_df["claim"]
claim_texts = claims_df.values.tolist()

# Found evidences for each claim, stored as (claim, evidence Series) tuples
claims_with_evidences = []

if claim_texts:
    # Tokenize and retrieve all claims in one batched call instead of one call
    # per claim. Progress bars stay disabled to keep the notebook output small.
    query_tokens = bm25.tokenize(claim_texts, stemmer=stemmer, show_progress=False)
    results = retriever.retrieve(query_tokens, k=k, show_progress=False)

    # results.documents holds one row of corpus positions per claim; .iloc maps
    # them back to df_evidences, whose row order matches corpus_stemmed.
    for claim, doc_ids in zip(claim_texts, results.documents):
        claims_with_evidences.append((claim, df_evidences.iloc[doc_ids.tolist()]))

                                                              

# Classification
Classify each claim as SUPPORTED, REFUTED, or NOT_ENOUGH_INFO by running an NLI model over its retrieved evidence passages.

In [3]:
# Load the pretrained NLI model and its tokenizer from the local cache only
# (local_files_only=True: no download is attempted).
MODEL_NAME = "roberta-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, local_files_only=True
)
# Inference mode (disables dropout) before scoring.
model.eval()

# score all evidence passages of one claim
def score_evidence(claim, evidence, nli_model=None, nli_tokenizer=None):
    """Return NLI label probabilities for each (evidence, claim) pair.

    Args:
        claim: Claim text, used as the NLI hypothesis.
        evidence: Evidence passages (sequence of str), each used as the NLI
            premise. A single string is treated as one passage.
        nli_model: Optional sequence-classification model; defaults to the
            notebook-level ``model``.
        nli_tokenizer: Optional tokenizer; defaults to the notebook-level
            ``tokenizer``.

    Returns:
        Tensor of shape (num_passages, num_labels) with softmax probabilities,
        one row per passage. It stays 2-D even for a single passage, so
        callers can always index columns with ``probs[:, label]``.
    """
    if nli_model is None:
        nli_model = model
    if nli_tokenizer is None:
        nli_tokenizer = tokenizer

    # A bare string would otherwise be iterated character by character.
    if isinstance(evidence, str):
        evidence = [evidence]
    else:
        evidence = list(evidence)

    # The tokenizer pairs premises and hypotheses element-wise, so the claim is
    # repeated once per evidence passage.
    claims = [claim] * len(evidence)
    inputs = nli_tokenizer(
        evidence,
        claims,
        padding=True,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = nli_model(**inputs)

    # Softmax over the label dimension. No .squeeze(): for one passage it
    # would collapse (1, num_labels) to (num_labels,) and break [:, label].
    return torch.softmax(outputs.logits, dim=-1)

# try to verify or refute claims
def verify_claim(
    claim,
    evidences,
    support_threshold=0.7,
    refutation_threshold=0.7,
    scorer=None
):
    """Classify a claim from the NLI probabilities of its evidence passages.

    Each passage is scored independently. The single strongest support and
    refute probabilities across all passages decide the verdict.

    Args:
        claim: Claim text.
        evidences: Sequence of evidence passages (e.g. a list of str).
        support_threshold: Minimum support probability (label 2) for
            "SUPPORTED".
        refutation_threshold: Minimum refute probability (label 0) for
            "REFUTED".
        scorer: Optional callable ``scorer(claim, evidences)`` returning
            per-passage label probabilities. Defaults to ``score_evidence``;
            useful for testing or swapping the NLI model.

    Returns:
        Tuple ``(label, score)``. Label is "SUPPORTED", "REFUTED" or
        "NOT_ENOUGH_INFO".
        - For SUPPORTED/REFUTED, score is the winning probability.
        - For NOT_ENOUGH_INFO, score is the larger of the best support and
          best refute probabilities, or 0.0 when there is no evidence.

    Note:
        Support is checked first. If one passage clears the support threshold
        and another clears the refutation threshold, the result is SUPPORTED.
    """
    if scorer is None:
        scorer = score_evidence

    # No evidence: nothing to score, and max() over an empty tensor would raise.
    if len(evidences) == 0:
        return "NOT_ENOUGH_INFO", 0.0

    probs = scorer(claim, evidences)
    # A scorer may return a 1-D tensor for a single passage (e.g. after
    # .squeeze()); normalize to (num_passages, num_labels) before indexing.
    probs = probs.reshape(-1, probs.shape[-1])

    # Assumed roberta-large-mnli label ids: 0 = contradiction (refute),
    # 2 = entailment (support) -- confirm with model.config.id2label.
    support_scores = probs[:, 2]
    refute_scores = probs[:, 0]

    max_support = support_scores.max().item()
    max_refute = refute_scores.max().item()

    # support/refute if over threshold (support takes precedence)
    if max_support >= support_threshold:
        return "SUPPORTED", max_support
    elif max_refute >= refutation_threshold:
        return "REFUTED", max_refute
    else:
        return "NOT_ENOUGH_INFO", max(max_support, max_refute)

# Classify every claim against its retrieved evidence passages and collect the
# (label, score) results. Slow: one NLI scoring pass per claim.
claim_classified = []
for claim, evidence_series in claims_with_evidences:
    verdict = verify_claim(claim, evidence_series.values.tolist())
    claim_classified.append(verdict)
    print(verdict)
print(claim_classified[:7])

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('NOT_ENOUGH_INFO', 0.6665968894958496)
('REFUTED', 0.978366494178772)
('SUPPORTED', 0.982164740562439)
('NOT_ENOUGH_INFO', 0.6775383353233337)
('NOT_ENOUGH_INFO', 0.4155708849430084)
('NOT_ENOUGH_INFO', 0.6211366653442383)
('SUPPORTED', 0.9932665824890137)
('REFUTED', 0.9994128942489624)
('NOT_ENOUGH_INFO', 0.2539704144001007)
('REFUTED', 0.7063152194023132)
('REFUTED', 0.8410676121711731)
('REFUTED', 0.9480153918266296)
('REFUTED', 0.9849709868431091)
('REFUTED', 0.8275517821311951)
('REFUTED', 0.993050217628479)
('REFUTED', 0.9982215762138367)
('NOT_ENOUGH_INFO', 0.09203308820724487)
('NOT_ENOUGH_INFO', 0.5824480056762695)
('NOT_ENOUGH_INFO', 0.08355019241571426)
('NOT_ENOUGH_INFO', 0.47513359785079956)
('NOT_ENOUGH_INFO', 0.3984202444553375)


KeyboardInterrupt: 