<a href="https://colab.research.google.com/github/enxo7899/INM706-Deep-Learning-for-Sequence-Analysis/blob/main/INM706_Seq2Seq_Machine_Translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install sacremoses
pip install wandb

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/897.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.1/897.5 kB[0m [31m5.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━[0m [32m501.8/897.5 kB[0m [31m7.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m890.9/897.5 kB[0m [31m8.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: sacremoses
Successfully installed sacremoses-0.1.1


In [1]:
# Load the dataset: two plain-text files, one sentence per line.
with open('GlobalVoices.en-sq.en', 'r', encoding='utf-8') as f:
    en_sentences = f.readlines()

with open('GlobalVoices.en-sq.sq', 'r', encoding='utf-8') as f:
    sq_sentences = f.readlines()

# Ensure both lists have the same length before looking at any pairs.
# (The original assert ended in a dangling comma, which is a SyntaxError.)
assert len(en_sentences) == len(sq_sentences), (
    f"Parallel corpus is misaligned: {len(en_sentences)} English lines "
    f"vs {len(sq_sentences)} Albanian lines"
)

# Print the first few lines of each dataset to understand the structure.
# Slicing (rather than indexing range(5)) also works for files shorter than 5 lines.
print("English sentences sample:")
for line in en_sentences[:5]:
    print(line.strip())

print("\nAlbanian sentences sample:")
for line in sq_sentences[:5]:
    print(line.strip())

# Print the total number of sentences
print(f"\nTotal number of sentence pairs: {len(en_sentences)}")


English sentences sample:
South Korea: North Korean Dictator, Kim Jong Il Is Dead · Global Voices
Kim Jong Il, the North Korean dictator who ruled the hermit kingdom for the past three decades, has died at the age of 69.
According to North Korean state television's official report on Monday, Kim passed away from "mental and physical strain" during a train ride on December 17, 2011.
The South Korean Twittersphere erupted with various responses.
Although the death of one of the world's most notorious dictators is something people might welcome, most South Koreans have expressed concern about the instability his sudden death might bring to Korean peninsula.

Albanian sentences sample:
Kore: Vdes diktatori koreano-verior, Kim Jong Il
Kim Jong Il, diktatori koreano-verior, i cili sundoi me mbretërinë e izoluar gjatë tre dekadave të kaluara, vdiq në moshën 69 vjeçare.
Sipas lajmit zyrtar të emituar ditën e hënë në televizionin shtetëror koreano-verior, Kim ka ndërruar jetë si rezultat i "lod

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import MarianTokenizer, MarianMTModel
import torch.nn as nn
import torch.optim as optim
import random
import math
import time

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load the dataset (one sentence per line in each file)
with open('GlobalVoices.en-sq.en', 'r', encoding='utf-8') as f:
    en_sentences = f.readlines()
with open('GlobalVoices.en-sq.sq', 'r', encoding='utf-8') as f:
    sq_sentences = f.readlines()

# The two lists are paired purely by index further down, so a length mismatch
# would silently train on wrong translations. Fail loudly instead.
assert len(en_sentences) == len(sq_sentences), (
    f"Parallel corpus is misaligned: {len(en_sentences)} English lines "
    f"vs {len(sq_sentences)} Albanian lines"
)

# Verify dataset loaded correctly
print(f"English sentences sample: {en_sentences[:5]}")
print(f"Albanian sentences sample: {sq_sentences[:5]}")
print(f"Total number of sentence pairs: {len(en_sentences)}")

# Use MarianTokenizer for tokenization
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-sq')

class TranslationDataset(Dataset):
    """Map-style dataset of tokenised (source, target) sentence pairs.

    Every item is a dict of 1-D tensors of length ``max_length``:
    ``src`` / ``trg`` hold token ids and ``src_mask`` / ``trg_mask`` the
    matching attention masks.

    Args:
        src_sentences: Sequence of source-language strings.
        trg_sentences: Sequence of target-language strings, aligned with
            ``src_sentences`` by index.
        tokenizer: Hugging Face style tokenizer, called as
            ``tokenizer(text=...)`` for sources and
            ``tokenizer(text_target=...)`` for targets.
        max_length: Fixed length every sequence is padded/truncated to.

    Raises:
        ValueError: If the two sentence lists differ in length.

    NOTE(review): no start-of-sequence token is added here, while
    ``Seq2Seq.forward`` feeds ``trg[0]`` as the first decoder input. Position 0
    is therefore whatever the tokenizer emits first; confirm this is intended.
    """

    def __init__(self, src_sentences, trg_sentences, tokenizer, max_length=128):
        if len(src_sentences) != len(trg_sentences):
            raise ValueError(
                f"src_sentences ({len(src_sentences)}) and trg_sentences "
                f"({len(trg_sentences)}) must have the same length"
            )
        self.src_sentences = src_sentences
        self.trg_sentences = trg_sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.src_sentences)

    def _encode(self, **text_kwargs):
        """Tokenise one sentence to a fixed-length ``(1, max_length)`` batch."""
        return self.tokenizer(
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            **text_kwargs,
        )

    def __getitem__(self, idx):
        # readlines() keeps the trailing newline on every sentence; drop it so
        # it is never handed to the tokenizer.
        src = self.src_sentences[idx].strip()
        trg = self.trg_sentences[idx].strip()

        src_enc = self._encode(text=src)
        # text_target makes the tokenizer encode in target-language mode
        # instead of treating the Albanian text as another source sentence.
        trg_enc = self._encode(text_target=trg)

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        return {
            'src': src_enc['input_ids'].squeeze(0),
            'src_mask': src_enc['attention_mask'].squeeze(0),
            'trg': trg_enc['input_ids'].squeeze(0),
            'trg_mask': trg_enc['attention_mask'].squeeze(0)
        }

# Create the dataset objects
dataset = TranslationDataset(en_sentences, sq_sentences, tokenizer)

# Split the dataset into train and validation sets (90% train, 10% validation)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same sentences end up in the validation set on every
# run; otherwise validation metrics are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoader objects (only the training data is reshuffled each epoch)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print("Data preprocessing complete.")

# Define the Seq2Seq model components
class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder.

    Args:
        input_dim: Size of the source vocabulary (embedding rows).
        emb_dim: Embedding width.
        hidden_dim: LSTM hidden size per direction.
        n_layers: Number of stacked LSTM layers; also the number of state
            layers handed to the decoder.
        dropout: Dropout probability for the embeddings and between LSTM layers.

    ``forward(src)`` takes sequence-first token ids (the LSTM is built without
    ``batch_first``) and returns ``(outputs, hidden, cell)``:

    * ``outputs``: per-step LSTM outputs, last dim ``2 * hidden_dim``.
    * ``hidden``: ``(n_layers, batch, hidden_dim)``. The top layer's forward and
      backward final states are projected by ``fc`` + tanh and copied to every layer.
    * ``cell``: ``(n_layers, batch, hidden_dim)``, the last ``n_layers`` entries
      of the LSTM's final cell state.
    """

    def __init__(self, input_dim, emb_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Remembered so forward() can size the decoder's initial state; the
        # original hard-coded 2 here and only worked for n_layers == 2.
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.rnn(embedded)
        # hidden[-2] / hidden[-1] are the final entries of the stacked state:
        # the top layer's forward and backward directions.
        summary = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
        hidden = summary.unsqueeze(0).repeat(self.n_layers, 1, 1)
        # Keep one cell tensor per decoder layer. For n_layers == 2 this is
        # exactly the former cell[-2:] slice.
        cell = cell[-self.n_layers:].contiguous()
        return outputs, hidden, cell

class Attention(nn.Module):
    """Additive attention over the encoder outputs.

    ``forward(hidden, encoder_outputs)`` takes a ``(batch, hidden_dim)`` decoder
    state and sequence-first encoder outputs whose last dim is
    ``2 * hidden_dim``. It returns ``(batch, src_len)`` weights that sum to 1
    along the source axis.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Input is [decoder state ; encoder output] = hidden_dim + 2 * hidden_dim.
        self.attn = nn.Linear(hidden_dim * 3, hidden_dim)
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden, encoder_outputs):
        n_src = encoder_outputs.size(0)
        # Batch-first view of the encoder outputs: (batch, src_len, 2 * hidden_dim).
        enc_batch_first = encoder_outputs.transpose(0, 1)
        # Present the decoder state once per source position.
        query = hidden.unsqueeze(1).expand(-1, n_src, -1)
        energy = torch.tanh(self.attn(torch.cat((query, enc_batch_first), dim=2)))
        # Collapse the feature axis with the learned vector v: one score per position.
        scores = (self.v * energy).sum(dim=2)
        return scores.softmax(dim=1)

class Decoder(nn.Module):
    """Single-step LSTM decoder with attention over the encoder outputs.

    Args:
        output_dim: Target vocabulary size (embedding rows and logit width).
        emb_dim: Embedding width.
        hidden_dim: LSTM hidden size; encoder outputs are ``2 * hidden_dim`` wide.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability for embeddings and between LSTM layers.
        attention: Module called as ``attention(top_hidden, encoder_outputs)``
            that returns ``(batch, src_len)`` weights.

    ``forward`` consumes one token id per batch element plus the previous
    ``(hidden, cell)`` state. It returns ``(prediction, hidden, cell)``, where
    ``prediction`` is ``(batch, output_dim)`` unnormalised scores.
    """

    def __init__(self, output_dim, emb_dim, hidden_dim, n_layers, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        context_dim = hidden_dim * 2  # width of the attention-weighted encoder summary
        self.rnn = nn.LSTM(context_dim + emb_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(context_dim + hidden_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        # Add a length-1 sequence axis so the token can be fed to the LSTM.
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))
        # Attention is driven by the top layer's hidden state: (batch, 1, src_len).
        attn_weights = self.attention(hidden[-1], encoder_outputs).unsqueeze(1)
        # Weighted sum of encoder outputs, moved back to sequence-first layout.
        context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1)).transpose(0, 1)
        output, (hidden, cell) = self.rnn(torch.cat((embedded, context), dim=2), (hidden, cell))
        # The output layer sees the LSTM output, the context and the embedding together.
        features = torch.cat(
            (output.squeeze(0), context.squeeze(0), embedded.squeeze(0)), dim=1
        )
        return self.fc_out(features), hidden, cell

class Seq2Seq(nn.Module):
    """Ties an encoder and a step-wise decoder into one trainable model.

    ``forward(src, trg, teacher_forcing_ratio)`` expects sequence-first token
    ids and returns a ``(trg_len, batch, vocab)`` tensor of decoder scores.
    Row 0 is never written and stays all zeros, because decoding starts by
    feeding ``trg[0]`` and the first prediction is stored at index 1.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        n_steps, n_batch = trg.shape[0], trg.shape[1]
        outputs = torch.zeros(n_steps, n_batch, self.decoder.output_dim, device=self.device)
        encoder_outputs, hidden, cell = self.encoder(src)

        step_input = trg[0, :]
        for t in range(1, n_steps):
            step_scores, hidden, cell = self.decoder(step_input, hidden, cell, encoder_outputs)
            outputs[t] = step_scores
            # One coin flip per step: feed the reference token with probability
            # teacher_forcing_ratio, otherwise the decoder's own best guess.
            use_reference = random.random() < teacher_forcing_ratio
            step_input = trg[t] if use_reference else step_scores.argmax(1)
        return outputs

# Model hyperparameters
# Source and target ids come from the same tokenizer, so both vocabulary sizes are identical.
INPUT_DIM = tokenizer.vocab_size
OUTPUT_DIM = tokenizer.vocab_size
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
# Shared by encoder, attention and decoder; their layer widths are all derived from it.
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# Initialize encoder, attention, decoder, and seq2seq model
# (the attention module is owned by the decoder, which calls it on every decoding step)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)
attn = Attention(HID_DIM).to(device)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, attn).to(device)
model = Seq2Seq(enc, dec, device).to(device)

# Loss and optimizer
# Adam runs with the library's default settings (no explicit learning rate is passed).
optimizer = optim.Adam(model.parameters())
# Padding positions are excluded from the loss via ignore_index.
TRG_PAD_IDX = tokenizer.pad_token_id
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return ``(mean_loss, mean_accuracy)``.

    Args:
        model: Seq2Seq model called as ``model(src, trg)``.
        iterator: Iterable of batches (dicts with batch-first ``'src'`` and
            ``'trg'`` id tensors) that supports ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(scores, targets)``.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Tuple of the per-batch loss and the per-batch non-padding token
        accuracy, each averaged over ``len(iterator)``.

    Reads the module-level ``device`` and ``TRG_PAD_IDX``.
    """
    model.train()
    epoch_loss = 0
    epoch_acc = 0  # Running sum of per-batch accuracies
    for i, batch in enumerate(iterator):
        # The loader yields (batch, seq); the model works sequence-first.
        src = batch['src'].T.to(device)
        trg = batch['trg'].T.to(device)
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        # Step 0 is never predicted by the model, so drop it from both sides.
        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

        # Calculate accuracy over non-padding targets only. A boolean mask is
        # used instead of nonzero().squeeze(): that form became a 0-d tensor
        # (len() raises) with exactly one real token, and divided by zero with none.
        preds = output.argmax(1)
        keep = trg != TRG_PAD_IDX
        n_tokens = int(keep.sum().item())
        n_correct = int((preds[keep] == trg[keep]).sum().item())
        acc = n_correct / n_tokens if n_tokens else 0.0
        epoch_acc += acc

        # Print some batches
        if i % 10 == 0:
            print(f'Batch {i} | Loss: {loss.item():.3f} | Accuracy: {acc:.3f}')

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion):
    """Score the model on ``iterator`` without updating it.

    Args:
        model: Seq2Seq model; called with teacher forcing ratio 0, so the
            decoder consumes its own predictions.
        iterator: Iterable of batches (dicts with batch-first ``'src'`` and
            ``'trg'`` id tensors) that supports ``len()``.
        criterion: Loss called as ``criterion(scores, targets)``.

    Returns:
        Tuple of the per-batch loss and the per-batch non-padding token
        accuracy, each averaged over ``len(iterator)``.

    Reads the module-level ``device`` and ``TRG_PAD_IDX``.
    """
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            # The loader yields (batch, seq); the model works sequence-first.
            src = batch['src'].T.to(device)
            trg = batch['trg'].T.to(device)
            output = model(src, trg, 0)
            output_dim = output.shape[-1]
            # Step 0 is never predicted by the model, so drop it from both sides.
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()

            # Calculate accuracy over non-padding targets only. A boolean mask
            # avoids the nonzero().squeeze() failure modes (0-d tensor with one
            # real token, division by zero with none).
            preds = output.argmax(1)
            keep = trg != TRG_PAD_IDX
            n_tokens = int(keep.sum().item())
            n_correct = int((preds[keep] == trg[keep]).sum().item())
            acc = n_correct / n_tokens if n_tokens else 0.0
            epoch_acc += acc

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated toward zero.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - 60 * minutes)
    return minutes, seconds

# Number of full passes over the training data.
N_EPOCHS = 10
# Gradient-norm clipping threshold handed to train().
CLIP = 1

for epoch in range(N_EPOCHS):
    # The timed span covers both the training pass and the validation pass.
    start_time = time.time()
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss, valid_acc = evaluate(model, val_loader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    # PPL (perplexity) is reported as exp(loss).
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} | Train Acc: {train_acc:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} |  Val. Acc: {valid_acc:.3f}')


Using device: cuda
English sentences sample: ['South Korea: North Korean Dictator, Kim Jong Il Is Dead · Global Voices\n', 'Kim Jong Il, the North Korean dictator who ruled the hermit kingdom for the past three decades, has died at the age of 69.\n', 'According to North Korean state television\'s official report on Monday, Kim passed away from "mental and physical strain" during a train ride on December 17, 2011.\n', 'The South Korean Twittersphere erupted with various responses.\n', "Although the death of one of the world's most notorious dictators is something people might welcome, most South Koreans have expressed concern about the instability his sudden death might bring to Korean peninsula.\n"]
Albanian sentences sample: ['Kore: Vdes diktatori koreano-verior, Kim Jong Il\n', 'Kim Jong Il, diktatori koreano-verior, i cili sundoi me mbretërinë e izoluar gjatë tre dekadave të kaluara, vdiq në moshën 69 vjeçare.\n', 'Sipas lajmit zyrtar të emituar ditën e hënë në televizionin shtetëro

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import MarianTokenizer, MarianMTModel
import torch.nn as nn
import torch.optim as optim
import random
import math
import time
import wandb
import os

# Set the notebook name
# NOTE(review): the Colab badge at the top of this notebook links to
# "INM706_Seq2Seq_Machine_Translation.ipynb" (underscore after INM706); confirm
# which spelling is the real file name.
os.environ["WANDB_NOTEBOOK_NAME"] = "INM706-Seq2Seq_Machine_Translation.ipynb"

# Log in without embedding the secret in the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset, key=None leaves wandb to
# use its own stored credentials or prompt for one.
# SECURITY: an API key used to be hard-coded on this line and has been
# published with the notebook. Revoke it in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Initialize wandb run
wandb.init(project='Translator', name='English-Albanian')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load the dataset (one sentence per line in each file)
with open('GlobalVoices.en-sq.en', 'r', encoding='utf-8') as f:
    en_sentences = f.readlines()
with open('GlobalVoices.en-sq.sq', 'r', encoding='utf-8') as f:
    sq_sentences = f.readlines()

# The two lists are paired purely by index further down, so a length mismatch
# would silently train on wrong translations. Fail loudly instead.
assert len(en_sentences) == len(sq_sentences), (
    f"Parallel corpus is misaligned: {len(en_sentences)} English lines "
    f"vs {len(sq_sentences)} Albanian lines"
)

# Verify dataset loaded correctly
print(f"English sentences sample: {en_sentences[:5]}")
print(f"Albanian sentences sample: {sq_sentences[:5]}")
print(f"Total number of sentence pairs: {len(en_sentences)}")

# Use MarianTokenizer for tokenization
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-sq')

class TranslationDataset(Dataset):
    """Map-style dataset of tokenised (source, target) sentence pairs.

    Every item is a dict of 1-D tensors of length ``max_length``:
    ``src`` / ``trg`` hold token ids and ``src_mask`` / ``trg_mask`` the
    matching attention masks.

    Args:
        src_sentences: Sequence of source-language strings.
        trg_sentences: Sequence of target-language strings, aligned with
            ``src_sentences`` by index.
        tokenizer: Hugging Face style tokenizer, called as
            ``tokenizer(text=...)`` for sources and
            ``tokenizer(text_target=...)`` for targets.
        max_length: Fixed length every sequence is padded/truncated to.

    Raises:
        ValueError: If the two sentence lists differ in length.

    NOTE(review): no start-of-sequence token is added here, while
    ``Seq2Seq.forward`` feeds ``trg[0]`` as the first decoder input. Position 0
    is therefore whatever the tokenizer emits first; confirm this is intended.
    """

    def __init__(self, src_sentences, trg_sentences, tokenizer, max_length=128):
        if len(src_sentences) != len(trg_sentences):
            raise ValueError(
                f"src_sentences ({len(src_sentences)}) and trg_sentences "
                f"({len(trg_sentences)}) must have the same length"
            )
        self.src_sentences = src_sentences
        self.trg_sentences = trg_sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.src_sentences)

    def _encode(self, **text_kwargs):
        """Tokenise one sentence to a fixed-length ``(1, max_length)`` batch."""
        return self.tokenizer(
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            **text_kwargs,
        )

    def __getitem__(self, idx):
        # readlines() keeps the trailing newline on every sentence; drop it so
        # it is never handed to the tokenizer.
        src = self.src_sentences[idx].strip()
        trg = self.trg_sentences[idx].strip()

        src_enc = self._encode(text=src)
        # text_target makes the tokenizer encode in target-language mode
        # instead of treating the Albanian text as another source sentence.
        trg_enc = self._encode(text_target=trg)

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        return {
            'src': src_enc['input_ids'].squeeze(0),
            'src_mask': src_enc['attention_mask'].squeeze(0),
            'trg': trg_enc['input_ids'].squeeze(0),
            'trg_mask': trg_enc['attention_mask'].squeeze(0)
        }

# Create the dataset objects
dataset = TranslationDataset(en_sentences, sq_sentences, tokenizer)

# Split the dataset into train and validation sets (90% train, 10% validation)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same sentences end up in the validation set on every
# run; otherwise validation metrics are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoader objects (only the training data is reshuffled each epoch)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print("Data preprocessing complete.")

# Define the Seq2Seq model components
class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder.

    Args:
        input_dim: Size of the source vocabulary (embedding rows).
        emb_dim: Embedding width.
        hidden_dim: LSTM hidden size per direction.
        n_layers: Number of stacked LSTM layers; also the number of state
            layers handed to the decoder.
        dropout: Dropout probability for the embeddings and between LSTM layers.

    ``forward(src)`` takes sequence-first token ids (the LSTM is built without
    ``batch_first``) and returns ``(outputs, hidden, cell)``:

    * ``outputs``: per-step LSTM outputs, last dim ``2 * hidden_dim``.
    * ``hidden``: ``(n_layers, batch, hidden_dim)``. The top layer's forward and
      backward final states are projected by ``fc`` + tanh and copied to every layer.
    * ``cell``: ``(n_layers, batch, hidden_dim)``, the last ``n_layers`` entries
      of the LSTM's final cell state.
    """

    def __init__(self, input_dim, emb_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Remembered so forward() can size the decoder's initial state; the
        # original hard-coded 2 here and only worked for n_layers == 2.
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.rnn(embedded)
        # hidden[-2] / hidden[-1] are the final entries of the stacked state:
        # the top layer's forward and backward directions.
        summary = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
        hidden = summary.unsqueeze(0).repeat(self.n_layers, 1, 1)
        # Keep one cell tensor per decoder layer. For n_layers == 2 this is
        # exactly the former cell[-2:] slice.
        cell = cell[-self.n_layers:].contiguous()
        return outputs, hidden, cell

class Attention(nn.Module):
    """Additive attention over the encoder outputs.

    ``forward(hidden, encoder_outputs)`` takes a ``(batch, hidden_dim)`` decoder
    state and sequence-first encoder outputs whose last dim is
    ``2 * hidden_dim``. It returns ``(batch, src_len)`` weights that sum to 1
    along the source axis.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Input is [decoder state ; encoder output] = hidden_dim + 2 * hidden_dim.
        self.attn = nn.Linear(hidden_dim * 3, hidden_dim)
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden, encoder_outputs):
        n_src = encoder_outputs.size(0)
        # Batch-first view of the encoder outputs: (batch, src_len, 2 * hidden_dim).
        enc_batch_first = encoder_outputs.transpose(0, 1)
        # Present the decoder state once per source position.
        query = hidden.unsqueeze(1).expand(-1, n_src, -1)
        energy = torch.tanh(self.attn(torch.cat((query, enc_batch_first), dim=2)))
        # Collapse the feature axis with the learned vector v: one score per position.
        scores = (self.v * energy).sum(dim=2)
        return scores.softmax(dim=1)

class Decoder(nn.Module):
    """Single-step LSTM decoder with attention over the encoder outputs.

    Args:
        output_dim: Target vocabulary size (embedding rows and logit width).
        emb_dim: Embedding width.
        hidden_dim: LSTM hidden size; encoder outputs are ``2 * hidden_dim`` wide.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability for embeddings and between LSTM layers.
        attention: Module called as ``attention(top_hidden, encoder_outputs)``
            that returns ``(batch, src_len)`` weights.

    ``forward`` consumes one token id per batch element plus the previous
    ``(hidden, cell)`` state. It returns ``(prediction, hidden, cell)``, where
    ``prediction`` is ``(batch, output_dim)`` unnormalised scores.
    """

    def __init__(self, output_dim, emb_dim, hidden_dim, n_layers, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        context_dim = hidden_dim * 2  # width of the attention-weighted encoder summary
        self.rnn = nn.LSTM(context_dim + emb_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(context_dim + hidden_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        # Add a length-1 sequence axis so the token can be fed to the LSTM.
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))
        # Attention is driven by the top layer's hidden state: (batch, 1, src_len).
        attn_weights = self.attention(hidden[-1], encoder_outputs).unsqueeze(1)
        # Weighted sum of encoder outputs, moved back to sequence-first layout.
        context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1)).transpose(0, 1)
        output, (hidden, cell) = self.rnn(torch.cat((embedded, context), dim=2), (hidden, cell))
        # The output layer sees the LSTM output, the context and the embedding together.
        features = torch.cat(
            (output.squeeze(0), context.squeeze(0), embedded.squeeze(0)), dim=1
        )
        return self.fc_out(features), hidden, cell

class Seq2Seq(nn.Module):
    """Ties an encoder and a step-wise decoder into one trainable model.

    ``forward(src, trg, teacher_forcing_ratio)`` expects sequence-first token
    ids and returns a ``(trg_len, batch, vocab)`` tensor of decoder scores.
    Row 0 is never written and stays all zeros, because decoding starts by
    feeding ``trg[0]`` and the first prediction is stored at index 1.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        n_steps, n_batch = trg.shape[0], trg.shape[1]
        outputs = torch.zeros(n_steps, n_batch, self.decoder.output_dim, device=self.device)
        encoder_outputs, hidden, cell = self.encoder(src)

        step_input = trg[0, :]
        for t in range(1, n_steps):
            step_scores, hidden, cell = self.decoder(step_input, hidden, cell, encoder_outputs)
            outputs[t] = step_scores
            # One coin flip per step: feed the reference token with probability
            # teacher_forcing_ratio, otherwise the decoder's own best guess.
            use_reference = random.random() < teacher_forcing_ratio
            step_input = trg[t] if use_reference else step_scores.argmax(1)
        return outputs

# Model hyperparameters
# Source and target ids come from the same tokenizer, so both vocabulary sizes are identical.
INPUT_DIM = tokenizer.vocab_size
OUTPUT_DIM = tokenizer.vocab_size
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
# Shared by encoder, attention and decoder; their layer widths are all derived from it.
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# Initialize wandb configuration
# "learning_rate" and "epochs" are read back from wandb.config further down.
# "batch_size" is only recorded here: the DataLoaders are built earlier in
# this cell with their own literal batch size.
wandb.config.update({
    "learning_rate": 1e-3,
    "epochs": 10,
    "batch_size": 32,
    "encoder_embedding_dim": ENC_EMB_DIM,
    "decoder_embedding_dim": DEC_EMB_DIM,
    "hidden_dim": HID_DIM,
    "num_layers": N_LAYERS,
    "encoder_dropout": ENC_DROPOUT,
    "decoder_dropout": DEC_DROPOUT
})

# Initialize encoder, attention, decoder, and seq2seq model
# (the attention module is owned by the decoder, which calls it on every decoding step)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)
attn = Attention(HID_DIM).to(device)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, attn).to(device)
model = Seq2Seq(enc, dec, device).to(device)

# Loss and optimizer
# The learning rate comes from the wandb config so the logged value is the one actually used.
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
# Padding positions are excluded from the loss via ignore_index.
TRG_PAD_IDX = tokenizer.pad_token_id
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch, logging every batch to wandb.

    Args:
        model: Seq2Seq model called as ``model(src, trg)``.
        iterator: Iterable of batches (dicts with batch-first ``'src'`` and
            ``'trg'`` id tensors) that supports ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(scores, targets)``.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Tuple of the per-batch loss and the per-batch non-padding token
        accuracy, each averaged over ``len(iterator)``.

    Reads the module-level ``device`` and ``TRG_PAD_IDX`` and logs to the
    active wandb run.
    """
    model.train()
    epoch_loss = 0
    epoch_acc = 0  # Running sum of per-batch accuracies
    for i, batch in enumerate(iterator):
        # The loader yields (batch, seq); the model works sequence-first.
        src = batch['src'].T.to(device)
        trg = batch['trg'].T.to(device)
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        # Step 0 is never predicted by the model, so drop it from both sides.
        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

        # Calculate accuracy over non-padding targets only. A boolean mask is
        # used instead of nonzero().squeeze(): that form became a 0-d tensor
        # (len() raises) with exactly one real token, and divided by zero with none.
        preds = output.argmax(1)
        keep = trg != TRG_PAD_IDX
        n_tokens = int(keep.sum().item())
        n_correct = int((preds[keep] == trg[keep]).sum().item())
        acc = n_correct / n_tokens if n_tokens else 0.0
        epoch_acc += acc

        # Log metrics to wandb
        wandb.log({"batch_loss": loss.item(), "batch_accuracy": acc})

        # Print some batches
        if i % 10 == 0:
            print(f'Batch {i} | Loss: {loss.item():.3f} | Accuracy: {acc:.3f}')

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion):
    """Score the model on ``iterator`` without updating it.

    Args:
        model: Seq2Seq model; called with teacher forcing ratio 0, so the
            decoder consumes its own predictions.
        iterator: Iterable of batches (dicts with batch-first ``'src'`` and
            ``'trg'`` id tensors) that supports ``len()``.
        criterion: Loss called as ``criterion(scores, targets)``.

    Returns:
        Tuple of the per-batch loss and the per-batch non-padding token
        accuracy, each averaged over ``len(iterator)``.

    Reads the module-level ``device`` and ``TRG_PAD_IDX``.
    """
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            # The loader yields (batch, seq); the model works sequence-first.
            src = batch['src'].T.to(device)
            trg = batch['trg'].T.to(device)
            output = model(src, trg, 0)
            output_dim = output.shape[-1]
            # Step 0 is never predicted by the model, so drop it from both sides.
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()

            # Calculate accuracy over non-padding targets only. A boolean mask
            # avoids the nonzero().squeeze() failure modes (0-d tensor with one
            # real token, division by zero with none).
            preds = output.argmax(1)
            keep = trg != TRG_PAD_IDX
            n_tokens = int(keep.sum().item())
            n_correct = int((preds[keep] == trg[keep]).sum().item())
            acc = n_correct / n_tokens if n_tokens else 0.0
            epoch_acc += acc

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated toward zero.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - 60 * minutes)
    return minutes, seconds

# Number of full passes over the training data, taken from the wandb run config.
N_EPOCHS = wandb.config.epochs
# Gradient-norm clipping threshold handed to train().
CLIP = 1

for epoch in range(N_EPOCHS):
    # The timed span covers both the training pass and the validation pass.
    start_time = time.time()
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss, valid_acc = evaluate(model, val_loader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    # PPL (perplexity) is reported as exp(loss).
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} | Train Acc: {train_acc:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} |  Val. Acc: {valid_acc:.3f}')

    # Log epoch metrics to wandb (one record per epoch, 1-based epoch index)
    wandb.log({"train_loss": train_loss, "train_accuracy": train_acc,
               "valid_loss": valid_loss, "valid_accuracy": valid_acc,
               "epoch": epoch + 1, "epoch_time_mins": epoch_mins, "epoch_time_secs": epoch_secs})
