In [None]:
# ======================================================
# Kaggle patch: provide fsspec.implementations.chained when it is missing
# ======================================================
import sys, types


class ChainedFileSystem:
    """Placeholder for ``fsspec.implementations.chained.ChainedFileSystem``.

    Only registered when the real module cannot be imported. It accepts and
    ignores all constructor arguments and implements no filesystem behaviour.
    """

    def __init__(self, *args, **kwargs):
        pass


# Register the stub before `datasets` imports fsspec, but only when the real
# module is unavailable, so an fsspec build that ships `implementations.chained`
# is not shadowed by this placeholder.
try:
    import fsspec.implementations.chained  # noqa: F401
except ImportError:
    fake_chained = types.ModuleType("fsspec.implementations.chained")
    fake_chained.ChainedFileSystem = ChainedFileSystem
    sys.modules["fsspec.implementations.chained"] = fake_chained
    print("âœ… fsspec chained stub inserted successfully (Kaggle-safe)")
else:
    print("fsspec.implementations.chained is importable; no stub inserted")


In [None]:
!pip install -q datasets==2.20.0 transformers==4.45.0 evaluate rouge_score tqdm joblib


In [None]:
import torch
import torch.nn as nn
import math, joblib
from tqdm.notebook import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration
from evaluate import load as load_metric

# Hyperparameters and run configuration (edit here, then re-run the cells below).
HYPER = {
    "model_name": "microsoft/prophetnet-large-uncased",  # Hugging Face checkpoint
    "lambda_simp": 0.8,      # weight of the simplicity penalty in the total loss
    "lr": 2e-5,              # AdamW learning rate
    "batch_size": 4,
    "epochs": 50,            # maximum epochs; early stopping may end training sooner
    "patience": 10,          # epochs without val improvement before stopping
    "min_delta": 0.001,      # minimum val-loss decrease that counts as an improvement
    "max_input_len": 512,    # article length in tokens (truncated/padded)
    "max_label_len": 128,    # summary length in tokens (truncated/padded)
    "seed": 42,              # seed for torch's global RNG
}

# Seed torch's global RNG so splitting, DataLoader shuffling and dropout are reproducible.
torch.manual_seed(HYPER["seed"])


In [None]:
print("ðŸ”¹ Loading dataset...")
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:2000]")

# Split into train/val/test (80/10/10); the test split absorbs the rounding remainder.
total = len(dataset)
train_size = int(0.8 * total)
val_size = int(0.1 * total)
test_size = total - train_size - val_size
# A dedicated, seeded generator gives the same split every time this cell runs,
# regardless of how much of torch's global RNG earlier cells consumed.
split_generator = torch.Generator().manual_seed(HYPER.get("seed", 42))
train_set, val_set, test_set = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)
print(f"ðŸ“Š Split: Train={len(train_set)}, Val={len(val_set)}, Test={len(test_set)}")

tokenizer = ProphetNetTokenizer.from_pretrained(HYPER["model_name"])

def preprocess(example):
    """Tokenize one record with ``article`` and ``highlights`` fields into 1-D tensors.

    Returns a dict with every key the tokenizer emits for the article (padded or
    truncated to ``HYPER["max_input_len"]``), an ``attention_mask`` recomputed as
    ``input_ids != pad_token_id``, and ``labels``: the highlights' token ids (padded
    or truncated to ``HYPER["max_label_len"]``) with padding replaced by -100.
    """
    def encode(text, length):
        # Tokenize a single string and drop the leading batch dimension of size 1.
        encoded = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=length,
            return_tensors="pt",
        )
        return {name: tensor.squeeze(0) for name, tensor in encoded.items()}

    features = encode(example["article"], HYPER["max_input_len"])
    target_ids = encode(example["highlights"], HYPER["max_label_len"])["input_ids"]

    pad_id = tokenizer.pad_token_id
    features["labels"] = target_ids.masked_fill(target_ids == pad_id, -100)
    features["attention_mask"] = (features["input_ids"] != pad_id).long()
    return features

# Materialize each split as a list of tokenized feature dicts (rebinding the split names).
train_set = list(map(preprocess, train_set))
val_set = list(map(preprocess, val_set))
test_set = list(map(preprocess, test_set))

def collate_fn(batch):
    """Stack per-example tensors into batch tensors.

    Only ``input_ids``, ``attention_mask`` and ``labels`` are kept, in that order.
    Every example must hold same-shaped tensors under those keys, otherwise
    ``torch.stack`` raises.
    """
    keys = ("input_ids", "attention_mask", "labels")
    return {key: torch.stack([example[key] for example in batch]) for key in keys}

# Both loaders share batch size and collation; only the training loader shuffles.
loader_kwargs = dict(batch_size=HYPER["batch_size"], collate_fn=collate_fn)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, **loader_kwargs)


In [None]:
class SATSWrapper(nn.Module):
    """ProphetNet summarizer trained with an extra weighted "simplicity" penalty.

    ``forward`` returns ``lm_loss + lambda_simp * simp_loss``. ``simp_loss`` is the
    expected per-token score from ``freq_table`` under the decoder's softmax
    distribution; token ids absent from ``freq_table`` score 0.5.

    Args:
        model_name: checkpoint passed to ``from_pretrained``; ``None`` means
            ``HYPER["model_name"]``, looked up when the instance is created.
        freq_table: mapping token id -> score. ``None``/empty means every token scores 0.5.
        lambda_simp: weight of the penalty; ``None`` means ``HYPER["lambda_simp"]``.
    """

    def __init__(self, model_name=None, freq_table=None, lambda_simp=None):
        super().__init__()
        # Resolve defaults at call time so edits to HYPER take effect without
        # re-running this cell (the previous defaults were frozen at class definition).
        if model_name is None:
            model_name = HYPER["model_name"]
        if lambda_simp is None:
            lambda_simp = HYPER["lambda_simp"]
        self.model = ProphetNetForConditionalGeneration.from_pretrained(model_name)
        self.freq_table = freq_table or {}
        self.lambda_simp = lambda_simp
        # Plain attribute (not a registered buffer), so state_dict keys are unchanged.
        self._simp_scores = None

    def _simp_score_vector(self, vocab_size, device):
        """Return a float32 tensor of length ``vocab_size`` holding each token id's score.

        Built once from ``freq_table`` and cached. It is rebuilt only when the
        vocabulary size or device changes, so later mutations of ``freq_table``
        are not picked up. Ids outside ``[0, vocab_size)`` are ignored.
        """
        cached = getattr(self, "_simp_scores", None)
        if cached is not None and cached.numel() == vocab_size and cached.device == device:
            return cached
        scores = torch.full((vocab_size,), 0.5, dtype=torch.float32)
        if self.freq_table:
            ids = torch.tensor([int(k) for k in self.freq_table], dtype=torch.long)
            vals = torch.tensor([float(v) for v in self.freq_table.values()], dtype=torch.float32)
            in_vocab = (ids >= 0) & (ids < vocab_size)
            scores[ids[in_vocab]] = vals[in_vocab]
        scores = scores.to(device)
        self._simp_scores = scores
        return scores

    def compute_simp_loss(self, logits, mask=None):
        """Differentiable simplicity penalty for a tensor of logits.

        Args:
            logits: tensor whose last dimension indexes the vocabulary.
            mask: optional bool tensor shaped like ``logits.shape[:-1]``; only True
                positions contribute. ``None`` averages over every position.

        Returns:
            Scalar float32 tensor: the mean over the selected positions of
            ``sum_v softmax(logits)[v] * score[v]``.
        """
        # An expectation under the softmax (instead of scoring argmax tokens and
        # re-wrapping the result with torch.tensor) keeps the penalty in the
        # autograd graph, so lambda_simp actually influences training.
        probs = torch.softmax(logits.float(), dim=-1)
        scores = self._simp_score_vector(logits.size(-1), logits.device)
        expected = probs @ scores  # expected score at each position
        if mask is None:
            return expected.mean()
        weights = mask.to(expected.dtype)
        # clamp avoids 0/0 when no position is selected.
        return (expected * weights).sum() / weights.sum().clamp(min=1.0)

    def forward(self, input_ids, attention_mask, labels):
        """Return the scalar training loss ``lm_loss + lambda_simp * simp_loss``.

        Label positions equal to -100 (the ignore value this notebook uses for
        padding) are excluded from the simplicity penalty.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        simp_loss = self.compute_simp_loss(outputs.logits, mask=labels != -100)
        return outputs.loss + self.lambda_simp * simp_loss

def build_dummy_freq_table(tokenizer):
    """Build a placeholder token-id -> score table for the simplicity penalty.

    No corpus frequencies are used. Each token gets a pseudo "frequency" in
    ``[2, 1_000_001]`` derived from a CRC-32 of its UTF-8 text. That value is
    mapped through ``1 / ln(freq)``, rescaled with bounds 0.1 and 1.2, and
    clipped to ``[0, 1]``.

    ``zlib.crc32`` replaces the built-in ``hash``: ``hash(str)`` is salted per
    interpreter process (PYTHONHASHSEED), so the table, and therefore the
    training loss, changed on every kernel restart.

    Args:
        tokenizer: object whose ``get_vocab()`` returns ``{token: id}``.

    Returns:
        dict mapping token id -> float in ``[0.0, 1.0]``.
    """
    import zlib  # local import keeps this cell self-contained

    scores = {}
    for token, idx in tokenizer.get_vocab().items():
        # +2 keeps freq >= 2 so log(freq) > 0 and the division is safe.
        freq = zlib.crc32(token.encode("utf-8")) % 1_000_000 + 2
        val = 1 / math.log(freq)
        val = (val - 0.1) / (1.2 - 0.1)
        scores[idx] = min(max(val, 0.0), 1.0)
    return scores


In [None]:
freq_table = build_dummy_freq_table(tokenizer)
# Use the GPU when one is attached; fall back to CPU instead of failing on .to("cuda").
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SATSWrapper(freq_table=freq_table).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=HYPER["lr"])


def _batch_to_device(batch):
    """Return ``(input_ids, attention_mask, labels)`` from a collated batch, moved to DEVICE."""
    return tuple(batch[key].to(DEVICE) for key in ("input_ids", "attention_mask", "labels"))


best_val_loss = float("inf")
patience_counter = 0

for epoch in range(HYPER["epochs"]):
    # Training pass: one optimizer step per batch.
    model.train()
    total_train_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{HYPER['epochs']}"):
        input_ids, attention_mask, labels = _batch_to_device(batch)
        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_loader)

    # Validation pass in eval mode without building an autograd graph.
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids, attention_mask, labels = _batch_to_device(batch)
            val_loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_val_loss += val_loss.item()
    avg_val_loss = total_val_loss / len(val_loader)

    print(f"âœ… Epoch {epoch+1}: Train Loss={avg_train_loss:.4f} | Val Loss={avg_val_loss:.4f}")

    # Early stopping: only a decrease larger than min_delta counts as an improvement.
    if avg_val_loss < best_val_loss - HYPER["min_delta"]:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
        # NOTE(review): joblib.dump pickles the whole module, duplicating the weights
        # already in best_model.pt; loading the .pkl later executes pickle.
        joblib.dump(model, "best_model.pkl")
        print("ðŸ’¾ Saved best model (.pt + .pkl)")
    else:
        patience_counter += 1
        if patience_counter >= HYPER["patience"]:
            print("ðŸ›‘ Early stopping triggered")
            break


In [None]:
rouge = load_metric("rouge")

# Reload the best checkpoint saved by the training loop onto the model's current device.
device = next(model.parameters()).device
model.load_state_dict(torch.load("best_model.pt", map_location=device, weights_only=True))
model.eval()

# A fixed index into `dataset` is not guaranteed to be unseen: random_split shuffles
# indices, so such a row most likely landed in train_set. Use a genuinely held-out
# example from test_set instead (already tokenized by `preprocess`).
sample = test_set[0]
input_ids = sample["input_ids"].unsqueeze(0).to(device)
attention_mask = sample["attention_mask"].unsqueeze(0).to(device)

summary_ids = model.model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_length=150,
    num_beams=5,
    early_stopping=True,
)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Raw highlights are not kept after preprocessing, so decode the label ids
# (-100 -> pad id so skip_special_tokens drops them). With the uncased checkpoint
# and max_label_len truncation this is a lowercased, possibly truncated reference.
reference_ids = sample["labels"].masked_fill(sample["labels"] == -100, tokenizer.pad_token_id)
reference = tokenizer.decode(reference_ids, skip_special_tokens=True)

print("ðŸ“° Original Summary:\n", reference)
print("\nðŸ§  Generated Summary:\n", summary)
print("\nROUGE:", rouge.compute(predictions=[summary], references=[reference]))
