In [6]:
import json
import ujson
import torch
from pathlib import Path
from sentence_transformers import CrossEncoder, InputExample, losses
from torch.utils.data import DataLoader
from tqdm import tqdm

# --------- Configuration ---------
DATA_DIR = Path("data")
TRAIN_FN = DATA_DIR / "train-claims-top100.json"
DEV_FN = DATA_DIR / "dev-claims-top100.json"
TEST_FN = DATA_DIR / "test-claims-top100.json"

MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
BATCH_SIZE = 32
NUM_EPOCHS = 3
TOP_M = 6
CHECKPOINT_DIR = DATA_DIR / "checkpoints"
FINAL_MODEL_DIR = DATA_DIR / "fine-tuned-model"

CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)
FINAL_MODEL_DIR.mkdir(exist_ok=True, parents=True)

# --------- Prepare Training Data ---------
def load_claim_texts(path):
    with open(path, encoding='utf-8') as f:
        data = ujson.load(f)
    return {cid: (item["claim_text"] if isinstance(item, dict) else item) for cid, item in data.items()}
def load_groundtruth_evidence(path):
    with open(path, encoding='utf-8') as f:
        data = ujson.load(f)
    gt = {}
    for cid, details in data.items():
        if isinstance(details, dict):
            gt[cid] = set(details.get("evidences", []))
        else:
            gt[cid] = set()
    return gt

def prepare_training_data(top100_path, claims_path, gt_evidence, evidences):
    with open(top100_path, encoding='utf-8') as f:
        top100 = ujson.load(f)
    claims = load_claim_texts(claims_path)

    examples = []
    for cid, entry in tqdm(top100.items(), desc="Preparing data"):
        claim = claims.get(cid, "")
        cand_ids = entry["evidences"] if isinstance(entry, dict) else entry
        gt_evid = gt_evidence.get(cid, set())

        for eid in cand_ids:
            label = 1.0 if eid in gt_evid else 0.0
            examples.append(InputExample(texts=[claim, evidences[eid]], label=label))

    return examples

# Load evidences
with open(DATA_DIR / "evidence.json", encoding='utf-8') as f:
    evid_dict = ujson.load(f)

# Ground truth evidence
train_gt = load_groundtruth_evidence(DATA_DIR / "train-claims.json")
dev_gt = load_groundtruth_evidence(DATA_DIR / "dev-claims.json")

# Training and Dev data
train_examples = prepare_training_data(TRAIN_FN, DATA_DIR / "train-claims.json", train_gt, evid_dict)
dev_examples = prepare_training_data(DEV_FN, DATA_DIR / "dev-claims.json", dev_gt, evid_dict)

# --------- Fine-tuning Cross-Encoder ---------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CrossEncoder(MODEL_NAME, num_labels=1, device=device)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
loss_fn = losses.MSELoss(model=model)

Preparing data: 100%|██████████| 1228/1228 [00:00<00:00, 2391.02it/s]
Preparing data: 100%|██████████| 154/154 [00:00<00:00, 4268.25it/s]


In [5]:
from sentence_transformers.evaluation import SentenceEvaluator
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

class CustomCrossEncoderEvaluator(SentenceEvaluator):
    def __init__(self, examples, name=""):
        self.examples = examples
        self.name = name

    def __call__(self, model, output_path=None, epoch=-1, steps=-1):
        texts = [example.texts for example in self.examples]
        labels = [example.label for example in self.examples]

        preds = model.predict(texts)
        preds_binary = (np.array(preds) > 0.5).astype(int)

        accuracy = accuracy_score(labels, preds_binary)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds_binary, average='binary')

        print(f"\n[{self.name} Evaluation] Epoch {epoch}, Step {steps}:")
        print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        # Returning F1 for save_best_model criterion
        return f1

In [7]:
# Fine-tune
dev_evaluator = CustomCrossEncoderEvaluator(dev_examples, name="dev")

model.fit(
    train_dataloader=train_dataloader,
    epochs=NUM_EPOCHS,
    evaluator=dev_evaluator,
    evaluation_steps=1000,
    warmup_steps=int(0.1 * len(train_dataloader) * NUM_EPOCHS),
    output_path=str(FINAL_MODEL_DIR),
    save_best_model=True
)

Step,Training Loss,Validation Loss,Evaluator
1000,0.0523,No log,0.126316
2000,0.0547,No log,0.101449
3000,0.0458,No log,0.156463
4000,0.045,No log,0.080586
5000,0.0409,No log,0.221538
6000,0.0383,No log,0.144828
7000,0.037,No log,0.155932
8000,0.0256,No log,0.2263
9000,0.033,No log,0.257373
10000,0.0303,No log,0.271003



[dev Evaluation] Epoch 0.26055237102657636, Step 1000:
Accuracy: 0.9919, Precision: 0.6429, Recall: 0.0700, F1: 0.1263

[dev Evaluation] Epoch 0.5211047420531527, Step 2000:
Accuracy: 0.9919, Precision: 0.7368, Recall: 0.0545, F1: 0.1014

[dev Evaluation] Epoch 0.7816571130797291, Step 3000:
Accuracy: 0.9919, Precision: 0.6216, Recall: 0.0895, F1: 0.1565

[dev Evaluation] Epoch 1.0, Step 3838:
Accuracy: 0.9920, Precision: 0.6774, Recall: 0.0817, F1: 0.1458

[dev Evaluation] Epoch 1.0422094841063054, Step 4000:
Accuracy: 0.9919, Precision: 0.6875, Recall: 0.0428, F1: 0.0806

[dev Evaluation] Epoch 1.3027618551328817, Step 5000:
Accuracy: 0.9918, Precision: 0.5294, Recall: 0.1401, F1: 0.2215

[dev Evaluation] Epoch 1.563314226159458, Step 6000:
Accuracy: 0.9919, Precision: 0.6364, Recall: 0.0817, F1: 0.1448

[dev Evaluation] Epoch 1.8238665971860344, Step 7000:
Accuracy: 0.9919, Precision: 0.6053, Recall: 0.0895, F1: 0.1559

[dev Evaluation] Epoch 2.0, Step 7676:
Accuracy: 0.9918, Preci

In [None]:
# --------- Rerank Test Set ---------
model = CrossEncoder(str(FINAL_MODEL_DIR), device=device)

def rerank_and_save(top100_path, claims_path, evidences, out_dense_path, out_text_path):
    with open(top100_path, encoding='utf-8') as f:
        top100 = ujson.load(f)
    claims = load_claim_texts(claims_path)

    dense_out = {}
    text_out = {}

    for cid, entry in tqdm(top100.items(), desc="Reranking"):
        cand_ids = entry["evidences"] if isinstance(entry, dict) else entry
        claim = claims.get(cid, "")
        pairs = [(claim, evidences[eid]) for eid in cand_ids]

        scores = model.predict(pairs, batch_size=BATCH_SIZE)

        top_idx = scores.argsort()[-TOP_M:][::-1]
        top_ids = [cand_ids[i] for i in top_idx]

        dense_out[cid] = top_ids
        text_out[cid] = {
            "claim_text": claim,
            "ranked_evidences": [{"id": eid, "text": evidences[eid]} for eid in top_ids]
        }

    with open(out_dense_path, "w", encoding="utf-8") as f:
        json.dump(dense_out, f, ensure_ascii=False, indent=2)

    with open(out_text_path, "w", encoding="utf-8") as f:
        json.dump(text_out, f, ensure_ascii=False, indent=2)

rerank_and_save(
    TEST_FN,
    DATA_DIR / "dev-claims.json",
    evid_dict,
    DATA_DIR / f"dev-claims-top{TOP_M}-dense-fce.json",
    DATA_DIR / f"dev-claims-top{TOP_M}-text-fce.json"
)

print("Done!")

Reranking: 100%|██████████| 153/153 [00:12<00:00, 11.88it/s]

Done!



