In [None]:
!pip install torch transformers datasets evaluate tqdm -q
!pip install evaluate
!pip install rouge_score
!pip uninstall torch torchvision torchaudio -y
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121


Found existing installation: torch 2.5.1+cu121
Uninstalling torch-2.5.1+cu121:
  Successfully uninstalled torch-2.5.1+cu121
Found existing installation: torchvision 0.20.1+cu121
Uninstalling torchvision-0.20.1+cu121:
  Successfully uninstalled torchvision-0.20.1+cu121
Found existing installation: torchaudio 2.5.1+cu121
Uninstalling torchaudio-2.5.1+cu121:
  Successfully uninstalled torchaudio-2.5.1+cu121
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch
  Using cached https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (780.4 MB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.3 MB)
Collecting torchaudio
  Using cached https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (3.4 MB)
Installing collected packages: torch, torchvision, torchaudio
Successfully installed torch-2.5.1+cu121 torchaudio-2.5.1+c

In [None]:
import math
import zlib

import torch
import torch.nn as nn
from datasets import load_dataset
from evaluate import load as load_metric
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration

In [None]:
# Hyper-parameters and run configuration shared by every cell below.
HYPER = {
    "model_name": "microsoft/prophetnet-large-uncased",
    "lambda_simp": 0.8,          # weight of the simplification loss added to the LM loss
    "lr": 2e-5,
    "batch_size": 2,             # ✅ kept small to avoid OOM
    "epochs": 100,               # upper bound; early stopping (patience) may end training sooner
    "patience": 10,              # epochs without improvement before training stops
    "min_delta": 0.001,          # minimum loss decrease that counts as an improvement
    "max_input_len": 512,        # ✅ ProphetNet encoder limit
    "max_label_len": 128,
    "seed": 42,                  # RNG seed; pass to torch.manual_seed before training for reproducible runs
}


In [None]:
class SATSWrapper(nn.Module):
    """ProphetNet seq2seq model trained with an added word-simplicity penalty.

    ``forward`` returns ``lm_loss + lambda_simp * simp_loss``. ``simp_loss`` is the
    expected ``freq_table`` score of the tokens the model predicts, so minimising
    it pushes probability mass toward lower-scored tokens.

    Args:
        model_name: checkpoint for ``ProphetNetForConditionalGeneration``.
            ``None`` (default) means ``HYPER["model_name"]``. The lookup happens
            when the object is built, not when this cell runs, so editing HYPER
            takes effect without re-running this cell.
        freq_table: dict mapping token id -> score. Ids missing from it use
            ``default_score``.
        lambda_simp: weight of the simplification term. ``None`` means
            ``HYPER["lambda_simp"]``, looked up at construction time.
        default_score: score for token ids absent from ``freq_table``.
    """

    def __init__(self, model_name=None, freq_table=None, lambda_simp=None, default_score=0.5):
        super().__init__()
        if model_name is None:
            model_name = HYPER["model_name"]
        if lambda_simp is None:
            lambda_simp = HYPER["lambda_simp"]
        self.model = ProphetNetForConditionalGeneration.from_pretrained(model_name)
        self.freq_table = freq_table or {}
        self.lambda_simp = lambda_simp
        self.default_score = default_score

    def _score_vector(self, vocab_size, device):
        """Return a dense float32 ``(vocab_size,)`` tensor of per-token scores on ``device``.

        The vector is cached. It is rebuilt when ``freq_table`` is replaced by a
        different object or when the vocab size or device changes. In-place
        edits to the same ``freq_table`` dict are NOT detected.
        """
        key = (vocab_size, str(device))
        cache = getattr(self, "_score_cache", None)
        if cache is not None and cache[0] is self.freq_table and cache[1] == key:
            return cache[2]

        scores = torch.full((vocab_size,), float(getattr(self, "default_score", 0.5)))
        known = [(int(i), float(s)) for i, s in self.freq_table.items() if 0 <= int(i) < vocab_size]
        if known:
            ids, vals = zip(*known)
            scores[torch.tensor(ids)] = torch.tensor(vals, dtype=scores.dtype)
        scores = scores.to(device)
        # Plain attribute (not a buffer) so the model's state_dict keys are unchanged.
        self._score_cache = (self.freq_table, key, scores)
        return scores

    def compute_simp_loss(self, logits, mask=None):
        """Compute the simplification loss from the word-frequency score table.

        Uses the expected score under the softmax distribution rather than the
        score of the argmax token. An argmax lookup has no gradient path back to
        the logits, so the term would not influence training at all.

        Args:
            logits: ``(batch, seq_len, vocab)`` decoder logits.
            mask: optional ``(batch, seq_len)`` bool or 0/1 tensor. Positions that
                are 0/False are ignored. ``None`` averages over every position.

        Returns:
            Scalar float32 tensor: the per-sequence mean expected score,
            averaged over the batch.
        """
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        expected = probs @ self._score_vector(probs.size(-1), probs.device)  # (batch, seq_len)
        if mask is None:
            weights = torch.ones_like(expected)
        else:
            weights = mask.to(device=expected.device, dtype=expected.dtype)
        # clamp_min avoids 0/0 for a fully masked (or empty) sequence.
        per_seq = (expected * weights).sum(dim=-1) / weights.sum(dim=-1).clamp_min(1.0)
        return per_seq.mean()

    def forward(self, input_ids, attention_mask, labels):
        """Return the combined training loss as a scalar tensor."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        logits = outputs.logits
        # Skip label padding (-100, set in preprocess) when the label and logit positions line up.
        mask = (labels != -100) if labels.shape == logits.shape[:-1] else None
        simp_loss = self.compute_simp_loss(logits, mask)
        return outputs.loss + self.lambda_simp * simp_loss

In [None]:
def build_dummy_freq_table(tokenizer, max_freq=1_000_000):
    """Build a placeholder token-id -> complexity-score table for SATSWrapper.

    No real corpus frequencies are used. Each token gets a pseudo-frequency
    ``crc32(token) % max_freq + 2`` and the score ``1 / ln(freq)``. The score is
    min-max normalised to [0, 1] over the full range that formula can produce,
    so a lower pseudo-frequency maps to a higher score.

    ``zlib.crc32`` is used instead of the built-in ``hash`` because str hashing
    is salted per process (PYTHONHASHSEED). With ``hash``, the table, and
    therefore the training objective, differed on every kernel restart.

    Args:
        tokenizer: object exposing ``get_vocab()`` -> ``{token: id}``.
        max_freq: number of distinct pseudo-frequencies; must be >= 2.

    Returns:
        dict mapping token id -> float in [0, 1].

    Raises:
        ValueError: if ``max_freq < 2`` (the normalisation range would be empty).
    """
    if max_freq < 2:
        raise ValueError(f"max_freq must be >= 2, got {max_freq}")
    # freq spans [2, max_freq + 1] and 1/ln(freq) is decreasing, so these are its extremes.
    score_max = 1 / math.log(2)
    score_min = 1 / math.log(max_freq + 1)
    span = score_max - score_min

    scores = {}
    for token, idx in tokenizer.get_vocab().items():
        freq = zlib.crc32(token.encode("utf-8")) % max_freq + 2
        scores[idx] = (1 / math.log(freq) - score_min) / span
    return scores


In [None]:
print("🔹 Loading dataset...")
# Demo run: keep only the first TRAIN_FRACTION of the 287,113 training articles.
# 0.002 (0.2%) keeps 574 articles. NOTE: the saved outputs below (1,435 examples,
# 718 batches) came from a run with a larger fraction (~0.5%) and are stale.
TRAIN_FRACTION = 0.002
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train")
# max(1, ...) keeps at least one example so the training loop never sees an empty loader.
dataset = dataset.select(range(max(1, int(len(dataset) * TRAIN_FRACTION))))

tokenizer = ProphetNetTokenizer.from_pretrained(HYPER["model_name"])

def preprocess(example):
    """Tokenise one CNN/DailyMail record into fixed-length model inputs.

    The article becomes ``input_ids`` / ``attention_mask``, padded or truncated
    to ``HYPER["max_input_len"]``. The highlights become ``labels``, padded or
    truncated to ``HYPER["max_label_len"]``, with pad positions set to -100.
    ``attention_mask`` is recomputed as ``input_ids != pad_token_id``.
    Returns a dict of 1-D tensors.
    """
    def _encode(text, max_len):
        encoded = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of 1; drop it for a single example.
        return {key: tensor.squeeze(0) for key, tensor in encoded.items()}

    features = _encode(example["article"], HYPER["max_input_len"])
    target_ids = _encode(example["highlights"], HYPER["max_label_len"])["input_ids"]

    # -100 marks padded label positions so the LM loss can ignore them.
    features["labels"] = target_ids.masked_fill(target_ids == tokenizer.pad_token_id, -100)
    features["attention_mask"] = (features["input_ids"] != tokenizer.pad_token_id).long()
    return features



# Tokenise each selected article/summary pair, one example per call.
dataset = dataset.map(function=preprocess, batched=False)


🔹 Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

3.0.0/train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

3.0.0/validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

3.0.0/test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

prophetnet.tokenizer: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]



Map:   0%|          | 0/1435 [00:00<?, ? examples/s]

In [None]:
def collate_fn(batch):
    """Stack per-example fields into batched tensors for the model.

    Each example's "input_ids", "attention_mask" and "labels" may be a tensor
    or a list of ints; lists are converted first. Examples must already share
    the same length per field, because they are stacked, not padded. Any other
    keys in the examples are dropped.
    """
    def to_tensor(value):
        return value if isinstance(value, torch.Tensor) else torch.tensor(value)

    return {
        field: torch.stack([to_tensor(example[field]) for example in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }

train_loader = DataLoader(
    dataset,
    batch_size=HYPER["batch_size"],
    shuffle=True,
    collate_fn=collate_fn,
)

# Sanity check: draw one batch and show the tensor shape of each field.
sample = next(iter(train_loader))
print({name: tensor.shape for name, tensor in sample.items()})

{'input_ids': torch.Size([2, 512]), 'attention_mask': torch.Size([2, 512]), 'labels': torch.Size([2, 128])}


In [None]:
# Train with early stopping driven by held-out validation loss.
torch.manual_seed(HYPER.get("seed", 42))
device = "cuda" if torch.cuda.is_available() else "cpu"

freq_table = build_dummy_freq_table(tokenizer)
model = SATSWrapper(freq_table=freq_table, lambda_simp=HYPER["lambda_simp"]).to(device)
optimizer = AdamW(model.parameters(), lr=HYPER["lr"])

# Early stopping must watch data the optimiser never trains on. On a small demo
# subset the training loss keeps falling as the model memorises it, so choosing
# the checkpoint by training loss simply picks the most over-fitted one.
N_VAL_EXAMPLES = 200
val_dataset = load_dataset("cnn_dailymail", "3.0.0", split=f"validation[:{N_VAL_EXAMPLES}]")
val_loader = DataLoader(
    val_dataset.map(preprocess, batched=False),
    batch_size=HYPER["batch_size"],
    shuffle=False,
    collate_fn=collate_fn,
)


def run_validation(loader):
    """Return the mean combined loss of `model` over `loader` without updating weights."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in loader:
            total += model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            ).item()
    return total / max(len(loader), 1)


best_loss = float("inf")
patience_counter = 0

for epoch in range(HYPER["epochs"]):
    print(f"\n🚀 Epoch {epoch+1}/{HYPER['epochs']}")
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, total=len(train_loader))
    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_loss = loss.item()  # one host sync per step
        running_loss += step_loss
        loop.set_postfix(loss=step_loss)

    epoch_loss = running_loss / max(len(train_loader), 1)
    val_loss = run_validation(val_loader)
    print(f"✅ Epoch {epoch+1} Average Loss: {epoch_loss:.4f} | Val Loss: {val_loss:.4f}")

    # EARLY STOPPING LOGIC (on validation loss)
    if val_loss < best_loss - HYPER["min_delta"]:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
        print("💾 Saved new best model.")
    else:
        patience_counter += 1
        print(f"⚠️ No improvement. Patience {patience_counter}/{HYPER['patience']}")
        if patience_counter >= HYPER["patience"]:
            print("🛑 Early stopping triggered.")
            break

model.safetensors:   0%|          | 0.00/1.57G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]


🚀 Epoch 1/100


  0%|          | 0/718 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


✅ Epoch 1 Average Loss: 3.1361
💾 Saved new best model.

🚀 Epoch 2/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 2 Average Loss: 2.4591
💾 Saved new best model.

🚀 Epoch 3/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 3 Average Loss: 1.9113
💾 Saved new best model.

🚀 Epoch 4/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 4 Average Loss: 1.4572
💾 Saved new best model.

🚀 Epoch 5/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 5 Average Loss: 1.1080
💾 Saved new best model.

🚀 Epoch 6/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 6 Average Loss: 0.8845
💾 Saved new best model.

🚀 Epoch 7/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 7 Average Loss: 0.6188
💾 Saved new best model.

🚀 Epoch 8/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 8 Average Loss: 0.4523
💾 Saved new best model.

🚀 Epoch 9/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 9 Average Loss: 0.3411
💾 Saved new best model.

🚀 Epoch 10/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 10 Average Loss: 0.2642
💾 Saved new best model.

🚀 Epoch 11/100


  0%|          | 0/718 [00:00<?, ?it/s]

✅ Epoch 11 Average Loss: 0.2353
💾 Saved new best model.

🚀 Epoch 12/100


  0%|          | 0/718 [00:00<?, ?it/s]

In [None]:
# Evaluate the best checkpoint on held-out test articles. Training articles
# would overstate quality because the model was fitted on them.
device = next(model.parameters()).device
model.load_state_dict(torch.load("best_model.pt", map_location=device, weights_only=True))
model.eval()

N_EVAL_EXAMPLES = 10
eval_dataset = load_dataset("cnn_dailymail", "3.0.0", split=f"test[:{N_EVAL_EXAMPLES}]")

predictions = []
for article in eval_dataset["article"]:
    inputs = tokenizer(
        article, return_tensors="pt", truncation=True, max_length=HYPER["max_input_len"]
    ).to(device)
    summary_ids = model.model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=150,
        num_beams=5,
        early_stopping=True,
    )
    predictions.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

summary = predictions[0]
print("\nGenerated Simplified Summary:")
print(summary)

# ROUGE over several articles; a single example gives a very noisy score.
rouge = load_metric("rouge")
score = rouge.compute(predictions=predictions, references=eval_dataset["highlights"])
print("\n📊 ROUGE Score:")
print(score)