# Check-worthiness detection using Large Language Models

First, the necessary Python modules are imported.

In [None]:
%load_ext autoreload
%autoreload

import os
if "src" in os.getcwd():
    os.chdir("..")

from src.claimbuster_utils import load_claimbuster_dataset
from src.checkthat_utils import load_check_that_dataset
from src.llm import HuggingFaceModel, run_llm_cross_validation, ICLUsage, PromptType, Experiment
from src.result_analysis import generate_error_analysis_report, print_padded_text, create_confusion_matrix
from src.dataset_utils import generate_cross_validation_datasets, CustomDataset
from src.plotting_utils import show_bar_plot
from src.liar_utils import LIARLabel
from src.rawfc_utils import RAWFCLabel
import pandas as pd
import ipywidgets as widgets
from huggingface_hub import login
import pandas as pd
import itertools

## Login to HuggingFace hub

To use the Llama 2 model, you need to log in to the Hugging Face Hub with an account that has been granted access to the gated model. Otherwise, this step can be skipped.

In [None]:
# Interactive Hugging Face Hub login; only needed for gated models such as LLama 2.
login()

## Check-worthiness predictions (experiments E1-E4)

Generates check-worthiness predictions with different LLM configurations and evaluates them with cross-validation.

### Generate Cross Validation datasets

In [None]:
%autoreload
claimbuster = load_claimbuster_dataset("../data/ClaimBuster/datasets")
clambuster_datasets = generate_cross_validation_datasets(
    data=claimbuster, 
    folder_path="../data/ClaimBuster/crossval"
)

checkthat = load_check_that_dataset("../data/CheckThat")
checkthat_datasets = generate_cross_validation_datasets(
    data=checkthat, 
    label_column="check_worthiness",
    folder_path="../data/CheckThat/crossval"
)

### Generate predictions

Use the ipywidgets below to select the model, dataset, and other parameters used to generate LLM predictions.

In [None]:
%autoreload

# General lauyout
input_style = dict(
    description_width="fit-content"
)

# Dataset 
dataset_select = widgets.Dropdown(
    options=[("ClaimBuster", CustomDataset.CLAIMBUSTER), ("CheckThat", CustomDataset.CHECK_THAT)],
    value=CustomDataset.CLAIMBUSTER,
    description="Dataset:"
)

# Model and parameters
model_select = widgets.Dropdown(
    options=[
        ("Mistral 7B Instruct", HuggingFaceModel.MISTRAL_7B_INSTRUCT), 
        ("Mixtral Instruct", HuggingFaceModel.MIXTRAL_INSTRUCT),
        ("LLama 2 7B Chat", HuggingFaceModel.LLAMA2_7B_CHAT)],
    value=HuggingFaceModel.MISTRAL_7B_INSTRUCT,
    description="Model:",
    style=input_style
)
max_new_tokens_int_text = widgets.IntText(
    value=64,
    description="Max new tokens:",
    style=input_style
)
batch_size = widgets.IntText(
    value=32,
    description="Batch size:",
    style=input_style
)
model_and_parameters = widgets.VBox(
    [model_select, max_new_tokens_int_text, batch_size],
)

# Prompting type
prompting_type = widgets.Dropdown(
    options=[("Standard", PromptType.STANDARD), ("Chain-of-Thought", PromptType.CHAIN_OF_THOUGHT)],
    value=PromptType.STANDARD,
    description="Prompting type:",
    style=input_style
)
icl_usage = widgets.Dropdown(
    options=[("Zero-shot", ICLUsage.ZERO_SHOT), ("Few-shot", ICLUsage.FEW_SHOT)],
    value=ICLUsage.ZERO_SHOT,
    description="ICL usage:",
    style=input_style
)
prompt_use = widgets.VBox(
    [prompting_type, icl_usage]
)

accordion = widgets.Accordion([
    dataset_select,
    model_and_parameters,
    prompt_use
],
    titles=["Dataset", "Model and parameters", "Prompting type"],
)

title = widgets.HTML(
    "<h1>Generation of predictions using LLMs</h1>",
)
description = widgets.HTML(
    "<div>Set the parameters to select what dataset, model and prompting to use when generating predictions. If you experience Cuda out of memory issues, please decrease the batch size.</div>",
    layout={"font-size": '14px'}
)
start_generation_button = widgets.Button(
    description="Start generation",
    disabled=False,
    button_style="success",
    layout={"height": "40px", "width": "calc(100% - 4px)"},
)

def handle_generation_click(_):
    """Generate LLM predictions for the parameters selected in the widgets.

    The generation runs in a separate ``python3 -m src.llm`` process so that
    the GPU resources it uses are disposed of when that process exits. A
    non-zero exit status is printed instead of being silently discarded.
    """
    import subprocess

    print("#" * 50)
    print_padded_text("Starting generation with parameters")
    print_padded_text(f"Dataset: {dataset_select.value.value}")
    print_padded_text(f"Model: {model_select.value.name}")
    print_padded_text(f"Prompting type: {prompting_type.value.value}")
    print_padded_text(f"ICL usage: {icl_usage.value.value}")
    print("#" * 50)
    # An argument list (no shell) keeps each widget value a single argument.
    command = [
        "python3", "-m", "src.llm",
        f"--dataset={dataset_select.value.value}",
        f"--prompt-type={prompting_type.value.value}",
        f"--icl_usage={icl_usage.value.value}",
        f"--batch-size={batch_size.value}",
        f"--max-new-tokens={max_new_tokens_int_text.value}",
        f"--model-id={model_select.value.value}",
    ]
    completed = subprocess.run(command, check=False)
    if completed.returncode != 0:
        print(f"Generation failed with exit status {completed.returncode}")


start_generation_button.on_click(handle_generation_click)


# Bordered flex column: title, description, parameter accordion, start button.
box = widgets.Box(
    children=[title, description, accordion, start_generation_button],
    layout=widgets.Layout(
        border="1px solid black",
        display="flex",
        flex_flow="column",
        align_items="stretch",
        padding="16px",
    ),
)
display(box)

### Cross validation

Use the ipywidgets below to select the configurations to run cross-validation on.

In [None]:
%autoreload

input_style = dict(
    description_width="fit-content"
)

dataset_select = widgets.Dropdown(
    options=[("ClaimBuster", CustomDataset.CLAIMBUSTER), ("CheckThat", CustomDataset.CHECK_THAT)],
    value=CustomDataset.CLAIMBUSTER,
    description="Dataset:"
)

model_select = widgets.Dropdown(
    options=[
        ("Mistral 7B Instruct", HuggingFaceModel.MISTRAL_7B_INSTRUCT), 
        ("Mixtral Instruct", HuggingFaceModel.MIXTRAL_INSTRUCT),
        ('LLama2 7B Chat', HuggingFaceModel.LLAMA2_7B_CHAT)
    ],
    value=HuggingFaceModel.MISTRAL_7B_INSTRUCT,
    description="Model:",
    style=input_style
)

prompting_type = widgets.Dropdown(
    options=[("Standard", PromptType.STANDARD), ("Chain-of-Thought", PromptType.CHAIN_OF_THOUGHT)],
    value=PromptType.STANDARD,
    description="Prompting type:",
    style=input_style
)
icl_usage = widgets.Dropdown(
    options=[("Zero-shot", ICLUsage.ZERO_SHOT), ("Few-shot", ICLUsage.FEW_SHOT)],
    value=ICLUsage.ZERO_SHOT,
    description="ICL usage:",
    style=input_style,
)

prompt_use = widgets.VBox(
    [prompting_type, icl_usage]
)

title = widgets.HTML(
    "<h1>Cross validation using LLMs</h1>",
)
description = widgets.HTML(
    "<div>Set the parameters to select what dataset, model and prompting to use when performing cross validation.</div>",
    layout={"font-size": '14px'}
)
start_cross_validation_button = widgets.Button(
    description="Start cross validation",
    disabled=False,
    button_style="success",
    layout={"height": "40px", "width": "calc(100% - 4px)"},
)

def handle_cross_validation_click(_):
    """Run cross validation on previously generated LLM scores.

    When the "Run all configurations" checkbox is ticked, every
    dataset / model / prompt type / ICL usage combination is evaluated;
    otherwise only the combination selected in the widgets. Combinations
    without a ``generated_scores.csv`` are skipped with a message.
    """
    if all_configurations.value:
        datasets = [CustomDataset.CLAIMBUSTER.value, CustomDataset.CHECK_THAT.value]
        models = [HuggingFaceModel.MISTRAL_7B_INSTRUCT.name, HuggingFaceModel.MIXTRAL_INSTRUCT.name, HuggingFaceModel.LLAMA2_7B_CHAT.name]
        prompt_types = [PromptType.STANDARD.value, PromptType.CHAIN_OF_THOUGHT.value]
        icl_uses = [ICLUsage.ZERO_SHOT.value, ICLUsage.FEW_SHOT.value]
    else:
        datasets = [dataset_select.value.value]
        models = [model_select.value.name]
        prompt_types = [prompting_type.value.value]
        # Reads the global ``icl_usage`` dropdown widget. The loop below must
        # not bind that name: any assignment to it in this function makes it
        # local, and this line would then raise UnboundLocalError.
        icl_uses = [icl_usage.value.value]
    for dataset, model, prompt_type, icl_use in itertools.product(datasets, models, prompt_types, icl_uses):
        dataset_folder = os.path.join(
            "results",
            dataset,
            model,
            prompt_type,
            icl_use,
        )
        dataset_path = os.path.join(dataset_folder, "generated_scores.csv")
        if not os.path.exists(dataset_path):
            print(f"No generated scores found at {dataset_path}")
            continue
        dataset_with_scores = pd.read_csv(dataset_path, index_col=0)
        crossval_folder = os.path.join(
            "data",
            dataset,
            "crossval"
        )
        label_column = "Verdict" if dataset == CustomDataset.CLAIMBUSTER.value else "check_worthiness"
        print("#" * 50)
        print_padded_text("Starting cross validation with parameters")
        print_padded_text(f"Dataset: {dataset}")
        print_padded_text(f"Model: {model}")
        print_padded_text(f"Prompting type: {prompt_type}")
        print_padded_text(f"ICL usage: {icl_use}")
        print("#" * 50)
        result, _ = run_llm_cross_validation(
            data=dataset_with_scores,
            crossval_folder=crossval_folder,
            save_folder=dataset_folder,
            label_column=label_column
        )
        display(result)

start_cross_validation_button.on_click(handle_cross_validation_click)

# Per-configuration selectors, grouped into collapsible sections.
accordion = widgets.Accordion(
    [dataset_select, model_select, prompt_use],
    titles=["Dataset", "Model", "Prompting type"],
)

# When ticked, the handler ignores the selectors and runs every configuration.
all_configurations = widgets.Checkbox(
    value=False,
    description="Run all configurations",
)

def handle_all_configurations_toggled(change):
    """Hide the configuration accordion while "Run all configurations" is ticked.

    ``change`` is the observer change mapping; ``change['new']`` holds the
    new checkbox value.
    """
    accordion.layout.display = "none" if change['new'] == True else "block"

all_configurations.observe(handle_all_configurations_toggled, names="value")


# Bordered flex column: title, description, "run all" checkbox,
# per-configuration accordion and the start button.
box = widgets.Box(
    children=[title, description, all_configurations, accordion, start_cross_validation_button],
    layout=widgets.Layout(
        border="1px solid black",
        display="flex",
        flex_flow="column",
        align_items="stretch",
        padding="16px",
    ),
)
display(box)

### LoRA finetuning (experiment E4)

#### ClaimBuster

In [None]:
import subprocess

# LoRA fine-tuning on ClaimBuster for each model, run in a separate process;
# the cross-validation table it writes is displayed afterwards.
for model_id in [HuggingFaceModel.MISTRAL_7B_INSTRUCT, HuggingFaceModel.LLAMA2_7B_CHAT]:
    completed = subprocess.run(
        [
            "python3", "-m", "src.lora_finetuning",
            f"--dataset={CustomDataset.CLAIMBUSTER.value}",
            f"--model-id={model_id.value}",
            f"--experiment={Experiment.FINE_TUNING.value}",
        ],
        check=False,
    )
    if completed.returncode != 0:
        # Do not display a missing or stale crossval.csv for a failed run.
        print(f"LoRA fine-tuning failed for model {model_id.name} (exit status {completed.returncode})")
        continue
    result_path = f"results/ClaimBuster/{model_id.name}/lora/crossval.csv"
    results = pd.read_csv(result_path, index_col=0)
    display(f"Cross validation results for model {model_id.name}", results)

#### CheckThat

In [None]:
import subprocess

# LoRA fine-tuning on CheckThat for each model, run in a separate process;
# the cross-validation table it writes is displayed afterwards.
for model_id in [HuggingFaceModel.MISTRAL_7B_INSTRUCT, HuggingFaceModel.LLAMA2_7B_CHAT]:
    completed = subprocess.run(
        [
            "python3", "-m", "src.lora_finetuning",
            f"--dataset={CustomDataset.CHECK_THAT.value}",
            f"--model-id={model_id.value}",
            f"--experiment={Experiment.FINE_TUNING.value}",
        ],
        check=False,
    )
    if completed.returncode != 0:
        # Do not display a missing or stale crossval.csv for a failed run.
        print(f"LoRA fine-tuning failed for model {model_id.name} (exit status {completed.returncode})")
        continue
    result_path = f"results/CheckThat/{model_id.name}/lora/crossval.csv"
    results = pd.read_csv(result_path, index_col=0)
    display(f"Cross validation results for model {model_id.name}", results)

### Summarize all results in one table for each dataset

In [None]:
def get_experiment_number(prompt_type: PromptType, icl_usage: ICLUsage) -> str:
    """Return the experiment label for a prompt type / ICL usage combination.

    Chain-of-Thought maps to "E3" and LoRA to "E4" regardless of ICL usage.
    Any other prompt type maps to "E1" for zero-shot and "E2" otherwise.
    """
    if prompt_type not in (PromptType.CHAIN_OF_THOUGHT, PromptType.LORA):
        return "E2" if icl_usage != ICLUsage.ZERO_SHOT else "E1"
    return "E3" if prompt_type == PromptType.CHAIN_OF_THOUGHT else "E4"

# Collect the averaged cross-validation scores of every experiment into one
# table per dataset, saved as results/<dataset>/all_results.csv.
# The ICL loop variable is ``icl`` rather than ``icl_usage`` so this cell does
# not overwrite the ``icl_usage`` dropdown widget that the button handlers
# read as a global.
for dataset in [CustomDataset.CLAIMBUSTER, CustomDataset.CHECK_THAT]:
    all_results = pd.DataFrame(columns=["Exp. nr", "Model", "Prompt type", "ICL usage", "F1-macro", "Accuracy"])
    prompt_types = [PromptType.STANDARD, PromptType.CHAIN_OF_THOUGHT, PromptType.LORA]
    icl_uses = [ICLUsage.ZERO_SHOT, ICLUsage.FEW_SHOT]
    models = [HuggingFaceModel.MISTRAL_7B_INSTRUCT, HuggingFaceModel.MIXTRAL_INSTRUCT, HuggingFaceModel.LLAMA2_7B_CHAT]
    for index, (prompt_type, icl, model) in enumerate(itertools.product(prompt_types, icl_uses, models)):
        # LoRA results have no ICL variants: count them once (as zero-shot).
        if prompt_type == PromptType.LORA and icl == ICLUsage.FEW_SHOT:
            continue
        # LoRA results are stored without an ICL usage sub-folder.
        icl_folder = icl.value if prompt_type != PromptType.LORA else ''
        result_path = os.path.join(
            "results",
            dataset.value,
            model.name,
            prompt_type.value,
            icl_folder,
            "crossval.csv"
        )
        if not os.path.exists(result_path):
            continue
        result = pd.read_csv(result_path, index_col=0)
        accuracy = result.loc["Average", "accuracy"]
        f1_macro = result.loc["Average", "macro avg_f1-score"]
        experiment_number = get_experiment_number(prompt_type, icl)
        all_results.loc[index] = [
            experiment_number,
            model.name,
            prompt_type.value,
            icl.value,
            f1_macro,
            accuracy,
        ]
    save_path = os.path.join("results", dataset.value, "all_results.csv")
    # The folder may not exist yet when no results were found for the dataset.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    all_results.to_csv(save_path, index=False)
    display(dataset.value, all_results)

### Error analysis

#### ClaimBuster

In [None]:
%autoreload
predictions = []
models = [HuggingFaceModel.MISTRAL_7B_INSTRUCT, HuggingFaceModel.MIXTRAL_INSTRUCT, HuggingFaceModel.LLAMA2_7B_CHAT]
prompt_types = [PromptType.STANDARD, PromptType.CHAIN_OF_THOUGHT]
icl_usages = [ICLUsage.ZERO_SHOT, ICLUsage.FEW_SHOT]
model_names = []
claimbuster = load_claimbuster_dataset("data/ClaimBuster/datasets")

for model, prompt_type, icl_usage in itertools.product(models, prompt_types, icl_usages):
    predictions_path = f"results/ClaimBuster/{model.name}/{prompt_type.value}/{icl_usage.value}/predictions.csv"
    if os.path.exists(predictions_path):
        new_predictions = pd.read_csv(predictions_path, index_col=0)
        # Exclude results from LLama 2 model in the final report since they are lackluster
        model_name = f"{model.name} {prompt_type.value} {icl_usage.value}"
        if (model != HuggingFaceModel.LLAMA2_7B_CHAT):
            predictions.append(new_predictions) 
            model_names.append(model_name)
        create_confusion_matrix(
            claimbuster, 
            new_predictions.loc[claimbuster.index],
            save_path=os.path.join(os.path.dirname(predictions_path), "confusion-matrix.pdf")
        )
# LORA
for model in models:
    predictions_path = f"results/ClaimBuster/{model.name}/lora/predictions.csv"
    if os.path.exists(predictions_path):
        new_predictions = pd.read_csv(predictions_path, index_col=0)
        model_name = f"{model.name} LORA"
        if (model != HuggingFaceModel.LLAMA2_7B_CHAT):
            predictions.append(new_predictions)
            model_names.append(model_name)
        create_confusion_matrix(
            claimbuster, 
            new_predictions.loc[claimbuster.index]["prediction"],
            save_path=os.path.join(os.path.dirname(predictions_path), "confusion-matrix.pdf")
        )
mistral_scores = pd.read_csv(f"results/ClaimBuster/MISTRAL_7B_INSTRUCT/standard/zeroshot/generated_scores.csv", index_col=0)
reasoning = mistral_scores["raw_response"]
generate_error_analysis_report(
    claimbuster,
    predictions=predictions,
    model_names=model_names,
    folder_path=f"results/ClaimBuster",
    reasoning=reasoning
)

#### CheckThat 2021 Task 1a Tweets

In [None]:
%autoreload
predictions = []
models = [HuggingFaceModel.MISTRAL_7B_INSTRUCT, HuggingFaceModel.MIXTRAL_INSTRUCT, HuggingFaceModel.LLAMA2_7B_CHAT]
prompt_types = [PromptType.STANDARD, PromptType.CHAIN_OF_THOUGHT]
icl_usages = [ICLUsage.ZERO_SHOT, ICLUsage.FEW_SHOT]
model_names = []
checkthat = load_check_that_dataset("data/CheckThat")

for model, prompt_type, icl_usage in itertools.product(models, prompt_types, icl_usages):
    predictions_path = f"results/CheckThat/{model.name}/{prompt_type.value}/{icl_usage.value}/predictions.csv"
    if os.path.exists(predictions_path):
        new_predictions = pd.read_csv(predictions_path, index_col=0)
        model_name = f"{model.name} {prompt_type.value} {icl_usage.value}"
        # Exclude LLama 2  results since they are lackluster
        if model != HuggingFaceModel.LLAMA2_7B_CHAT:
            predictions.append(new_predictions) 
            model_names.append(model_name)
        create_confusion_matrix(
            checkthat, 
            new_predictions.loc[checkthat.index],
            label_column_name="check_worthiness",
            save_path=os.path.join(os.path.dirname(predictions_path), "confusion-matrix.pdf")
        )
# LORA
for model in models:
    predictions_path = f"results/CheckThat/{model.name}/lora/predictions.csv"
    if os.path.exists(predictions_path):
        new_predictions = pd.read_csv(predictions_path, index_col=0)
        model_name = f"{model.name} LORA"
        if model != HuggingFaceModel.LLAMA2_7B_CHAT:
            predictions.append(new_predictions)
            model_names.append(model_name)
        create_confusion_matrix(
            checkthat, 
            new_predictions.loc[checkthat.index]["prediction"],
            label_column_name="check_worthiness",
            save_path=os.path.join(os.path.dirname(predictions_path), "confusion-matrix.pdf")
        )

mistral_scores = pd.read_csv(f"results/CheckThat/MISTRAL_7B_INSTRUCT/standard/zeroshot/generated_scores.csv", index_col=0)
reasoning = mistral_scores["raw_response"]
generate_error_analysis_report(
    checkthat,
    predictions=predictions,
    model_names=model_names,
    folder_path="results/CheckThat",
    label_column_name="check_worthiness",
    text_column_name="tweet_text",
    reasoning=reasoning
)

## Relating truthfulness and check-worthiness (experiment E5)

Runs check-worthiness detection on two fact-verification datasets, using an LLM fine-tuned on the ClaimBuster dataset.

### Generate check-worthiness predictions on fact-verification datasets

In [None]:
import subprocess

# Run the truthfulness experiment with Mistral 7B Instruct on each
# fact-verification dataset, in a separate process, reporting failed runs.
for dataset in [CustomDataset.LIAR, CustomDataset.RAWFC]:
    completed = subprocess.run(
        [
            "python3", "-m", "src.lora_finetuning",
            f"--dataset={dataset.value}",
            f"--model-id={HuggingFaceModel.MISTRAL_7B_INSTRUCT.value}",
            f"--experiment={Experiment.TRUTH_FULNESS.value}",
        ],
        check=False,
    )
    if completed.returncode != 0:
        print(f"Check-worthiness prediction failed for {dataset.value} (exit status {completed.returncode})")

### LIAR

In [None]:
%autoreload

liar = pd.read_csv("results/LIAR/checkworthiness.csv", index_col=0)
liar.head()
label_to_name = {
    LIARLabel.PANTS_FIRE: "Pants on fire",
    LIARLabel.FALSE: "False",
    LIARLabel.BARELY_TRUE: "Barely true",
    LIARLabel.HALF_TRUE: "Half true",
    LIARLabel.MOSTLY_TRUE: "Mostly true",
    LIARLabel.TRUE: "True"
}
x = [label_to_name[label] for label in LIARLabel]
y = [liar[liar["label"] == label.value]["check_worthiness"].mean() for label in LIARLabel]
file_path = f"figures/liar/checkworthiness/checkworthiness.pdf"

os.makedirs(os.path.dirname(file_path), exist_ok=True)
show_bar_plot(
    x, 
    y, 
    xlabel="Label", 
    ylabel="Proportion of check-worthy claims", 
    y_ticks=[i*0.1 for i in range(11)],
    file_path=file_path, 
    force_save=True,
    use_bar_labels=True,
)

# Look for non-checkworthy claims for each label
non_checkworthy_folder = "results/LIAR/non-checkworthy"
os.makedirs(non_checkworthy_folder, exist_ok=True)
checkworthy_folder = "results/LIAR/checkworthy"
os.makedirs(checkworthy_folder, exist_ok=True)
for label in LIARLabel:
    non_checkworthy = liar.query(f"label == {label.value} and check_worthiness == 0")
    non_checkworthy.to_csv(f"{non_checkworthy_folder}/{label.name}.csv")
    checkworthy = liar.query(f"label == {label.value} and check_worthiness == 1")
    checkworthy.to_csv(f"{checkworthy_folder}/{label.name}")

### RAWFC

In [None]:
%autoreload

from plotting_utils import show_bar_plot
rawfc = pd.read_csv("results/RAWFC/checkworthiness.csv", index_col=0)
rawfc.head()
label_to_name = {
        RAWFCLabel.FALSE: "False",
        RAWFCLabel.HALF_TRUE: "Half true",
        RAWFCLabel.TRUE: "True"
    }
x = [label_to_name[label] for label in RAWFCLabel]
y = [rawfc[rawfc["label"] == label.value]["check_worthiness"].mean() for label in RAWFCLabel]
file_path = f"figures/rawfc/checkworthiness/checkworthiness.pdf"

os.makedirs(os.path.dirname(file_path), exist_ok=True)
show_bar_plot(
    x, 
    y, 
    xlabel="Label", 
    ylabel="Proportion of check-worthy claims", 
    y_ticks=[i*0.1 for i in range(11)],
    file_path=file_path, 
    force_save=True,
    use_bar_labels=True,
)

# Look for non-checkworthy claims for each label
non_checkworthy_folder = "results/RAWFC/non-checkworthy"
os.makedirs(non_checkworthy_folder, exist_ok=True)
checkworthy_folder = "results/RAWFC/checkworthy"
os.makedirs(checkworthy_folder, exist_ok=True)
for label in RAWFCLabel:
    non_checkworthy = rawfc.query(f"label == {label.value} and check_worthiness == 0")
    non_checkworthy.to_csv(f"{non_checkworthy_folder}/{label.name}.csv")
    checkworthy = rawfc.query(f"label == {label.value} and check_worthiness == 1")
    checkworthy.to_csv(f"{checkworthy_folder}/{label.name}")

## Inference time evaluation (experiment E6)

In [None]:
%autoreload

# General lauyout
input_style = dict(
    description_width="fit-content"
)

# Dataset 
dataset_select = widgets.Dropdown(
    options=[("ClaimBuster", CustomDataset.CLAIMBUSTER), ("CheckThat", CustomDataset.CHECK_THAT)],
    value=CustomDataset.CLAIMBUSTER,
    description="Dataset:"
)

# Model and parameters
model_select = widgets.Dropdown(
    options=[
        ("Mistral 7B Instruct", HuggingFaceModel.MISTRAL_7B_INSTRUCT), 
        ("Mixtral Instruct", HuggingFaceModel.MIXTRAL_INSTRUCT),
        ("LLama 2 7B Chat", HuggingFaceModel.LLAMA2_7B_CHAT)],
    value=HuggingFaceModel.MISTRAL_7B_INSTRUCT,
    description="Model:",
    style=input_style
)
max_new_tokens_int_text = widgets.IntText(
    value=64,
    description="Max new tokens:",
    style=input_style
)
batch_size = widgets.IntText(
    value=32,
    description="Batch size:",
    style=input_style
)
model_and_parameters = widgets.VBox(
    [model_select, max_new_tokens_int_text, batch_size],
)

# Prompting type
prompting_type = widgets.Dropdown(
    options=[("Standard", PromptType.STANDARD), ("Chain-of-Thought", PromptType.CHAIN_OF_THOUGHT), ("LORA", PromptType.LORA)],
    value=PromptType.STANDARD,
    description="Prompting type:",
    style=input_style
)
icl_usage = widgets.Dropdown(
    options=[("Zero-shot", ICLUsage.ZERO_SHOT), ("Few-shot", ICLUsage.FEW_SHOT)],
    value=ICLUsage.ZERO_SHOT,
    description="ICL usage:",
    style=input_style
)
prompt_use = widgets.VBox(
    [prompting_type, icl_usage]
)

accordion = widgets.Accordion([
    dataset_select,
    model_and_parameters,
    prompt_use
],
    titles=["Dataset", "Model and parameters", "Prompting type"],
)

title = widgets.HTML(
    "<h1>Inference time evaluation of LLMs</h1>",
)
description = widgets.HTML(
    "<div>Set the parameters to select what dataset, model and prompting to use when generating predictions. If you experience Cuda out of memory issues, please decrease the batch size.</div>",
    layout={"font-size": '14px'}
)
start_generation_button = widgets.Button(
    description="Start generation",
    disabled=False,
    button_style="success",
    layout={"height": "40px", "width": "calc(100% - 4px)"},
)

def handle_generation_click(_):
    """Run the inference-time experiment for the parameters selected in the widgets.

    The LoRA prompting type goes through ``src.lora_finetuning``; every other
    prompting type goes through ``src.llm``. Both run as separate Python
    processes, and a non-zero exit status is printed instead of being
    silently discarded.
    """
    import subprocess

    print("#" * 50)
    print_padded_text("Starting generation with parameters")
    print_padded_text(f"Dataset: {dataset_select.value.value}")
    print_padded_text(f"Model: {model_select.value.name}")
    print_padded_text(f"Prompting type: {prompting_type.value.value}")
    print_padded_text(f"ICL usage: {icl_usage.value.value}")
    print("#" * 50)

    # ``python3 -m src....`` must be started from the directory that
    # contains the ``src`` package.
    if "src" in os.getcwd():
        os.chdir("..")
    if prompting_type.value == PromptType.LORA:
        # NOTE(review): this branch always uses ClaimBuster and does not pass
        # the dataset, ICL usage, batch size or max-new-tokens selections —
        # confirm this is intended.
        command = [
            "python3", "-m", "src.lora_finetuning",
            f"--dataset={CustomDataset.CLAIMBUSTER.value}",
            f"--model-id={model_select.value.value}",
            f"--experiment={Experiment.INFERENCE_TIME.value}",
        ]
    else:
        command = [
            "python3", "-m", "src.llm",
            f"--experiment={Experiment.INFERENCE_TIME.value}",
            f"--dataset={dataset_select.value.value}",
            f"--prompt-type={prompting_type.value.value}",
            f"--icl_usage={icl_usage.value.value}",
            f"--batch-size={batch_size.value}",
            f"--max-new-tokens={max_new_tokens_int_text.value}",
            f"--model-id={model_select.value.value}",
        ]
    # An argument list (no shell) keeps each widget value a single argument.
    completed = subprocess.run(command, check=False)
    if completed.returncode != 0:
        print(f"Inference time evaluation failed with exit status {completed.returncode}")

start_generation_button.on_click(handle_generation_click)


# Bordered flex column: title, description, parameter accordion, start button.
box = widgets.Box(
    children=[title, description, accordion, start_generation_button],
    layout=widgets.Layout(
        border="1px solid black",
        display="flex",
        flex_flow="column",
        align_items="stretch",
        padding="16px",
    ),
)
display(box)