In [1]:
import torch
import pandas as pd
import re
import random
from torch_geometric.data import Data
from transformers import RobertaModel, AutoTokenizer
from torch_geometric.nn import GCNConv, global_mean_pool

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")  

In [2]:
class SaGP(torch.nn.Module):
    def __init__(self, hidden_dim=768):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained("roberta-base")
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, 2)
        self.rationale_head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask, edge_index, batch_index):
        with torch.no_grad():
            out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        x = out.last_hidden_state[:, 0, :]
        x = torch.nn.functional.relu(self.gcn1(x, edge_index))
        x = torch.nn.functional.relu(self.gcn2(x, edge_index))
        pooled = global_mean_pool(x, batch_index)
        logits = self.classifier(pooled)
        rationale_scores = torch.sigmoid(self.rationale_head(x)).squeeze(-1)
        return logits, rationale_scores


In [3]:
model = SaGP().to(device)
model.load_state_dict(torch.load("./model/sagp_model.pt", map_location=device))
model.eval()


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SaGP(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm(

In [4]:
def sentence_tokenize(text):
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', str(text).strip()) if s]

def encode_graph_sample(sample, max_len=128):
    inputs = [sample['claim'] + " [SEP] " + sent for sent in sample['evidences']]
    encoding = tokenizer(inputs, padding='max_length', truncation=True, max_length=max_len, return_tensors='pt')
    num_nodes = len(inputs)
    edge_index = torch.combinations(torch.arange(num_nodes), r=2).T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1).long()
    return Data(
        input_ids=encoding['input_ids'],
        attention_mask=encoding['attention_mask'],
        edge_index=edge_index,
        y=torch.tensor(sample['label']),
        rationale=torch.tensor(sample['rationale_indices']),
        num_nodes=num_nodes
    )


In [5]:
def load_and_process_csv(path):
    df = pd.read_csv(path)
    data = []
    for _, row in df.iterrows():
        sentences = sentence_tokenize(row["Context"])
        evidence = str(row["Evidence"]).strip()
        rationale_indices = [
            i for i, s in enumerate(sentences)
            if evidence in s or s in evidence
        ]
        data.append({
            "claim": row["Statement"],
            "evidences": sentences,
            "label": int(row["labels"]),
            "evidence_text": evidence,
            "rationale_indices": rationale_indices
        })
    return data

test_data = load_and_process_csv("./data/test_clean.csv")


In [197]:
def demo_single_sample(model, dataset, threshold=0.5):
    model.eval()
    sample = random.choice(dataset)
    claim = sample["claim"]
    sentences = sample["evidences"]
    label = sample["label"]
    rationale_gt = sample["rationale_indices"]
    evidence_text = sample["evidence_text"]

    encoded = encode_graph_sample(sample)
    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)
    edge_index = encoded.edge_index.to(device)
    batch_index = torch.zeros(input_ids.size(0), dtype=torch.long).to(device)

    with torch.no_grad():
        logits, rationale_scores = model(input_ids, attn_mask, edge_index, batch_index)

    pred = torch.argmax(logits).item()
    rationale_pred = [i for i, s in enumerate(rationale_scores.cpu()) if s > threshold]
    # if not rationale_pred:
    #     rationale_pred = [torch.argmax(rationale_scores).item()]
    #     print(f"\n⚠️ Không có câu nào vượt threshold. Chọn Top-1 rationale thay thế.")


    if not rationale_pred:
        top2_indices = torch.topk(rationale_scores, k=2).indices.tolist()
        rationale_pred = top2_indices
        print(f"\n⚠️ Không có câu nào vượt threshold. Chọn Top-2 rationale thay thế:")




    print("📝 Claim:")
    print(claim)

    print("\n📜 Context Sentences:")
    for i, sent in enumerate(sentences):
        mark = ""
        if i in rationale_pred:
            mark += "🟩 Pred"
            print(f"  - ({i}) {sent} {mark}")
        # if i in rationale_gt:
        #     mark += "🟦 GT"
        #     # print(f"  - ({i}) {sent} {mark}")
        # print(f"  - ({i}) {sent} {mark}")

    print("\n📚 Ground-truth Evidence Text:")
    print(evidence_text)

    print("\n🎯 True Label:", "SUPPORTED" if label == 0 else "REFUTED")
    print("🔮 Predicted Label:", "SUPPORTED" if pred == 0 else "REFUTED")


In [228]:
demo_single_sample(model, test_data, threshold=0.7)


⚠️ Không có câu nào vượt threshold. Chọn Top-2 rationale thay thế:
📝 Claim:
Thí sinh theo quy định phải mang theo và xuất trình giấy tờ tùy thân gồm một trong các loại giấy tờ sau: CMND, CCCD, hộ chiếu.

📜 Context Sentences:
  - (0) Theo quy định, thí sinh (TS) phải mang theo và xuất trình giấy tờ tùy thân đã sử dụng đăng ký dự thi khi đến địa điểm thi. 🟩 Pred
  - (1) Giấy tờ tùy thân gồm một trong các loại giấy tờ sau (bản chính, còn hạn sử dụng): CMND, CCCD, hộ chiếu. 🟩 Pred

📚 Ground-truth Evidence Text:
Theo quy định, thí sinh phải mang theo và xuất trình giấy tờ tùy thân CMND, CCCD, hộ chiếu Giấy tờ tùy thân gồm một trong các loại giấy tờ sau

🎯 True Label: SUPPORTED
🔮 Predicted Label: SUPPORTED
