In [None]:
import pandas as pd
import re

def sentence_tokenize(text):
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', str(text).strip()) if s]

def load_and_process_csv(path):
    df = pd.read_csv(path)
    processed = []

    for _, row in df.iterrows():
        claim = row['Statement']
        context = row['Context']
        evidence = row['Evidence']
        label = int(row['labels'])

        sentences = sentence_tokenize(context)
        rationale_indices = [
            i for i, s in enumerate(sentences)
            if evidence.strip() in s or s in evidence.strip()
        ]

        processed.append({
            "claim": claim,
            "evidences": sentences,
            "label": label,
            "rationale_indices": rationale_indices
        })

    return processed

train_data = load_and_process_csv("./data/train_clean.csv")
test_data = load_and_process_csv("./data/test_clean.csv")


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaModel, RobertaTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

class IBModel(nn.Module):
    def __init__(self, hidden_dim=768):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained("roberta-base")
        self.mask_head = nn.Linear(hidden_dim, 1)
        self.classifier = nn.Linear(hidden_dim, 2)

    def forward(self, claim, sentences):
        inputs = [claim + " [SEP] " + s for s in sentences]
        enc = tokenizer(inputs, return_tensors='pt', padding=True, truncation=True, max_length=128)

        with torch.no_grad():
            out = self.encoder(input_ids=enc['input_ids'].to(device),
                                attention_mask=enc['attention_mask'].to(device))

        cls_embeddings = out.last_hidden_state[:, 0, :]  # [num_sentences, hidden_dim]
        mask_logits = self.mask_head(cls_embeddings).squeeze(-1)  # [num_sentences]
        mask_probs = torch.sigmoid(mask_logits)

        masked_embeds = cls_embeddings * mask_probs.unsqueeze(1)
        pooled = masked_embeds.sum(dim=0) / (mask_probs.sum() + 1e-6)

        logits = self.classifier(pooled)
        return logits, mask_probs


In [3]:
def compute_loss(logits, label, masks, lambda_sparsity=1.0):
    ce_loss = F.cross_entropy(logits.unsqueeze(0), torch.tensor([label]).to(logits.device))
    sparsity = masks.mean()
    return ce_loss + lambda_sparsity * sparsity


In [4]:
model = IBModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

for epoch in range(10):
    model.train()
    total_loss = 0
    for sample in train_data:
        claim = sample["claim"]
        sentences = sample["evidences"]
        label = sample["label"]

        logits, masks = model(claim, sentences)
        loss = compute_loss(logits, label, masks, lambda_sparsity=0.5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1, Loss: 2750.6581
Epoch 2, Loss: 2434.0114
Epoch 3, Loss: 2377.1775
Epoch 4, Loss: 2364.2100
Epoch 5, Loss: 2359.8566
Epoch 6, Loss: 2357.6049
Epoch 7, Loss: 2356.3872
Epoch 8, Loss: 2355.3495
Epoch 9, Loss: 2354.8925
Epoch 10, Loss: 2353.8180


In [5]:
from sklearn.metrics import classification_report

model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for sample in test_data:
        claim = sample["claim"]
        sentences = sample["evidences"]
        label = sample["label"]

        logits, _ = model(claim, sentences)
        pred = torch.argmax(logits).item()

        all_preds.append(pred)
        all_labels.append(label)

print("Claim Classification:")
print(classification_report(all_labels, all_preds, target_names=["SUPPORTED", "REFUTED"]))


Claim Classification:
              precision    recall  f1-score   support

   SUPPORTED       0.55      0.74      0.63       508
     REFUTED       0.55      0.35      0.42       468

    accuracy                           0.55       976
   macro avg       0.55      0.54      0.53       976
weighted avg       0.55      0.55      0.53       976



In [6]:
def evaluate_ib_rationale(model, dataset, threshold=0.3):
    model.eval()
    all_precisions = []
    all_recalls = []
    all_f1s = []

    with torch.no_grad():
        for sample in dataset:
            claim = sample["claim"]
            sentences = sample["evidences"]
            gold = sample["rationale_indices"]

            _, scores = model(claim, sentences)
            pred = [i for i, s in enumerate(scores.cpu()) if s > threshold]

            true_set = set(gold)
            pred_set = set(pred)

            if not true_set and not pred_set:
                precision = recall = f1 = 1.0
            elif not pred_set:
                precision = recall = f1 = 0.0
            else:
                tp = len(true_set & pred_set)
                precision = tp / len(pred_set) if pred_set else 0.0
                recall = tp / len(true_set) if true_set else 0.0
                f1 = 2 * precision * recall / (precision + recall + 1e-8) if (precision + recall) > 0 else 0.0

            all_precisions.append(precision)
            all_recalls.append(recall)
            all_f1s.append(f1)

    print("Rationale Extraction Quality:")
    print(f"Precision: {sum(all_precisions)/len(all_precisions):.3f}")
    print(f"Recall:    {sum(all_recalls)/len(all_recalls):.3f}")
    print(f"F1-score:  {sum(all_f1s)/len(all_f1s):.3f}")


In [7]:
evaluate_ib_rationale(model, test_data, threshold=0.3)


Rationale Extraction Quality:
Precision: 0.316
Recall:    0.316
F1-score:  0.316


In [8]:
evaluate_ib_rationale(model, test_data, threshold=0.5)


Rationale Extraction Quality:
Precision: 0.316
Recall:    0.316
F1-score:  0.316


In [9]:
evaluate_ib_rationale(model, test_data, threshold=0.7)


Rationale Extraction Quality:
Precision: 0.316
Recall:    0.316
F1-score:  0.316


In [10]:
torch.save(model.state_dict(), "./model/ib_model.pt")

In [11]:
# model = IBModel().to(device)
# model.load_state_dict(torch.load("ib_model.pt", map_location=device))
# model.eval()