In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.datasets import Multi30k
from config import config
from model import Transformer
from tqdm import tqdm
import importlib
import data_utils
importlib.reload(data_utils)
from data_utils import (
    build_vocab, token_transform, collate_fn,
    create_src_mask, create_tgt_mask,
    PAD_IDX, BOS_IDX, EOS_IDX
)
import os

In [2]:
# Build the vocabularies with data_utils.build_vocab(). The result is indexed
# by language code ("de" / "en") when the vocabulary sizes are read below, and
# is handed to collate_fn when the training DataLoader is created.
vocab_transform = build_vocab()

In [3]:
# Publish the data-dependent settings into the shared config: the "de"
# vocabulary size is recorded as the source size, the "en" vocabulary size as
# the target size, followed by the special-token indices from data_utils.
for _setting, _value in (
    ("src_vocab_size", len(vocab_transform["de"])),
    ("tgt_vocab_size", len(vocab_transform["en"])),
    ("pad_idx", PAD_IDX),
    ("bos_idx", BOS_IDX),
    ("eos_idx", EOS_IDX),
):
    config[_setting] = _value

In [4]:
# Materialise the Multi30k training split as a list so the DataLoader can
# index into it and shuffle it.
train_iter = list(Multi30k(split='train'))


def _collate_train_batch(batch):
    """Collate one batch using the notebook-level token and vocab transforms."""
    return collate_fn(batch, token_transform, vocab_transform)


train_loader = DataLoader(
    train_iter,
    batch_size=32,
    shuffle=True,
    collate_fn=_collate_train_batch,
)

In [5]:
# Every Transformer constructor argument is read from the config entry of the
# same name, so the keyword arguments are assembled from this list of keys.
_model_hparams = {
    key: config[key]
    for key in (
        "src_vocab_size",
        "tgt_vocab_size",
        "model_dim",
        "num_heads",
        "ff_dim",
        "num_layers",
        "max_seq_length",
        "dropout",
    )
}
model = Transformer(**_model_hparams)

# Targets equal to PAD_IDX are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [6]:
# Resumable training.
#
# If a "latest" checkpoint exists, model/optimizer state, the epoch counter and
# the best average loss seen so far are restored from it; otherwise training
# starts from epoch 0. After every epoch the latest checkpoint is rewritten,
# and the model weights are additionally saved as "best" whenever the epoch's
# average training loss improves on the best value so far.
start_epoch = 0
num_epochs = config["num_epochs"]
checkpoint_dir = "checkpoints"
checkpoint_path = "checkpoints/checkpoint_latest.pt"
best_model_path = "checkpoints/transformer_best.pt"

if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on another device be read here;
    # load_state_dict then copies the values into the existing parameters.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # The checkpoint stores the index of the last finished epoch.
    start_epoch = checkpoint["epoch"] + 1
    best_loss = checkpoint["best_loss"]
    print(f"🔁 Resumed from epoch {start_epoch}, best_loss = {best_loss:.4f}")
else:
    print("🚀 Starting training from scratch.")
    best_loss = float("inf")

# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    # Count batches with enumerate rather than reading progress_bar.n: tqdm
    # only refreshes `n` lazily while iterating (and never when disabled), so
    # it is not a reliable denominator for the running average.
    for num_batches, (src_batch, tgt_batch) in enumerate(progress_bar, start=1):
        src_mask = create_src_mask(src_batch, config["pad_idx"])
        # Teacher forcing: the decoder input drops the last position and the
        # target drops the first, so position i is trained to predict i + 1.
        tgt_input = tgt_batch[:, :-1]
        tgt_mask = create_tgt_mask(tgt_input, config["pad_idx"])

        logits = model(src_batch, tgt_input, src_mask, tgt_mask)
        # Flatten all leading dimensions so the loss sees one row per position.
        output = logits.reshape(-1, logits.size(-1))
        target = tgt_batch[:, 1:].reshape(-1)

        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        avg_loss = epoch_loss / num_batches
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}", avg=f"{avg_loss:.4f}")

    if num_batches == 0:
        # Without this guard avg_loss would be undefined, or silently carried
        # over from a previous epoch.
        raise RuntimeError("train_loader produced no batches; cannot compute an epoch loss.")

    tqdm.write(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

    # Decide whether this epoch is the new best BEFORE building the checkpoint,
    # so the value restored on resume already reflects this epoch.
    is_best = avg_loss < best_loss
    if is_best:
        best_loss = avg_loss

    # Save checkpoint
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "best_loss": best_loss
    }
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Write to a temporary file and rename it into place, so an interrupted
    # save cannot leave a truncated "latest" checkpoint behind.
    tmp_checkpoint_path = checkpoint_path + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)

    if is_best:
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Saved new best model with avg loss {avg_loss:.4f}")

🔁 Resumed from epoch 6, best_loss = 1.6773


Epoch 7/10: 100%|██████████| 907/907 [10:59<00:00,  1.37it/s, avg=1.2234, loss=1.3405]


Epoch 7 completed. Average Loss: 1.2234
✅ Saved new best model with avg loss 1.2234


Epoch 8/10: 100%|██████████| 907/907 [13:24<00:00,  1.13it/s, avg=1.0220, loss=0.8763]


Epoch 8 completed. Average Loss: 1.0220
✅ Saved new best model with avg loss 1.0220


Epoch 9/10: 100%|██████████| 907/907 [14:08<00:00,  1.07it/s, avg=0.8415, loss=0.8796]


Epoch 9 completed. Average Loss: 0.8415
✅ Saved new best model with avg loss 0.8415


Epoch 10/10: 100%|██████████| 907/907 [14:06<00:00,  1.07it/s, avg=0.6777, loss=0.6657]


Epoch 10 completed. Average Loss: 0.6777
✅ Saved new best model with avg loss 0.6777
