In [None]:
import torch
import chromadb
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from datasets import load_dataset, load_from_disk, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from matplotlib import pyplot as plt

Note: You have to request access to Llama-3-8B-Instruct and log in to Hugging Face to run this notebook.

- https://huggingface.co/docs/huggingface_hub/quick-start
- https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct

In [None]:
class Generator:
    """Base class turning a text prompt into generated text with a causal LM.

    Subclasses are expected to assign ``self.tokenizer`` and ``self.model`` in
    their ``__init__``; this class only implements the generation call.
    """

    def __init__(self):
        pass

    @staticmethod
    def _empty_device_cache():
        """Free cached accelerator memory on whichever backend is present.

        The model is placed with ``device_map="auto"``, so it may live on CUDA,
        on Apple MPS, or on the CPU; only touch a backend that is available.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

    def __call__(self, prompt, max_new_tokens=100):
        """Generate a continuation for ``prompt``.

        Args:
            prompt: The full prompt text passed to the tokenizer.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded continuation only (prompt tokens are stripped), with
            special tokens removed.
        """
        self._empty_device_cache()
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
            self.model.device
        )
        # Single un-padded prompt, so every position is a real token.
        attention_mask = torch.ones_like(input_ids)

        self.model.eval()
        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                attention_mask=attention_mask,
                num_return_sequences=1,
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_length = input_ids.shape[-1]
        output_text = self.tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
        return output_text


class Llama3_8b(Generator):
    """Generator backed by the Meta-Llama-3-8B-Instruct checkpoint."""

    def __init__(self):
        checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Weights are loaded in bfloat16; device placement is left to device_map="auto".
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
        )


# Single shared generator instance used by all experiments below.
generator = Llama3_8b()

<b style="color:red">IMPORTANT</b>: Before running the experiments contained in this notebook, make sure that you have created the "fever-fine-coarse" dataset using the `create_fever_fine_coarse_dataset.ipynb` notebook.

In [None]:
# Load the pre-built FEVER-Fine-Coarse dataset from local disk.
claims = load_from_disk("fever-fine-coarse")
print(f"{len(claims)} claims loaded")

# Draw a fixed-seed subset so that every experiment sees the same claims.
sample_size = 10
print(f"Sample {sample_size} claims for experiments")
shuffled_claims = claims.shuffle(seed=150620241351)
claims_sample = shuffled_claims.select(range(sample_size))
claims_sample

# Quantitative Evaluation of Meta-Llama-3-8B-Instruct on FEVER-Fine-Coarse

In [None]:
# Function to evaluate the results in the experiments

def eval_accuracy(filename):
    """Print and return accuracy statistics for a CSV of stored predictions.

    The CSV must contain a ``label`` and a ``prediction`` column. Responses
    that could not be parsed are stored as empty cells, which pandas reads
    back as NaN (not ``None``); these are reported as invalid predictions.

    Args:
        filename: Path to the predictions CSV.

    Returns:
        Tuple ``(correct, supports_correct, supports_total, refutes_correct,
        refutes_total, nef_supports, nef_refutes)``, where the ``nef_*`` values
        count SUPPORTS / REFUTES claims predicted as "NOT ENOUGH EVIDENCE".
    """
    predictions = pd.read_csv(filename)
    labels = predictions["label"]
    predicted = predictions["prediction"]

    def _count(mask):
        return int(mask.sum())

    def _ratio(part, whole):
        # A label class can be missing from a small sample; report 0.00
        # instead of raising ZeroDivisionError.
        return part / whole if whole else 0.0

    # NaN never compares equal, so unparsable predictions are never "correct".
    is_correct = labels == predicted
    is_supports = labels == "SUPPORTS"
    is_refutes = labels == "REFUTES"
    is_not_enough = predicted == "NOT ENOUGH EVIDENCE"

    total = len(predictions)
    correct = _count(is_correct)
    invalid = _count(predicted.isna() & ~is_correct)

    supports_total = _count(is_supports)
    supports_correct = _count(is_supports & is_correct)

    refutes_total = _count(is_refutes)
    refutes_correct = _count(is_refutes & is_correct)

    nef_supports = _count(is_supports & is_not_enough)
    nef_refutes = _count(is_refutes & is_not_enough)

    print(f"Total Accuracy: {correct}/{total} ({_ratio(correct, total):.2f})")
    print(f"Supports correct: {supports_correct}/{supports_total} ({_ratio(supports_correct, supports_total):.2f})")
    print(f"Refutes correct: {refutes_correct}/{refutes_total} ({_ratio(refutes_correct, refutes_total):.2f})")
    print(f"Not enough info (supports): {nef_supports}/{supports_total} ({_ratio(nef_supports, supports_total):.2f})")
    print(f"Not enough info (refutes): {nef_refutes}/{refutes_total} ({_ratio(nef_refutes, refutes_total):.2f})")
    print(f"Invalid predictions: {invalid} ({_ratio(invalid, total):.2f})")

    return correct, supports_correct, supports_total, refutes_correct, refutes_total, nef_supports, nef_refutes

## Experiment 1

Llama3-8b with gold evidence (fine) and 2 answer choices (SUPPORTS, REFUTES)

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if context EXPLICITLY supports the claim.
You answer REFUTES if the context EXPLICITLY refutes the claim.

Context:
{context}

Claim:
{claim}

Answer:
"""

set_seed(5243)

invalid_count = 0
predictions = []

# Labels are searched in this order; the first one found in the response wins.
answer_choices = ("SUPPORTS", "REFUTES")

for claim in tqdm(claims_sample):
    label, claim_text = claim["label"], claim["claim"]
    # Gold evidence at the fine granularity, one item per line.
    context = "\n".join(claim["evidence_fine"])
    prompt = prompt_template.format(context=context, claim=claim_text)
    response = generator(prompt, max_new_tokens=10)

    # None marks a response that contains none of the expected labels.
    prediction = next((choice for choice in answer_choices if choice in response), None)

    predictions.append(
        {"claim": claim_text, "context": context, "label": label, "prediction": prediction}
    )

# Store
output_file = "predictions-gold-fine-2way.csv"
predictions_df = pd.DataFrame(predictions)
predictions_df.to_csv(output_file, index=False)

# Evaluate
eval_accuracy(output_file)

## Experiment 2

Llama3-8b with gold evidence (fine) and 3 answer choices (SUPPORTS, REFUTES, NOT ENOUGH EVIDENCE)

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if context EXPLICITLY supports the claim.
You answer REFUTES if the context EXPLICITLY refutes the claim.
You answer NOT ENOUGH EVIDENCE if the context does not provide enough information to explicitly support or refute the claim.

Context:
{context}

Claim:
{claim}

Answer:
"""

set_seed(5243)

invalid_count = 0
predictions = []

# Labels are searched in this order; the first one found in the response wins.
answer_choices = ("SUPPORTS", "REFUTES", "NOT ENOUGH EVIDENCE")

for claim in tqdm(claims_sample):
    label, claim_text = claim["label"], claim["claim"]
    # Gold evidence at the fine granularity, one item per line.
    context = "\n".join(claim["evidence_fine"])
    prompt = prompt_template.format(context=context, claim=claim_text)
    response = generator(prompt, max_new_tokens=10)

    # None marks a response that contains none of the expected labels.
    prediction = next((choice for choice in answer_choices if choice in response), None)

    predictions.append(
        {"claim": claim_text, "context": context, "label": label, "prediction": prediction}
    )

# Store
output_file = "predictions-gold-fine-3way.csv"
predictions_df = pd.DataFrame(predictions)
predictions_df.to_csv(output_file, index=False)

# Evaluate
eval_accuracy(output_file)

## Experiment 3

Llama3-8b with gold evidence (coarse) and 2 answer choices (SUPPORTS, REFUTES)

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if context EXPLICITLY supports the claim.
You answer REFUTES if the context EXPLICITLY refutes the claim.

Context:
{context}

Claim:
{claim}

Answer:
"""

set_seed(5243)

predictions = []

for claim in tqdm(claims_sample):
    label = claim["label"]
    claim_text = claim["claim"]
    # Gold evidence at the coarse granularity, one item per line.
    context = "\n".join(claim["evidence_coarse"])
    prompt = prompt_template.format(context=context, claim=claim_text)
    response = generator(prompt, max_new_tokens=10)

    # None marks a response that contains none of the expected labels.
    prediction = None
    if "SUPPORTS" in response:
        prediction = "SUPPORTS"
    elif "REFUTES" in response:
        prediction = "REFUTES"

    predictions.append({"claim": claim_text, "context": context, "label": label, "prediction": prediction})

# Store. The "-2way" suffix matches the other experiments and is the file name
# the comparison bar chart at the end of the notebook reads.
output_file = "predictions-gold-coarse-2way.csv"
predictions_df = pd.DataFrame(predictions)
predictions_df.to_csv(output_file, index=False)

# Evaluate
eval_accuracy(output_file)

## Experiment 4

Llama3-8b with gold evidence (coarse) and 3 answer choices (SUPPORTS, REFUTES, NOT ENOUGH EVIDENCE)

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if context EXPLICITLY supports the claim.
You answer REFUTES if the context EXPLICITLY refutes the claim.
You answer NOT ENOUGH EVIDENCE if the context does not provide enough information to explicitly support or refute the claim.

Context:
{context}

Claim:
{claim}

Answer:
"""

set_seed(5243)

invalid_count = 0
predictions = []

# Labels are searched in this order; the first one found in the response wins.
answer_choices = ("SUPPORTS", "REFUTES", "NOT ENOUGH EVIDENCE")

for claim in tqdm(claims_sample):
    label, claim_text = claim["label"], claim["claim"]
    # Gold evidence at the coarse granularity, one item per line.
    context = "\n".join(claim["evidence_coarse"])
    prompt = prompt_template.format(context=context, claim=claim_text)
    response = generator(prompt, max_new_tokens=10)

    # None marks a response that contains none of the expected labels.
    prediction = next((choice for choice in answer_choices if choice in response), None)

    predictions.append(
        {"claim": claim_text, "context": context, "label": label, "prediction": prediction}
    )

# Store the results
output_file = "predictions-gold-coarse-3way.csv"
predictions_df = pd.DataFrame(predictions)
predictions_df.to_csv(output_file, index=False)

eval_accuracy(output_file)

## Experiment 5

Llama3-8b without context (parametric knowledge only) and 2 answer choices (SUPPORTS, REFUTES)

Note: The fact verification task is to tell whether a claim is supported or refuted by the given evidence.
Here, the model just needs to predict whether it "thinks" the claim is "true" or "false" based on the knowledge it has been trained on.
Strictly speaking, this is not a fact verification task, but shall serve as a baseline for the fact verification task.

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if the claim is true.
You answer REFUTES if the claim is false.

Claim:
{claim}

Answer:
"""

set_seed(5243)

predictions = []

for claim in tqdm(claims_sample):
    label = claim["label"]
    claim_text = claim["claim"]
    # No context: the model has to rely on its parametric knowledge only.
    prompt = prompt_template.format(claim=claim_text)
    response = generator(prompt, max_new_tokens=10)

    # None marks a response that contains none of the expected labels.
    prediction = None
    if "SUPPORTS" in response:
        prediction = "SUPPORTS"
    elif "REFUTES" in response:
        prediction = "REFUTES"

    predictions.append({"claim": claim_text, "label": label, "prediction": prediction})

# Store the results. The same path is used for writing and for evaluation
# (previously the file was written with a hyphen but read with an underscore),
# and the "-2way" suffix is the name the comparison bar chart expects.
output_file = "predictions-parametric-2way.csv"
predictions_df = pd.DataFrame(predictions)
predictions_df.to_csv(output_file, index=False)

eval_accuracy(output_file)

## Experiment 6

Llama3-8b without context (parametric knowledge only) and 3 answer choices (SUPPORTS, REFUTES, NOT ENOUGH EVIDENCE)

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if the claim is true.
You answer REFUTES if the claim is false.
You answer NOT ENOUGH EVIDENCE if your are not sure.

Claim:
{claim}

Answer:
"""

set_seed(5243)

predictions = []

for claim in tqdm(claims_sample):
    label = claim["label"]
    claim_text = claim["claim"]
    # No context: the model has to rely on its parametric knowledge only.
    prompt = prompt_template.format(claim=claim_text)
    response = generator(prompt, max_new_tokens=10)

    # None marks a response that contains none of the expected labels.
    prediction = None
    if "SUPPORTS" in response:
        prediction = "SUPPORTS"
    elif "REFUTES" in response:
        prediction = "REFUTES"
    elif "NOT ENOUGH EVIDENCE" in response:
        # The prompt offers this third choice, so it has to be recognised here;
        # otherwise such answers are stored as invalid and never show up in the
        # "not enough evidence" statistics of eval_accuracy.
        prediction = "NOT ENOUGH EVIDENCE"

    predictions.append({"claim": claim_text, "label": label, "prediction": prediction})

# Store the results
output_file = "predictions-parametric-3way.csv"
predictions_df = pd.DataFrame(predictions)
predictions_df.to_csv(output_file, index=False)

eval_accuracy(output_file)

## Experiment 7

Create the retriever and evaluate its accuracy.

In [None]:
client = chromadb.PersistentClient(path="./chroma_facts2")
embed = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").encode

n = 500000
# get_or_create_collection replaces a bare `except:` around get_collection,
# which would also have swallowed unrelated errors (even KeyboardInterrupt).
db = client.get_or_create_collection(name=f"claims-n{n}")

# Only build the index when the collection holds no documents yet; an already
# populated collection is reused as-is.
if db.count() == 0:
    wiki_pages = load_dataset("fever/fever", "wiki_pages")["wikipedia_pages"]
    wiki_pages_sample = wiki_pages.shuffle(seed=42).select(range(n))

    # A set de-duplicates page texts before they are embedded.
    pages = set()
    for page in tqdm(wiki_pages_sample):
        pages.add(page["text"])

    # Add all gold evidence so that every claim is answerable in principle.
    # Same on-disk dataset as loaded at the top of the notebook ("fever-fine-coarse");
    # a separate name is used so the global `claims` is not overwritten.
    evidence_claims = load_from_disk("fever-fine-coarse")
    for claim in tqdm(evidence_claims):
        pages.update(claim["evidence_coarse"])
        pages.update(claim["evidence_fine"])

    pages = list(pages)

    for i, txt in enumerate(tqdm(pages)):
        db.add(ids=[str(i)], documents=[txt], embeddings=[embed(txt).tolist()])


def retriever(query, k):
    """Return the texts of the ``k`` documents nearest to ``query`` in the collection."""
    return db.query(
        query_embeddings=embed(query).tolist(),
        n_results=k,
    )["documents"][0]

In [None]:
# For every retrieval depth i, record whether any gold evidence text is
# contained in one of the i retrieved documents, and persist the result.
for i in range(1, 11):
    data = []
    for claim in tqdm(claims_sample):
        evidence_retriever = retriever(claim["claim"], i)
        evidence_gt = claim["evidence_fine"] + claim["evidence_coarse"]

        # Substring match: a gold evidence text counts as found when it occurs
        # inside any retrieved document.
        evidence_found = any(
            gold in retrieved
            for gold in evidence_gt
            for retrieved in evidence_retriever
        )

        record = dict(
            claim=claim["claim"],
            label=claim["label"],
            evidence_found=evidence_found,
            evidence_gt=evidence_gt,
            evidence_retriever=evidence_retriever,
        )
        data.append(record)

    data_df = pd.DataFrame(data)
    data_hf = Dataset.from_pandas(data_df)
    data_hf.save_to_disk(f"claims_sample_retrieval_2_top_{i}")

In [None]:
# Retrieval accuracy = share of sampled claims for which gold evidence was
# found among the top-k retrieved documents.
accuracies = []
for k in range(1, 11):
    # Load the result set for THIS k (the loop previously read a stale `i`
    # left over from another cell, so every point was the top-10 accuracy).
    data = load_from_disk(f"claims_sample_retrieval_2_top_{k}")

    hits = sum(1 for row in data if row["evidence_found"])
    accuracies.append(hits / len(data) if len(data) else 0.0)

fig, ax = plt.subplots()
ax.plot(range(1, 11), accuracies, marker="o")
ax.set_title("Top-k retrieval accuracy")
ax.set_xlabel("k")
ax.set_ylabel("Accuracy")
ax.set_ylim(0, 1)
ax.grid(alpha=0.25)
# Called after the labels are set so they are included in the layout.
fig.tight_layout()
plt.show()


## Experiment 8

Evaluate the retriever and the generator together.

In [None]:
prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if context EXPLICITLY supports the claim.
You answer REFUTES if the context EXPLICITLY refutes the claim.

Context:
{context}

Claim:
{claim}

Answer:
"""
for k in [1, 2, 3, 4, 5, 6, 10]:
    set_seed(5243)
    # Read the retrieval results written by the retrieval sweep above, which
    # saves them as "claims_sample_retrieval_2_top_{k}".
    data = load_from_disk(f"claims_sample_retrieval_2_top_{k}")

    predictions = []
    for row in tqdm(data):
        label = row["label"]
        claim = row["claim"]
        # Context is the retrieved documents (not gold evidence), one per line.
        context = "\n".join(row["evidence_retriever"])
        prompt = prompt_template.format(context=context, claim=claim)
        response = generator(prompt, max_new_tokens=10)

        # None marks a response that contains none of the expected labels.
        prediction = None
        if "SUPPORTS" in response:
            prediction = "SUPPORTS"
        elif "REFUTES" in response:
            prediction = "REFUTES"

        predictions.append({"claim": claim, "context": context, "label": label, "prediction": prediction, "evidence_found": row["evidence_found"]})

    # store predictions
    predictions_df = pd.DataFrame(predictions)
    predictions_df.to_csv(f"predictions_claims_sample_retrieval_top_{k}.csv", index=False)

Compute the average context length of $k$ retrieved documents for $k = 1, 2, 3, 4, 5, 6, 10$.

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer once; it was previously re-loaded for every single row.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

avg_context_len_list = []
for k in [1, 2, 3, 4, 5, 6, 10]:
    predictions = pd.read_csv(f"predictions_claims_sample_retrieval_top_{k}.csv")

    # Sum of token counts over all stored contexts for this k.
    context_len_total = 0
    for _, item in tqdm(predictions.iterrows(), total=len(predictions)):
        context_len_total += len(tokenizer(item["context"])["input_ids"])

    context_len_avg = context_len_total / len(predictions)
    avg_context_len_list.append(context_len_avg)
print(avg_context_len_list)

# Plots

Note: To compute the plots you have to enter the data obtained from the evaluation function manually.

In [None]:
# Accuracy restricted to claims for which the retriever found gold evidence.
total_accuracies = []
supports_accuracies = []
refutes_accuracies = []


def _accuracy_or_nan(n_correct, n_total):
    """Return n_correct / n_total, or NaN when there are no rows to score."""
    return n_correct / n_total if n_total else float("nan")


for k in [10]:
    predictions = pd.read_csv(f"predictions_claims_sample_retrieval_top_{k}.csv")

    label_supports = 0
    label_refutes = 0
    correct_supports = 0
    correct_refutes = 0

    for _, item in predictions.iterrows():
        label = item["label"]
        prediction = item["prediction"]

        # Skip claims whose gold evidence was not retrieved.
        if not bool(item["evidence_found"]):
            continue

        if label == "SUPPORTS":
            label_supports += 1
            if label == prediction:
                correct_supports += 1
        elif label == "REFUTES":
            label_refutes += 1
            if label == prediction:
                correct_refutes += 1
        else:
            print("Unknown label...")

    # With small samples a class may have no evidence-found rows at all;
    # report NaN for it instead of raising ZeroDivisionError.
    total_accuracies.append(
        _accuracy_or_nan(correct_supports + correct_refutes, label_supports + label_refutes)
    )
    supports_accuracies.append(_accuracy_or_nan(correct_supports, label_supports))
    refutes_accuracies.append(_accuracy_or_nan(correct_refutes, label_refutes))

print(f"Total accuracies: {total_accuracies}")
print(f"Supports accuracies: {supports_accuracies}")
print(f"Refutes accuracies: {refutes_accuracies}")

In [None]:
# Values entered manually from earlier runs (see the note above this cell).
avg_context_length = [184.72, 334.745, 466.355, 608.745, 736.585, 1359.502512562814]
total_accuracies_with_evidence = [0.9764705882352941, 0.970873786407767, 0.9629629629629629, 0.9724770642201835, 0.9553571428571429, 0.9473684210526315]
supports_accuracies_with_evidence = [0.9622641509433962, 0.9682539682539683, 0.9545454545454546, 0.9696969696969697, 0.9558823529411765, 0.9285714285714286]
refutes_accuracies_with_evidence = [1.0, 0.975, 0.9761904761904762, 0.9767441860465116, 0.9545454545454546, 0.9772727272727273]
total_accuracies_without_evidence = [0.6782608695652174, 0.6494845360824743, 0.6521739130434783, 0.6703296703296703, 0.6363636363636364, 0.6235294117647059]
supports_accuracies_without_evidence = [0.33962264150943394, 0.27906976744186046, 0.225, 0.3, 0.23684210526315788, 0.19444444444444445]
refutes_accuracies_without_evidence = [0.967741935483871, 0.9444444444444444, 0.9807692307692307, 0.9607843137254902, 0.94, 0.9387755102040817]

plt.figure(figsize=(12, 8))

# One entry per curve: (accuracy values, legend label, extra line styling).
# "Without evidence" curves are dashed; the others use the default line style.
dashed = {'linestyle': 'dashed'}
curves = [
    (total_accuracies_with_evidence, 'With Evidence (Total)', {}),
    (supports_accuracies_with_evidence, 'With Evidence (Supports)', {}),
    (refutes_accuracies_with_evidence, 'With Evidence (Refutes)', {}),
    (total_accuracies_without_evidence, 'Without Evidence (Total)', dashed),
    (supports_accuracies_without_evidence, 'Without Evidence (Supports)', dashed),
    (refutes_accuracies_without_evidence, 'Without Evidence (Refutes)', dashed),
]
for curve_values, curve_label, line_style in curves:
    plt.plot(avg_context_length, curve_values, marker='o', label=curve_label, **line_style)

plt.xlabel('Avg. Context Length')
plt.ylabel('Accuracy')
plt.title('Accuracies w.r.t. Avg. Context Length')
plt.legend(loc='center right', bbox_to_anchor=(0.95, 0.7))
plt.grid(alpha=0.25)
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Prediction files of experiments 1-6, in the order they are shown on the x-axis.
filenames = [
    "predictions-parametric-2way.csv",
    "predictions-parametric-3way.csv",
    "predictions-gold-fine-2way.csv",
    "predictions-gold-fine-3way.csv",
    "predictions-gold-coarse-2way.csv",
    "predictions-gold-coarse-3way.csv",
]

# Tick labels, one per entry of `filenames`. Kept in its own variable (these
# were previously stored in `avg_context_length`, overwriting that list), and
# the last label now names the coarse experiment it actually shows.
experiment_labels = [
    "Parametric 2-way",
    "Parametric 3-way",
    "Gold fine 2-way",
    "Gold fine 3-way",
    "Gold coarse 2-way",
    "Gold coarse 3-way",
]

correct_list = []
total_list = []
supp_correct_list = []
supp_total_list = []
ref_correct_list = []
ref_total_list = []
nef_supp_list = []
nef_ref_list = []

for filename in filenames:
    correct, supp_correct, supp_total, ref_correct, ref_total, nef_supp, nef_ref = eval_accuracy(filename)
    correct_list.append(correct)
    total_list.append(supp_total + ref_total)
    supp_correct_list.append(supp_correct)
    supp_total_list.append(supp_total)
    ref_correct_list.append(ref_correct)
    ref_total_list.append(ref_total)
    nef_supp_list.append(nef_supp)
    nef_ref_list.append(nef_ref)

x = np.arange(len(experiment_labels))
bar_width = 0.12

fig, ax = plt.subplots(figsize=(14, 8))

# (horizontal offset in bar widths, bar heights, bar styling), in drawing order.
# A "total" bar and its "correct" bar share an offset so the solid bar is
# drawn on top of the translucent hatched one.
bar_specs = [
    (-2.75, total_list, dict(color="orange", edgecolor="orange", hatch="/", alpha=0.25,
                             label="No. of claims")),
    (-2.75, correct_list, dict(color="orange",
                               label="No. of correct predictions")),
    (-1, supp_total_list, dict(color="green", edgecolor="green", hatch="/", alpha=0.25,
                               label="No. of claims with label 'SUPPORTS'")),
    (-1, supp_correct_list, dict(color="green",
                                 label="No. of correct predictions for label 'SUPPORTS'")),
    (0, nef_supp_list, dict(color="grey", hatch="o", alpha=0.5,
                            label="No. of claims with label 'SUPPORTS' but predicted 'NOT ENOUGH EVIDENCE'")),
    (1.75, ref_total_list, dict(color="red", edgecolor="red", hatch="/", alpha=0.25,
                                label="No. of claims with label 'REFUTES'")),
    (1.75, ref_correct_list, dict(color="red",
                                  label="No. of correct predictions for label 'REFUTES'")),
    (2.75, nef_ref_list, dict(color="grey", hatch="O", alpha=0.5,
                              label="No. of claims with label 'REFUTES' but predicted 'NOT ENOUGH EVIDENCE'")),
]
for offset, heights, bar_style in bar_specs:
    ax.bar(x + offset * bar_width, heights, width=bar_width, **bar_style)

ax.set_ylabel("No. of Claims")
ax.set_xticks(x)
ax.set_xticklabels(experiment_labels)
ax.legend()
ax.grid(alpha=0.25)
plt.show()

## Chain of Thought (CoT) Evaluation

In [None]:
CoT_prompt_template = """
You are a helpful, smart, kind, and efficient AI assistant who always fulfills the user requests to the best of its abilities and strictly sticks to the given instructions.

Instructions:
You answer SUPPORTS if context EXPLICITLY supports the claim.
You answer REFUTES if the context EXPLICITLY refutes the claim.
You answer NOT ENOUGH EVIDENCE if the context does not provide enough information to explicitly support or refute the claim.

Context:
Barack Hussein Obama II[a] (born August 4, 1961) is an American politician who served as the 44th president of the United States from 2009 to 2017. As a member of the Democratic Party, he was the first African-American president in United States history. Obama previously served as a U.S. senator representing Illinois from 2005 to 2008, as an Illinois state senator from 1997 to 2004, and as a community service organizer, civil rights lawyer, and university lecturer. 

Claim:
Obama served as a US senator before becoming president.

Answer:
The context states that Barack Obama served as United States senator representing Illinois from 2005 to 2008.
It also mentions that he served as president from 2009 to 2017.
Hence, the context SUPPORTS the claim.

Context:
{context}

Claim:
{claim}

Answer:
"""

# Re-prompt, with a one-shot chain-of-thought example, only those REFUTES
# claims that the 3-way coarse experiment got wrong.
predictions = pd.read_csv("predictions-gold-coarse-3way.csv")

total = len(predictions)
new_predictions = []
for counter, (_, item) in enumerate(predictions.iterrows(), start=1):
    print(f"Processing {counter}/{total}", end="\r")

    label = item["label"]
    prediction = item["prediction"]
    if label != "REFUTES" or label == prediction:
        continue

    context = item["context"]
    claim = item["claim"]

    prompt = CoT_prompt_template.format(context=context, claim=claim)
    response = generator(prompt, max_new_tokens=100)

    # None marks a response that contains none of the expected labels.
    new_prediction = None
    if "SUPPORTS" in response:
        new_prediction = "SUPPORTS"
    elif "REFUTES" in response:
        new_prediction = "REFUTES"
    elif "NOT ENOUGH EVIDENCE" in response:
        new_prediction = "NOT ENOUGH EVIDENCE"

    # Store the prediction actually parsed from the response (this field was
    # previously hard-coded to "REFUTES", i.e. always equal to the label).
    new_predictions.append({"claim": claim, "context": context, "label": label, "old_prediction": prediction, "new_prediction": new_prediction, "response": response})

# Explicit columns keep the CSV header even when no claim had to be re-prompted,
# so the file can still be read back with pd.read_csv.
new_predictions_df = pd.DataFrame(
    new_predictions,
    columns=["claim", "context", "label", "old_prediction", "new_prediction", "response"],
)
new_predictions_df.to_csv("predictions-gold-coarse-3way-refutes-CoT.csv", index=False)

In [None]:
new_predictions_df = pd.read_csv("predictions-gold-coarse-3way-refutes-CoT.csv")

# A re-prompted claim counts as correct when "REFUTES" appears within the
# first two lines of the chain-of-thought response.
total = 0
correct = 0
for _, item in new_predictions_df.iterrows():
    total += 1
    # An empty response is read back from the CSV as NaN; str() keeps .split usable.
    response = "".join(str(item["response"]).split("\n")[0:2])
    if "REFUTES" in response:
        print(response)
        correct += 1

if total:
    print(f"Total Accuracy: {correct}/{total} ({correct/total:.2f})")
else:
    # Nothing was re-prompted (no misclassified REFUTES claims in the sample).
    print("Total Accuracy: 0/0 (n/a)")