# In [None]:
import torch
import time
import math
import argparse

# -Train- #
def train(model, dataloader, optimizer, criterion, args):
    """Run one training epoch and return the mean per-batch loss.

    Every batch is unpacked as ``(src, trg)``. The model receives ``src`` and
    the target with its last position dropped (``trg[:, :-1]``); the loss is
    computed against the target with its first position dropped
    (``trg[:, 1:]``), i.e. the model output at each step is scored against
    the following target position.

    Args:
        model: Module called as ``model(src, trg_in)``; it must return a
            pair whose first element is the output tensor (the second
            element is ignored here).
        dataloader: Iterable yielding ``(src, trg)`` batches. It does not
            need to define ``__len__``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, target)`` with
            ``output`` flattened to ``(-1, output_dim)`` and ``target``
            flattened to ``(-1,)``.
        args: Object exposing ``clip``, the maximum gradient norm passed to
            ``torch.nn.utils.clip_grad_norm_``.

    Returns:
        The sum of the per-batch losses divided by the number of batches
        processed.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    # Count batches explicitly instead of relying on len(dataloader): not
    # every iterable of batches supports len(), and the failure would only
    # surface after the whole epoch had already been computed.
    num_batches = 0

    for src, trg in dataloader:
        optimizer.zero_grad()
        output, _ = model(src, trg[:, :-1])

        # Flatten all leading dimensions so the criterion sees one row per
        # predicted position.
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        target = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train() received a dataloader that yielded no batches")

    return epoch_loss / num_batches

# Training loop

def run_training_loop(model, train_loader, optimizer, criterion, args):
    """Train for ``args.n_epochs`` epochs, printing a report after each one.

    Each epoch calls ``train(model, train_loader, optimizer, criterion,
    args)`` and prints the elapsed wall time, the returned loss, and
    ``exp(loss)`` labelled as perplexity.

    Args:
        model: Forwarded unchanged to ``train``.
        train_loader: Forwarded unchanged to ``train``.
        optimizer: Forwarded unchanged to ``train``.
        criterion: Forwarded unchanged to ``train``.
        args: Object exposing ``n_epochs`` (number of epochs to run); also
            forwarded to ``train``.

    Returns:
        None. Results are reported via ``print`` only.
    """
    for epoch in range(args.n_epochs):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments during a long epoch.
        start_time = time.perf_counter()
        train_loss = train(model, train_loader, optimizer, criterion, args)
        elapsed = time.perf_counter() - start_time

        # divmod on a float returns a float quotient; convert it so the
        # minutes print as "3m" rather than "3.0m".
        epoch_mins, epoch_secs = divmod(elapsed, 60)
        epoch_mins = int(epoch_mins)

        # math.exp raises OverflowError for very large inputs; a diverged
        # loss should show up in the report instead of aborting the run.
        try:
            train_ppl = math.exp(train_loss)
        except OverflowError:
            train_ppl = float('inf')

        print(f'Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs:.2f}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {train_ppl:.3f}')