# Imports

In [1]:
import torch
import torch.nn as nn
from torch import optim
import os
from models import EncoderRNN, LuongAttnDecoderRNN
from vocabulary import Voc
import random
from utils import loadPreparedData, batch2TrainData, train, validate

# Loading Data

In [5]:
# Load preprocessed pairs
training_pairs = loadPreparedData("data/training_pairs.txt")
validation_pairs = loadPreparedData("data/validation_pairs.txt")

In [6]:
# Reconstruct vocabulary
with open("data/voc.txt", "r") as f:
    voc_dict = f.read()
voc_dict = eval(voc_dict)
voc = Voc(voc_dict['name'])
voc.__dict__ = voc_dict

# Training

In [4]:
# Define training parameters and hyperparameters
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64
clip = 50.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
validate_print_save = n_iteration // 100
teacher_forcing_ratio = 0.5

In [5]:
# Initialize models
encoder = EncoderRNN(hidden_size=hidden_size, 
                     embedding=nn.Embedding(num_embeddings=voc.num_words, embedding_dim=hidden_size),
                     n_layers=encoder_n_layers, 
                     dropout=dropout)

decoder = LuongAttnDecoderRNN(attn_model='dot', 
                              embedding=nn.Embedding(num_embeddings=voc.num_words, embedding_dim=hidden_size), 
                              hidden_size=hidden_size, 
                              output_size=voc.num_words, 
                              n_layers=decoder_n_layers, 
                              dropout=dropout)

# Initialize optimizers
print('Building optimizers...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)

Building optimizers...


In [6]:
# Initialize print_loss for tracking progress
print_loss = 0
print_total_words = 0
losses = []
total_words = []
loss_avgs = []
perplexity_scores = []
loss_avgs_val = []
perplexity_scores_val = []
    
for iteration in range(1, n_iteration + 1):
    training_batch = [random.choice(training_pairs) for _ in range(batch_size)]
    
    # Extract fields from batch
    input_variable, lengths, target_variable, mask, max_target_len = batch2TrainData(voc, training_batch)
    
    # Run a training iteration
    loss, n_total = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                 decoder, encoder_optimizer, decoder_optimizer, batch_size, clip, teacher_forcing_ratio)
    
    print_loss += loss
    print_total_words += n_total
    losses.append(loss)
    total_words.append(n_total)

    if iteration % validate_print_save == 0:
        # Validation 
        encoder.eval()
        decoder.eval()

        validation_batches = [validation_pairs[i:i+batch_size] for i in range(0, len(validation_pairs), batch_size)][:-1]
        n_validation = len(validation_batches)
        validation_loss = 0
        val_n_total = 0

        for i in range(n_validation):
            validation_batch = validation_batches[i]
            input_variable, lengths, target_variable, mask, max_target_len = batch2TrainData(voc, validation_batch)
            val_loss, n_total = validate(encoder, decoder, batch_size, input_variable, lengths, target_variable, mask, max_target_len)
            validation_loss += val_loss
            val_n_total += n_total

        encoder.train()
        decoder.train()

        validation_loss_avg = validation_loss / val_n_total
        perplexity_val = torch.exp(torch.tensor(validation_loss / val_n_total))
        loss_avgs_val.append(validation_loss_avg)
        perplexity_scores_val.append(perplexity_val)

        # Print progress
        print_loss_avg = print_loss / validate_print_save
        perplexity = torch.exp(torch.tensor(print_loss / print_total_words))
        loss_avgs.append(print_loss_avg)
        perplexity_scores.append(perplexity)
        print(f"Iter: {iteration}; Train loss: {print_loss_avg:.4f}; Val loss: {validation_loss_avg:.4f}; Train perplexity: {perplexity:.4f}; Val perplexity: {perplexity_val:.4f}")
        print_loss = 0
        print_total_words = 0

        # Save checkpoint
        directory = os.path.join("checkpoints")
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save({
            'iteration': iteration,
            'en': encoder.state_dict(),
            'de': decoder.state_dict(),
            'en_opt': encoder_optimizer.state_dict(),
            'de_opt': decoder_optimizer.state_dict(),
            'train loss': loss_avgs,
            'val loss': loss_avgs_val,
            'train perplexity': perplexity_scores,
            'val perplexity': perplexity_scores_val,
            'voc_dict': voc.__dict__,
        }, os.path.join(directory, f'{iteration}_checkpoint.tar'))

Iter: 40; Train loss: 4601.6239; Val loss: 5.6065; Train perplexity: 556.0798; Val perplexity: 272.2034
Iter: 80; Train loss: 3984.0238; Val loss: 5.6606; Train perplexity: 249.9569; Val perplexity: 287.3350
Iter: 120; Train loss: 4042.6906; Val loss: 5.5603; Train perplexity: 240.9499; Val perplexity: 259.9080
Iter: 160; Train loss: 4110.4025; Val loss: 5.5140; Train perplexity: 238.5346; Val perplexity: 248.1417
Iter: 200; Train loss: 4005.0571; Val loss: 5.5206; Train perplexity: 226.1534; Val perplexity: 249.7895
Iter: 240; Train loss: 3949.2191; Val loss: 5.4492; Train perplexity: 219.5968; Val perplexity: 232.5762
Iter: 280; Train loss: 3809.3168; Val loss: 5.4466; Train perplexity: 199.1803; Val perplexity: 231.9658
Iter: 320; Train loss: 4025.4761; Val loss: 5.4312; Train perplexity: 191.8540; Val perplexity: 228.4186
Iter: 360; Train loss: 3879.2387; Val loss: 5.4230; Train perplexity: 185.8782; Val perplexity: 226.5617
Iter: 400; Train loss: 3744.7536; Val loss: 5.4126; Train