In [None]:
from torch.utils.data import DataLoader
import torch
from torch import nn, optim
from torch.utils.data import Dataset
import torch
from tqdm import tqdm

from dataset import EngSpaDataset
from model import Seq2SeqTransformer
from utils import inference, TransformerScheduler, PadCollate, prepare_mask

In [None]:
# ---------------------------------------------------------------------------
# Run configuration: everything a reader might want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Optimisation settings.
# NOTE(review): LR is extremely small; presumably TransformerScheduler
# overrides/rescales the optimizer's learning rate — confirm in utils.
LR = 1e-9
BETAS = [0.9, 0.98]  # intended Adam (beta1, beta2) coefficients
WARMUP_STEPS = 1     # warm-up step count handed to TransformerScheduler

# Model architecture (must match the checkpoint that is loaded later).
D_MODEL = 512
NUM_HEADS = 8
ENC_LAYERS = 6
DEC_LAYERS = 6

# Pick the best available accelerator instead of hard-coding "mps": on a
# machine without Apple's Metal backend a fixed "mps" device makes every
# later `.to(DEV)` call fail. On MPS-capable machines the result is unchanged.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

# Data and checkpoint locations.
DATASET_PATH = "eng-spa.csv"
# Checkpoints are written to SAVE_PATH + "<epoch>.pth".
# NOTE(review): "/" is the filesystem root — confirm this is the intended
# destination (it usually requires elevated permissions).
SAVE_PATH = "/"

# Training schedule.
BATCH_SIZE = 128
EPOCHS = 100

In [None]:
# Load the parallel corpus and wrap it in a batch loader; PadCollate is the
# collate function that assembles each batch.
# NOTE(review): shuffle=False means batches arrive in the same order every
# epoch — confirm this is intentional for training.
dataset = EngSpaDataset(DATASET_PATH)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=PadCollate(),
)

# Vocabulary sizes come from the dataset's token->index tables; the remaining
# hyperparameters come from the configuration cell.
model_kwargs = dict(
    src_dim=len(dataset.eng2idx),
    tgt_dim=len(dataset.spa2idx),
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    enc_layers=ENC_LAYERS,
    dec_layers=DEC_LAYERS,
)
model = Seq2SeqTransformer(**model_kwargs).to(DEV)

# Resume from the "2.pth" checkpoint, mapping its tensors onto DEV, and leave
# the model in evaluation mode until the training loop switches it back.
model.load_state_dict(torch.load("2.pth", map_location=DEV))
model.eval()

In [None]:
# Loss, optimizer and learning-rate schedule used by the training loop.

# NOTE(review): CrossEntropyLoss is created with its defaults, so every target
# position contributes to the loss. If `tgt` contains padding ids, an
# `ignore_index` is probably wanted — confirm against PadCollate/the dataset.
crit = nn.CrossEntropyLoss()

# Pass the configured BETAS explicitly. Previously the constant was defined
# but never used, so Adam silently ran with its defaults (0.9, 0.999).
opt = optim.Adam(model.parameters(), lr=LR, betas=tuple(BETAS))

# Project-local scheduler built from the optimizer, warm-up step count and
# model width; it is stepped once per batch in the training loop.
scheduler = TransformerScheduler(opt, WARMUP_STEPS, D_MODEL)

In [None]:
# Training loop: one full pass over `loader` per epoch, followed by a sample
# translation and a checkpoint of the model weights.
for e in range(1, EPOCHS + 1):
    model.train()  # back to training mode (the epoch ends in eval mode)
    total_loss = 0

    # Progress bar over the batches; the postfix shows the running mean loss.
    loop = tqdm(enumerate(loader), total=len(loader))
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")

    for i, ((src, src_mask), (dec_input, dec_mask), tgt) in loop:
        # Move the batch tensors to the training device.
        src = src.to(DEV)
        dec_input = dec_input.to(DEV)
        tgt = tgt.to(DEV)

        # Build the attention masks; the decoder mask is requested with
        # no_peek_future=True.
        src_mask = prepare_mask(src_mask).to(DEV)
        dec_mask = prepare_mask(dec_mask, no_peek_future=True).to(DEV)

        # Forward pass; the output is flattened over all leading dimensions so
        # it lines up with the flattened targets for the loss.
        opt.zero_grad()
        yhat = model(src, dec_input, src_mask, dec_mask)
        last_dim = yhat.shape[-1]
        loss = crit(yhat.view(-1, last_dim), tgt.view(-1))

        # Backward pass with the gradient norm clipped to 0.5 before the update.
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        opt.step()

        # Bookkeeping, then advance the learning-rate schedule once per batch.
        total_loss += loss.item()
        loop.set_postfix(avg_loss=total_loss / (i + 1))
        scheduler.step()

    # End of epoch: translate a fixed sentence as a quick qualitative check.
    model.eval()
    with torch.no_grad():
        pred = inference("i like to swim", model, dataset, DEV, max_gen_len=20)
    print(f"Epoch {e} : {pred}")

    # Save this epoch's weights as SAVE_PATH + "<epoch>.pth".
    torch.save(model.state_dict(), f"{SAVE_PATH}{e}.pth")