# Check-worthiness detection using Large Language Models

First, the necessary Python modules are imported.

In [1]:
%load_ext autoreload

from claimbuster_utils import load_claimbuster_dataset
from checkthat_utils import load_check_that_dataset
import pandas as pd
from llm import load_huggingface_model, HuggingFaceModel, run_llm_cross_validation, generate_llm_predictions, ICLUsage, PromptType
from result_analysis import generate_error_analysis_report, print_padded_text
from dataset_utils import generate_cross_validation_datasets, Dataset
import ipywidgets as widgets
import os

2024-04-07 15:05:03.960963: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-07 15:05:03.961044: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-07 15:05:03.962882: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-07 15:05:03.974677: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Generate cross-validation datasets

In [2]:
%autoreload
claimbuster = load_claimbuster_dataset("../data/ClaimBuster/datasets")
clambuster_datasets = generate_cross_validation_datasets(
    data=claimbuster, 
    folder_path="../data/ClaimBuster/crossval"
)

checkthat = load_check_that_dataset("../data/CheckThat")
checkthat_datasets = generate_cross_validation_datasets(
    data=checkthat, 
    label_column="check_worthiness",
    folder_path="../data/CheckThat/crossval"
)

## Load model

In [3]:
%autoreload
model_id = HuggingFaceModel.MIXTRAL_INSTRUCT
pipe = load_huggingface_model(model_id, max_new_tokens=1024)


Loading checkpoint shards: 100%|██████████| 19/19 [03:01<00:00,  9.54s/it]


## Generate predictions

Use the ipywidgets form below to choose the dataset, model, and generation/prompting parameters used to generate LLM predictions.

In [2]:
%autoreload

# General lauyout
input_style = dict(
    description_width="fit-content"
)

# Dataset 
dataset_select = widgets.Dropdown(
    options=[("ClaimBuster", Dataset.CLAIMBUSTER), ("CheckThat", Dataset.CHECK_THAT)],
    value=Dataset.CLAIMBUSTER,
    description="Dataset:"
)

# Model and parameters
model_select = widgets.Dropdown(
    options=[("Mistral 7B Instruct", HuggingFaceModel.MISTRAL_7B_INSTRUCT), ("Mixtral Instruct", HuggingFaceModel.MIXTRAL_INSTRUCT)],
    value=HuggingFaceModel.MISTRAL_7B_INSTRUCT,
    description="Model:",
    style=input_style
)
max_new_tokens_int_text = widgets.IntText(
    value=64,
    description="Max new tokens:",
    style=input_style
)
batch_size = widgets.IntText(
    value=128,
    description="Batch size:",
    style=input_style
)
model_and_parameters = widgets.VBox(
    [model_select, max_new_tokens_int_text, batch_size],
)

# Prompting type
prompting_type = widgets.Dropdown(
    options=[("Standard", PromptType.STANDARD), ("Chain-of-Thought", PromptType.CHAIN_OF_THOUGHT)],
    value=PromptType.STANDARD,
    description="Prompting type:",
    style=input_style
)
icl_usage = widgets.Dropdown(
    options=[("Zero-shot", ICLUsage.ZERO_SHOT), ("Few-shot", ICLUsage.FEW_SHOT)],
    value=ICLUsage.ZERO_SHOT,
    description="ICL usage:",
    style=input_style
)
prompt_use = widgets.VBox(
    [prompting_type, icl_usage]
)

accordion = widgets.Accordion([
    dataset_select,
    model_and_parameters,
    prompt_use
],
    titles=["Dataset", "Model and parameters", "Prompting type"],
)

title = widgets.HTML(
    "<h1>Generation of predictions using LLMs</h1>",
)
description = widgets.HTML(
    "<div>Set the parameters to select what dataset, model and prompting to use when generating predictions. If you experience Cuda out of memory issues, please decrease the batch size.</div>",
    layout={"font-size": '14px'}
)
start_generation_button = widgets.Button(
    description="Start generation",
    disabled=False,
    button_style="success",
    layout={"height": "40px", "width": "calc(100% - 4px)"},
)

def handle_generation_click(_):
    if dataset_select.value == Dataset.CLAIMBUSTER:
        dataset = load_claimbuster_dataset("../data/ClaimBuster/datasets")
        label_column = "Verdict"
        text_column = "Text"
    else:
        dataset = load_check_that_dataset("../data/CheckThat")
        label_column = "check_worthiness"
        text_column = "tweet_text"

    instruction_path = os.path.join(
        "../prompts",
        dataset_select.value.value,
        prompting_type.value.value,
        icl_usage.value.value,
        "instruction.txt"
    )
    print("#" * 50)
    print_padded_text("Starting generation with parameters")
    print_padded_text(f"Dataset: {dataset_select.value.value}")
    print_padded_text(f"Model: {model_select.value.name}")
    print_padded_text(f"Prompting type: {prompting_type.value.value}")
    print_padded_text(f"ICL usage: {icl_usage.value.value}")
    print("#" * 50)
    if not os.path.exists(instruction_path):
        print("No instruction found, exiting...")
        return
    with open(instruction_path, "r") as f:
        instruction = f.read().replace("\n", "")
    prompts = [ f"[INST]{instruction} '''{text}'''[/INST]" for text in dataset[text_column]]
    print("Loading model...")
    pipe = load_huggingface_model(
        model_id=model_select.value, 
        max_new_tokens=max_new_tokens_int_text.value
    )

    print("Generating predictions...")
    save_path = os.path.join(
        "../results",
        dataset_select.value.value,
        model_select.value.name,
        prompting_type.value.value,
        icl_usage.value.value,
        "generated_scores.csv"
    )
    generate_llm_predictions(
        data=dataset,
        prompts=prompts,
        pipe=pipe,
        batch_size=batch_size.value,
        label_column=label_column,
        text_column=text_column,
        save_path=save_path
    )


start_generation_button.on_click(handle_generation_click)

# Stack title, description, parameter accordion and start button vertically
# inside a bordered container, then render the form.
generation_form_layout = widgets.Layout(
    border="1px solid black",
    padding="16px",
    display="flex",
    flex_flow="column",
    align_items="stretch",
)
box = widgets.Box(
    children=[title, description, accordion, start_generation_button],
    layout=generation_form_layout,
)
display(box)

Box(children=(HTML(value='<h1>Generation of predictions using LLMs</h1>'), HTML(value='<div>Set the parameters…

##################################################
#      Starting generation with parameters       #
#              Dataset: ClaimBuster              #
#           Model: MISTRAL_7B_INSTRUCT           #
#            Prompting type: standard            #
#               ICL usage: fewshot               #
##################################################
Loading model...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

max_new_tokens=64
Generating predictions...


  0%|          | 0/9674 [00:00<?, ?it/s]

  dataset_with_scores.loc[dataset_index, "score"] = score


## Cross validation

Use the ipywidgets form below to choose the dataset, model, and prompting setup whose generated scores should be cross-validated.

In [3]:
%autoreload

input_style = dict(
    description_width="fit-content"
)

dataset_select = widgets.Dropdown(
    options=[("ClaimBuster", Dataset.CLAIMBUSTER), ("CheckThat", Dataset.CHECK_THAT)],
    value=Dataset.CLAIMBUSTER,
    description="Dataset:"
)

model_select = widgets.Dropdown(
    options=[("Mistral 7B Instruct", HuggingFaceModel.MISTRAL_7B_INSTRUCT), ("Mixtral Instruct", HuggingFaceModel.MIXTRAL_INSTRUCT)],
    value=HuggingFaceModel.MISTRAL_7B_INSTRUCT,
    description="Model:",
    style=input_style
)

prompting_type = widgets.Dropdown(
    options=[("Standard", PromptType.STANDARD), ("Chain-of-Thought", PromptType.CHAIN_OF_THOUGHT)],
    value=PromptType.STANDARD,
    description="Prompting type:",
    style=input_style
)
icl_usage = widgets.Dropdown(
    options=[("Zero-shot", ICLUsage.ZERO_SHOT), ("Few-shot", ICLUsage.FEW_SHOT)],
    value=ICLUsage.ZERO_SHOT,
    description="ICL usage:",
    style=input_style
)
prompt_use = widgets.VBox(
    [prompting_type, icl_usage]
)

title = widgets.HTML(
    "<h1>Cross validation using LLMs</h1>",
)
description = widgets.HTML(
    "<div>Set the parameters to select what dataset, model and prompting to use when performing cross validation.</div>",
    layout={"font-size": '14px'}
)
start_cross_validation_button = widgets.Button(
    description="Start cross validation",
    disabled=False,
    button_style="success",
    layout={"height": "40px", "width": "calc(100% - 4px)"},
)

def handle_cross_validation_click(_):
    dataset_folder = os.path.join(
        "../results",
        dataset_select.value.value,
        model_select.value.name,
        prompting_type.value.value,
        icl_usage.value.value,
    )
    dataset_path = os.path.join(dataset_folder, "generated_scores.csv")
    if not os.path.exists(dataset_path):
        print("No generated scores found")
        return
    dataset_with_scores = pd.read_csv(dataset_path, index_col=0)
    crossval_folder = os.path.join(
        "../data",
        dataset_select.value.value,
        "crossval"
    )
    label_column = "Verdict" if dataset_select.value == Dataset.CLAIMBUSTER else "check_worthiness"
    print("#" * 50)
    print_padded_text("Starting cross validation with parameters")
    print_padded_text(f"Dataset: {dataset_select.value.value}")
    print_padded_text(f"Model: {model_select.value.name}")
    print_padded_text(f"Prompting type: {prompting_type.value.value}")
    print_padded_text(f"ICL usage: {icl_usage.value.value}")
    print("#" * 50)
    result, _ = run_llm_cross_validation(
        data=dataset_with_scores, 
        crossval_folder=crossval_folder,
        save_folder=dataset_folder,
        label_column=label_column
    )
    display(result)

start_cross_validation_button.on_click(handle_cross_validation_click)

accordion = widgets.Accordion([
    dataset_select,
    model_select,
    prompt_use
],
    titles=["Dataset", "Model", "Prompting type"],
)

box = widgets.Box(
    [title, description, accordion, start_cross_validation_button],
    layout=widgets.Layout(
        padding= '16px', 
        display= "flex", 
        flex_flow="column",
        align_items="stretch",
        border="1px solid black"
    )
) 
display(box)

Box(children=(HTML(value='<h1>Cross validation using LLMs</h1>'), HTML(value='<div>Set the parameters to selec…

##################################################
#   Starting cross validation with parameters    #
#              Dataset: ClaimBuster              #
#           Model: MISTRAL_7B_INSTRUCT           #
#            Prompting type: standard            #
#               ICL usage: fewshot               #
##################################################
self.threshold=31
self.threshold=8
self.threshold=31
self.threshold=8


Unnamed: 0,accuracy,0_precision,0_recall,0_f1-score,1_precision,1_recall,1_f1-score,macro avg_precision,macro avg_recall,macro avg_f1-score,weighted avg_precision,weighted avg_recall,weighted avg_f1-score
0,0.809012,0.824949,0.929977,0.87432,0.7431,0.506512,0.60241,0.784024,0.718245,0.738365,0.801568,0.809012,0.796647
1,0.746589,0.912657,0.713542,0.800909,0.536517,0.829233,0.651507,0.724587,0.771387,0.726208,0.805211,0.746589,0.758232
2,0.797767,0.822732,0.913723,0.865844,0.702,0.507959,0.589421,0.762366,0.710841,0.727632,0.78823,0.797767,0.786849
3,0.757651,0.930566,0.713955,0.807995,0.548033,0.86686,0.671525,0.739299,0.790407,0.73976,0.821248,0.757651,0.768995
Average,0.777755,0.872726,0.817799,0.837267,0.632412,0.677641,0.628715,0.752569,0.74772,0.732991,0.804064,0.777755,0.777681


## Zero-shot classification

#### Using contextual features

In [None]:
# Zero-shot predictions on a ClaimBuster sample, giving the model the sentences
# that preceded each claim as context. Uses `pipe` and `model_id` from the
# "Load model" cell.
N_CONTEXTUAL_SAMPLES = 10  # set to None to run on the full dataset

data = load_claimbuster_dataset(
    "../data/ClaimBuster/datasets",
    use_contextual_features=True,
    debate_transcripts_folder="../data/ClaimBuster/debate_transcripts",
)
if N_CONTEXTUAL_SAMPLES is not None:
    data = data[:N_CONTEXTUAL_SAMPLES]

# Neither `instruction` nor `texts` exists at notebook level (the generation
# handler only defines a local `instruction`), so both are loaded explicitly here.
instruction_path = os.path.join(
    "../prompts",
    Dataset.CLAIMBUSTER.value,
    PromptType.STANDARD.value,
    ICLUsage.ZERO_SHOT.value,
    "instruction.txt",
)
with open(instruction_path, "r", encoding="utf-8") as f:
    # Join lines with spaces so words on adjacent lines are not glued together.
    instruction = " ".join(line.strip() for line in f.read().splitlines() if line.strip())

texts = data["Text"].tolist()
contexts = data["previous_sentences"].tolist()
# Same [INST] ... [/INST] wrapper that the generation form passes to
# generate_llm_predictions.
prompts = [
    f"[INST]{instruction} For context, the following sentences were said prior to the one in question: {context} Only evaluate the check-worthiness of the following sentence: '''{text}'''[/INST]"
    for text, context in zip(texts, contexts)
]
# f-string: without the prefix the literal "{model_id.name}" ended up in the path.
zeroshot_output = f"../results/ClaimBuster/{model_id.name}/zeroshot/zeroshot_contextual_preds.csv"
os.makedirs(os.path.dirname(zeroshot_output), exist_ok=True)

generate_llm_predictions(
    data=data,
    pipe=pipe,
    prompts=prompts,
    label_column="Verdict",
    text_column="Text",
    save_path=zeroshot_output,
)

#### Error analysis

In [42]:
%autoreload
mistral_predictins = pd.read_csv(f"../results/ClaimBuster/{HuggingFaceModel.MISTRAL_7B_INSTRUCT.name}/zeroshot/predictions.csv", index_col=0)
mixtral_predictions = pd.read_csv(f"../results/ClaimBuster/{HuggingFaceModel.MIXTRAL_INSTRUCT.name}/zeroshot/predictions.csv", index_col=0)
lora_predictions = pd.read_csv(f"../results/ClaimBuster/{HuggingFaceModel.MISTRAL_7B_INSTRUCT.name}/lora/predictions.csv", index_col=0)
predictions = [mistral_predictins, mistral_predictins, lora_predictions]
model_names = [HuggingFaceModel.MISTRAL_7B_INSTRUCT.name, HuggingFaceModel.MIXTRAL_INSTRUCT.name, "LORA"]
display(claimbuster.head())
generate_error_analysis_report(
    claimbuster,
    predictions=predictions,
    model_names=model_names,
    folder_path=f"../results/ClaimBuster"
)

Unnamed: 0_level_0,Verdict,Text
sentence_id,Unnamed: 1_level_1,Unnamed: 2_level_1
27247,1,We're 9 million jobs short of that.
10766,1,"You know, last year up to this time, we've los..."
3327,1,And in November of 1975 I was the first presid...
19700,1,And what we've done during the Bush administra...
12600,1,Do you know we don't have a single program spo...


##################################################
#              MISTRAL_7B_INSTRUCT               #
#              False positives: 913              #
#              False negatives: 726              #
##################################################
#                MIXTRAL_INSTRUCT                #
#              False positives: 913              #
#              False negatives: 726              #
##################################################
#                      LORA                      #
#              False positives: 366              #
#              False negatives: 406              #
##################################################
#                     Total                      #
#             False positives: 1109              #
#              False negatives: 860              #
#        Overlapping false positives: 170        #
#        Overlapping false negatives: 272        #
##################################################


### CheckThat 2021 Task 1a Tweets

#### Error analysis

In [40]:
%autoreload
folder_path = f"../results/CheckThat"
mistral_predictions = pd.read_csv(f"{folder_path}/{HuggingFaceModel.MISTRAL_7B_INSTRUCT.name}/zeroshot/predictions.csv", index_col=0)
mixtral_predictions = pd.read_csv(f"{folder_path}/{HuggingFaceModel.MIXTRAL_INSTRUCT.name}/zeroshot/predictions.csv", index_col=0)
lora_predictions = pd.read_csv(f"{folder_path}/{HuggingFaceModel.MISTRAL_7B_INSTRUCT.name}/lora/predictions.csv", index_col=0)
results = [mistral_predictions, mixtral_predictions, lora_predictions]
model_names = [HuggingFaceModel.MISTRAL_7B_INSTRUCT.name, HuggingFaceModel.MIXTRAL_INSTRUCT.name, "LORA"]
generate_error_analysis_report(
    checkthat,
    predictions=results,
    model_names=model_names,
    folder_path=folder_path,
    label_column_name="check_worthiness",
    text_column_name="tweet_text",
)

##################################################
#              MISTRAL_7B_INSTRUCT               #
#              False positives: 315              #
#              False negatives: 102              #
##################################################
#                MIXTRAL_INSTRUCT                #
#              False positives: 161              #
#              False negatives: 131              #
##################################################
#                      LORA                      #
#              False positives: 115              #
#              False negatives: 77               #
##################################################
#                     Total                      #
#              False positives: 390              #
#              False negatives: 182              #
#        Overlapping false positives: 54         #
#        Overlapping false negatives: 32         #
##################################################
