In [6]:
import torch
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader

from dataset import EngSpaDataset
from model import Seq2SeqTransformer
from utils import inference, TransformerScheduler, PadCollate, prepare_mask, save_state, load_state

In [7]:
# Training parameters
# NOTE(review): LR is presumably overwritten every step by TransformerScheduler (warmup schedule),
# so this is only the pre-first-step value — confirm in utils.
LR = 1e-9
BATCH_SIZE = 128
BETAS = (0.9, 0.98)  # Adam (beta1, beta2); torch documents this argument as a tuple
WARMUP_STEPS = 4000
START_EPOCH = 1
EPOCHS = 100
CLIP_VALUE = 0.5  # max total gradient norm passed to nn.utils.clip_grad_norm_

# Model hyperparameters
D_MODEL = 512
NUM_HEADS = 8
ENC_LAYERS = 6
DEC_LAYERS = 6

# Other
DATASET_PATH = "eng-spa.csv"
SAVE_PATH = ""  # string prefix prepended to checkpoint file names; "" = current directory
SEED = 42


def _pick_device() -> torch.device:
    """Return the best available device: Apple MPS first, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Hard-coding "mps" makes the first .to(DEV) fail on non-Apple machines; fall back instead.
DEV = _pick_device()

# Seed torch's global RNG so weight init (and any DataLoader shuffling) is reproducible.
torch.manual_seed(SEED)


# Training data. PadCollate (from utils) builds each batch; judging by how the training loop
# unpacks it, it yields ((src, src_mask), (dec_input, dec_mask), tgt) — presumably padded.
dataset = EngSpaDataset(DATASET_PATH)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    # Reshuffle every epoch so the model does not see batches in the same (file) order each time.
    shuffle=True,
    collate_fn=PadCollate(),
)
# NOTE(review): unless PadCollate pads targets with -100 (CrossEntropyLoss's default
# ignore_index), padded target positions are counted in the loss. If the pad index is known,
# pass ignore_index=<pad idx> here — TODO confirm the pad index in utils.
crit = nn.CrossEntropyLoss()

In [8]:
# Encoder-decoder Transformer; input/output sizes are the English and Spanish vocabulary sizes.
model = Seq2SeqTransformer(
    src_dim=len(dataset.eng2idx),
    tgt_dim=len(dataset.spa2idx),
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    enc_layers=ENC_LAYERS,
    dec_layers=DEC_LAYERS,
).to(DEV)

opt = optim.Adam(model.parameters(), lr=LR, betas=BETAS)
# NOTE(review): presumably the warmup LR schedule from "Attention Is All You Need" — confirm in utils.
scheduler = TransformerScheduler(WARMUP_STEPS, D_MODEL)

# Set to a checkpoint path to resume instead of training from scratch. The training loop saves
# checkpoints as SAVE_PATH + f"model-{epoch}.pth". Leaving this None starts a fresh run.
RESUME_CHECKPOINT = None  # e.g. SAVE_PATH + "model-5.pth"
if RESUME_CHECKPOINT is not None:
    model, opt, scheduler, START_EPOCH = load_state(RESUME_CHECKPOINT, DEV)

In [9]:
# Sentence translated after every epoch as a quick qualitative check of progress.
SAMPLE_SENTENCE = "hello, how are you today?"

for e in range(START_EPOCH, EPOCHS + 1):
    loop = tqdm(enumerate(loader), total=len(loader))
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")

    total_loss = 0
    model.train()  # nn.Module.train() mutates in place; no need to reassign

    for i, ((src, src_mask), (dec_input, dec_mask), tgt) in loop:
        # Move the batch to the device and convert the collate masks via prepare_mask
        src, dec_input, tgt = src.to(DEV), dec_input.to(DEV), tgt.to(DEV)
        src_mask = prepare_mask(src_mask).to(DEV)
        # NOTE(review): no_peek_future presumably adds a causal (look-ahead) mask — confirm in utils
        dec_mask = prepare_mask(dec_mask, no_peek_future=True).to(DEV)

        # Forward and backward pass
        opt.zero_grad()
        yhat = model(src, dec_input, src_mask, dec_mask)
        # Flatten to (N, num_classes) logits vs (N,) targets for CrossEntropyLoss.
        # reshape (unlike view) also works if the model returns a non-contiguous tensor.
        loss = crit(yhat.reshape(-1, yhat.shape[-1]), tgt.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), CLIP_VALUE)
        opt.step()
        # Called once per batch; presumably updates opt's learning rate — confirm in utils
        scheduler.step(model, opt)

        total_loss += loss.item()
        loop.set_postfix(avg_loss=total_loss / (i + 1))

    # End of epoch: bump the model's epoch counter, translate a sample sentence, checkpoint.
    model.epoch += 1
    model.eval()
    with torch.no_grad():
        pred = inference(SAMPLE_SENTENCE, model, dataset, DEV)
    print(f"Epoch {e} : {pred}")
    save_state(SAVE_PATH + f"model-{e}.pth", model)

Epoch : [1/100]:   1%|          | 10/1092 [00:03<07:09,  2.52it/s, avg_loss=183]


KeyboardInterrupt: 