In [3]:
import torch
import torch.nn as nn
from data_utils import get_train_loader, get_val_loader, get_test_loader, set_seed
from baseline_transformer_architecture import create_small_transformer
from modeling_functions import validate_transformer
from optimizer_scheduler import get_optimizer, get_plateau_scheduler, linear_teacher_scheduler
from tokenizers import Tokenizer
from tqdm.auto import tqdm
import json
import matplotlib.pyplot as plt
from rouge_score import rouge_scorer

class LabelSmoothingLoss(nn.Module):
    """Label-smoothed cross-entropy that skips padding positions.

    The true class receives probability ``1 - smoothing`` and the remaining
    ``smoothing`` mass is spread uniformly over the other ``C - 1`` classes.
    Positions whose target equals ``ignore_index`` contribute nothing and are
    excluded from the average, matching how ``nn.CrossEntropyLoss`` treats
    ``ignore_index``.

    Args:
        smoothing: Probability mass moved off the true class (0 disables smoothing).
        ignore_index: Target id whose positions are ignored (e.g. the pad id).
    """

    def __init__(self, smoothing=0.1, ignore_index=1):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """Return the mean smoothed loss over non-ignored positions.

        Args:
            pred: Unnormalised logits of shape ``(N, C)``; log-softmax is applied here.
            target: Integer class ids of shape ``(N,)``.

        Returns:
            A scalar tensor; ``0`` when every target equals ``ignore_index``.
        """
        num_classes = pred.size(1)
        mask = target != self.ignore_index
        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (num_classes - 1))
            # Ignored targets are redirected to class 0 only so scatter_ gets a
            # valid index; those rows are zeroed out immediately afterwards.
            safe_target = target.masked_fill(~mask, 0)
            true_dist.scatter_(1, safe_target.unsqueeze(1), 1.0 - self.smoothing)
            true_dist.masked_fill_(~mask.unsqueeze(1), 0)
        per_position = torch.sum(-true_dist * pred.log_softmax(dim=1), dim=1)
        # Average over real tokens only. Dividing by N (what torch.mean did)
        # lets the amount of padding in a batch shrink the loss.
        num_tokens = mask.sum().clamp(min=1)
        return per_position.sum() / num_tokens


def compute_rouge_score(preds, targets, tokenizer):
    """Return the mean ROUGE-L F1 between decoded predictions and references.

    Args:
        preds: Sequence of token-id rows (tensors or anything with ``tolist``),
            one per example.
        targets: Reference token-id rows, aligned one-to-one with ``preds``.
        tokenizer: Object whose ``decode(ids, skip_special_tokens=True)``
            returns text.

    Returns:
        The average ROUGE-L F-measure, or ``0.0`` when ``preds`` is empty.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in length (``zip``
            would otherwise silently truncate and skew the average).
    """
    # len() rather than truthiness: a multi-row tensor has no truth value.
    if len(preds) != len(targets):
        raise ValueError(
            f"preds and targets must have the same length, "
            f"got {len(preds)} and {len(targets)}"
        )
    if len(preds) == 0:
        return 0.0

    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    total_score = 0.0
    for p, t in zip(preds, targets):
        pred_text = tokenizer.decode(p.tolist(), skip_special_tokens=True)
        target_text = tokenizer.decode(t.tolist(), skip_special_tokens=True)
        # RougeScorer.score takes (target, prediction) in that order.
        total_score += scorer.score(target_text, pred_text)['rougeL'].fmeasure
    return total_score / len(preds)


if __name__ == "__main__":
    # Train the small transformer with the label-smoothed loss, evaluate on the
    # test split, dump the metric history to JSON and plot loss / ROUGE-L.
    torch.cuda.empty_cache()
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One constant drives both the loop and the progress-bar label (the label
    # previously read "/15" while only 5 epochs ran).
    num_epochs = 5

    config = {
        "vocab_size": 20000,
        "dropout": 0.1,
        "d_model": 384,
        "nhead": 6,
        "num_encoder_layers": 4,
        "num_decoder_layers": 4,
        "dim_feedforward": 1536,
    }

    tokenizer = Tokenizer.from_file("cnn_bpe_tokenizer_20k.json")
    pad_idx = tokenizer.token_to_id("[PAD]")
    if pad_idx is None:
        # token_to_id can return None for an unknown token; the loss's
        # ignore_index mask would then silently misbehave.
        raise ValueError("Tokenizer has no [PAD] token")

    model = create_small_transformer(**config).to(device)
    optimizer = get_optimizer(model)
    plateau_scheduler = get_plateau_scheduler(optimizer)
    # NOTE(review): used below as an object whose .step() returns the
    # teacher-forcing ratio; confirm linear_teacher_scheduler is such an
    # instance rather than a function/class that must be called first.
    teacher_scheduler = linear_teacher_scheduler
    criterion = LabelSmoothingLoss(smoothing=0.1, ignore_index=pad_idx)

    train_loader = get_train_loader(tokenizer, batch_size=32, num_workers=2)
    val_loader = get_val_loader(tokenizer, batch_size=4, num_workers=0)
    test_loader = get_test_loader(tokenizer, batch_size=4, num_workers=0)

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_rouge": [],
        "val_rouge": [],
        "test_loss": None,
        "test_rouge": None,
        "learning_rate": [],
        "teacher_forcing_ratio": [],
    }

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        all_preds = []
        all_targets = []
        tf_ratio = teacher_scheduler.step()

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            optimizer.zero_grad()
            output = model(
                src=input_ids,
                tgt=labels,
                src_key_padding_mask=(attn_mask == 0),
                teacher_forcing_ratio=tf_ratio,
            )
            # Output positions are scored against labels[:, 1:], i.e. the
            # target sequence shifted left by one.
            logits = output.view(-1, output.size(-1))
            shifted_labels = labels[:, 1:]
            targets = shifted_labels.contiguous().view(-1)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            total_loss += loss.item()

            pred_ids = logits.argmax(dim=-1).view(shifted_labels.shape)
            all_preds.append(pred_ids.detach().cpu())
            all_targets.append(shifted_labels.detach().cpu())

        avg_train_loss = total_loss / len(train_loader)
        # Flatten per-batch tensors into per-example rows instead of torch.cat,
        # so batches padded to different lengths are also handled.
        # NOTE(review): this decodes and scores every training example each
        # epoch, which is slow on the full training split.
        train_rouge = compute_rouge_score(
            [row for chunk in all_preds for row in chunk],
            [row for chunk in all_targets for row in chunk],
            tokenizer,
        )

        torch.cuda.empty_cache()
        val_loss, val_rouge = validate_transformer(model, val_loader, criterion, tokenizer, device, pad_idx, max_length_generate=40)

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(val_loss)
        history["train_rouge"].append(train_rouge)
        history["val_rouge"].append(val_rouge)
        # Logged before the plateau step, i.e. the LR used during this epoch.
        history["learning_rate"].append(optimizer.param_groups[0]['lr'])
        history["teacher_forcing_ratio"].append(tf_ratio)

        plateau_scheduler.step(val_loss)
        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Train ROUGE: {train_rouge:.4f}, Val ROUGE: {val_rouge:.4f}")

    torch.cuda.empty_cache()
    model.eval()
    total_test_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            # NOTE(review): teacher_forcing_ratio is not passed, so the model's
            # default applies. Test ROUGE below comes from per-position argmax,
            # whereas validate_transformer(..., max_length_generate=40)
            # presumably generates freely, so test and val ROUGE may not be
            # directly comparable. Confirm.
            output = model(
                src=input_ids,
                tgt=labels,
                src_key_padding_mask=(attn_mask == 0),
            )
            logits = output.view(-1, output.size(-1))
            shifted_labels = labels[:, 1:]
            targets = shifted_labels.contiguous().view(-1)
            loss = criterion(logits, targets)
            total_test_loss += loss.item()

            pred_ids = logits.argmax(dim=-1).view(shifted_labels.shape)
            all_preds.append(pred_ids.cpu())
            all_targets.append(shifted_labels.cpu())

    avg_test_loss = total_test_loss / len(test_loader)
    test_rouge = compute_rouge_score(
        [row for chunk in all_preds for row in chunk],
        [row for chunk in all_targets for row in chunk],
        tokenizer,
    )
    history["test_loss"] = avg_test_loss
    history["test_rouge"] = test_rouge

    print(f"\n Test Loss: {avg_test_loss:.4f} | Test ROUGE-L F1: {test_rouge:.4f}")

    with open("label_smooth_history3.json", "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    epochs = list(range(1, len(history["train_loss"]) + 1))

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, history["train_loss"], label="Train Loss", marker="o")
    plt.plot(epochs, history["val_loss"], label="Validation Loss", marker="s")
    plt.axhline(y=history["test_loss"], color='r', linestyle='--', label=f"Test Loss: {history['test_loss']:.4f}")
    plt.title("Loss Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("loss_plot_smooth.png")
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, history["train_rouge"], label="Train ROUGE-L", marker="o")
    plt.plot(epochs, history["val_rouge"], label="Validation ROUGE-L", marker="s")
    plt.axhline(y=history["test_rouge"], color='g', linestyle='--', label=f"Test ROUGE-L: {history['test_rouge']:.4f}")
    plt.title("ROUGE-L F1 Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("ROUGE-L F1 Score")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("rouge_plot_smooth.png")
    plt.show()


README.md:   0%|          | 0.00/15.6k [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Epoch 1/15:   0%|          | 0/8973 [00:00<?, ?it/s]

  output = torch._nested_tensor_from_mask(


KeyboardInterrupt: 