In [1]:
%matplotlib inline

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [3]:
# corpus_name = "cornell movie-dialogs corpus"
# corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

# printLines(os.path.join(corpus, "movie_lines.txt"))

In [4]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

Now we’ll call these functions and create the file. We’ll call it
*formatted_movie_lines.txt*.




In [5]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [6]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split(',')] for l in lines[1:]]
    voc = Voc(corpus_name)
    return voc, pairs

import ipdb
# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
#     pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

corpus_name = 'sarch_query'
datafile = 'data/Batch_generation_2/train_step_query_question.csv'

# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 1789 sentence pairs
Trimmed to 1789 sentence pairs
Counting words...
Counted words: 4242

pairs:
['remove chicken thighs brine pat dry paper towel dry chicken need cooking dry chicken need cooking', 'dry chicken needs']
['place paper towels drain excess oil benefits draining oil potatoes purpose draining excess oil', 'include reason draining excess oil']
['add oregano garlic powder cumin chili powder cayenne salt pepper stir well covered cook another minutes level heat use cooking soul chili level heat use', 'would useful know heat needs adjusted step']
['traditionally turkish kisir eaten lettuce leaves serve lettuce leaf leave side wrap kisir inside lettuce leaves eat also serve tomatoes alongside peppers affect turkish kisir peppers affect recipe', 'paprika subbed red pepper flakes']
['deglaze skillet wine add cream chile puree cook reduced desired consistency stir chives make medallions pork tenderloin steak alternative cream w

In [7]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 2088 / 4239 = 0.4926
Trimmed from 1789 pairs to 650, 0.3633 of total


In [8]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[1472,  301,  758,  191, 1026],
        [ 609,  163,  115, 1376, 1679],
        [ 216,  214,   25,  388,  161],
        [ 598,  357,  107, 1669,  131],
        [ 189,  215,  505, 1923,  509],
        [ 319, 1774,   18,  586,  310],
        [1847,   25,  131,  383, 1642],
        [ 853,  163,   34, 1682,   69],
        [ 154,  800,  853,  113,   97],
        [ 420,   97,  609, 1872, 1555],
        [ 421,  214,   62,  660,  176],
        [ 747,  163,   18,  515,  746],
        [  76,   41,  177,  816, 1680],
        [ 609,  278,  265, 1682, 1679],
        [ 587,  415,   18,  113,  176],
        [1684,  380,   41,   62,  746],
        [1461,  800,  853,    2,    2],
        [ 137,   97,  609,    0,    0],
        [  85, 1845,    2,    0,    0],
        [ 212,  536,    0,    0,    0],
        [ 574,  214,    0,    0,    0],
        [ 216,  239,    0,    0,    0],
        [  36,  258,    0,    0,    0],
        [1725,   38,    0,    0,    0],
        [ 129,  242,    

In [9]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

In [10]:
# Luong attention layer
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [11]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

In [12]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [13]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [14]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [15]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [16]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [17]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

corpus_name = 'search_query'
alldatafile = 'data/Batch_generation_2/all.csv'
traindatafile = 'data/Batch_generation_2/train_step_query_question.csv'

voc, _ = loadPrepareData(corpus_name, alldatafile, save_dir)
_, pairs = loadPrepareData(corpus_name, traindatafile, save_dir)

voc.addWord('')



# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Start preparing training data ...
Reading lines...
Read 2558 sentence pairs
Trimmed to 2558 sentence pairs
Counting words...
Counted words: 4955
Start preparing training data ...
Reading lines...
Read 1789 sentence pairs
Trimmed to 1789 sentence pairs
Counting words...
Counted words: 4242
Building encoder and decoder ...
Models built and ready to go!


In [18]:
def evaluateFile(encoder, decoder, searcher, voc, filename, targetname):
    text = list(csv.reader(open(filename, 'rt')))
    target = list(csv.reader(open(targetname, 'rt')))
    responses = []
#     input_sentence = ''
    for input_sentence in text:
        try:
            # Get input sentence
#             input_sentence = input('> ')
            # Check if it is quit case
#             if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence[0])
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            responses.append(' '.join(output_words))
#             print('Bot:', ' '.join(output_words))
        except KeyError:
            responses.append(' ')
#             print("Error: Encountered unknown word.")
    return text, target, responses

In [19]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)


# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.5049
Iteration: 2; Percent complete: 0.1%; Average loss: 8.4838
Iteration: 3; Percent complete: 0.1%; Average loss: 8.4561
Iteration: 4; Percent complete: 0.1%; Average loss: 8.4229
Iteration: 5; Percent complete: 0.1%; Average loss: 8.3832
Iteration: 6; Percent complete: 0.1%; Average loss: 8.3369
Iteration: 7; Percent complete: 0.2%; Average loss: 8.2127
Iteration: 8; Percent complete: 0.2%; Average loss: 8.0221
Iteration: 9; Percent complete: 0.2%; Average loss: 7.9129
Iteration: 10; Percent complete: 0.2%; Average loss: 7.6109
Iteration: 11; Percent complete: 0.3%; Average loss: 7.2662
Iteration: 12; Percent complete: 0.3%; Average loss: 7.1880
Iteration: 13; Percent complete: 0.3%; Average loss: 7.4995
Iteration: 14; Percent complete: 0.4%; Average loss: 7.5862
Iteration: 15; Percent complete: 0.4%; Average loss: 7.4271
Iteration: 16; Percent complete: 0.4%

Iteration: 137; Percent complete: 3.4%; Average loss: 5.5078
Iteration: 138; Percent complete: 3.5%; Average loss: 5.8630
Iteration: 139; Percent complete: 3.5%; Average loss: 5.6463
Iteration: 140; Percent complete: 3.5%; Average loss: 5.3693
Iteration: 141; Percent complete: 3.5%; Average loss: 5.6880
Iteration: 142; Percent complete: 3.5%; Average loss: 6.1075
Iteration: 143; Percent complete: 3.6%; Average loss: 5.8313
Iteration: 144; Percent complete: 3.6%; Average loss: 5.7129
Iteration: 145; Percent complete: 3.6%; Average loss: 5.6829
Iteration: 146; Percent complete: 3.6%; Average loss: 5.7552
Iteration: 147; Percent complete: 3.7%; Average loss: 5.8091
Iteration: 148; Percent complete: 3.7%; Average loss: 6.1498
Iteration: 149; Percent complete: 3.7%; Average loss: 5.4333
Iteration: 150; Percent complete: 3.8%; Average loss: 5.8937
Iteration: 151; Percent complete: 3.8%; Average loss: 5.3922
Iteration: 152; Percent complete: 3.8%; Average loss: 5.6846
Iteration: 153; Percent 

Iteration: 272; Percent complete: 6.8%; Average loss: 4.6919
Iteration: 273; Percent complete: 6.8%; Average loss: 4.8491
Iteration: 274; Percent complete: 6.9%; Average loss: 5.3302
Iteration: 275; Percent complete: 6.9%; Average loss: 4.9119
Iteration: 276; Percent complete: 6.9%; Average loss: 5.0942
Iteration: 277; Percent complete: 6.9%; Average loss: 5.2709
Iteration: 278; Percent complete: 7.0%; Average loss: 4.5176
Iteration: 279; Percent complete: 7.0%; Average loss: 5.1298
Iteration: 280; Percent complete: 7.0%; Average loss: 4.5606
Iteration: 281; Percent complete: 7.0%; Average loss: 4.8029
Iteration: 282; Percent complete: 7.0%; Average loss: 4.5034
Iteration: 283; Percent complete: 7.1%; Average loss: 5.1475
Iteration: 284; Percent complete: 7.1%; Average loss: 4.7926
Iteration: 285; Percent complete: 7.1%; Average loss: 5.1639
Iteration: 286; Percent complete: 7.1%; Average loss: 4.8843
Iteration: 287; Percent complete: 7.2%; Average loss: 4.9261
Iteration: 288; Percent 

Iteration: 407; Percent complete: 10.2%; Average loss: 3.8800
Iteration: 408; Percent complete: 10.2%; Average loss: 3.7698
Iteration: 409; Percent complete: 10.2%; Average loss: 4.1487
Iteration: 410; Percent complete: 10.2%; Average loss: 3.7958
Iteration: 411; Percent complete: 10.3%; Average loss: 3.8432
Iteration: 412; Percent complete: 10.3%; Average loss: 3.5826
Iteration: 413; Percent complete: 10.3%; Average loss: 3.8659
Iteration: 414; Percent complete: 10.3%; Average loss: 4.0029
Iteration: 415; Percent complete: 10.4%; Average loss: 3.8280
Iteration: 416; Percent complete: 10.4%; Average loss: 3.5111
Iteration: 417; Percent complete: 10.4%; Average loss: 4.0842
Iteration: 418; Percent complete: 10.4%; Average loss: 3.6807
Iteration: 419; Percent complete: 10.5%; Average loss: 3.6384
Iteration: 420; Percent complete: 10.5%; Average loss: 3.8069
Iteration: 421; Percent complete: 10.5%; Average loss: 3.5838
Iteration: 422; Percent complete: 10.5%; Average loss: 3.5902
Iteratio

Iteration: 540; Percent complete: 13.5%; Average loss: 3.0567
Iteration: 541; Percent complete: 13.5%; Average loss: 3.1240
Iteration: 542; Percent complete: 13.6%; Average loss: 3.0432
Iteration: 543; Percent complete: 13.6%; Average loss: 3.2226
Iteration: 544; Percent complete: 13.6%; Average loss: 2.9931
Iteration: 545; Percent complete: 13.6%; Average loss: 3.1367
Iteration: 546; Percent complete: 13.7%; Average loss: 2.9200
Iteration: 547; Percent complete: 13.7%; Average loss: 2.8024
Iteration: 548; Percent complete: 13.7%; Average loss: 2.7388
Iteration: 549; Percent complete: 13.7%; Average loss: 2.7278
Iteration: 550; Percent complete: 13.8%; Average loss: 3.1817
Iteration: 551; Percent complete: 13.8%; Average loss: 2.9359
Iteration: 552; Percent complete: 13.8%; Average loss: 2.6737
Iteration: 553; Percent complete: 13.8%; Average loss: 2.9066
Iteration: 554; Percent complete: 13.9%; Average loss: 3.0006
Iteration: 555; Percent complete: 13.9%; Average loss: 2.7003
Iteratio

Iteration: 675; Percent complete: 16.9%; Average loss: 2.1000
Iteration: 676; Percent complete: 16.9%; Average loss: 2.1227
Iteration: 677; Percent complete: 16.9%; Average loss: 2.0162
Iteration: 678; Percent complete: 17.0%; Average loss: 2.3978
Iteration: 679; Percent complete: 17.0%; Average loss: 2.0079
Iteration: 680; Percent complete: 17.0%; Average loss: 2.0594
Iteration: 681; Percent complete: 17.0%; Average loss: 2.1606
Iteration: 682; Percent complete: 17.1%; Average loss: 1.8726
Iteration: 683; Percent complete: 17.1%; Average loss: 2.2331
Iteration: 684; Percent complete: 17.1%; Average loss: 2.2468
Iteration: 685; Percent complete: 17.1%; Average loss: 2.0680
Iteration: 686; Percent complete: 17.2%; Average loss: 2.0030
Iteration: 687; Percent complete: 17.2%; Average loss: 1.7674
Iteration: 688; Percent complete: 17.2%; Average loss: 2.3861
Iteration: 689; Percent complete: 17.2%; Average loss: 2.1264
Iteration: 690; Percent complete: 17.2%; Average loss: 2.0919
Iteratio

Iteration: 808; Percent complete: 20.2%; Average loss: 1.3653
Iteration: 809; Percent complete: 20.2%; Average loss: 1.7809
Iteration: 810; Percent complete: 20.2%; Average loss: 1.4592
Iteration: 811; Percent complete: 20.3%; Average loss: 1.6569
Iteration: 812; Percent complete: 20.3%; Average loss: 1.6765
Iteration: 813; Percent complete: 20.3%; Average loss: 1.6324
Iteration: 814; Percent complete: 20.3%; Average loss: 1.4003
Iteration: 815; Percent complete: 20.4%; Average loss: 1.2891
Iteration: 816; Percent complete: 20.4%; Average loss: 1.5492
Iteration: 817; Percent complete: 20.4%; Average loss: 1.4329
Iteration: 818; Percent complete: 20.4%; Average loss: 1.5130
Iteration: 819; Percent complete: 20.5%; Average loss: 1.5947
Iteration: 820; Percent complete: 20.5%; Average loss: 1.5829
Iteration: 821; Percent complete: 20.5%; Average loss: 1.3220
Iteration: 822; Percent complete: 20.5%; Average loss: 1.8379
Iteration: 823; Percent complete: 20.6%; Average loss: 1.4028
Iteratio

Iteration: 941; Percent complete: 23.5%; Average loss: 0.9900
Iteration: 942; Percent complete: 23.5%; Average loss: 0.8580
Iteration: 943; Percent complete: 23.6%; Average loss: 1.1967
Iteration: 944; Percent complete: 23.6%; Average loss: 0.9893
Iteration: 945; Percent complete: 23.6%; Average loss: 1.0617
Iteration: 946; Percent complete: 23.6%; Average loss: 1.1141
Iteration: 947; Percent complete: 23.7%; Average loss: 0.8692
Iteration: 948; Percent complete: 23.7%; Average loss: 1.0028
Iteration: 949; Percent complete: 23.7%; Average loss: 0.9492
Iteration: 950; Percent complete: 23.8%; Average loss: 0.9250
Iteration: 951; Percent complete: 23.8%; Average loss: 0.9547
Iteration: 952; Percent complete: 23.8%; Average loss: 1.0907
Iteration: 953; Percent complete: 23.8%; Average loss: 1.0486
Iteration: 954; Percent complete: 23.8%; Average loss: 1.1051
Iteration: 955; Percent complete: 23.9%; Average loss: 0.9046
Iteration: 956; Percent complete: 23.9%; Average loss: 0.9509
Iteratio

Iteration: 1073; Percent complete: 26.8%; Average loss: 0.7442
Iteration: 1074; Percent complete: 26.9%; Average loss: 0.6301
Iteration: 1075; Percent complete: 26.9%; Average loss: 0.6886
Iteration: 1076; Percent complete: 26.9%; Average loss: 0.6459
Iteration: 1077; Percent complete: 26.9%; Average loss: 0.8132
Iteration: 1078; Percent complete: 27.0%; Average loss: 0.6891
Iteration: 1079; Percent complete: 27.0%; Average loss: 0.6401
Iteration: 1080; Percent complete: 27.0%; Average loss: 0.6201
Iteration: 1081; Percent complete: 27.0%; Average loss: 0.5032
Iteration: 1082; Percent complete: 27.1%; Average loss: 0.5167
Iteration: 1083; Percent complete: 27.1%; Average loss: 0.6909
Iteration: 1084; Percent complete: 27.1%; Average loss: 0.5926
Iteration: 1085; Percent complete: 27.1%; Average loss: 0.6572
Iteration: 1086; Percent complete: 27.2%; Average loss: 0.7845
Iteration: 1087; Percent complete: 27.2%; Average loss: 0.6931
Iteration: 1088; Percent complete: 27.2%; Average loss:

Iteration: 1205; Percent complete: 30.1%; Average loss: 0.4882
Iteration: 1206; Percent complete: 30.1%; Average loss: 0.4515
Iteration: 1207; Percent complete: 30.2%; Average loss: 0.3992
Iteration: 1208; Percent complete: 30.2%; Average loss: 0.3858
Iteration: 1209; Percent complete: 30.2%; Average loss: 0.4078
Iteration: 1210; Percent complete: 30.2%; Average loss: 0.4447
Iteration: 1211; Percent complete: 30.3%; Average loss: 0.4512
Iteration: 1212; Percent complete: 30.3%; Average loss: 0.3615
Iteration: 1213; Percent complete: 30.3%; Average loss: 0.4404
Iteration: 1214; Percent complete: 30.3%; Average loss: 0.3798
Iteration: 1215; Percent complete: 30.4%; Average loss: 0.3659
Iteration: 1216; Percent complete: 30.4%; Average loss: 0.4251
Iteration: 1217; Percent complete: 30.4%; Average loss: 0.4353
Iteration: 1218; Percent complete: 30.4%; Average loss: 0.4461
Iteration: 1219; Percent complete: 30.5%; Average loss: 0.4001
Iteration: 1220; Percent complete: 30.5%; Average loss:

Iteration: 1336; Percent complete: 33.4%; Average loss: 0.3267
Iteration: 1337; Percent complete: 33.4%; Average loss: 0.3168
Iteration: 1338; Percent complete: 33.5%; Average loss: 0.2823
Iteration: 1339; Percent complete: 33.5%; Average loss: 0.2573
Iteration: 1340; Percent complete: 33.5%; Average loss: 0.2881
Iteration: 1341; Percent complete: 33.5%; Average loss: 0.2742
Iteration: 1342; Percent complete: 33.6%; Average loss: 0.2441
Iteration: 1343; Percent complete: 33.6%; Average loss: 0.2991
Iteration: 1344; Percent complete: 33.6%; Average loss: 0.2821
Iteration: 1345; Percent complete: 33.6%; Average loss: 0.3688
Iteration: 1346; Percent complete: 33.7%; Average loss: 0.3114
Iteration: 1347; Percent complete: 33.7%; Average loss: 0.2806
Iteration: 1348; Percent complete: 33.7%; Average loss: 0.2824
Iteration: 1349; Percent complete: 33.7%; Average loss: 0.2692
Iteration: 1350; Percent complete: 33.8%; Average loss: 0.3269
Iteration: 1351; Percent complete: 33.8%; Average loss:

Iteration: 1468; Percent complete: 36.7%; Average loss: 0.1523
Iteration: 1469; Percent complete: 36.7%; Average loss: 0.1716
Iteration: 1470; Percent complete: 36.8%; Average loss: 0.1758
Iteration: 1471; Percent complete: 36.8%; Average loss: 0.1873
Iteration: 1472; Percent complete: 36.8%; Average loss: 0.2017
Iteration: 1473; Percent complete: 36.8%; Average loss: 0.1901
Iteration: 1474; Percent complete: 36.9%; Average loss: 0.1731
Iteration: 1475; Percent complete: 36.9%; Average loss: 0.1952
Iteration: 1476; Percent complete: 36.9%; Average loss: 0.1649
Iteration: 1477; Percent complete: 36.9%; Average loss: 0.1853
Iteration: 1478; Percent complete: 37.0%; Average loss: 0.1773
Iteration: 1479; Percent complete: 37.0%; Average loss: 0.1843
Iteration: 1480; Percent complete: 37.0%; Average loss: 0.1481
Iteration: 1481; Percent complete: 37.0%; Average loss: 0.2035
Iteration: 1482; Percent complete: 37.0%; Average loss: 0.1675
Iteration: 1483; Percent complete: 37.1%; Average loss:

Iteration: 1600; Percent complete: 40.0%; Average loss: 0.1480
Iteration: 1601; Percent complete: 40.0%; Average loss: 0.1023
Iteration: 1602; Percent complete: 40.1%; Average loss: 0.1337
Iteration: 1603; Percent complete: 40.1%; Average loss: 0.1352
Iteration: 1604; Percent complete: 40.1%; Average loss: 0.1362
Iteration: 1605; Percent complete: 40.1%; Average loss: 0.1166
Iteration: 1606; Percent complete: 40.2%; Average loss: 0.1429
Iteration: 1607; Percent complete: 40.2%; Average loss: 0.1130
Iteration: 1608; Percent complete: 40.2%; Average loss: 0.1229
Iteration: 1609; Percent complete: 40.2%; Average loss: 0.1144
Iteration: 1610; Percent complete: 40.2%; Average loss: 0.1090
Iteration: 1611; Percent complete: 40.3%; Average loss: 0.1189
Iteration: 1612; Percent complete: 40.3%; Average loss: 0.1326
Iteration: 1613; Percent complete: 40.3%; Average loss: 0.1168
Iteration: 1614; Percent complete: 40.4%; Average loss: 0.1190
Iteration: 1615; Percent complete: 40.4%; Average loss:

Iteration: 1731; Percent complete: 43.3%; Average loss: 0.0813
Iteration: 1732; Percent complete: 43.3%; Average loss: 0.0897
Iteration: 1733; Percent complete: 43.3%; Average loss: 0.0825
Iteration: 1734; Percent complete: 43.4%; Average loss: 0.0884
Iteration: 1735; Percent complete: 43.4%; Average loss: 0.0837
Iteration: 1736; Percent complete: 43.4%; Average loss: 0.0753
Iteration: 1737; Percent complete: 43.4%; Average loss: 0.1067
Iteration: 1738; Percent complete: 43.5%; Average loss: 0.0968
Iteration: 1739; Percent complete: 43.5%; Average loss: 0.0857
Iteration: 1740; Percent complete: 43.5%; Average loss: 0.0952
Iteration: 1741; Percent complete: 43.5%; Average loss: 0.0745
Iteration: 1742; Percent complete: 43.5%; Average loss: 0.0765
Iteration: 1743; Percent complete: 43.6%; Average loss: 0.0759
Iteration: 1744; Percent complete: 43.6%; Average loss: 0.0797
Iteration: 1745; Percent complete: 43.6%; Average loss: 0.1004
Iteration: 1746; Percent complete: 43.6%; Average loss:

Iteration: 1862; Percent complete: 46.6%; Average loss: 0.0565
Iteration: 1863; Percent complete: 46.6%; Average loss: 0.0744
Iteration: 1864; Percent complete: 46.6%; Average loss: 0.0723
Iteration: 1865; Percent complete: 46.6%; Average loss: 0.0574
Iteration: 1866; Percent complete: 46.7%; Average loss: 0.0611
Iteration: 1867; Percent complete: 46.7%; Average loss: 0.0592
Iteration: 1868; Percent complete: 46.7%; Average loss: 0.0592
Iteration: 1869; Percent complete: 46.7%; Average loss: 0.0701
Iteration: 1870; Percent complete: 46.8%; Average loss: 0.0581
Iteration: 1871; Percent complete: 46.8%; Average loss: 0.0563
Iteration: 1872; Percent complete: 46.8%; Average loss: 0.0583
Iteration: 1873; Percent complete: 46.8%; Average loss: 0.0512
Iteration: 1874; Percent complete: 46.9%; Average loss: 0.0592
Iteration: 1875; Percent complete: 46.9%; Average loss: 0.0587
Iteration: 1876; Percent complete: 46.9%; Average loss: 0.0619
Iteration: 1877; Percent complete: 46.9%; Average loss:

Iteration: 1993; Percent complete: 49.8%; Average loss: 0.0452
Iteration: 1994; Percent complete: 49.9%; Average loss: 0.0532
Iteration: 1995; Percent complete: 49.9%; Average loss: 0.0459
Iteration: 1996; Percent complete: 49.9%; Average loss: 0.0501
Iteration: 1997; Percent complete: 49.9%; Average loss: 0.0436
Iteration: 1998; Percent complete: 50.0%; Average loss: 0.0452
Iteration: 1999; Percent complete: 50.0%; Average loss: 0.0553
Iteration: 2000; Percent complete: 50.0%; Average loss: 0.0491
Iteration: 2001; Percent complete: 50.0%; Average loss: 0.0529
Iteration: 2002; Percent complete: 50.0%; Average loss: 0.0428
Iteration: 2003; Percent complete: 50.1%; Average loss: 0.0557
Iteration: 2004; Percent complete: 50.1%; Average loss: 0.0516
Iteration: 2005; Percent complete: 50.1%; Average loss: 0.0416
Iteration: 2006; Percent complete: 50.1%; Average loss: 0.0435
Iteration: 2007; Percent complete: 50.2%; Average loss: 0.0536
Iteration: 2008; Percent complete: 50.2%; Average loss:

Iteration: 2124; Percent complete: 53.1%; Average loss: 0.0374
Iteration: 2125; Percent complete: 53.1%; Average loss: 0.0344
Iteration: 2126; Percent complete: 53.1%; Average loss: 0.0352
Iteration: 2127; Percent complete: 53.2%; Average loss: 0.0350
Iteration: 2128; Percent complete: 53.2%; Average loss: 0.0402
Iteration: 2129; Percent complete: 53.2%; Average loss: 0.0372
Iteration: 2130; Percent complete: 53.2%; Average loss: 0.0312
Iteration: 2131; Percent complete: 53.3%; Average loss: 0.0385
Iteration: 2132; Percent complete: 53.3%; Average loss: 0.0390
Iteration: 2133; Percent complete: 53.3%; Average loss: 0.0343
Iteration: 2134; Percent complete: 53.3%; Average loss: 0.0417
Iteration: 2135; Percent complete: 53.4%; Average loss: 0.0332
Iteration: 2136; Percent complete: 53.4%; Average loss: 0.0361
Iteration: 2137; Percent complete: 53.4%; Average loss: 0.0378
Iteration: 2138; Percent complete: 53.4%; Average loss: 0.0348
Iteration: 2139; Percent complete: 53.5%; Average loss:

Iteration: 2257; Percent complete: 56.4%; Average loss: 0.0287
Iteration: 2258; Percent complete: 56.5%; Average loss: 0.0305
Iteration: 2259; Percent complete: 56.5%; Average loss: 0.0283
Iteration: 2260; Percent complete: 56.5%; Average loss: 0.0316
Iteration: 2261; Percent complete: 56.5%; Average loss: 0.0355
Iteration: 2262; Percent complete: 56.5%; Average loss: 0.0297
Iteration: 2263; Percent complete: 56.6%; Average loss: 0.0304
Iteration: 2264; Percent complete: 56.6%; Average loss: 0.0355
Iteration: 2265; Percent complete: 56.6%; Average loss: 0.0312
Iteration: 2266; Percent complete: 56.6%; Average loss: 0.0348
Iteration: 2267; Percent complete: 56.7%; Average loss: 0.0370
Iteration: 2268; Percent complete: 56.7%; Average loss: 0.0266
Iteration: 2269; Percent complete: 56.7%; Average loss: 0.0333
Iteration: 2270; Percent complete: 56.8%; Average loss: 0.0268
Iteration: 2271; Percent complete: 56.8%; Average loss: 0.0302
Iteration: 2272; Percent complete: 56.8%; Average loss:

Iteration: 2389; Percent complete: 59.7%; Average loss: 0.0421
Iteration: 2390; Percent complete: 59.8%; Average loss: 0.0287
Iteration: 2391; Percent complete: 59.8%; Average loss: 0.0287
Iteration: 2392; Percent complete: 59.8%; Average loss: 0.0281
Iteration: 2393; Percent complete: 59.8%; Average loss: 0.0281
Iteration: 2394; Percent complete: 59.9%; Average loss: 0.0274
Iteration: 2395; Percent complete: 59.9%; Average loss: 0.0268
Iteration: 2396; Percent complete: 59.9%; Average loss: 0.0306
Iteration: 2397; Percent complete: 59.9%; Average loss: 0.0290
Iteration: 2398; Percent complete: 60.0%; Average loss: 0.0355
Iteration: 2399; Percent complete: 60.0%; Average loss: 0.0264
Iteration: 2400; Percent complete: 60.0%; Average loss: 0.0309
Iteration: 2401; Percent complete: 60.0%; Average loss: 0.0386
Iteration: 2402; Percent complete: 60.1%; Average loss: 0.0287
Iteration: 2403; Percent complete: 60.1%; Average loss: 0.0276
Iteration: 2404; Percent complete: 60.1%; Average loss:

Iteration: 2520; Percent complete: 63.0%; Average loss: 0.0378
Iteration: 2521; Percent complete: 63.0%; Average loss: 0.0414
Iteration: 2522; Percent complete: 63.0%; Average loss: 0.0476
Iteration: 2523; Percent complete: 63.1%; Average loss: 0.0472
Iteration: 2524; Percent complete: 63.1%; Average loss: 0.0392
Iteration: 2525; Percent complete: 63.1%; Average loss: 0.0722
Iteration: 2526; Percent complete: 63.1%; Average loss: 0.0396
Iteration: 2527; Percent complete: 63.2%; Average loss: 0.0524
Iteration: 2528; Percent complete: 63.2%; Average loss: 0.0292
Iteration: 2529; Percent complete: 63.2%; Average loss: 0.0588
Iteration: 2530; Percent complete: 63.2%; Average loss: 0.0617
Iteration: 2531; Percent complete: 63.3%; Average loss: 0.0276
Iteration: 2532; Percent complete: 63.3%; Average loss: 0.0430
Iteration: 2533; Percent complete: 63.3%; Average loss: 0.0413
Iteration: 2534; Percent complete: 63.3%; Average loss: 0.0450
Iteration: 2535; Percent complete: 63.4%; Average loss:

Iteration: 2651; Percent complete: 66.3%; Average loss: 0.0378
Iteration: 2652; Percent complete: 66.3%; Average loss: 0.0320
Iteration: 2653; Percent complete: 66.3%; Average loss: 0.0366
Iteration: 2654; Percent complete: 66.3%; Average loss: 0.0378
Iteration: 2655; Percent complete: 66.4%; Average loss: 0.0570
Iteration: 2656; Percent complete: 66.4%; Average loss: 0.0315
Iteration: 2657; Percent complete: 66.4%; Average loss: 0.0316
Iteration: 2658; Percent complete: 66.5%; Average loss: 0.0469
Iteration: 2659; Percent complete: 66.5%; Average loss: 0.0258
Iteration: 2660; Percent complete: 66.5%; Average loss: 0.0366
Iteration: 2661; Percent complete: 66.5%; Average loss: 0.0288
Iteration: 2662; Percent complete: 66.5%; Average loss: 0.0274
Iteration: 2663; Percent complete: 66.6%; Average loss: 0.0450
Iteration: 2664; Percent complete: 66.6%; Average loss: 0.0308
Iteration: 2665; Percent complete: 66.6%; Average loss: 0.0257
Iteration: 2666; Percent complete: 66.6%; Average loss:

Iteration: 2782; Percent complete: 69.5%; Average loss: 0.0841
Iteration: 2783; Percent complete: 69.6%; Average loss: 0.1038
Iteration: 2784; Percent complete: 69.6%; Average loss: 0.0967
Iteration: 2785; Percent complete: 69.6%; Average loss: 0.0950
Iteration: 2786; Percent complete: 69.7%; Average loss: 0.0782
Iteration: 2787; Percent complete: 69.7%; Average loss: 0.0956
Iteration: 2788; Percent complete: 69.7%; Average loss: 0.0981
Iteration: 2789; Percent complete: 69.7%; Average loss: 0.0948
Iteration: 2790; Percent complete: 69.8%; Average loss: 0.0714
Iteration: 2791; Percent complete: 69.8%; Average loss: 0.0974
Iteration: 2792; Percent complete: 69.8%; Average loss: 0.0735
Iteration: 2793; Percent complete: 69.8%; Average loss: 0.1070
Iteration: 2794; Percent complete: 69.8%; Average loss: 0.0914
Iteration: 2795; Percent complete: 69.9%; Average loss: 0.1191
Iteration: 2796; Percent complete: 69.9%; Average loss: 0.0635
Iteration: 2797; Percent complete: 69.9%; Average loss:

Iteration: 2913; Percent complete: 72.8%; Average loss: 0.0467
Iteration: 2914; Percent complete: 72.9%; Average loss: 0.0411
Iteration: 2915; Percent complete: 72.9%; Average loss: 0.0454
Iteration: 2916; Percent complete: 72.9%; Average loss: 0.0487
Iteration: 2917; Percent complete: 72.9%; Average loss: 0.0383
Iteration: 2918; Percent complete: 73.0%; Average loss: 0.0530
Iteration: 2919; Percent complete: 73.0%; Average loss: 0.0431
Iteration: 2920; Percent complete: 73.0%; Average loss: 0.0385
Iteration: 2921; Percent complete: 73.0%; Average loss: 0.0390
Iteration: 2922; Percent complete: 73.0%; Average loss: 0.0546
Iteration: 2923; Percent complete: 73.1%; Average loss: 0.0437
Iteration: 2924; Percent complete: 73.1%; Average loss: 0.0686
Iteration: 2925; Percent complete: 73.1%; Average loss: 0.0707
Iteration: 2926; Percent complete: 73.2%; Average loss: 0.0356
Iteration: 2927; Percent complete: 73.2%; Average loss: 0.0403
Iteration: 2928; Percent complete: 73.2%; Average loss:

Iteration: 3046; Percent complete: 76.1%; Average loss: 0.0379
Iteration: 3047; Percent complete: 76.2%; Average loss: 0.0800
Iteration: 3048; Percent complete: 76.2%; Average loss: 0.0430
Iteration: 3049; Percent complete: 76.2%; Average loss: 0.0387
Iteration: 3050; Percent complete: 76.2%; Average loss: 0.0327
Iteration: 3051; Percent complete: 76.3%; Average loss: 0.1544
Iteration: 3052; Percent complete: 76.3%; Average loss: 0.0353
Iteration: 3053; Percent complete: 76.3%; Average loss: 0.0425
Iteration: 3054; Percent complete: 76.3%; Average loss: 0.0419
Iteration: 3055; Percent complete: 76.4%; Average loss: 0.0406
Iteration: 3056; Percent complete: 76.4%; Average loss: 0.0356
Iteration: 3057; Percent complete: 76.4%; Average loss: 0.0678
Iteration: 3058; Percent complete: 76.4%; Average loss: 0.0330
Iteration: 3059; Percent complete: 76.5%; Average loss: 0.0615
Iteration: 3060; Percent complete: 76.5%; Average loss: 0.0436
Iteration: 3061; Percent complete: 76.5%; Average loss:

Iteration: 3178; Percent complete: 79.5%; Average loss: 0.0233
Iteration: 3179; Percent complete: 79.5%; Average loss: 0.0286
Iteration: 3180; Percent complete: 79.5%; Average loss: 0.0238
Iteration: 3181; Percent complete: 79.5%; Average loss: 0.0297
Iteration: 3182; Percent complete: 79.5%; Average loss: 0.0210
Iteration: 3183; Percent complete: 79.6%; Average loss: 0.0224
Iteration: 3184; Percent complete: 79.6%; Average loss: 0.0709
Iteration: 3185; Percent complete: 79.6%; Average loss: 0.0418
Iteration: 3186; Percent complete: 79.7%; Average loss: 0.0263
Iteration: 3187; Percent complete: 79.7%; Average loss: 0.0240
Iteration: 3188; Percent complete: 79.7%; Average loss: 0.0284
Iteration: 3189; Percent complete: 79.7%; Average loss: 0.0245
Iteration: 3190; Percent complete: 79.8%; Average loss: 0.0266
Iteration: 3191; Percent complete: 79.8%; Average loss: 0.0273
Iteration: 3192; Percent complete: 79.8%; Average loss: 0.0241
Iteration: 3193; Percent complete: 79.8%; Average loss:

Iteration: 3309; Percent complete: 82.7%; Average loss: 0.0156
Iteration: 3310; Percent complete: 82.8%; Average loss: 0.0253
Iteration: 3311; Percent complete: 82.8%; Average loss: 0.0177
Iteration: 3312; Percent complete: 82.8%; Average loss: 0.0188
Iteration: 3313; Percent complete: 82.8%; Average loss: 0.0168
Iteration: 3314; Percent complete: 82.8%; Average loss: 0.0162
Iteration: 3315; Percent complete: 82.9%; Average loss: 0.0141
Iteration: 3316; Percent complete: 82.9%; Average loss: 0.0168
Iteration: 3317; Percent complete: 82.9%; Average loss: 0.0164
Iteration: 3318; Percent complete: 83.0%; Average loss: 0.0181
Iteration: 3319; Percent complete: 83.0%; Average loss: 0.0148
Iteration: 3320; Percent complete: 83.0%; Average loss: 0.0152
Iteration: 3321; Percent complete: 83.0%; Average loss: 0.0181
Iteration: 3322; Percent complete: 83.0%; Average loss: 0.0189
Iteration: 3323; Percent complete: 83.1%; Average loss: 0.0227
Iteration: 3324; Percent complete: 83.1%; Average loss:

Iteration: 3441; Percent complete: 86.0%; Average loss: 0.0146
Iteration: 3442; Percent complete: 86.1%; Average loss: 0.0147
Iteration: 3443; Percent complete: 86.1%; Average loss: 0.0154
Iteration: 3444; Percent complete: 86.1%; Average loss: 0.0189
Iteration: 3445; Percent complete: 86.1%; Average loss: 0.0197
Iteration: 3446; Percent complete: 86.2%; Average loss: 0.0135
Iteration: 3447; Percent complete: 86.2%; Average loss: 0.0157
Iteration: 3448; Percent complete: 86.2%; Average loss: 0.0185
Iteration: 3449; Percent complete: 86.2%; Average loss: 0.0115
Iteration: 3450; Percent complete: 86.2%; Average loss: 0.0142
Iteration: 3451; Percent complete: 86.3%; Average loss: 0.0233
Iteration: 3452; Percent complete: 86.3%; Average loss: 0.0184
Iteration: 3453; Percent complete: 86.3%; Average loss: 0.0175
Iteration: 3454; Percent complete: 86.4%; Average loss: 0.0765
Iteration: 3455; Percent complete: 86.4%; Average loss: 0.0135
Iteration: 3456; Percent complete: 86.4%; Average loss:

Iteration: 3572; Percent complete: 89.3%; Average loss: 0.0106
Iteration: 3573; Percent complete: 89.3%; Average loss: 0.0118
Iteration: 3574; Percent complete: 89.3%; Average loss: 0.0119
Iteration: 3575; Percent complete: 89.4%; Average loss: 0.0119
Iteration: 3576; Percent complete: 89.4%; Average loss: 0.0098
Iteration: 3577; Percent complete: 89.4%; Average loss: 0.0133
Iteration: 3578; Percent complete: 89.5%; Average loss: 0.0110
Iteration: 3579; Percent complete: 89.5%; Average loss: 0.0137
Iteration: 3580; Percent complete: 89.5%; Average loss: 0.0125
Iteration: 3581; Percent complete: 89.5%; Average loss: 0.0132
Iteration: 3582; Percent complete: 89.5%; Average loss: 0.0114
Iteration: 3583; Percent complete: 89.6%; Average loss: 0.0137
Iteration: 3584; Percent complete: 89.6%; Average loss: 0.0187
Iteration: 3585; Percent complete: 89.6%; Average loss: 0.0105
Iteration: 3586; Percent complete: 89.6%; Average loss: 0.0147
Iteration: 3587; Percent complete: 89.7%; Average loss:

Iteration: 3704; Percent complete: 92.6%; Average loss: 0.0100
Iteration: 3705; Percent complete: 92.6%; Average loss: 0.0107
Iteration: 3706; Percent complete: 92.7%; Average loss: 0.0086
Iteration: 3707; Percent complete: 92.7%; Average loss: 0.0092
Iteration: 3708; Percent complete: 92.7%; Average loss: 0.0102
Iteration: 3709; Percent complete: 92.7%; Average loss: 0.0098
Iteration: 3710; Percent complete: 92.8%; Average loss: 0.0109
Iteration: 3711; Percent complete: 92.8%; Average loss: 0.0108
Iteration: 3712; Percent complete: 92.8%; Average loss: 0.0092
Iteration: 3713; Percent complete: 92.8%; Average loss: 0.0075
Iteration: 3714; Percent complete: 92.8%; Average loss: 0.0103
Iteration: 3715; Percent complete: 92.9%; Average loss: 0.0162
Iteration: 3716; Percent complete: 92.9%; Average loss: 0.0094
Iteration: 3717; Percent complete: 92.9%; Average loss: 0.0078
Iteration: 3718; Percent complete: 93.0%; Average loss: 0.0082
Iteration: 3719; Percent complete: 93.0%; Average loss:

Iteration: 3836; Percent complete: 95.9%; Average loss: 0.0087
Iteration: 3837; Percent complete: 95.9%; Average loss: 0.0073
Iteration: 3838; Percent complete: 96.0%; Average loss: 0.0084
Iteration: 3839; Percent complete: 96.0%; Average loss: 0.0073
Iteration: 3840; Percent complete: 96.0%; Average loss: 0.0091
Iteration: 3841; Percent complete: 96.0%; Average loss: 0.0080
Iteration: 3842; Percent complete: 96.0%; Average loss: 0.0106
Iteration: 3843; Percent complete: 96.1%; Average loss: 0.0082
Iteration: 3844; Percent complete: 96.1%; Average loss: 0.0073
Iteration: 3845; Percent complete: 96.1%; Average loss: 0.0064
Iteration: 3846; Percent complete: 96.2%; Average loss: 0.0125
Iteration: 3847; Percent complete: 96.2%; Average loss: 0.0135
Iteration: 3848; Percent complete: 96.2%; Average loss: 0.0078
Iteration: 3849; Percent complete: 96.2%; Average loss: 0.0077
Iteration: 3850; Percent complete: 96.2%; Average loss: 0.0080
Iteration: 3851; Percent complete: 96.3%; Average loss:

Iteration: 3967; Percent complete: 99.2%; Average loss: 0.0107
Iteration: 3968; Percent complete: 99.2%; Average loss: 0.0067
Iteration: 3969; Percent complete: 99.2%; Average loss: 0.0067
Iteration: 3970; Percent complete: 99.2%; Average loss: 0.0073
Iteration: 3971; Percent complete: 99.3%; Average loss: 0.0065
Iteration: 3972; Percent complete: 99.3%; Average loss: 0.0060
Iteration: 3973; Percent complete: 99.3%; Average loss: 0.0062
Iteration: 3974; Percent complete: 99.4%; Average loss: 0.0061
Iteration: 3975; Percent complete: 99.4%; Average loss: 0.0067
Iteration: 3976; Percent complete: 99.4%; Average loss: 0.0073
Iteration: 3977; Percent complete: 99.4%; Average loss: 0.0056
Iteration: 3978; Percent complete: 99.5%; Average loss: 0.0074
Iteration: 3979; Percent complete: 99.5%; Average loss: 0.0069
Iteration: 3980; Percent complete: 99.5%; Average loss: 0.0071
Iteration: 3981; Percent complete: 99.5%; Average loss: 0.0064
Iteration: 3982; Percent complete: 99.6%; Average loss:

Run Evaluation
~~~~~~~~~~~~~~

To chat with your model, run the following block.




In [21]:
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
evaluateInput(encoder, decoder, searcher, voc)

> cook chicken
Bot: type macaroni recommend thi recip get minut recip


KeyboardInterrupt: 

Conclusion
----------

That’s all for this one, folks. Congratulations, you now know the
fundamentals to building a generative chatbot model! If you’re
interested, you can try tailoring the chatbot’s behavior by tweaking the
model and training parameters and customizing the data that you train
the model on.

Check out the other tutorials for more cool deep learning applications
in PyTorch!




In [20]:
filename = 'data/Batch_generation_2/test_step_query_question_text.csv'
targetname = 'data/Batch_generation_2/test_step_query_question_target.csv'
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
text, target, responses = evaluateFile(encoder, decoder, searcher, voc, filename=filename, targetname=targetname)

with open('data/Batch_generation_2/test_step_query_question_predict.csv', 'wt') as f:
    for l in responses:
        f.write('{}\n'.format(l))

In [21]:
import os
res = os.popen('perl multi_bleu.perl data/Batch_generation_2/test_step_query_question_target.csv < data/Batch_generation_2/test_step_query_question_predict.csv')
res.read()

'BLEU = 9.79, 15.4/9.3/8.0/7.9 (BP=1.000, ratio=1.022, hyp_len=3127, ref_len=3059)\n'

In [26]:
res.close()

512

In [22]:
with open('data/Batch_generation_2/seq2seq_step_query_question_pred.txt', 'wt') as f:
    for i, (t, g, p) in enumerate(zip(text, target, responses)):
        print('\n\n\n')
        f.write('\n\n\n')
        if len(g) == 0:
            g = [None]
        print('TEXT:\n{}\nTARGET:\n{}\nPREDICT:\n{}'.format(t[0], g[0], p.strip()))
        f.write('TEXT:\n{}\nTARGET:\n{}\nPREDICT:\n{}'.format(t[0], g[0], p.strip()))
#     if i == 100:
#         break





TEXT:
additional recipes barbecue ideas available www walmart ca recipes buy raw materials walmart recipes buy raw materials make recipes
TARGET:
use raw material bought walmart make walmart recipes process easier
PREDICT:
advertisement non slices




TEXT:
combine ingredients pour ice sugared rim glass fizzy pink drink recipe sugared rim glass
TARGET:
sugared rim glass
PREDICT:
chickpeas turn black beetroot bit softer teaspoon dried




TEXT:
using pastry brush coat sides pork tenderloin place roasting pan rack honey mustard pork loin pastry brush
TARGET:
happens pastry brush
PREDICT:
long heat flour ground use bowl butter better




TEXT:
preheat oven 375 degrees f guarantee scones donγçöt stick line baking sheet parchment paper oven temperature scones different altitudes scones bake differently different altitudes
TARGET:
temperature needs change altitude
PREDICT:
need know parchment paper prevent firm tofu less water firm




TEXT:
instruction given package long mix ingredients

In [33]:
target

[[],
 [],
 [],
 [],
 ['cups packed light brown sugar tablespoons margarine tablespoons vegetable shortening cups dark molasses tablespoon baking soda cup boiling water cups all purpose flour sifted tablespoon ground cloves tablespoons ground ginger tablespoon ground cinnamon'],
 [],
 [],
 [],
 ['image of finished product'],
 [],
 [],
 ['peeled and cut into inch thick rounds bechamel sauce cup butter cups hot milk tablespoons flour eggs cup grated kefalograviera cheese or parm teaspoon salt'],
 [],
 ['visualized instruction'],
 [],
 ['how the chicken will look like after frying .'],
 [],
 ['tell exact amounts of ingredients'],
 ['place the cookie crusts in the freezer and make the banana ice cream'],
 ['show a video of this step and show the correct consistency .'],
 [],
 [],
 [],
 ['visualized instructions for clarifying purposes'],
 ['make the pink coconut cream .'],
 [],
 [],
 ['baking dish'],
 ['show images or video of this step . suggest turning the oven on or have the assistant tu