# Check-worthiness detection using Large Language Models

First, the necessary python modules are imported

In [37]:
%load_ext autoreload

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from claimbuster_utils import load_claimbuster_dataset
from tqdm.auto import tqdm
import json
import numpy as np
import re
import torch
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.base import BaseEstimator, TransformerMixin

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load model

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    quantization_config = bnb_config,
    # attn_implementation="flash_attention_2", 
    device_map={"": 0}
)
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
pipe = pipeline(
    "text-generation", 
    model=model, 
    tokenizer=tokenizer, 
    return_full_text=False,
    max_new_tokens=256,
    pad_token_id=tokenizer.eos_token_id
)


Loading checkpoint shards: 100%|██████████| 3/3 [00:27<00:00,  9.07s/it]


## Zero-shot classification

### ClaimBuster

In [4]:
with open("../prompts/ClaimBuster/standard/zero-shot.txt", "r") as f:
    instruction = f.read().replace("\n", "")
use_contextual = False
data = load_claimbuster_dataset(
    "../data/ClaimBuster_Datasets/datasets",
    use_contextual_features=use_contextual,
    debate_transcripts_folder="../data/ClaimBuster_Datasets/debate_transcripts",
)

texts = data["Text"]
if use_contextual is False:
    prompts = [f"{instruction} '''{text}'''" for text in texts]
    zeroshot_output = "../results/ClaimBuster/zeroshot1.csv"
else:
    contexts = data["previous_sentences"].tolist()
    prompts = [
        f"{instruction} For context, the following senteces were said prior to the one in question: {context} Only evaluate the check-worthiness of the following sentence: '''{text}'''"
        for text, context in zip(texts, contexts)
    ]
    zeroshot_output = "../results/ClaimBuster/zeroshot_contextual.csv"


class ProgressDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)


prompts_data = ProgressDataset(prompts)

dataset_with_scores = data.copy()

display(data.head())
dict_matcher = re.compile(r"{.*}")
score_matcher = re.compile(r"([Ss]core[^\d]*)(\d+)")
non_check_worthy_matcher = re.compile(
    r"(non-checkworthy)|(not check-worthy)|(non check-worthy)"
)

responses = pipe(prompts_data, batch_size=128)
for index, result in enumerate(tqdm(responses, total=len(prompts))):
    response = result[0]["generated_text"].replace("\n", "")
    dataset_index = data.index[index]
    try:
        parsed_json = json.loads(dict_matcher.search(response).group(0))
        dataset_with_scores.loc[dataset_index, "score"] = parsed_json["score"]
        dataset_with_scores.loc[dataset_index, "reasoning"] = parsed_json["reasoning"]
    except (json.decoder.JSONDecodeError, AttributeError) as e:
        # Try to find score
        score = score_matcher.search(response)
        if score is not None:
            score = score[2]
        else:
            score = 0.0 if non_check_worthy_matcher.search(response) else np.nan
        dataset_with_scores.loc[dataset_index, "score"] = score
        dataset_with_scores.loc[dataset_index, "reasoning"] = response
        continue
# Set the following column order: Verdict, score, Text, reasoning, previous_sentences
columns =  ["Verdict", "score", "Text", "reasoning"]
if use_contextual:
    columns.append("previous_sentences")
dataset_with_scores = dataset_with_scores[columns]
dataset_with_scores.to_csv(zeroshot_output, index=True)

Unnamed: 0_level_0,Verdict,Text
sentence_id,Unnamed: 1_level_1,Unnamed: 2_level_1
27247,1,We're 9 million jobs short of that.
10766,1,"You know, last year up to this time, we've los..."
3327,1,And in November of 1975 I was the first presid...
19700,1,And what we've done during the Bush administra...
12600,1,Do you know we don't have a single program spo...


  dataset_with_scores.loc[dataset_index, "score"] = score
100%|██████████| 9674/9674 [31:28<00:00,  5.12it/s]


#### Discussion of results

In [43]:
# Print the number of empty scores
dataset_path = "../results/ClaimBuster/zeroshot_contextual.csv"
dataset_with_scores = pd.read_csv(dataset_path, index_col=0)
class ThresholdOptimizer(BaseEstimator, TransformerMixin):

    def __init__(self):
        self.threshold = None

    def fit(self, x: pd.DataFrame, y: pd.Series):
        
        y_gold = x["Verdict"].values

        reports = []
        for threshold in range(1, 100):
            y_pred = x["score"].map(lambda x: 1 if x >= threshold else 0).values
            report = classification_report(y_gold, y_pred, output_dict=True)
            report["threshold"] = threshold
            reports.append(report)
        self.threshold = max(reports, key=lambda report: report["macro avg"]["f1-score"])["threshold"]
    
    def predict(self, x: pd.DataFrame):
        predictions = x["score"].map(lambda x: 1 if x >= self.threshold else 0)
        return predictions

# Do a four fold cross validation where the threshold is optimized
print(cross_validate(
    ThresholdOptimizer(),
    X=dataset_with_scores,
    y=dataset_with_scores["Verdict"],
    cv=StratifiedKFold(n_splits=4),
    scoring=["f1_macro", "accuracy"],
))

{'fit_time': array([1.28376102, 1.50472569, 1.22792745, 1.2808435 ]), 'score_time': array([0.00417757, 0.0038116 , 0.00471807, 0.00366807]), 'test_f1_macro': array([0.64137709, 0.63468915, 0.62384077, 0.62942657]), 'test_accuracy': array([0.68168665, 0.68582059, 0.66956162, 0.67162945])}


### CheckThat 2021 Task 1a Tweets