In [2]:
%matplotlib inline

In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

from encoder import EncoderRNN
from decoder import LuongAttnDecoderRNN
from movie_line_process import loadLines, loadConversations, extractSentencePairs
from voc import loadPrepareData, trimRareWords, normalizeString
from voc import MIN_COUNT, MAX_LENGTH, PAD_token, SOS_token, EOS_token
from prepare_data import indexesFromSentence, batch2TrainData
from train import trainIters
from model_config import model_name, attn_model, hidden_size
from model_config import encoder_n_layers, decoder_n_layers, dropout, batch_size
from model_config import device, loadFilename, checkpoint_iter
from train_config import clip, learning_rate, decoder_learning_ratio, n_iteration
from train_config import print_every, save_every
from evaluate import GreedySearchDecoder, evaluateInput

In [4]:
# Dataset name; it doubles as the directory name under "data" and is also
# passed to loadPrepareData and trainIters further down.
corpus_name = "cornell movie-dialogs corpus"
# Relative path of the corpus directory: data/<corpus_name>
corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    """Print the first ``n`` lines of ``file``, one ``bytes`` object per line.

    The file is opened in binary mode, so every line is printed as its
    ``bytes`` repr (``b'...'``) with the trailing newline still attached.

    Args:
        file: Path of the file to preview.
        n: Maximum number of lines to print. Must be a non-negative int,
            or ``None`` to print every line. Defaults to 10.
    """
    with open(file, 'rb') as datafile:
        # Stream lazily: only the first n lines are ever read, instead of
        # loading the whole (potentially very large) file into memory.
        for line in itertools.islice(datafile, n):
            print(line)

# Preview the raw movie_lines.txt file (printLines shows 10 lines by default).
printLines(os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [5]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

# Field separator for the formatted file. '\t' in Python source is already a
# literal tab character, so no unicode_escape decoding step is required.
delimiter = '\t'

# Field ids passed to the movie-lines / movie-conversations loaders
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file
print("\nWriting newly formatted file...")
# newline='' leaves line endings entirely to the csv writer, so the '\n'
# lineterminator is written verbatim on every platform (without it, Windows
# text mode would translate each '\n' into '\r\n').
with open(datafile, 'w', encoding='utf-8', newline='') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    # One row per sentence pair, fields separated by the delimiter
    writer.writerows(extractSentencePairs(conversations))

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't dat

In [7]:
# Build the vocabulary object and the list of sentence pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Sanity check: show the first ten pairs
print("\npairs:")
for pair in itertools.islice(pairs, 10):
    print(pair)


Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [8]:
# Trim voc and pairs
# NOTE: `pairs` is rebound to the result, so the untrimmed list loaded above
# is no longer reachable under this name once this cell has run.
# MIN_COUNT comes from the voc module; presumably it is the minimum word
# frequency to keep -- confirm in voc.trimRareWords.
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [9]:
# Assemble one tiny batch from randomly drawn pairs to inspect its layout
small_batch_size = 5
sample_pairs = [random.choice(pairs) for _ in range(small_batch_size)]
batches = batch2TrainData(voc, sample_pairs)
input_variable, lengths, target_variable, mask, max_target_len = batches

# Show each component of the batch next to its name
for label, value in (
    ("input_variable:", input_variable),
    ("lengths:", lengths),
    ("target_variable:", target_variable),
    ("mask:", mask),
    ("max_target_len:", max_target_len),
):
    print(label, value)

input_variable: tensor([[7402, 2800, 6883,  900, 6883],
        [4085,  845, 6230, 4350, 3667],
        [7589, 7435, 2154, 2876, 7463],
        [3609,  283, 1887, 7463,  845],
        [ 845, 1183,  637, 4004, 1683],
        [ 907, 7551, 6230, 4513,    2],
        [7435, 4513, 3945,    2,    0],
        [7332, 7419,    2,    0,    0],
        [4513, 1683,    0,    0,    0],
        [   2,    2,    0,    0,    0]])
lengths: tensor([10, 10,  8,  7,  6])
target_variable: tensor([[7589, 5565, 4958, 2876, 7589],
        [3702,  969,    2, 4513, 6905],
        [4513, 1683,    0,    2, 3803],
        [   2, 7582,    0,    0, 4056],
        [   0, 4513,    0,    0, 4513],
        [   0,    2,    0,    0,    2]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 0, 1, 1],
        [1, 1, 0, 0, 1],
        [0, 1, 0, 0, 1],
        [0, 1, 0, 0, 1]], dtype=torch.uint8)
max_target_len: 6


In [10]:
# Load model if a loadFilename is provided
if loadFilename:
    # map_location=device remaps every stored tensor onto the configured
    # device, so a checkpoint saved on a GPU machine also loads on a CPU-only
    # one (and vice versa) without editing this cell.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(loadFilename, map_location=device)
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    # Restore the vocabulary state saved with the checkpoint; this must happen
    # before voc.num_words is read to size the embedding below.
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models (both are given the same embedding module)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [11]:
# Switch both networks to train mode so their dropout layers are in train mode
for network in (encoder, decoder):
    network.train()

# Create one Adam optimizer per network; the decoder's learning rate is the
# base rate scaled by decoder_learning_ratio
print('Building optimizers ...')
decoder_lr = learning_rate * decoder_learning_ratio
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=decoder_lr)
# When resuming from a checkpoint, restore the saved optimizer state as well
if loadFilename:
    for optimizer, saved_state in (
        (encoder_optimizer, encoder_optimizer_sd),
        (decoder_optimizer, decoder_optimizer_sd),
    ):
        optimizer.load_state_dict(saved_state)

# Kick off the training loop
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9672
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8414
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6743
Iteration: 4; Percent complete: 0.1%; Average loss: 8.3528
Iteration: 5; Percent complete: 0.1%; Average loss: 7.9295
Iteration: 6; Percent complete: 0.1%; Average loss: 7.3426
Iteration: 7; Percent complete: 0.2%; Average loss: 6.7936
Iteration: 8; Percent complete: 0.2%; Average loss: 6.7994
Iteration: 9; Percent complete: 0.2%; Average loss: 6.8104
Iteration: 10; Percent complete: 0.2%; Average loss: 6.6114
Iteration: 11; Percent complete: 0.3%; Average loss: 6.0325
Iteration: 12; Percent complete: 0.3%; Average loss: 5.8126
Iteration: 13; Percent complete: 0.3%; Average loss: 5.8437
Iteration: 14; Percent complete: 0.4%; Average loss: 5.5351
Iteration: 15; Percent complete: 0.4%; Average loss: 5.7294
Iteration: 16; Percent complete: 0.4%

Iteration: 136; Percent complete: 3.4%; Average loss: 4.0892
Iteration: 137; Percent complete: 3.4%; Average loss: 4.1628
Iteration: 138; Percent complete: 3.5%; Average loss: 4.5382
Iteration: 139; Percent complete: 3.5%; Average loss: 4.4526
Iteration: 140; Percent complete: 3.5%; Average loss: 4.1813
Iteration: 141; Percent complete: 3.5%; Average loss: 4.4443
Iteration: 142; Percent complete: 3.5%; Average loss: 4.2510
Iteration: 143; Percent complete: 3.6%; Average loss: 4.0126
Iteration: 144; Percent complete: 3.6%; Average loss: 4.1882
Iteration: 145; Percent complete: 3.6%; Average loss: 4.3367
Iteration: 146; Percent complete: 3.6%; Average loss: 4.4069
Iteration: 147; Percent complete: 3.7%; Average loss: 4.2325
Iteration: 148; Percent complete: 3.7%; Average loss: 4.2469
Iteration: 149; Percent complete: 3.7%; Average loss: 4.3869
Iteration: 150; Percent complete: 3.8%; Average loss: 4.4267
Iteration: 151; Percent complete: 3.8%; Average loss: 4.1921
Iteration: 152; Percent 

Iteration: 272; Percent complete: 6.8%; Average loss: 3.9734
Iteration: 273; Percent complete: 6.8%; Average loss: 3.9466
Iteration: 274; Percent complete: 6.9%; Average loss: 4.0678
Iteration: 275; Percent complete: 6.9%; Average loss: 3.7542
Iteration: 276; Percent complete: 6.9%; Average loss: 4.3542
Iteration: 277; Percent complete: 6.9%; Average loss: 4.1220
Iteration: 278; Percent complete: 7.0%; Average loss: 4.1248
Iteration: 279; Percent complete: 7.0%; Average loss: 3.7764
Iteration: 280; Percent complete: 7.0%; Average loss: 3.8323
Iteration: 281; Percent complete: 7.0%; Average loss: 3.9930
Iteration: 282; Percent complete: 7.0%; Average loss: 3.9599
Iteration: 283; Percent complete: 7.1%; Average loss: 4.0044
Iteration: 284; Percent complete: 7.1%; Average loss: 4.1102
Iteration: 285; Percent complete: 7.1%; Average loss: 3.6955
Iteration: 286; Percent complete: 7.1%; Average loss: 3.8666
Iteration: 287; Percent complete: 7.2%; Average loss: 3.8124
Iteration: 288; Percent 

Iteration: 408; Percent complete: 10.2%; Average loss: 3.9573
Iteration: 409; Percent complete: 10.2%; Average loss: 3.7308
Iteration: 410; Percent complete: 10.2%; Average loss: 3.7733
Iteration: 411; Percent complete: 10.3%; Average loss: 4.0984
Iteration: 412; Percent complete: 10.3%; Average loss: 3.9018
Iteration: 413; Percent complete: 10.3%; Average loss: 4.1333
Iteration: 414; Percent complete: 10.3%; Average loss: 3.8529
Iteration: 415; Percent complete: 10.4%; Average loss: 3.7287
Iteration: 416; Percent complete: 10.4%; Average loss: 3.9228
Iteration: 417; Percent complete: 10.4%; Average loss: 3.9028
Iteration: 418; Percent complete: 10.4%; Average loss: 3.6317
Iteration: 419; Percent complete: 10.5%; Average loss: 3.7359
Iteration: 420; Percent complete: 10.5%; Average loss: 3.6530
Iteration: 421; Percent complete: 10.5%; Average loss: 3.8707
Iteration: 422; Percent complete: 10.5%; Average loss: 3.5575
Iteration: 423; Percent complete: 10.6%; Average loss: 3.6878
Iteratio

Iteration: 541; Percent complete: 13.5%; Average loss: 3.6803
Iteration: 542; Percent complete: 13.6%; Average loss: 3.7308
Iteration: 543; Percent complete: 13.6%; Average loss: 3.9338
Iteration: 544; Percent complete: 13.6%; Average loss: 3.7486
Iteration: 545; Percent complete: 13.6%; Average loss: 3.5895
Iteration: 546; Percent complete: 13.7%; Average loss: 3.6698
Iteration: 547; Percent complete: 13.7%; Average loss: 3.6503
Iteration: 548; Percent complete: 13.7%; Average loss: 3.5930
Iteration: 549; Percent complete: 13.7%; Average loss: 3.6429
Iteration: 550; Percent complete: 13.8%; Average loss: 3.7918
Iteration: 551; Percent complete: 13.8%; Average loss: 3.5817
Iteration: 552; Percent complete: 13.8%; Average loss: 3.6501
Iteration: 553; Percent complete: 13.8%; Average loss: 3.6722
Iteration: 554; Percent complete: 13.9%; Average loss: 3.5452
Iteration: 555; Percent complete: 13.9%; Average loss: 3.6879
Iteration: 556; Percent complete: 13.9%; Average loss: 3.4671
Iteratio

Iteration: 675; Percent complete: 16.9%; Average loss: 3.8760
Iteration: 676; Percent complete: 16.9%; Average loss: 3.4281
Iteration: 677; Percent complete: 16.9%; Average loss: 3.5345
Iteration: 678; Percent complete: 17.0%; Average loss: 3.4572
Iteration: 679; Percent complete: 17.0%; Average loss: 3.6764
Iteration: 680; Percent complete: 17.0%; Average loss: 3.7040
Iteration: 681; Percent complete: 17.0%; Average loss: 3.4969
Iteration: 682; Percent complete: 17.1%; Average loss: 3.7089
Iteration: 683; Percent complete: 17.1%; Average loss: 3.3397
Iteration: 684; Percent complete: 17.1%; Average loss: 3.9406
Iteration: 685; Percent complete: 17.1%; Average loss: 3.5826
Iteration: 686; Percent complete: 17.2%; Average loss: 3.5017
Iteration: 687; Percent complete: 17.2%; Average loss: 3.5395
Iteration: 688; Percent complete: 17.2%; Average loss: 3.5686
Iteration: 689; Percent complete: 17.2%; Average loss: 3.7019
Iteration: 690; Percent complete: 17.2%; Average loss: 3.9858
Iteratio

Iteration: 809; Percent complete: 20.2%; Average loss: 3.4761
Iteration: 810; Percent complete: 20.2%; Average loss: 3.5062
Iteration: 811; Percent complete: 20.3%; Average loss: 3.5040
Iteration: 812; Percent complete: 20.3%; Average loss: 3.8700
Iteration: 813; Percent complete: 20.3%; Average loss: 3.6983
Iteration: 814; Percent complete: 20.3%; Average loss: 3.5467
Iteration: 815; Percent complete: 20.4%; Average loss: 3.4389
Iteration: 816; Percent complete: 20.4%; Average loss: 3.6339
Iteration: 817; Percent complete: 20.4%; Average loss: 3.5700
Iteration: 818; Percent complete: 20.4%; Average loss: 3.7353
Iteration: 819; Percent complete: 20.5%; Average loss: 3.5721
Iteration: 820; Percent complete: 20.5%; Average loss: 3.3302
Iteration: 821; Percent complete: 20.5%; Average loss: 3.3968
Iteration: 822; Percent complete: 20.5%; Average loss: 3.4729
Iteration: 823; Percent complete: 20.6%; Average loss: 3.5855
Iteration: 824; Percent complete: 20.6%; Average loss: 3.6176
Iteratio

Iteration: 943; Percent complete: 23.6%; Average loss: 3.6393
Iteration: 944; Percent complete: 23.6%; Average loss: 3.3255
Iteration: 945; Percent complete: 23.6%; Average loss: 3.8286
Iteration: 946; Percent complete: 23.6%; Average loss: 3.4308
Iteration: 947; Percent complete: 23.7%; Average loss: 3.5520
Iteration: 948; Percent complete: 23.7%; Average loss: 3.3313
Iteration: 949; Percent complete: 23.7%; Average loss: 3.3601
Iteration: 950; Percent complete: 23.8%; Average loss: 3.3232
Iteration: 951; Percent complete: 23.8%; Average loss: 3.3807
Iteration: 952; Percent complete: 23.8%; Average loss: 3.4661
Iteration: 953; Percent complete: 23.8%; Average loss: 3.3895
Iteration: 954; Percent complete: 23.8%; Average loss: 3.4214
Iteration: 955; Percent complete: 23.9%; Average loss: 3.6682
Iteration: 956; Percent complete: 23.9%; Average loss: 3.3832
Iteration: 957; Percent complete: 23.9%; Average loss: 3.3722
Iteration: 958; Percent complete: 23.9%; Average loss: 3.7194
Iteratio

Iteration: 1075; Percent complete: 26.9%; Average loss: 3.3273
Iteration: 1076; Percent complete: 26.9%; Average loss: 3.5164
Iteration: 1077; Percent complete: 26.9%; Average loss: 3.1955
Iteration: 1078; Percent complete: 27.0%; Average loss: 3.4948
Iteration: 1079; Percent complete: 27.0%; Average loss: 3.2137
Iteration: 1080; Percent complete: 27.0%; Average loss: 3.6245
Iteration: 1081; Percent complete: 27.0%; Average loss: 3.3317
Iteration: 1082; Percent complete: 27.1%; Average loss: 3.3919
Iteration: 1083; Percent complete: 27.1%; Average loss: 3.2003
Iteration: 1084; Percent complete: 27.1%; Average loss: 3.3533
Iteration: 1085; Percent complete: 27.1%; Average loss: 3.3546
Iteration: 1086; Percent complete: 27.2%; Average loss: 3.4340
Iteration: 1087; Percent complete: 27.2%; Average loss: 3.6111
Iteration: 1088; Percent complete: 27.2%; Average loss: 3.4457
Iteration: 1089; Percent complete: 27.2%; Average loss: 3.0763
Iteration: 1090; Percent complete: 27.3%; Average loss:

Iteration: 1207; Percent complete: 30.2%; Average loss: 3.6070
Iteration: 1208; Percent complete: 30.2%; Average loss: 3.6060
Iteration: 1209; Percent complete: 30.2%; Average loss: 3.5330
Iteration: 1210; Percent complete: 30.2%; Average loss: 3.4075
Iteration: 1211; Percent complete: 30.3%; Average loss: 3.4468
Iteration: 1212; Percent complete: 30.3%; Average loss: 3.2276
Iteration: 1213; Percent complete: 30.3%; Average loss: 3.2578
Iteration: 1214; Percent complete: 30.3%; Average loss: 3.5384
Iteration: 1215; Percent complete: 30.4%; Average loss: 3.4816
Iteration: 1216; Percent complete: 30.4%; Average loss: 3.6064
Iteration: 1217; Percent complete: 30.4%; Average loss: 3.3228
Iteration: 1218; Percent complete: 30.4%; Average loss: 3.2934
Iteration: 1219; Percent complete: 30.5%; Average loss: 3.1886
Iteration: 1220; Percent complete: 30.5%; Average loss: 3.4935
Iteration: 1221; Percent complete: 30.5%; Average loss: 3.3656
Iteration: 1222; Percent complete: 30.6%; Average loss:

Iteration: 1339; Percent complete: 33.5%; Average loss: 3.5293
Iteration: 1340; Percent complete: 33.5%; Average loss: 3.2435
Iteration: 1341; Percent complete: 33.5%; Average loss: 3.2041
Iteration: 1342; Percent complete: 33.6%; Average loss: 3.3042
Iteration: 1343; Percent complete: 33.6%; Average loss: 3.3741
Iteration: 1344; Percent complete: 33.6%; Average loss: 3.5760
Iteration: 1345; Percent complete: 33.6%; Average loss: 3.1972
Iteration: 1346; Percent complete: 33.7%; Average loss: 3.3641
Iteration: 1347; Percent complete: 33.7%; Average loss: 3.5006
Iteration: 1348; Percent complete: 33.7%; Average loss: 3.4386
Iteration: 1349; Percent complete: 33.7%; Average loss: 3.3077
Iteration: 1350; Percent complete: 33.8%; Average loss: 3.3077
Iteration: 1351; Percent complete: 33.8%; Average loss: 3.3756
Iteration: 1352; Percent complete: 33.8%; Average loss: 3.2481
Iteration: 1353; Percent complete: 33.8%; Average loss: 3.0745
Iteration: 1354; Percent complete: 33.9%; Average loss:

Iteration: 1471; Percent complete: 36.8%; Average loss: 3.4952
Iteration: 1472; Percent complete: 36.8%; Average loss: 3.4494
Iteration: 1473; Percent complete: 36.8%; Average loss: 2.9870
Iteration: 1474; Percent complete: 36.9%; Average loss: 3.5215
Iteration: 1475; Percent complete: 36.9%; Average loss: 3.3816
Iteration: 1476; Percent complete: 36.9%; Average loss: 3.3283
Iteration: 1477; Percent complete: 36.9%; Average loss: 3.3835
Iteration: 1478; Percent complete: 37.0%; Average loss: 3.4738
Iteration: 1479; Percent complete: 37.0%; Average loss: 3.3921
Iteration: 1480; Percent complete: 37.0%; Average loss: 3.4381
Iteration: 1481; Percent complete: 37.0%; Average loss: 3.1905
Iteration: 1482; Percent complete: 37.0%; Average loss: 3.4407
Iteration: 1483; Percent complete: 37.1%; Average loss: 3.1471
Iteration: 1484; Percent complete: 37.1%; Average loss: 3.1589
Iteration: 1485; Percent complete: 37.1%; Average loss: 2.9922
Iteration: 1486; Percent complete: 37.1%; Average loss:

Iteration: 1603; Percent complete: 40.1%; Average loss: 3.2745
Iteration: 1604; Percent complete: 40.1%; Average loss: 3.4176
Iteration: 1605; Percent complete: 40.1%; Average loss: 3.5768
Iteration: 1606; Percent complete: 40.2%; Average loss: 3.1686
Iteration: 1607; Percent complete: 40.2%; Average loss: 3.2415
Iteration: 1608; Percent complete: 40.2%; Average loss: 3.5429
Iteration: 1609; Percent complete: 40.2%; Average loss: 3.2112
Iteration: 1610; Percent complete: 40.2%; Average loss: 3.2987
Iteration: 1611; Percent complete: 40.3%; Average loss: 3.2300
Iteration: 1612; Percent complete: 40.3%; Average loss: 3.1874
Iteration: 1613; Percent complete: 40.3%; Average loss: 3.2258
Iteration: 1614; Percent complete: 40.4%; Average loss: 3.4236
Iteration: 1615; Percent complete: 40.4%; Average loss: 3.5116
Iteration: 1616; Percent complete: 40.4%; Average loss: 3.2061
Iteration: 1617; Percent complete: 40.4%; Average loss: 3.1347
Iteration: 1618; Percent complete: 40.5%; Average loss:

Iteration: 1735; Percent complete: 43.4%; Average loss: 3.4043
Iteration: 1736; Percent complete: 43.4%; Average loss: 3.0116
Iteration: 1737; Percent complete: 43.4%; Average loss: 3.1372
Iteration: 1738; Percent complete: 43.5%; Average loss: 3.5309
Iteration: 1739; Percent complete: 43.5%; Average loss: 3.2364
Iteration: 1740; Percent complete: 43.5%; Average loss: 3.3139
Iteration: 1741; Percent complete: 43.5%; Average loss: 3.3209
Iteration: 1742; Percent complete: 43.5%; Average loss: 2.9889
Iteration: 1743; Percent complete: 43.6%; Average loss: 3.1228
Iteration: 1744; Percent complete: 43.6%; Average loss: 3.2989
Iteration: 1745; Percent complete: 43.6%; Average loss: 3.2926
Iteration: 1746; Percent complete: 43.6%; Average loss: 3.3397
Iteration: 1747; Percent complete: 43.7%; Average loss: 3.3238
Iteration: 1748; Percent complete: 43.7%; Average loss: 3.0636
Iteration: 1749; Percent complete: 43.7%; Average loss: 3.3020
Iteration: 1750; Percent complete: 43.8%; Average loss:

Iteration: 1867; Percent complete: 46.7%; Average loss: 3.2572
Iteration: 1868; Percent complete: 46.7%; Average loss: 3.2886
Iteration: 1869; Percent complete: 46.7%; Average loss: 3.0329
Iteration: 1870; Percent complete: 46.8%; Average loss: 3.1916
Iteration: 1871; Percent complete: 46.8%; Average loss: 3.3083
Iteration: 1872; Percent complete: 46.8%; Average loss: 2.9666
Iteration: 1873; Percent complete: 46.8%; Average loss: 3.6201
Iteration: 1874; Percent complete: 46.9%; Average loss: 3.0751
Iteration: 1875; Percent complete: 46.9%; Average loss: 3.3245
Iteration: 1876; Percent complete: 46.9%; Average loss: 3.3663
Iteration: 1877; Percent complete: 46.9%; Average loss: 3.0891
Iteration: 1878; Percent complete: 46.9%; Average loss: 3.1404
Iteration: 1879; Percent complete: 47.0%; Average loss: 3.1130
Iteration: 1880; Percent complete: 47.0%; Average loss: 3.3248
Iteration: 1881; Percent complete: 47.0%; Average loss: 3.3043
Iteration: 1882; Percent complete: 47.0%; Average loss:

Iteration: 1999; Percent complete: 50.0%; Average loss: 3.0254
Iteration: 2000; Percent complete: 50.0%; Average loss: 3.2714
Iteration: 2001; Percent complete: 50.0%; Average loss: 3.3982
Iteration: 2002; Percent complete: 50.0%; Average loss: 3.3516
Iteration: 2003; Percent complete: 50.1%; Average loss: 3.2693
Iteration: 2004; Percent complete: 50.1%; Average loss: 3.2177
Iteration: 2005; Percent complete: 50.1%; Average loss: 3.1888
Iteration: 2006; Percent complete: 50.1%; Average loss: 3.2184
Iteration: 2007; Percent complete: 50.2%; Average loss: 3.0245
Iteration: 2008; Percent complete: 50.2%; Average loss: 3.1891
Iteration: 2009; Percent complete: 50.2%; Average loss: 3.1936
Iteration: 2010; Percent complete: 50.2%; Average loss: 3.1308
Iteration: 2011; Percent complete: 50.3%; Average loss: 3.2456
Iteration: 2012; Percent complete: 50.3%; Average loss: 3.1245
Iteration: 2013; Percent complete: 50.3%; Average loss: 3.2957
Iteration: 2014; Percent complete: 50.3%; Average loss:

Iteration: 2131; Percent complete: 53.3%; Average loss: 3.2827
Iteration: 2132; Percent complete: 53.3%; Average loss: 3.2853
Iteration: 2133; Percent complete: 53.3%; Average loss: 3.1622
Iteration: 2134; Percent complete: 53.3%; Average loss: 3.2794
Iteration: 2135; Percent complete: 53.4%; Average loss: 3.0744
Iteration: 2136; Percent complete: 53.4%; Average loss: 3.0485
Iteration: 2137; Percent complete: 53.4%; Average loss: 3.0526
Iteration: 2138; Percent complete: 53.4%; Average loss: 2.9021
Iteration: 2139; Percent complete: 53.5%; Average loss: 3.1565
Iteration: 2140; Percent complete: 53.5%; Average loss: 3.1133
Iteration: 2141; Percent complete: 53.5%; Average loss: 2.9026
Iteration: 2142; Percent complete: 53.5%; Average loss: 3.2447
Iteration: 2143; Percent complete: 53.6%; Average loss: 2.9717
Iteration: 2144; Percent complete: 53.6%; Average loss: 3.5557
Iteration: 2145; Percent complete: 53.6%; Average loss: 3.0374
Iteration: 2146; Percent complete: 53.6%; Average loss:

Iteration: 2263; Percent complete: 56.6%; Average loss: 2.8800
Iteration: 2264; Percent complete: 56.6%; Average loss: 2.9738
Iteration: 2265; Percent complete: 56.6%; Average loss: 2.7869
Iteration: 2266; Percent complete: 56.6%; Average loss: 2.9751
Iteration: 2267; Percent complete: 56.7%; Average loss: 2.7252
Iteration: 2268; Percent complete: 56.7%; Average loss: 3.0372
Iteration: 2269; Percent complete: 56.7%; Average loss: 3.0020
Iteration: 2270; Percent complete: 56.8%; Average loss: 3.2173
Iteration: 2271; Percent complete: 56.8%; Average loss: 2.8960
Iteration: 2272; Percent complete: 56.8%; Average loss: 2.9098
Iteration: 2273; Percent complete: 56.8%; Average loss: 3.1830
Iteration: 2274; Percent complete: 56.9%; Average loss: 3.0930
Iteration: 2275; Percent complete: 56.9%; Average loss: 3.1275
Iteration: 2276; Percent complete: 56.9%; Average loss: 3.1662
Iteration: 2277; Percent complete: 56.9%; Average loss: 2.8856
Iteration: 2278; Percent complete: 57.0%; Average loss:

Iteration: 2395; Percent complete: 59.9%; Average loss: 2.9503
Iteration: 2396; Percent complete: 59.9%; Average loss: 2.8880
Iteration: 2397; Percent complete: 59.9%; Average loss: 2.8456
Iteration: 2398; Percent complete: 60.0%; Average loss: 3.2393
Iteration: 2399; Percent complete: 60.0%; Average loss: 2.9492
Iteration: 2400; Percent complete: 60.0%; Average loss: 2.7921
Iteration: 2401; Percent complete: 60.0%; Average loss: 2.9755
Iteration: 2402; Percent complete: 60.1%; Average loss: 3.1697
Iteration: 2403; Percent complete: 60.1%; Average loss: 2.7544
Iteration: 2404; Percent complete: 60.1%; Average loss: 3.0201
Iteration: 2405; Percent complete: 60.1%; Average loss: 2.9710
Iteration: 2406; Percent complete: 60.2%; Average loss: 2.9479
Iteration: 2407; Percent complete: 60.2%; Average loss: 3.0872
Iteration: 2408; Percent complete: 60.2%; Average loss: 3.0270
Iteration: 2409; Percent complete: 60.2%; Average loss: 3.0285
Iteration: 2410; Percent complete: 60.2%; Average loss:

Iteration: 2527; Percent complete: 63.2%; Average loss: 3.1040
Iteration: 2528; Percent complete: 63.2%; Average loss: 3.0402
Iteration: 2529; Percent complete: 63.2%; Average loss: 2.8175
Iteration: 2530; Percent complete: 63.2%; Average loss: 2.9095
Iteration: 2531; Percent complete: 63.3%; Average loss: 2.8267
Iteration: 2532; Percent complete: 63.3%; Average loss: 2.9882
Iteration: 2533; Percent complete: 63.3%; Average loss: 3.0019
Iteration: 2534; Percent complete: 63.3%; Average loss: 3.1199
Iteration: 2535; Percent complete: 63.4%; Average loss: 2.9301
Iteration: 2536; Percent complete: 63.4%; Average loss: 2.9351
Iteration: 2537; Percent complete: 63.4%; Average loss: 2.9542
Iteration: 2538; Percent complete: 63.4%; Average loss: 3.0032
Iteration: 2539; Percent complete: 63.5%; Average loss: 3.2938
Iteration: 2540; Percent complete: 63.5%; Average loss: 3.1509
Iteration: 2541; Percent complete: 63.5%; Average loss: 2.8756
Iteration: 2542; Percent complete: 63.5%; Average loss:

Iteration: 2659; Percent complete: 66.5%; Average loss: 2.8543
Iteration: 2660; Percent complete: 66.5%; Average loss: 3.0699
Iteration: 2661; Percent complete: 66.5%; Average loss: 2.9586
Iteration: 2662; Percent complete: 66.5%; Average loss: 3.2112
Iteration: 2663; Percent complete: 66.6%; Average loss: 3.0156
Iteration: 2664; Percent complete: 66.6%; Average loss: 2.6452
Iteration: 2665; Percent complete: 66.6%; Average loss: 2.9693
Iteration: 2666; Percent complete: 66.6%; Average loss: 2.9948
Iteration: 2667; Percent complete: 66.7%; Average loss: 3.0569
Iteration: 2668; Percent complete: 66.7%; Average loss: 3.0619
Iteration: 2669; Percent complete: 66.7%; Average loss: 2.9829
Iteration: 2670; Percent complete: 66.8%; Average loss: 2.9904
Iteration: 2671; Percent complete: 66.8%; Average loss: 2.8697
Iteration: 2672; Percent complete: 66.8%; Average loss: 3.0103
Iteration: 2673; Percent complete: 66.8%; Average loss: 3.0081
Iteration: 2674; Percent complete: 66.8%; Average loss:

Iteration: 2791; Percent complete: 69.8%; Average loss: 2.8510
Iteration: 2792; Percent complete: 69.8%; Average loss: 2.8934
Iteration: 2793; Percent complete: 69.8%; Average loss: 3.0251
Iteration: 2794; Percent complete: 69.8%; Average loss: 2.9049
Iteration: 2795; Percent complete: 69.9%; Average loss: 3.0333
Iteration: 2796; Percent complete: 69.9%; Average loss: 2.8987
Iteration: 2797; Percent complete: 69.9%; Average loss: 3.1432
Iteration: 2798; Percent complete: 70.0%; Average loss: 2.8388
Iteration: 2799; Percent complete: 70.0%; Average loss: 3.2025
Iteration: 2800; Percent complete: 70.0%; Average loss: 2.9040
Iteration: 2801; Percent complete: 70.0%; Average loss: 2.9765
Iteration: 2802; Percent complete: 70.0%; Average loss: 2.9705
Iteration: 2803; Percent complete: 70.1%; Average loss: 3.2497
Iteration: 2804; Percent complete: 70.1%; Average loss: 2.8670
Iteration: 2805; Percent complete: 70.1%; Average loss: 2.9310
Iteration: 2806; Percent complete: 70.2%; Average loss:

Iteration: 2923; Percent complete: 73.1%; Average loss: 2.8454
Iteration: 2924; Percent complete: 73.1%; Average loss: 2.9459
Iteration: 2925; Percent complete: 73.1%; Average loss: 2.9008
Iteration: 2926; Percent complete: 73.2%; Average loss: 2.8274
Iteration: 2927; Percent complete: 73.2%; Average loss: 2.5827
Iteration: 2928; Percent complete: 73.2%; Average loss: 2.9653
Iteration: 2929; Percent complete: 73.2%; Average loss: 2.8170
Iteration: 2930; Percent complete: 73.2%; Average loss: 2.7579
Iteration: 2931; Percent complete: 73.3%; Average loss: 2.8656
Iteration: 2932; Percent complete: 73.3%; Average loss: 2.9794
Iteration: 2933; Percent complete: 73.3%; Average loss: 2.8198
Iteration: 2934; Percent complete: 73.4%; Average loss: 2.6231
Iteration: 2935; Percent complete: 73.4%; Average loss: 3.0398
Iteration: 2936; Percent complete: 73.4%; Average loss: 2.9560
Iteration: 2937; Percent complete: 73.4%; Average loss: 2.8980
Iteration: 2938; Percent complete: 73.5%; Average loss:

Iteration: 3055; Percent complete: 76.4%; Average loss: 2.8805
Iteration: 3056; Percent complete: 76.4%; Average loss: 2.8071
Iteration: 3057; Percent complete: 76.4%; Average loss: 3.0033
Iteration: 3058; Percent complete: 76.4%; Average loss: 2.7541
Iteration: 3059; Percent complete: 76.5%; Average loss: 2.8955
Iteration: 3060; Percent complete: 76.5%; Average loss: 2.8475
Iteration: 3061; Percent complete: 76.5%; Average loss: 2.9127
Iteration: 3062; Percent complete: 76.5%; Average loss: 2.7515
Iteration: 3063; Percent complete: 76.6%; Average loss: 2.8675
Iteration: 3064; Percent complete: 76.6%; Average loss: 2.9977
Iteration: 3065; Percent complete: 76.6%; Average loss: 2.9499
Iteration: 3066; Percent complete: 76.6%; Average loss: 2.9072
Iteration: 3067; Percent complete: 76.7%; Average loss: 2.8840
Iteration: 3068; Percent complete: 76.7%; Average loss: 2.9930
Iteration: 3069; Percent complete: 76.7%; Average loss: 2.8795
Iteration: 3070; Percent complete: 76.8%; Average loss:

Iteration: 3187; Percent complete: 79.7%; Average loss: 3.0520
Iteration: 3188; Percent complete: 79.7%; Average loss: 2.6942
Iteration: 3189; Percent complete: 79.7%; Average loss: 2.7551
Iteration: 3190; Percent complete: 79.8%; Average loss: 2.8332
Iteration: 3191; Percent complete: 79.8%; Average loss: 2.7846
Iteration: 3192; Percent complete: 79.8%; Average loss: 2.8722
Iteration: 3193; Percent complete: 79.8%; Average loss: 2.8109
Iteration: 3194; Percent complete: 79.8%; Average loss: 2.8406
Iteration: 3195; Percent complete: 79.9%; Average loss: 2.7464
Iteration: 3196; Percent complete: 79.9%; Average loss: 2.6919
Iteration: 3197; Percent complete: 79.9%; Average loss: 2.9910
Iteration: 3198; Percent complete: 80.0%; Average loss: 3.0618
Iteration: 3199; Percent complete: 80.0%; Average loss: 2.8668
Iteration: 3200; Percent complete: 80.0%; Average loss: 3.0748
Iteration: 3201; Percent complete: 80.0%; Average loss: 2.9348
Iteration: 3202; Percent complete: 80.0%; Average loss:

Iteration: 3319; Percent complete: 83.0%; Average loss: 2.8821
Iteration: 3320; Percent complete: 83.0%; Average loss: 2.9371
Iteration: 3321; Percent complete: 83.0%; Average loss: 2.6466
Iteration: 3322; Percent complete: 83.0%; Average loss: 2.6584
Iteration: 3323; Percent complete: 83.1%; Average loss: 2.5871
Iteration: 3324; Percent complete: 83.1%; Average loss: 2.8540
Iteration: 3325; Percent complete: 83.1%; Average loss: 2.9339
Iteration: 3326; Percent complete: 83.2%; Average loss: 2.9244
Iteration: 3327; Percent complete: 83.2%; Average loss: 2.8822
Iteration: 3328; Percent complete: 83.2%; Average loss: 2.9647
Iteration: 3329; Percent complete: 83.2%; Average loss: 2.7104
Iteration: 3330; Percent complete: 83.2%; Average loss: 2.8606
Iteration: 3331; Percent complete: 83.3%; Average loss: 2.6477
Iteration: 3332; Percent complete: 83.3%; Average loss: 3.0846
Iteration: 3333; Percent complete: 83.3%; Average loss: 2.8401
Iteration: 3334; Percent complete: 83.4%; Average loss:

Iteration: 3451; Percent complete: 86.3%; Average loss: 2.7095
Iteration: 3452; Percent complete: 86.3%; Average loss: 2.7979
Iteration: 3453; Percent complete: 86.3%; Average loss: 2.7906
Iteration: 3454; Percent complete: 86.4%; Average loss: 2.8687
Iteration: 3455; Percent complete: 86.4%; Average loss: 2.6362
Iteration: 3456; Percent complete: 86.4%; Average loss: 2.6447
Iteration: 3457; Percent complete: 86.4%; Average loss: 2.8305
Iteration: 3458; Percent complete: 86.5%; Average loss: 2.7644
Iteration: 3459; Percent complete: 86.5%; Average loss: 2.8565
Iteration: 3460; Percent complete: 86.5%; Average loss: 2.6950
Iteration: 3461; Percent complete: 86.5%; Average loss: 2.7161
Iteration: 3462; Percent complete: 86.6%; Average loss: 2.8624
Iteration: 3463; Percent complete: 86.6%; Average loss: 2.8020
Iteration: 3464; Percent complete: 86.6%; Average loss: 2.6686
Iteration: 3465; Percent complete: 86.6%; Average loss: 2.8625
Iteration: 3466; Percent complete: 86.7%; Average loss:

Iteration: 3583; Percent complete: 89.6%; Average loss: 2.7304
Iteration: 3584; Percent complete: 89.6%; Average loss: 2.7884
Iteration: 3585; Percent complete: 89.6%; Average loss: 2.7890
Iteration: 3586; Percent complete: 89.6%; Average loss: 2.8354
Iteration: 3587; Percent complete: 89.7%; Average loss: 2.9144
Iteration: 3588; Percent complete: 89.7%; Average loss: 2.7822
Iteration: 3589; Percent complete: 89.7%; Average loss: 2.7195
Iteration: 3590; Percent complete: 89.8%; Average loss: 2.7865
Iteration: 3591; Percent complete: 89.8%; Average loss: 2.8193
Iteration: 3592; Percent complete: 89.8%; Average loss: 2.5123
Iteration: 3593; Percent complete: 89.8%; Average loss: 2.4664
Iteration: 3594; Percent complete: 89.8%; Average loss: 2.6889
Iteration: 3595; Percent complete: 89.9%; Average loss: 2.8893
Iteration: 3596; Percent complete: 89.9%; Average loss: 2.6172
Iteration: 3597; Percent complete: 89.9%; Average loss: 2.9699
Iteration: 3598; Percent complete: 90.0%; Average loss:

Iteration: 3715; Percent complete: 92.9%; Average loss: 2.7476
Iteration: 3716; Percent complete: 92.9%; Average loss: 2.6863
Iteration: 3717; Percent complete: 92.9%; Average loss: 2.6458
Iteration: 3718; Percent complete: 93.0%; Average loss: 2.4397
Iteration: 3719; Percent complete: 93.0%; Average loss: 2.7951
Iteration: 3720; Percent complete: 93.0%; Average loss: 2.7705
Iteration: 3721; Percent complete: 93.0%; Average loss: 2.7373
Iteration: 3722; Percent complete: 93.0%; Average loss: 2.6062
Iteration: 3723; Percent complete: 93.1%; Average loss: 2.5042
Iteration: 3724; Percent complete: 93.1%; Average loss: 2.6484
Iteration: 3725; Percent complete: 93.1%; Average loss: 2.7070
Iteration: 3726; Percent complete: 93.2%; Average loss: 2.6274
Iteration: 3727; Percent complete: 93.2%; Average loss: 2.7477
Iteration: 3728; Percent complete: 93.2%; Average loss: 2.6696
Iteration: 3729; Percent complete: 93.2%; Average loss: 2.4833
Iteration: 3730; Percent complete: 93.2%; Average loss:

Iteration: 3847; Percent complete: 96.2%; Average loss: 2.7953
Iteration: 3848; Percent complete: 96.2%; Average loss: 2.5719
Iteration: 3849; Percent complete: 96.2%; Average loss: 2.5523
Iteration: 3850; Percent complete: 96.2%; Average loss: 2.4551
Iteration: 3851; Percent complete: 96.3%; Average loss: 2.5660
Iteration: 3852; Percent complete: 96.3%; Average loss: 2.5622
Iteration: 3853; Percent complete: 96.3%; Average loss: 2.6376
Iteration: 3854; Percent complete: 96.4%; Average loss: 2.6488
Iteration: 3855; Percent complete: 96.4%; Average loss: 2.3396
Iteration: 3856; Percent complete: 96.4%; Average loss: 2.7166
Iteration: 3857; Percent complete: 96.4%; Average loss: 2.6506
Iteration: 3858; Percent complete: 96.5%; Average loss: 2.6339
Iteration: 3859; Percent complete: 96.5%; Average loss: 2.6372
Iteration: 3860; Percent complete: 96.5%; Average loss: 2.5091
Iteration: 3861; Percent complete: 96.5%; Average loss: 2.7149
Iteration: 3862; Percent complete: 96.5%; Average loss:

Iteration: 3979; Percent complete: 99.5%; Average loss: 2.7250
Iteration: 3980; Percent complete: 99.5%; Average loss: 2.5929
Iteration: 3981; Percent complete: 99.5%; Average loss: 2.6971
Iteration: 3982; Percent complete: 99.6%; Average loss: 2.5897
Iteration: 3983; Percent complete: 99.6%; Average loss: 2.6827
Iteration: 3984; Percent complete: 99.6%; Average loss: 2.9289
Iteration: 3985; Percent complete: 99.6%; Average loss: 2.6169
Iteration: 3986; Percent complete: 99.7%; Average loss: 2.5391
Iteration: 3987; Percent complete: 99.7%; Average loss: 2.5094
Iteration: 3988; Percent complete: 99.7%; Average loss: 2.9162
Iteration: 3989; Percent complete: 99.7%; Average loss: 2.5455
Iteration: 3990; Percent complete: 99.8%; Average loss: 2.5603
Iteration: 3991; Percent complete: 99.8%; Average loss: 2.5727
Iteration: 3992; Percent complete: 99.8%; Average loss: 2.4838
Iteration: 3993; Percent complete: 99.8%; Average loss: 2.5804
Iteration: 3994; Percent complete: 99.9%; Average loss:

In [None]:
# Switch both models to evaluation mode before chatting
# (per the original note: this puts the dropout layers in eval mode).
encoder.eval()
decoder.eval()

# Initialize the search module, built from the encoder/decoder pair above.
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting. The call below is live (not commented out), so running this
# cell starts the chat session right away.
# NOTE(review): the session appears to prompt for input interactively, which
# would stall an unattended "Restart & Run All" -- confirm, and comment the
# call out if the notebook must run end-to-end without a user.
evaluateInput(encoder, decoder, searcher, voc)

> hiya
Bot: what ? . . . .
> adsf
Error: Encountered unknown word.
> huh
Bot: what do you mean ? . .
> hiy
Error: Encountered unknown word.
> where is cathyu
Error: Encountered unknown word.
> where is cathy
Bot: where is he ? . . .
> where is bob
Bot: where is he ? . . .
