In [None]:
from torch.utils.data import DataLoader
from torch import nn, optim
from tqdm import tqdm
import torch

from dataset import EngSpaDataset
from utils import sample_translation
from model import BiGRUEncoder, GRUDecoder

In [None]:
# Training configuration — everything a reader might tune lives here.
NUM_SENTENCES = 100000  # passed to the dataset as end_idx
BATCH_SIZE = 128
HIDDEN_SIZE = 512       # hidden_size for the encoder; also first arg to the decoder
LR = 3e-4               # initial Adam learning rate for encoder and decoder
MAX_LEN = 30            # passed to GRUDecoder (presumably its max output length — confirm)
EPOCHS = 10
# Per-epoch LR multiplier (MultiplicativeLR): epoch e trains at LR * DECAY**(e - 1),
# i.e. ~1.8e-6 after 10 decays at 0.6 — aggressive, tune with care.
DECAY = 0.6
DEV = torch.device("cpu")

# Seed torch's global RNG so weight initialisation and the DataLoader's
# shuffle order are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Load sentence pairs (rows up to end_idx=NUM_SENTENCES) plus the pre-built
# embedding artifact, and wrap them in a shuffling batch loader.
# NOTE(review): "nmt_glove.pkl" is presumably unpickled by EngSpaDataset —
# only ever load pickle files from a trusted source.
dataset = EngSpaDataset(
    "eng-spa.csv",
    "nmt_glove.pkl",
    end_idx=NUM_SENTENCES,
)
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

# input_size=300 is assumed to match the embedding dimension produced by the
# dataset (GloVe-sized) — TODO confirm. The decoder is sized by
# len(dataset.spa2idx), presumably the Spanish vocabulary size.
encoder = BiGRUEncoder(hidden_size=HIDDEN_SIZE, input_size=300).to(DEV)
decoder = GRUDecoder(
    HIDDEN_SIZE,
    len(dataset.spa2idx),
    MAX_LEN,
    device=DEV,
).to(DEV)

In [None]:
# Token-level cross-entropy over the decoder's logits.
# NOTE(review): no ignore_index is set, so any padding positions in the
# targets are scored like real tokens — consider ignore_index=<pad id> once
# the dataset's padding index is confirmed.
crit = nn.CrossEntropyLoss()

# One Adam optimizer per module, both starting at the same learning rate.
enc_opt = optim.Adam(encoder.parameters(), lr=LR)
dec_opt = optim.Adam(decoder.parameters(), lr=LR)


def decay_fn(epoch):
    """Constant LR multiplier for MultiplicativeLR.

    MultiplicativeLR multiplies each param group's current learning rate by
    the returned value on every ``scheduler.step()``; returning DECAY for every
    ``epoch`` makes the learning rate shrink geometrically by DECAY per step.
    """
    return DECAY


enc_scheduler = optim.lr_scheduler.MultiplicativeLR(enc_opt, decay_fn)
dec_scheduler = optim.lr_scheduler.MultiplicativeLR(dec_opt, decay_fn)

In [None]:
# Train for EPOCHS epochs. After each epoch: decay both learning rates, print
# one sample translation, and checkpoint both models' state dicts to disk.
for e in range(1, EPOCHS + 1):
    encoder.train()
    decoder.train()
    loop = tqdm(enumerate(loader), total=len(loader), position=0)
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
    total_loss = 0
    # NOTE(review): src_mask and tgt_mask are unpacked but never used, so the
    # encoder sees padded source positions and the loss scores every target
    # position, padding included. Confirm the mask convention, then pass
    # src_mask to the encoder and mask (or ignore_index) padded targets.
    for i, ((x, src_mask), (y, tgt_mask)) in loop:
        # y must be 2-D; assumed (batch, seq) token ids — TODO confirm.
        seq_len = y.shape[1]
        x, y = x.to(DEV), y.to(DEV)
        # Clear grads through the optimizers that own exactly these parameters.
        enc_opt.zero_grad()
        dec_opt.zero_grad()

        encoded_source = encoder(x)
        yhat = decoder(encoded_source, y)

        # Keep the first seq_len decoder steps so predictions line up with the
        # targets, then flatten to (N, vocab) logits vs (N,) targets as
        # CrossEntropyLoss expects. reshape (unlike view) also accepts
        # non-contiguous tensors.
        logits = yhat[:, :seq_len, :].reshape(-1, yhat.shape[-1])
        loss = crit(logits, y.reshape(-1))
        loss.backward()
        enc_opt.step()
        dec_opt.step()

        total_loss += loss.item()
        loop.set_postfix(loss=total_loss / (i + 1))

    # Once per epoch: multiply both learning rates by DECAY.
    enc_scheduler.step()
    dec_scheduler.step()

    encoder.eval()
    decoder.eval()

    with torch.no_grad():
        print(sample_translation(encoder, decoder, dataset, "i like to swim everyday", DEV))

    torch.save(encoder.state_dict(), f"{e}enc.pth")
    torch.save(decoder.state_dict(), f"{e}dec.pth")