# In [None]:
import torch
import torch.nn as nn
import numpy as np
import time
import os

def train_model(model, train_iterator, optimizer, criterion, trg_vocab_size, device):
    """Run one optimisation pass over ``train_iterator`` and report the mean loss.

    For every batch the ``de`` field is used as the model's first input and
    the ``en`` field as the target. The model receives the target with its
    last position along dim 0 removed, and is scored against the target with
    its first position along dim 0 removed (presumably dim 0 is the time
    axis, i.e. next-token prediction -- confirm against the data pipeline).

    Args:
        model: Module called as ``model(source, target_input)``; put into
            training mode by this function.
        train_iterator: Iterable of batches exposing ``.de`` and ``.en``
            tensors.
        optimizer: Optimizer whose gradients are reset before, and stepped
            after, every backward pass.
        criterion: Loss callable taking ``(predictions, labels)``.
        trg_vocab_size: Size of the last dimension the model output is
            flattened to before the loss is computed.
        device: Device the batch tensors are moved to.

    Returns:
        A one-element list holding the mean of the per-batch losses
        (``nan`` when the iterator yields no batches).
    """
    model.train()
    batch_losses = []

    for batch in train_iterator:
        source = batch.de.to(device)
        target = batch.en.to(device)

        # Model input drops the final target position; labels drop the first.
        decoder_input = target[:-1]
        labels = target[1:].reshape(-1)

        optimizer.zero_grad()
        predictions = model(source, decoder_input).reshape(-1, trg_vocab_size)
        batch_loss = criterion(predictions, labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return [np.mean(batch_losses)]

def valid_model(model, valid_iterator, optimizer, criterion, trg_vocab_size, device):
    """Compute the mean loss over ``valid_iterator`` without updating the model.

    The model is switched to evaluation mode (and left in it) and every
    forward pass runs under ``torch.no_grad()``. Batches are handled the same
    way as in training: ``batch.de`` is the first model input, the model is
    fed ``batch.en`` without its last position along dim 0 and is scored
    against ``batch.en`` without its first position along dim 0.

    Args:
        model: Module called as ``model(source, target_input)``.
        valid_iterator: Iterable of batches exposing ``.de`` and ``.en``
            tensors.
        optimizer: Unused. Kept only so existing call sites keep working;
            evaluation no longer touches the optimizer or any stored
            gradients, so ``None`` is accepted.
        criterion: Loss callable taking ``(predictions, labels)``.
        trg_vocab_size: Size of the last dimension the model output is
            flattened to before the loss is computed.
        device: Device the batch tensors are moved to.

    Returns:
        A one-element list holding the mean of the per-batch losses
        (``nan`` when the iterator yields no batches).
    """
    step_valid_loss = []
    model.eval()

    with torch.no_grad():
        for batch in valid_iterator:
            input_sentence = batch.de.to(device)
            trg = batch.en.to(device)

            # No optimizer.zero_grad() here: nothing is back-propagated during
            # evaluation, and clearing gradients would be a hidden side effect.
            out = model(input_sentence, trg[:-1])
            out = out.reshape(-1, trg_vocab_size)
            labels = trg[1:].reshape(-1)
            loss = criterion(out, labels)

            step_valid_loss.append(loss.item())

    return [np.mean(step_valid_loss)]

def train_f(model, train_iterator, valid_iterator, optimizer, criterion, device, checkpoint_dir, epochs, save_frequency, trg_vocab_size):
    """Train and validate ``model`` for ``epochs`` epochs, checkpointing periodically.

    Each epoch runs ``train_model`` followed by ``valid_model``, prints both
    losses, and on every other epoch (1st, 3rd, ...) also prints how long the
    epoch took. Whenever the 1-based epoch number is a multiple of
    ``save_frequency`` the model's ``state_dict`` is written to
    ``checkpoint_dir/checkpoint_epoch_<n>.pt``.

    Args:
        model: Model to optimise.
        train_iterator: Batches passed to ``train_model``.
        valid_iterator: Batches passed to ``valid_model``.
        optimizer: Optimizer forwarded to both helpers.
        criterion: Loss callable forwarded to both helpers.
        device: Device forwarded to both helpers.
        checkpoint_dir: Directory for checkpoint files; created on demand if
            it does not exist. An empty string means the current directory.
        epochs: Number of epochs to run. Zero or a negative value runs none.
        save_frequency: Save a checkpoint every this many epochs. ``0`` or
            ``None`` disables checkpointing.
        trg_vocab_size: Target vocabulary size forwarded to both helpers.

    Returns:
        ``(train_loss, val_loss)`` as returned by the helpers for the LAST
        epoch only, or two empty lists when no epoch was run.
    """
    # Bound up front so that running zero epochs returns cleanly instead of
    # raising UnboundLocalError at the final return.
    train_loss, val_loss = [], []

    for e in range(epochs):
        epoch_number = e + 1
        start_time = time.time()

        train_loss = train_model(model, train_iterator, optimizer, criterion, trg_vocab_size, device)
        val_loss = valid_model(model, valid_iterator, optimizer, criterion, trg_vocab_size, device)

        print(f'Epoch {epoch_number}/{epochs}')
        print('-' * 10)
        print("Train Loss at epoch {}: {:.4f}".format(epoch_number, train_loss[0]))
        print("Valid Loss at epoch {}: {:.4f}".format(epoch_number, val_loss[0]))

        # Timing is reported only on even loop indices, as in the original.
        if e % 2 == 0:
            epoch_time = time.time() - start_time
            print("Time for epoch {}: {:.2f} seconds".format(epoch_number, epoch_time))

        # A falsy save_frequency disables checkpointing (and avoids a
        # modulo-by-zero).
        if save_frequency and epoch_number % save_frequency == 0:
            if checkpoint_dir:
                # torch.save does not create missing parent directories.
                os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_name = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch_number}.pt')
            torch.save(model.state_dict(), checkpoint_name)
            print(f"Saved checkpoint for epoch {epoch_number} at {checkpoint_name}")

    return train_loss, val_loss

