Main

In [None]:
import pandas as pd
import re

def sentence_tokenize(text):
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', str(text).strip()) if s]

def load_and_process_csv(path):
    df = pd.read_csv(path)

    processed = []
    for _, row in df.iterrows():
        claim = row['Statement']
        context = row['Context']
        evidence = row['Evidence']
        label = int(row['labels'])

        sentences = sentence_tokenize(context)
        rationale_indices = [
            i for i, s in enumerate(sentences)
            if evidence.strip() in s or s in evidence.strip()
        ]

        processed.append({
            "claim": claim,
            "evidences": sentences,
            "label": label,
            "rationale_indices": rationale_indices
        })
    return processed

train_data = load_and_process_csv("./data/train_clean.csv")
test_data = load_and_process_csv("./data/test_clean.csv")


In [2]:
from collections import Counter

# Đếm số lượng mỗi nhãn trong train/test
train_labels = [item["label"] for item in train_data]
test_labels = [item["label"] for item in test_data]

train_counts = Counter(train_labels)
test_counts = Counter(test_labels)

# Tính tỉ lệ %
def display_label_stats(counts, name):
    total = sum(counts.values())
    print(f"\n📊 {name} label distribution:")
    for label, count in counts.items():
        percent = 100 * count / total
        name_label = "SUPPORTED" if label == 0 else "REFUTED"
        print(f"  {name_label} ({label}): {count} samples ({percent:.2f}%)")

# Hiển thị
display_label_stats(train_counts, "Train")
display_label_stats(test_counts, "Test")



📊 Train label distribution:
  SUPPORTED (0): 1751 samples (51.36%)
  REFUTED (1): 1658 samples (48.64%)

📊 Test label distribution:
  REFUTED (1): 468 samples (47.95%)
  SUPPORTED (0): 508 samples (52.05%)


In [3]:
import torch
from transformers import RobertaTokenizer
from torch_geometric.data import Data

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

def encode_graph_sample(sample, max_len=128):
    inputs = [sample['claim'] + " [SEP] " + sent for sent in sample['evidences']]
    encoding = tokenizer(inputs, padding='max_length', truncation=True, max_length=max_len, return_tensors='pt')

    num_nodes = len(inputs)
    if num_nodes < 2:
        return None

    edge_index = torch.combinations(torch.arange(num_nodes), r=2).T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1).long()

    return Data(
        input_ids=encoding['input_ids'],
        attention_mask=encoding['attention_mask'],
        edge_index=edge_index,
        y=torch.tensor(sample['label']),
        rationale=torch.tensor(sample['rationale_indices']),
        num_nodes=num_nodes
    )

train_graph = [encode_graph_sample(s) for s in train_data if encode_graph_sample(s) is not None]
test_graph = [encode_graph_sample(s) for s in test_data if encode_graph_sample(s) is not None]


In [4]:
print(f"✅ Training graphs: {len(train_graph)}")
print(f"✅ Test graphs:     {len(test_graph)}")

✅ Training graphs: 3409
✅ Test graphs:     976


In [None]:
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

def collate(batch):
    return Batch.from_data_list(batch)

train_loader = DataLoader(train_graph, batch_size=8, shuffle=True, collate_fn=collate)
test_loader = DataLoader(test_graph, batch_size=8, shuffle=False, collate_fn=collate)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaModel
from torch_geometric.nn import GCNConv, global_mean_pool

class SaGP(nn.Module):
    def __init__(self, hidden_dim=768):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained("roberta-base")
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, 2)
        self.rationale_head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask, edge_index, batch_index):
        with torch.no_grad():
            out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        x = out.last_hidden_state[:, 0, :]  # [CLS] token
        x = F.relu(self.gcn1(x, edge_index))
        x = F.relu(self.gcn2(x, edge_index))

        pooled = global_mean_pool(x, batch_index)  # [batch_size, hidden_dim]
        logits = self.classifier(pooled)

        rationale_scores = torch.sigmoid(self.rationale_head(x)).squeeze(-1)  # [num_nodes]
        return logits, rationale_scores


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SaGP().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0

    for batch in train_loader:
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)
        edge_index = batch.edge_index.to(device)
        batch_index = batch.batch.to(device)
        labels = batch.y.to(device)

        logits, _ = model(input_ids, attention_mask, edge_index, batch_index)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1, Loss: 295.9503
Epoch 2, Loss: 296.2538
Epoch 3, Loss: 296.0361
Epoch 4, Loss: 296.0292
Epoch 5, Loss: 295.4590
Epoch 6, Loss: 295.2324
Epoch 7, Loss: 294.8454
Epoch 8, Loss: 295.1835
Epoch 9, Loss: 294.3911
Epoch 10, Loss: 294.9427


In [8]:
from sklearn.metrics import classification_report

model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)
        edge_index = batch.edge_index.to(device)
        batch_index = batch.batch.to(device)
        labels = batch.y.to(device)

        logits, _ = model(input_ids, attention_mask, edge_index, batch_index)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

print("Claim Classification:")
print(classification_report(all_labels, all_preds, target_names=["SUPPORTED", "REFUTED"]))


Claim Classification:
              precision    recall  f1-score   support

   SUPPORTED       0.52      0.97      0.68       508
     REFUTED       0.60      0.04      0.08       468

    accuracy                           0.53       976
   macro avg       0.56      0.51      0.38       976
weighted avg       0.56      0.53      0.39       976



In [9]:
print(f"Evaluated samples: {len(all_labels)}")

Evaluated samples: 976


In [10]:
from torch_geometric.data import Batch

def evaluate_rationale(model, data_loader, threshold=0.5):
    model.eval()
    all_precisions = []
    all_recalls = []
    all_f1s = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch.input_ids.to(device)
            attention_mask = batch.attention_mask.to(device)
            edge_index = batch.edge_index.to(device)
            batch_index = batch.batch.to(device)

            _, rationale_scores = model(input_ids, attention_mask, edge_index, batch_index)

            graph_sizes = torch.bincount(batch_index).tolist()
            num_graphs = len(graph_sizes)

            # Tách lại từng Data sample từ batch
            data_list = batch.to_data_list()

            start = 0
            for i, data_item in enumerate(data_list):
                size = data_item.num_nodes
                end = start + size
                scores = rationale_scores[start:end].cpu()
                pred_idxs = [j for j, s in enumerate(scores) if s > threshold]

                # --- Lấy ground truth rationale ---
                rationale_tensor = data_item.rationale
                if isinstance(rationale_tensor, torch.Tensor):
                    if rationale_tensor.ndim == 0:
                        true_idxs = [int(rationale_tensor.item())]
                    else:
                        true_idxs = [int(x.item()) for x in rationale_tensor]
                elif isinstance(rationale_tensor, (int, float)):
                    true_idxs = [int(rationale_tensor)]
                else:
                    true_idxs = list(rationale_tensor)

                true_idxs = [idx for idx in true_idxs if isinstance(idx, int) and idx >= 0]

                # --- Tính Precision / Recall / F1 ---
                if not true_idxs and not pred_idxs:
                    precision = recall = f1 = 1.0
                elif not pred_idxs:
                    precision = recall = f1 = 0.0
                else:
                    true_set = set(true_idxs)
                    pred_set = set(pred_idxs)
                    tp = len(true_set & pred_set)
                    precision = tp / len(pred_set) if pred_set else 0.0
                    recall = tp / len(true_set) if true_set else 0.0
                    f1 = 2 * precision * recall / (precision + recall + 1e-8) if (precision + recall) > 0 else 0.0

                all_precisions.append(precision)
                all_recalls.append(recall)
                all_f1s.append(f1)

                start = end

    print("Rationale Extraction Quality:")
    print(f"Precision: {sum(all_precisions)/len(all_precisions):.3f}")
    print(f"Recall:    {sum(all_recalls)/len(all_recalls):.3f}")
    print(f"F1-score:  {sum(all_f1s)/len(all_f1s):.3f}")


In [11]:
evaluate_rationale(model, test_loader, threshold=0.3)

Rationale Extraction Quality:
Precision: 0.053
Recall:    0.684
F1-score:  0.095


In [12]:
evaluate_rationale(model, test_loader, threshold=0.5)

Rationale Extraction Quality:
Precision: 0.266
Recall:    0.388
F1-score:  0.274


In [13]:
evaluate_rationale(model, test_loader, threshold=0.7)

Rationale Extraction Quality:
Precision: 0.316
Recall:    0.316
F1-score:  0.316


In [None]:
# torch.save(model.state_dict(), "./model/sagp_model.pt")