# Imports

In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import os
from models import EncoderRNN, LuongAttnDecoderRNN
from vocabulary import Voc
import itertools
import random
import matplotlib.pyplot as plt
import pickle
import unicodedata
import re
import numpy as np
from utils import loadPreparedData, split_data, addPairsToVoc, trimRareWords, indexesFromSentence, zeroPadding, binaryMatrix, inputVar, outputVar, batch2TrainData, maskNLLLoss, train, validate

# Loading Data, Trimming Words & Splitting

In [2]:
# Read the already-preprocessed sentence pairs from disk
pairs_path = "preprocessed_pairs.txt"
pairs = loadPreparedData(pairs_path)

# Partition the pairs into a training part and a held-out validation part.
# The second argument of split_data is the split ratio; with 0.9 the recorded
# run put roughly 90% of the pairs into the training set.
split_ratio = 0.9
training_pairs, validation_pairs = split_data(pairs, split_ratio)

# Report the size of each partition
print(f"Training pairs: {len(training_pairs)}")
print(f"Validation pairs: {len(validation_pairs)}")

Skipping malformed line: me neither .
Skipping malformed line: joey you don't have to count down every time we kiss .
Skipping malformed line: i can do it okay ? come on let's go .
Skipping malformed line: i can't do it !
Training pairs: 49937
Validation pairs: 5549


In [3]:
# Indices reserved for the special tokens handed to the vocabulary
PAD_token, SOS_token, EOS_token = 0, 1, 2

# Build the vocabulary from the training pairs only
voc = Voc("FriendsCorpus", PAD_token=PAD_token, SOS_token=SOS_token, EOS_token=EOS_token)
addPairsToVoc(voc, training_pairs)

# Let trimRareWords prune the training pairs (it also receives the vocabulary)
training_pairs = trimRareWords(voc, training_pairs)


def _sentence_in_vocabulary(sentence):
    """Return True when every whitespace-separated word of `sentence` is a key of voc.word2index."""
    return all(word in voc.word2index for word in sentence.split())


# Keep a validation pair only if both of its sentences are fully covered by the vocabulary
validation_pairs = [
    pair
    for pair in validation_pairs
    if _sentence_in_vocabulary(pair[0]) and _sentence_in_vocabulary(pair[1])
]

print(f"Training pairs: {len(training_pairs)}")
print(f"Validation pairs: {len(validation_pairs)}")

keep_words 6110 / 15826 = 0.3861
Trimmed from 49937 pairs to 36300, 0.7269 of total
Training pairs: 36300
Validation pairs: 3859


# Training

In [10]:
# Define training parameters and hyperparameters

# Seed the random number generators so a fresh run is repeatable:
# `random` drives the training-batch sampling and torch drives the
# initialisation of the model parameters created after this cell.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

hidden_size = 500               # hidden_size of both models and embedding_dim of both embeddings
encoder_n_layers = 2            # n_layers passed to EncoderRNN
decoder_n_layers = 2            # n_layers passed to LuongAttnDecoderRNN
dropout = 0.1                   # dropout passed to both models
batch_size = 64                 # number of pairs sampled per training iteration
clip = 50.0                     # passed to train(); presumably the gradient-clipping threshold - confirm in utils
learning_rate = 0.0001          # Adam learning rate of the encoder
decoder_learning_ratio = 5.0    # the decoder learning rate is learning_rate * decoder_learning_ratio
n_iteration = 4000              # total number of training iterations

# Reporting/validation happen 100 times per run. Clamp to at least 1 so that a
# short run (n_iteration < 100) does not end up with a modulus of zero.
print_every = max(1, n_iteration // 100)
save_every = 500                # a checkpoint is written every save_every iterations
validate_every = max(1, n_iteration // 100)
teacher_forcing_ratio = 0.5     # passed to train(); presumably the teacher-forcing probability - confirm in utils

In [11]:
# Build the encoder together with its own embedding table.
# Construction order (encoder embedding, encoder, decoder embedding, decoder)
# is kept so that parameter initialisation draws from the RNG in the same sequence.
encoder_embedding = nn.Embedding(voc.num_words, hidden_size)
encoder = EncoderRNN(
    hidden_size=hidden_size,
    embedding=encoder_embedding,
    n_layers=encoder_n_layers,
    dropout=dropout,
)

# Build the decoder ('dot' attention model) with a separate embedding table;
# its output size equals the number of words in the vocabulary.
decoder_embedding = nn.Embedding(voc.num_words, hidden_size)
decoder = LuongAttnDecoderRNN(
    attn_model='dot',
    embedding=decoder_embedding,
    hidden_size=hidden_size,
    output_size=voc.num_words,
    n_layers=decoder_n_layers,
    dropout=dropout,
)

# One Adam optimizer per model; the decoder's learning rate is scaled by decoder_learning_ratio
print('Building optimizers...')
decoder_lr = learning_rate * decoder_learning_ratio
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=decoder_lr)

Building optimizers...


In [12]:
# Running totals for the current reporting window (reset after every report)
print_loss = 0
print_total_words = 0

# Per-iteration history of what train() returned
losses = []
total_words = []

# Per-report history (one entry every print_every / validate_every iterations)
loss_avgs = []
perplexity_scores = []
loss_avgs_val = []
perplexity_scores_val = []

# The validation pairs do not change during training, so batch them once.
# Only full batches are kept because validate() is handed a fixed batch_size;
# a trailing partial batch is dropped, a trailing full batch is not.
validation_batches = [
    validation_pairs[start:start + batch_size]
    for start in range(0, len(validation_pairs), batch_size)
]
validation_batches = [batch for batch in validation_batches if len(batch) == batch_size]

# Directory that receives the periodic checkpoints
directory = os.path.join("checkpoints")

for iteration in range(1, n_iteration + 1):
    # Sample a training batch (with replacement) and convert it to model inputs
    training_batch = [random.choice(training_pairs) for _ in range(batch_size)]
    input_variable, lengths, target_variable, mask, max_target_len = batch2TrainData(voc, training_batch)

    # Run a training iteration
    loss, n_total = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                          decoder, encoder_optimizer, decoder_optimizer, batch_size, clip, teacher_forcing_ratio)

    print_loss += loss
    print_total_words += n_total
    losses.append(loss)
    total_words.append(n_total)

    # Validation (skipped when there is not a single full validation batch,
    # which would otherwise divide by zero below)
    # NOTE(review): whether validate() switches the models to eval mode and
    # disables gradient tracking is not visible here - confirm in utils.
    if iteration % validate_every == 0 and validation_batches:
        validation_loss = 0
        for validation_batch in validation_batches:
            val_input, val_lengths, val_target, val_mask, val_max_target_len = batch2TrainData(voc, validation_batch)
            validation_loss += validate(encoder, decoder, batch_size, val_input, val_lengths,
                                        val_target, val_mask, val_max_target_len)

        validation_loss_avg = validation_loss / len(validation_batches)
        # Perplexity is the exponential of the mean per-token loss. validate()
        # appears to return a per-token mean for its batch (recorded validation
        # averages are ~5.6 while the summed training losses are in the
        # thousands) - confirm in utils. Dividing the accumulated loss by the
        # number of validation pairs, as before, gave a meaningless value.
        perplexity_val = torch.exp(torch.tensor(validation_loss_avg))
        loss_avgs_val.append(validation_loss_avg)
        perplexity_scores_val.append(perplexity_val)
        print(f"Validation - Iteration: {iteration}; Average loss: {validation_loss_avg:.4f}")

    # Print progress
    if iteration % print_every == 0:
        # Average of the per-iteration losses returned by train() over this window
        print_loss_avg = print_loss / print_every
        # Window loss normalised by the word count accumulated from train()
        perplexity = torch.exp(torch.tensor(print_loss / print_total_words))
        loss_avgs.append(print_loss_avg)
        perplexity_scores.append(perplexity)
        print(f"Training - Iteration: {iteration}; Average loss: {print_loss_avg:.4f}; Perplexity: {perplexity:.2f}")
        print_loss = 0
        print_total_words = 0

    # Save checkpoint
    if iteration % save_every == 0:
        # exist_ok avoids the separate existence check
        os.makedirs(directory, exist_ok=True)
        torch.save({
            'iteration': iteration,
            'en': encoder.state_dict(),
            'de': decoder.state_dict(),
            'en_opt': encoder_optimizer.state_dict(),
            'de_opt': decoder_optimizer.state_dict(),
            'train loss': loss_avgs,
            'val loss': loss_avgs_val,
            'train perplexity': perplexity_scores,
            'val perplexity': perplexity_scores_val,
            'voc_dict': voc.__dict__,
        }, os.path.join(directory, f'{iteration}_checkpoint.tar'))

Validation - Iteration: 40; Average loss: 5.6308
Training - Iteration: 40; Average loss: 4541.1929; Perplexity: 514.42
Validation - Iteration: 80; Average loss: 5.6236
Training - Iteration: 80; Average loss: 4160.2117; Perplexity: 268.84
Validation - Iteration: 120; Average loss: 5.5883
Training - Iteration: 120; Average loss: 4131.7711; Perplexity: 250.30
Validation - Iteration: 160; Average loss: 5.5457
Training - Iteration: 160; Average loss: 3972.9066; Perplexity: 229.48
Validation - Iteration: 200; Average loss: 5.5174
Training - Iteration: 200; Average loss: 3798.7093; Perplexity: 216.04
Validation - Iteration: 240; Average loss: 5.4962
Training - Iteration: 240; Average loss: 3715.9675; Perplexity: 194.55
Validation - Iteration: 280; Average loss: 5.5593
Training - Iteration: 280; Average loss: 3758.5097; Perplexity: 192.72
Validation - Iteration: 320; Average loss: 5.4644
Training - Iteration: 320; Average loss: 3713.6651; Perplexity: 184.26
Validation - Iteration: 360; Average

KeyboardInterrupt: 