In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

from encoder import EncoderRNN
from decoder import LuongAttnDecoderRNN
from movie_line_process import loadLines, loadConversations, extractSentencePairs
from voc import loadPrepareData, trimRareWords, normalizeString, makeVoc
from voc import MIN_COUNT, MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH, PAD_token, SOS_token, EOS_token
from prepare_data import indexesFromSentence, batch2TrainData
from train import trainIters
from model_config import model_name, attn_model, hidden_size
from model_config import encoder_n_layers, decoder_n_layers, dropout, batch_size
from model_config import device, loadFilename, checkpoint_iter
from train_config import clip, learning_rate, decoder_learning_ratio, n_iteration
from train_config import print_every, save_every
from evaluate import GreedySearchDecoder, evaluateInput

from squad_loader import prepare_par_pairs, prepare_sent_pairs

In [3]:
corpus_name = "squad"
corpus = os.path.join("data", corpus_name)


def printLines(file, n=10):
    """Print the first ``n`` raw lines of ``file`` as ``bytes`` reprs.

    Only the requested lines are read, so peeking at a large corpus file does
    not load the whole file into memory.

    Args:
        file: Path of the file to inspect (opened in binary mode).
        n: Number of leading lines to print (default 10). ``None`` prints every
            line; a negative value keeps list-slice semantics (all but the last
            ``-n`` lines).
    """
    with open(file, 'rb') as datafile:
        if n is not None and n < 0:
            # islice rejects negative stops; fall back to slicing a full read
            selected = datafile.readlines()[:n]
        else:
            # Stream only what is needed instead of materialising the whole file
            selected = list(itertools.islice(datafile, n))
    for line in selected:
        print(line)

# Uncomment to peek at the raw SQuAD json:
# printLines(os.path.join(corpus, "train-v2.0.json"))

In [4]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_dev_squad_qa.txt")

delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n\n')
    pairs = prepare_sent_pairs()
    for pair in pairs:
        writer.writerow(pair)
    
# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Writing newly formatted file...
Processing SquAD dataset...
Processing done!

Sample lines from file:
b" Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame   in the late 1990s   as lead singer of R&B girl-group Destiny's Child\tWhen did Beyonce start becoming popular?\n"
b'\n'
b" Born and raised in Houston, Texas, she performed in various   singing and dancing   competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child\tWhat areas did Beyonce compete in when she was growing up?\n"
b'\n'
b'" Their hiatus saw the release of Beyonc\xc3\xa9\'s debut album, Dangerously in Love (  2003  ), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles ""Crazy in Love"" and ""Baby Boy"""\tWhen did Beyonce leave Destiny\'s Child and become a solo singer?\n'
b'\n'
b" Born and raised in   Houston, Texas  , sh

In [5]:
# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
voc = makeVoc(corpus_name)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)
print(pairs[-1])


Start preparing training data ...
Reading lines...
Read 86821 sentence pairs
Trimmed to 86585 sentence pairs
Counting words...
Processing SquAD dataset...
Processing done!

pairs:
['born and raised in houston texas she performed in various singing and dancing competitions as a child and rose to fame in the late 1990s as lead singer of r b girl group destiny s child', 'when did beyonce start becoming popular ?']
['born and raised in houston texas she performed in various singing and dancing competitions as a child and rose to fame in the late 1990s as lead singer of r b girl group destiny s child', 'what areas did beyonce compete in when she was growing up ?']
['their hiatus saw the release of beyonce s debut album dangerously in love 2003 which established her as a solo artist worldwide earned five grammy awards and featured the billboard hot 100 number one singles crazy in love and baby boy', 'when did beyonce leave destiny s child and become a solo singer ?']
['born and raised in hou

In [6]:
# Remove rare words from the vocabulary using the MIN_COUNT threshold
# (Voc.trim prints how many words were kept).
# NOTE(review): only `voc` is trimmed — `pairs` is left as-is (a
# trimRareWords(voc, pairs, MIN_COUNT) call was previously disabled here), so
# pairs may contain words no longer in voc; presumably these are mapped to UNK
# when indexing — confirm in prepare_data.indexesFromSentence.
voc.trim(MIN_COUNT)

keep_words 39927 / 82663 = 0.4830


In [7]:
# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[24992, 35743, 24992, 16012,  7734],
        [29237, 24992, 26097, 27437,  4391],
        [25438,  6273, 38701, 21651, 10104],
        [17392, 38701, 16387,  1761, 25812],
        [ 1837, 28993, 25812,  7603, 30421],
        [25812, 16012, 18736, 31437, 14547],
        [ 5571, 19910, 17375, 34783, 28992],
        [ 4862, 39047, 14902, 14793,  6958],
        [30867, 24992, 16143, 28253,     3],
        [24884, 24551, 24850,  1974,     2],
        [ 1837,  8482, 11643, 26574,     0],
        [38701, 25812, 37116, 19359,     0],
        [36314, 24992, 25110, 17375,     0],
        [24469,  5915,  4448, 30930,     0],
        [31080, 17320, 33946, 20352,     0],
        [24992, 38701, 36294, 19924,     0],
        [  298, 24992, 24992, 24992,     0],
        [38701, 27620, 28598, 29684,     0],
        [ 8149, 18242, 39047, 17375,     0],
        [33946,  1486, 32020,  6996,     0],
        [35970, 35970, 17375, 21651,     0],
        [24992, 17270, 24992,  1974,   

In [8]:
# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [13]:
# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Number of training iterations for this run.
# NOTE: this overrides `n_iteration` from train_config; set it to `n_iteration`
# to use the configured value instead.
train_iterations = 3000

# Initialize optimizers (decoder learning rate is scaled by decoder_learning_ratio)
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    # Resume optimizer state saved alongside the model weights
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, train_iterations, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 10; Percent complete: 0.3%; Average loss: 2.5526
Iteration: 20; Percent complete: 0.7%; Average loss: 2.5791
Iteration: 30; Percent complete: 1.0%; Average loss: 2.5328
Iteration: 40; Percent complete: 1.3%; Average loss: 2.4961
Iteration: 50; Percent complete: 1.7%; Average loss: 2.6192
Iteration: 60; Percent complete: 2.0%; Average loss: 2.5750
Iteration: 70; Percent complete: 2.3%; Average loss: 2.5523
Iteration: 80; Percent complete: 2.7%; Average loss: 2.6706
Iteration: 90; Percent complete: 3.0%; Average loss: 2.6813
Iteration: 100; Percent complete: 3.3%; Average loss: 2.6381
Iteration: 110; Percent complete: 3.7%; Average loss: 2.5897
Iteration: 120; Percent complete: 4.0%; Average loss: 2.5830
Iteration: 130; Percent complete: 4.3%; Average loss: 2.7236
Iteration: 140; Percent complete: 4.7%; Average loss: 2.6574
Iteration: 150; Percent complete: 5.0%; Average loss: 2.7179
Iteration: 160; Percen

Iteration: 1330; Percent complete: 44.3%; Average loss: 2.4355
Iteration: 1340; Percent complete: 44.7%; Average loss: 2.4636
Iteration: 1350; Percent complete: 45.0%; Average loss: 2.4563
Iteration: 1360; Percent complete: 45.3%; Average loss: 2.4075
Iteration: 1370; Percent complete: 45.7%; Average loss: 2.4780
Iteration: 1380; Percent complete: 46.0%; Average loss: 2.4145
Iteration: 1390; Percent complete: 46.3%; Average loss: 2.5272
Iteration: 1400; Percent complete: 46.7%; Average loss: 2.4811
Iteration: 1410; Percent complete: 47.0%; Average loss: 2.4501
Iteration: 1420; Percent complete: 47.3%; Average loss: 2.4198
Iteration: 1430; Percent complete: 47.7%; Average loss: 2.4869
Iteration: 1440; Percent complete: 48.0%; Average loss: 2.4668
Iteration: 1450; Percent complete: 48.3%; Average loss: 2.3614
Iteration: 1460; Percent complete: 48.7%; Average loss: 2.4502
Iteration: 1470; Percent complete: 49.0%; Average loss: 2.4074
Iteration: 1480; Percent complete: 49.3%; Average loss:

Iteration: 2640; Percent complete: 88.0%; Average loss: 2.2506
Iteration: 2650; Percent complete: 88.3%; Average loss: 2.2089
Iteration: 2660; Percent complete: 88.7%; Average loss: 2.2857
Iteration: 2670; Percent complete: 89.0%; Average loss: 2.1899
Iteration: 2680; Percent complete: 89.3%; Average loss: 2.2212
Iteration: 2690; Percent complete: 89.7%; Average loss: 2.2137
Iteration: 2700; Percent complete: 90.0%; Average loss: 2.2282
Iteration: 2710; Percent complete: 90.3%; Average loss: 2.1780
Iteration: 2720; Percent complete: 90.7%; Average loss: 2.2228
Iteration: 2730; Percent complete: 91.0%; Average loss: 2.3413
Iteration: 2740; Percent complete: 91.3%; Average loss: 2.2295
Iteration: 2750; Percent complete: 91.7%; Average loss: 2.2032
Iteration: 2760; Percent complete: 92.0%; Average loss: 2.3101
Iteration: 2770; Percent complete: 92.3%; Average loss: 2.1837
Iteration: 2780; Percent complete: 92.7%; Average loss: 2.2665
Iteration: 2790; Percent complete: 93.0%; Average loss:

In [None]:
# Set True to start the interactive chat loop. evaluateInput repeatedly reads
# from stdin, so set this to False for non-interactive runs (Restart & Run All).
START_CHAT = True

# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (blocks waiting for user input while running)
if START_CHAT:
    evaluateInput(encoder, decoder, searcher, voc)

> According to professor Jeffrey Pfeffer from Stanford, there are only three levels in the organization and CEO Jim Goodnight has 27 people who directly report to him.
Bot: how many members does the new york city have ?
> Employees are encouraged to do volunteer work and the company makes donation to non-profits where employees are involved.
Bot: what are the two companies that are connected to the company ? t receive a result of the revolution ?
> SAS started building its current headquarters in a forested area of Cary, North Carolina in 1980.
Bot: what is the name of the area in miami ? is what ?
> Stanlow & Thornton railway station is located within the Stanlow Refinery in Cheshire, England.
Bot: what is the railway station in the UNK railway ? UNK UNK to UNK spain ? is what
> American Sentinel University was established through the joining of two separate schools: American College of Computer & Information Sciences and American Graduate School of Management. 
Bot: what was the name