# Imports

In [1]:
import torch
import torch.nn as nn
from torch import optim
import os
from models import EncoderRNN, LuongAttnDecoderRNN
from vocabulary import Voc
import random
from utils import loadPreparedData, split_data, addPairsToVoc, trimRareWords, batch2TrainData, train, validate

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\cathe\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# Loading Data, Trimming Words & Splitting

In [2]:
# Location of the preprocessed pairs file and the fraction handed to split_data.
# Kept as named values so the split can be changed in one place.
pairs_path = "preprocessed_pairs.txt"
train_fraction = 0.9

# Load preprocessed pairs
pairs = loadPreparedData(pairs_path)

# Split data into training and validation sets
training_pairs, validation_pairs = split_data(pairs, train_fraction)

# Fail fast: the training loop samples from training_pairs and validates on
# validation_pairs, so an empty split would only surface later as an obscure error.
assert len(training_pairs) > 0, f"no training pairs were produced from {pairs_path!r}"
assert len(validation_pairs) > 0, f"no validation pairs were produced from {pairs_path!r}"

print(f"Training pairs: {len(training_pairs)}")
print(f"Validation pairs: {len(validation_pairs)}")

Skipping malformed line: me neither .
Skipping malformed line: joey you don't have to count down every time we kiss .
Skipping malformed line: i can do it okay ? come on let's go .
Skipping malformed line: i can't do it !
Training pairs: 49937
Validation pairs: 5549


In [3]:
# Indices reserved for the special tokens, forwarded to Voc below
PAD_token, SOS_token, EOS_token = 0, 1, 2

# Build the vocabulary from the training split only
voc = Voc(
    "FriendsCorpus",
    PAD_token=PAD_token,
    SOS_token=SOS_token,
    EOS_token=EOS_token,
)
addPairsToVoc(voc, training_pairs)

# Trim rare words. NOTE: this rebinds training_pairs, so re-running this cell
# starts from the already-trimmed list rather than the original split.
training_pairs = trimRareWords(voc, training_pairs)


def _all_words_known(sentence):
    """Return True when every whitespace-separated token of `sentence` is a key of voc.word2index."""
    return all(word in voc.word2index for word in sentence.split())


# Keep only the validation pairs whose two sentences use exclusively known words
validation_pairs = [
    pair
    for pair in validation_pairs
    if _all_words_known(pair[0]) and _all_words_known(pair[1])
]

print(f"Training pairs: {len(training_pairs)}")
print(f"Validation pairs: {len(validation_pairs)}")

keep_words 6137 / 15813 = 0.3881
Trimmed from 49937 pairs to 36294, 0.7268 of total
Training pairs: 36294
Validation pairs: 3885


# Training

In [4]:
# Define training parameters and hyperparameters

# Seed the RNGs used from here on (random.choice batch sampling, torch weight
# initialisation) so a Restart & Run All repeats the same training run.
# NOTE(review): split_data runs in an earlier cell; if it shuffles, it is not
# covered by this seed.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

hidden_size = 500             # used as both the embedding dimension and the model hidden size
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1                 # passed to both the encoder and the decoder
batch_size = 64               # number of pairs sampled per training iteration
clip = 50.0                   # passed to train(); presumably the gradient-clipping threshold
learning_rate = 0.0001        # Adam learning rate for the encoder
decoder_learning_ratio = 5.0  # decoder learning rate = learning_rate * decoder_learning_ratio
n_iteration = 4000

# Validate, print and checkpoint 100 times per run. Clamped to at least 1 so that
# `iteration % validate_print_save` cannot divide by zero when n_iteration < 100.
validate_print_save = max(1, n_iteration // 100)
teacher_forcing_ratio = 0.5   # passed to train()

In [5]:
# Build the encoder with its own embedding table (one row per vocabulary word)
encoder_embedding = nn.Embedding(num_embeddings=voc.num_words, embedding_dim=hidden_size)
encoder = EncoderRNN(
    hidden_size=hidden_size,
    embedding=encoder_embedding,
    n_layers=encoder_n_layers,
    dropout=dropout,
)

# Build the decoder; it gets a separate embedding table and emits one score per vocabulary word
decoder_embedding = nn.Embedding(num_embeddings=voc.num_words, embedding_dim=hidden_size)
decoder = LuongAttnDecoderRNN(
    attn_model='dot',
    embedding=decoder_embedding,
    hidden_size=hidden_size,
    output_size=voc.num_words,
    n_layers=decoder_n_layers,
    dropout=dropout,
)

# One Adam optimizer per model; the decoder's learning rate is scaled by decoder_learning_ratio
print('Building optimizers...')
decoder_learning_rate = learning_rate * decoder_learning_ratio
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=decoder_learning_rate)

Building optimizers...


In [6]:
# Running totals since the last report (reset after every report)
print_loss = 0
print_total_words = 0
# Per-iteration history of the values returned by train()
losses = []
total_words = []
# Per-report history; these lists are also written into every checkpoint
loss_avgs = []
perplexity_scores = []
loss_avgs_val = []
perplexity_scores_val = []
bleu_scores = []

# Create the checkpoint directory once, up front, instead of re-checking it at every save
directory = "checkpoints"
os.makedirs(directory, exist_ok=True)

for iteration in range(1, n_iteration + 1):
    # Sample a batch (with replacement) from the training pairs
    training_batch = [random.choice(training_pairs) for _ in range(batch_size)]

    # Extract fields from batch
    input_variable, lengths, target_variable, mask, max_target_len = batch2TrainData(voc, training_batch)

    # Run a training iteration
    loss, n_total = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                          decoder, encoder_optimizer, decoder_optimizer, batch_size, clip, teacher_forcing_ratio)

    print_loss += loss
    print_total_words += n_total
    losses.append(loss)
    total_words.append(n_total)

    if iteration % validate_print_save == 0:
        # Validation: switch both models to eval mode and disable autograd so no
        # graph is built (harmless if validate() already does this itself).
        # The finally-clause restores train mode even if validate() raises, so a
        # failed run does not leave the kernel's models stuck in eval mode.
        encoder.eval()
        decoder.eval()
        try:
            with torch.no_grad():
                # Separate name so the training batch's n_total is not overwritten
                validation_loss, n_total_val, avg_bleu_score = validate(encoder, decoder, voc, validation_pairs, batch_size)
        finally:
            encoder.train()
            decoder.train()

        validation_loss_avg = validation_loss / n_total_val
        perplexity_val = torch.exp(torch.tensor(validation_loss_avg))
        loss_avgs_val.append(validation_loss_avg)
        perplexity_scores_val.append(perplexity_val)
        bleu_scores.append(avg_bleu_score)

        # Training loss is normalised per word, exactly like the validation loss,
        # so the two reported losses are directly comparable. It was previously
        # divided by the number of iterations, which put "Train loss" and
        # "Val loss" in different units.
        # NOTE(review): assumes train() returns the loss summed over n_total words,
        # as the perplexity computation already implied - confirm in utils.train.
        print_loss_avg = print_loss / print_total_words
        perplexity = torch.exp(torch.tensor(print_loss_avg))
        loss_avgs.append(print_loss_avg)
        perplexity_scores.append(perplexity)

        # Print progress
        print(
            f"Iter: {iteration}; Train loss: {print_loss_avg:.4f}; Val loss: {validation_loss_avg:.4f}; "
            f"Train perplexity: {perplexity:.4f}; Val perplexity: {perplexity_val:.4f}; BLEU: {avg_bleu_score:.4f}"
        )
        print_loss = 0
        print_total_words = 0

        # Save checkpoint: model and optimizer state plus the metric history so far
        torch.save({
            'iteration': iteration,
            'en': encoder.state_dict(),
            'de': decoder.state_dict(),
            'en_opt': encoder_optimizer.state_dict(),
            'de_opt': decoder_optimizer.state_dict(),
            'train loss': loss_avgs,
            'val loss': loss_avgs_val,
            'train perplexity': perplexity_scores,
            'val perplexity': perplexity_scores_val,
            'bleu': bleu_scores,
            'voc_dict': voc.__dict__,
        }, os.path.join(directory, f'{iteration}_checkpoint.tar'))