In [105]:
import torch
import torch.nn as nn
from torch import optim
import os
from models import EncoderRNN, LuongAttnDecoderRNN
from vocabulary import Voc
import itertools
import random
from nltk.translate.bleu_score import corpus_bleu

In [2]:
PAD_token = 0
SOS_token = 1
EOS_token = 2

voc = Voc("FriendsCorpus")

In [3]:
def loadPreparedData(preprocessed_file):
    pairs = []
    with open(preprocessed_file, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.strip().split('\t')
            if len(parts) == 2:
                pairs.append(parts)
            else:
                print(f"Skipping malformed line: {line.strip()}")
    return pairs

def addPairsToVoc(voc, pairs):
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])

In [4]:
pairs = loadPreparedData("preprocessed_pairs.txt")
addPairsToVoc(voc, pairs)

Skipping malformed line: me neither .
Skipping malformed line: joey you don't have to count down every time we kiss .
Skipping malformed line: i can do it okay ? come on let's go .
Skipping malformed line: i can't do it !


In [5]:
len(pairs), voc.num_words

(55486, 16041)

In [6]:
def trimRareWords(voc, pairs, min_count=3):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(min_count)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs

In [7]:
pairs = trimRareWords(voc, pairs)

keep_words 8898 / 16038 = 0.5548
Trimmed from 55486 pairs to 45901, 0.8273 of total


In [8]:
len(pairs), voc.num_words

(45901, 8901)

In [17]:
# Define training parameters and hyperparameters
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64
clip = 50.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = n_iteration // 100
save_every = 500
teacher_forcing_ratio = 0.5

In [18]:
# Choose device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
encoder = EncoderRNN(hidden_size=hidden_size, 
                     embedding=nn.Embedding(num_embeddings=voc.num_words, embedding_dim=hidden_size),
                     n_layers=encoder_n_layers, 
                     dropout=dropout).to(device)

decoder = LuongAttnDecoderRNN(attn_model='dot', 
                              embedding=nn.Embedding(num_embeddings=voc.num_words, embedding_dim=hidden_size), 
                              hidden_size=hidden_size, 
                              output_size=voc.num_words, 
                              n_layers=decoder_n_layers, 
                              dropout=dropout).to(device)

# Initialize optimizers
print('Building optimizers...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)

Building optimizers...


In [19]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index.get(word) for word in sentence.split(' ')] + [EOS_token]

def zeroPadding(l, fillvalue=0):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=0):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == value:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [20]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [21]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, encoder_optimizer, decoder_optimizer, batch_size, clip, max_length, teacher_forcing_ratio):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    loss = 0
    print_losses = []
    n_totals = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths, None)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    decoder_hidden = encoder_hidden[:decoder.n_layers]

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden, _ = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_input = target_variable[t].view(1, -1)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden, _ = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            _, topi = decoder_output.topk(1)

            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    loss.backward()

    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses), n_totals

In [22]:
# Initialize print_loss for tracking progress
print_loss = 0
print_total_words = 0
losses = []
total_words = []

for iteration in range(1, n_iteration + 1):
    training_batch = [random.choice(pairs) for _ in range(batch_size)]
    # Extract fields from batch
    input_variable, lengths, target_variable, mask, max_target_len = batch2TrainData(voc, training_batch)
    
    # Run a training iteration
    loss, n_total = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                 decoder, encoder_optimizer, decoder_optimizer, batch_size, clip, device, teacher_forcing_ratio)
    
    print_loss += loss
    print_total_words += n_total
    losses.append(loss)
    total_words.append(n_total)
    
    # Print progress
    if iteration % print_every == 0:
        print_loss_avg = print_loss / print_every
        print(f"Iteration: {iteration}; Percent complete: {iteration / n_iteration * 100:.1f}%; Average loss: {print_loss_avg:.4f}")
        print_loss = 0
        print_total_words = 0

    # Save checkpoint
    if iteration % save_every == 0:
        directory = os.path.join("checkpoints")
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save({
            'iteration': iteration,
            'en': encoder.state_dict(),
            'de': decoder.state_dict(),
            'en_opt': encoder_optimizer.state_dict(),
            'de_opt': decoder_optimizer.state_dict(),
            'loss': loss,
            'voc_dict': voc.__dict__,
        }, os.path.join(directory, f'{iteration}_checkpoint.tar'))

Iteration: 40; Percent complete: 1.0%; Average loss: 6.6350
Iteration: 80; Percent complete: 2.0%; Average loss: 5.7125
Iteration: 120; Percent complete: 3.0%; Average loss: 5.6347
Iteration: 160; Percent complete: 4.0%; Average loss: 5.5997
Iteration: 200; Percent complete: 5.0%; Average loss: 5.4916
Iteration: 240; Percent complete: 6.0%; Average loss: 5.4354
Iteration: 280; Percent complete: 7.0%; Average loss: 5.4344
Iteration: 320; Percent complete: 8.0%; Average loss: 5.3948
Iteration: 360; Percent complete: 9.0%; Average loss: 5.2953
Iteration: 400; Percent complete: 10.0%; Average loss: 5.3268
Iteration: 440; Percent complete: 11.0%; Average loss: 5.2158
Iteration: 480; Percent complete: 12.0%; Average loss: 5.2519
Iteration: 520; Percent complete: 13.0%; Average loss: 5.2048
Iteration: 560; Percent complete: 14.0%; Average loss: 5.2508
Iteration: 600; Percent complete: 15.0%; Average loss: 5.2567
Iteration: 640; Percent complete: 16.0%; Average loss: 5.0851
Iteration: 680; Per

In [109]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=10):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to("cpu")
    
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words

class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [110]:
# Example call to evaluate function
searcher = GreedySearchDecoder(encoder, decoder)

# Input sentence
input_sentence = "how are you ?"

# Evaluate sentence
output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
output_sentence = ' '.join(output_words)

print('Input:', input_sentence)
print('Output:', output_sentence)

Input: how are you ?
Output: yeah i i EOS i EOS i EOS i EOS


Input: how are you ?
Output: yeah i was the one i was in the bathroom

In [86]:
def evaluate(encoder, decoder, voc, sentence, max_length=5, beam_width=5):
    with torch.no_grad():
        # Convert input sentence to indexes and add EOS token
        indexes_batch = [indexesFromSentence(voc, sentence)]
        lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
        input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)

        # Forward input through encoder model
        encoder_outputs, encoder_hidden = encoder(input_batch, lengths, None)

        # Create starting vectors for decoder
        decoder_input = torch.LongTensor([[SOS_token]])  # SOS
        decoder_hidden = encoder_hidden[:decoder.n_layers]

        # Initialize the Beam search
        topk = 2  # how many sentences to generate
        decoded_batch = []

        # Number of sentence to generate
        for _ in range(topk):
            # Start with the first word (SOS_token)
            decoder_input = torch.LongTensor([[SOS_token]])  # SOS

            # Number of steps to unroll
            for i in range(max_length):
                # Forward pass through decoder
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs
                )

                # Obtain word with highest probability and its index
                prob, index = torch.topk(decoder_output, beam_width)
                print(f'step {i}: prob {prob}, index {index}')
                prob = prob.detach().cpu().numpy().tolist()[0]
                index = index.detach().cpu().numpy().tolist()[0]

                # Create new set of input
                decoder_input = torch.LongTensor([[index[0]]])

                # Break if EOS token generated
                if index[0] == EOS_token:
                    break

            # Add the decoded sentence to the list
            decoded_batch.append(index)
    return decoded_batch

In [87]:
sentence = "there's nothing to tell ! he's just some guy i work with !"
output_words = evaluate(encoder, decoder, voc, sentence)
print('Input:', sentence)
for idx in range(len(output_words)):
    print(f'Output {idx}:', ' '.join([voc.index2word[token] for token in output_words[idx]]))

step 0: prob tensor([[0.1478, 0.0719, 0.0684, 0.0465, 0.0366]]), index tensor([[50, 47, 51, 12, 82]])
step 1: prob tensor([[0.2157, 0.0966, 0.0458, 0.0437, 0.0408]]), index tensor([[  7,  29,  55, 105,  36]])
step 2: prob tensor([[0.1963, 0.0717, 0.0576, 0.0390, 0.0369]]), index tensor([[  7,  29,   2, 105,  55]])
step 3: prob tensor([[0.1860, 0.0801, 0.0687, 0.0424, 0.0341]]), index tensor([[ 7,  2, 29, 55, 36]])
step 4: prob tensor([[0.1793, 0.0799, 0.0635, 0.0493, 0.0465]]), index tensor([[ 7,  2, 29, 55, 36]])
step 0: prob tensor([[0.2765, 0.0929, 0.0601, 0.0464, 0.0250]]), index tensor([[ 7, 29, 36, 55,  2]])
step 1: prob tensor([[0.1945, 0.0927, 0.0674, 0.0388, 0.0339]]), index tensor([[ 7,  2, 29, 55, 36]])
step 2: prob tensor([[0.1714, 0.1012, 0.0654, 0.0446, 0.0353]]), index tensor([[ 7,  2, 29, 55, 36]])
step 3: prob tensor([[0.1543, 0.1026, 0.0611, 0.0496, 0.0402]]), index tensor([[ 7,  2, 29, 55, 36]])
step 4: prob tensor([[0.1514, 0.1073, 0.0600, 0.0488, 0.0403]]), index t

In [97]:
voc.index2word[29]

'.'