In [None]:
from datasets import load_dataset, Dataset
from tqdm import tqdm
import pandas as pd

In [None]:
# Load the three inputs used by the rest of the notebook:
#   fever      - the "paper_test" split of the FEVER v1.0 configuration
#   fever_gold - the "test" split of copenlu/fever_gold_evidence
#   wiki_pages - the "wikipedia_pages" split of the FEVER wiki_pages configuration
fever = load_dataset("fever/fever", "v1.0", trust_remote_code=True)["paper_test"]
fever_gold = load_dataset("copenlu/fever_gold_evidence")["test"]
# Same repository as the v1.0 load above, so the same trust_remote_code flag is
# passed here too; otherwise the two loads of one repo behave inconsistently.
wiki_pages = load_dataset("fever/fever", "wiki_pages", trust_remote_code=True)["wikipedia_pages"]

# Create a dictionary so we can retrieve the text of a wikipedia page by its url faster
# (maps each page's "id" field to its "text" field; a later duplicate id overwrites an earlier one)
wiki_dict = {}
for page in tqdm(wiki_pages):
    wiki_dict[page["id"]] = page["text"]

In [None]:
# Create a dictionary with the FEVER claims and their evidence (coarse-grained)
# Maps claim text -> collection of Wikipedia page texts referenced by that claim.
claims_coarse = {}
for record in fever:
    claim_text = record["claim"]
    # Pages that cannot be found in wiki_dict are recorded as the empty string.
    page_text = wiki_dict.get(record["evidence_wiki_url"], "")
    claims_coarse.setdefault(claim_text, set()).add(page_text)

# Verify that there are no claims without evidence
# NOTE: every set receives at least one entry in the loop above (possibly ""),
# so this only guards against an empty set, not against empty evidence text.
assert all(len(page_texts) != 0 for page_texts in claims_coarse.values()), "There is a claim without evidence."

# Freeze the de-duplicated sets into lists.
claims_coarse = {claim_text: list(page_texts) for claim_text, page_texts in claims_coarse.items()}

print("Loaded coarse-grained FEVER claims")
print("Number of coarse-grained FEVER claims:", len(claims_coarse))

In [None]:
# Create a dictionary with the FEVER claims and their evidence (fine-grained)
# Maps claim text -> collection of evidence strings taken from fever_gold.
claims_fine = {}
for record in fever_gold:
    claim_text = record["claim"]
    # Only the first evidence entry of each record is read, and from it the
    # element at index 2.
    # NOTE(review): presumably index 2 holds the evidence sentence text - confirm
    # against the dataset schema.
    first_evidence = record["evidence"][0]
    claims_fine.setdefault(claim_text, set()).add(first_evidence[2])

# Verify that there are no claims without evidence
# NOTE: every set receives at least one entry in the loop above, so this only
# guards against an empty set, not against empty evidence text.
assert all(len(sentences) != 0 for sentences in claims_fine.values()), "There is a claim without evidence."

# Freeze the de-duplicated sets into lists.
claims_fine = {claim_text: list(sentences) for claim_text, sentences in claims_fine.items()}

print("Loaded fine-grained FEVER claims")
print("Number of fine-grained FEVER claims:", len(claims_fine))

In [None]:
# Create a dictionary with the FEVER claims and their labels
# Maps claim text -> collection of distinct labels seen for that claim.
labels = {}
for record in fever:
    labels.setdefault(record["claim"], set()).add(record["label"])

# Count the claims that carry more than one distinct label. Nothing is
# rejected here; the count is only reported below.
clash_count = sum(1 for label_set in labels.values() if len(label_set) > 1)

# Freeze the de-duplicated sets into lists.
labels = {claim_text: list(label_set) for claim_text, label_set in labels.items()}

print("Loaded FEVER labels 🎉")
print("Number of FEVER labels:", len(labels))
print("Number of claims with multiple labels:", clash_count)

In [None]:
# Merge the coarse-grained and fine-grained FEVER claims and store dataset


def _has_nonblank_text(texts):
    """Return True if at least one string in `texts` is neither empty nor whitespace-only."""
    return any(not (text == "" or text.isspace()) for text in texts)


# Both evidence dictionaries must cover exactly the same claims. Every claim
# found on one side only is printed before the assertion fails.
claims_without_fine = [claim for claim in claims_coarse if claim not in claims_fine]
for claim in claims_without_fine:
    print("Missing fine-grained evidence for claim:", claim)

claims_without_coarse = [claim for claim in claims_fine if claim not in claims_coarse]
for claim in claims_without_coarse:
    print("Missing coarse-grained evidence for claim:", claim)

mismatch = bool(claims_without_fine) or bool(claims_without_coarse)
assert not mismatch, "Mismatch between coarse-grained and fine-grained FEVER claims"

# One record per claim; claims that do not have exactly one label are skipped.
claims = [
    {
        "claim": claim_text,
        "label": claim_labels[0],
        "evidence_coarse": list(claims_coarse[claim_text]),
        "evidence_fine": list(claims_fine[claim_text]),
    }
    for claim_text, claim_labels in labels.items()
    if len(claim_labels) == 1
]

print("Merged coarse-grained and fine-grained FEVER claims")
print("Number of merged FEVER claims:", len(claims))

claims_df = pd.DataFrame(claims)
claims_dataset = Dataset.from_pandas(claims_df)

# Apply the three filters in sequence, reporting the remaining count after each.
print("Filter out claims with label 'NOT ENOUGH INFO'")
claims_filtered = claims_dataset.filter(lambda row: "NOT ENOUGH INFO" not in row["label"])
print(f"{len(claims_filtered)} claims remaining")

print("Filter out claims without any coarse evidence")
claims_filtered = claims_filtered.filter(lambda row: _has_nonblank_text(row["evidence_coarse"]))
print(f"{len(claims_filtered)} claims remaining")

print("Filter out claims without any fine evidence")
claims_filtered = claims_filtered.filter(lambda row: _has_nonblank_text(row["evidence_fine"]))
print(f"{len(claims_filtered)} claims remaining")

claims_filtered.save_to_disk("fever-fine-coarse")

print("Saved merged FEVER claims to 💿")