In [1]:
%matplotlib inline

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [3]:
# corpus_name = "cornell movie-dialogs corpus"
# corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

# printLines(os.path.join(corpus, "movie_lines.txt"))

In [4]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

Now we’ll call these functions and create the file. We’ll call it
*formatted_movie_lines.txt*.




In [5]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [7]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split(',')] for l in lines[1:]]
    voc = Voc(corpus_name)
    return voc, pairs

import ipdb
# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
#     pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

corpus_name = 'sarch_query'
datafile = 'data/Batch_3648643_batch_results_rob/train_step_query.csv'

# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 156 sentence pairs
Trimmed to 156 sentence pairs
Counting words...
Counted words: 971

pairs:
['mix crush oreo cooki butter press firmli base inch cake pan put refriger chill firm make oreo hazlenut banana cheesecak', 'specfic space instruct detail']
['take pork steak inch thick place one gallon zip loc bag add one cup appl cider bag let marin hour overnight make oven pork steak', 'specfic space instruct detail']
['sift flour sugar cocoa larg bowl sift', 'use sifter sift ingredi bowl']
['heat ghee fri rava temperatur best pan fri rava', 'temperatur want fri rava exampl medium high heat']
['meantim add pasta water boil stir often keep stick togeth pasta cook accord tast drain coland return pot toss tbsp butter cup chop parsley serv chicken pasta green salad desir green salad', 'would green salad consist']
['cook dice onion coconut oil transluc keep food stick', 'occasion stir onion prevent stick bottom pan']
['mix togeth quich base

In [8]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 407 / 968 = 0.4205
Trimmed from 156 pairs to 14, 0.0897 of total


In [9]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[242,  30, 288, 205, 351],
        [316,  76, 257,  46,  38],
        [129, 127,  70,  93, 108],
        [ 30, 142, 126,   3, 297],
        [ 41, 127, 219, 107,   2],
        [113,  76,  97,  63,   0],
        [126,  30,  70,  45,   0],
        [227,  16, 223,  93,   0],
        [131, 275,  38,   2,   0],
        [ 30, 276,   2,   0,   0],
        [ 31,   2,   0,   0,   0],
        [236,   0,   0,   0,   0],
        [ 55,   0,   0,   0,   0],
        [136,   0,   0,   0,   0],
        [  2,   0,   0,   0,   0]])
lengths: tensor([15, 11, 10,  9,  5])
target_variable: tensor([[147, 264, 289, 147, 297],
        [ 34, 131, 185,  94,   2],
        [157, 110,  70,   2,   0],
        [170,   2, 290,   0,   0],
        [  2,   0,  97,   0,   0],
        [  0,   0,   2,   0,   0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 0, 1, 0, 0],
        [0, 0, 1, 0, 0]], dtype=torch.uint8)
max_target_len: 

In [10]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

In [11]:
# Luong attention layer
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [12]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

In [13]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [14]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [15]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [16]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [17]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [18]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

corpus_name = 'search_query'
alldatafile = 'data/Batch_3648643_batch_results_rob/all.csv'
traindatafile = 'data/Batch_3648643_batch_results_rob/train_step_query.csv'

voc, _ = loadPrepareData(corpus_name, alldatafile, save_dir)
_, pairs = loadPrepareData(corpus_name, traindatafile, save_dir)

voc.addWord('')



# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Start preparing training data ...
Reading lines...
Read 226 sentence pairs
Trimmed to 226 sentence pairs
Counting words...
Counted words: 1159
Start preparing training data ...
Reading lines...
Read 156 sentence pairs
Trimmed to 156 sentence pairs
Counting words...
Counted words: 971
Building encoder and decoder ...
Models built and ready to go!


In [19]:
def evaluateFile(encoder, decoder, searcher, voc, filename, targetname):
    text = list(csv.reader(open(filename, 'rt')))
    target = list(csv.reader(open(targetname, 'rt')))
    responses = []
#     input_sentence = ''
    for input_sentence in text:
        try:
            # Get input sentence
#             input_sentence = input('> ')
            # Check if it is quit case
#             if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence[0])
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            responses.append(' '.join(output_words))
#             print('Bot:', ' '.join(output_words))
        except KeyError:
            responses.append(' ')
#             print("Error: Encountered unknown word.")
    return text, target, responses

In [20]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)


# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 7.0627
Iteration: 2; Percent complete: 0.1%; Average loss: 7.0437
Iteration: 3; Percent complete: 0.1%; Average loss: 7.0024
Iteration: 4; Percent complete: 0.1%; Average loss: 6.9281
Iteration: 5; Percent complete: 0.1%; Average loss: 6.9019
Iteration: 6; Percent complete: 0.1%; Average loss: 6.8447
Iteration: 7; Percent complete: 0.2%; Average loss: 6.6096
Iteration: 8; Percent complete: 0.2%; Average loss: 6.4029
Iteration: 9; Percent complete: 0.2%; Average loss: 6.5052
Iteration: 10; Percent complete: 0.2%; Average loss: 6.1473
Iteration: 11; Percent complete: 0.3%; Average loss: 6.0708
Iteration: 12; Percent complete: 0.3%; Average loss: 6.0141
Iteration: 13; Percent complete: 0.3%; Average loss: 6.0534
Iteration: 14; Percent complete: 0.4%; Average loss: 6.0314
Iteration: 15; Percent complete: 0.4%; Average loss: 6.1948
Iteration: 16; Percent complete: 0.4%

Iteration: 136; Percent complete: 3.4%; Average loss: 1.6083
Iteration: 137; Percent complete: 3.4%; Average loss: 1.4556
Iteration: 138; Percent complete: 3.5%; Average loss: 1.4920
Iteration: 139; Percent complete: 3.5%; Average loss: 1.6796
Iteration: 140; Percent complete: 3.5%; Average loss: 1.4757
Iteration: 141; Percent complete: 3.5%; Average loss: 1.4374
Iteration: 142; Percent complete: 3.5%; Average loss: 1.5514
Iteration: 143; Percent complete: 3.6%; Average loss: 1.3870
Iteration: 144; Percent complete: 3.6%; Average loss: 1.3498
Iteration: 145; Percent complete: 3.6%; Average loss: 1.3289
Iteration: 146; Percent complete: 3.6%; Average loss: 1.3568
Iteration: 147; Percent complete: 3.7%; Average loss: 1.1022
Iteration: 148; Percent complete: 3.7%; Average loss: 1.3453
Iteration: 149; Percent complete: 3.7%; Average loss: 1.4134
Iteration: 150; Percent complete: 3.8%; Average loss: 1.2856
Iteration: 151; Percent complete: 3.8%; Average loss: 1.1700
Iteration: 152; Percent 

Iteration: 271; Percent complete: 6.8%; Average loss: 0.1344
Iteration: 272; Percent complete: 6.8%; Average loss: 0.1382
Iteration: 273; Percent complete: 6.8%; Average loss: 0.1237
Iteration: 274; Percent complete: 6.9%; Average loss: 0.1431
Iteration: 275; Percent complete: 6.9%; Average loss: 0.1268
Iteration: 276; Percent complete: 6.9%; Average loss: 0.1364
Iteration: 277; Percent complete: 6.9%; Average loss: 0.1294
Iteration: 278; Percent complete: 7.0%; Average loss: 0.1130
Iteration: 279; Percent complete: 7.0%; Average loss: 0.1204
Iteration: 280; Percent complete: 7.0%; Average loss: 0.1212
Iteration: 281; Percent complete: 7.0%; Average loss: 0.0998
Iteration: 282; Percent complete: 7.0%; Average loss: 0.1176
Iteration: 283; Percent complete: 7.1%; Average loss: 0.1372
Iteration: 284; Percent complete: 7.1%; Average loss: 0.1170
Iteration: 285; Percent complete: 7.1%; Average loss: 0.1097
Iteration: 286; Percent complete: 7.1%; Average loss: 0.1053
Iteration: 287; Percent 

Iteration: 407; Percent complete: 10.2%; Average loss: 0.0353
Iteration: 408; Percent complete: 10.2%; Average loss: 0.0360
Iteration: 409; Percent complete: 10.2%; Average loss: 0.0361
Iteration: 410; Percent complete: 10.2%; Average loss: 0.0360
Iteration: 411; Percent complete: 10.3%; Average loss: 0.0325
Iteration: 412; Percent complete: 10.3%; Average loss: 0.0360
Iteration: 413; Percent complete: 10.3%; Average loss: 0.0342
Iteration: 414; Percent complete: 10.3%; Average loss: 0.0341
Iteration: 415; Percent complete: 10.4%; Average loss: 0.0345
Iteration: 416; Percent complete: 10.4%; Average loss: 0.0349
Iteration: 417; Percent complete: 10.4%; Average loss: 0.0386
Iteration: 418; Percent complete: 10.4%; Average loss: 0.0307
Iteration: 419; Percent complete: 10.5%; Average loss: 0.0351
Iteration: 420; Percent complete: 10.5%; Average loss: 0.0348
Iteration: 421; Percent complete: 10.5%; Average loss: 0.0337
Iteration: 422; Percent complete: 10.5%; Average loss: 0.0275
Iteratio

Iteration: 540; Percent complete: 13.5%; Average loss: 0.0194
Iteration: 541; Percent complete: 13.5%; Average loss: 0.0202
Iteration: 542; Percent complete: 13.6%; Average loss: 0.0178
Iteration: 543; Percent complete: 13.6%; Average loss: 0.0189
Iteration: 544; Percent complete: 13.6%; Average loss: 0.0182
Iteration: 545; Percent complete: 13.6%; Average loss: 0.0186
Iteration: 546; Percent complete: 13.7%; Average loss: 0.0183
Iteration: 547; Percent complete: 13.7%; Average loss: 0.0186
Iteration: 548; Percent complete: 13.7%; Average loss: 0.0187
Iteration: 549; Percent complete: 13.7%; Average loss: 0.0187
Iteration: 550; Percent complete: 13.8%; Average loss: 0.0166
Iteration: 551; Percent complete: 13.8%; Average loss: 0.0191
Iteration: 552; Percent complete: 13.8%; Average loss: 0.0170
Iteration: 553; Percent complete: 13.8%; Average loss: 0.0194
Iteration: 554; Percent complete: 13.9%; Average loss: 0.0186
Iteration: 555; Percent complete: 13.9%; Average loss: 0.0187
Iteratio

Iteration: 674; Percent complete: 16.9%; Average loss: 0.0107
Iteration: 675; Percent complete: 16.9%; Average loss: 0.0118
Iteration: 676; Percent complete: 16.9%; Average loss: 0.0131
Iteration: 677; Percent complete: 16.9%; Average loss: 0.0102
Iteration: 678; Percent complete: 17.0%; Average loss: 0.0115
Iteration: 679; Percent complete: 17.0%; Average loss: 0.0113
Iteration: 680; Percent complete: 17.0%; Average loss: 0.0110
Iteration: 681; Percent complete: 17.0%; Average loss: 0.0114
Iteration: 682; Percent complete: 17.1%; Average loss: 0.0123
Iteration: 683; Percent complete: 17.1%; Average loss: 0.0117
Iteration: 684; Percent complete: 17.1%; Average loss: 0.0113
Iteration: 685; Percent complete: 17.1%; Average loss: 0.0115
Iteration: 686; Percent complete: 17.2%; Average loss: 0.0112
Iteration: 687; Percent complete: 17.2%; Average loss: 0.0111
Iteration: 688; Percent complete: 17.2%; Average loss: 0.0107
Iteration: 689; Percent complete: 17.2%; Average loss: 0.0103
Iteratio

Iteration: 807; Percent complete: 20.2%; Average loss: 0.0080
Iteration: 808; Percent complete: 20.2%; Average loss: 0.0078
Iteration: 809; Percent complete: 20.2%; Average loss: 0.0076
Iteration: 810; Percent complete: 20.2%; Average loss: 0.0078
Iteration: 811; Percent complete: 20.3%; Average loss: 0.0083
Iteration: 812; Percent complete: 20.3%; Average loss: 0.0087
Iteration: 813; Percent complete: 20.3%; Average loss: 0.0078
Iteration: 814; Percent complete: 20.3%; Average loss: 0.0083
Iteration: 815; Percent complete: 20.4%; Average loss: 0.0076
Iteration: 816; Percent complete: 20.4%; Average loss: 0.0080
Iteration: 817; Percent complete: 20.4%; Average loss: 0.0088
Iteration: 818; Percent complete: 20.4%; Average loss: 0.0077
Iteration: 819; Percent complete: 20.5%; Average loss: 0.0076
Iteration: 820; Percent complete: 20.5%; Average loss: 0.0069
Iteration: 821; Percent complete: 20.5%; Average loss: 0.0085
Iteration: 822; Percent complete: 20.5%; Average loss: 0.0075
Iteratio

Iteration: 940; Percent complete: 23.5%; Average loss: 0.0087
Iteration: 941; Percent complete: 23.5%; Average loss: 0.0107
Iteration: 942; Percent complete: 23.5%; Average loss: 0.0226
Iteration: 943; Percent complete: 23.6%; Average loss: 0.0067
Iteration: 944; Percent complete: 23.6%; Average loss: 0.0073
Iteration: 945; Percent complete: 23.6%; Average loss: 0.0092
Iteration: 946; Percent complete: 23.6%; Average loss: 0.0078
Iteration: 947; Percent complete: 23.7%; Average loss: 0.0110
Iteration: 948; Percent complete: 23.7%; Average loss: 0.0200
Iteration: 949; Percent complete: 23.7%; Average loss: 0.0071
Iteration: 950; Percent complete: 23.8%; Average loss: 0.0098
Iteration: 951; Percent complete: 23.8%; Average loss: 0.0086
Iteration: 952; Percent complete: 23.8%; Average loss: 0.0162
Iteration: 953; Percent complete: 23.8%; Average loss: 0.0091
Iteration: 954; Percent complete: 23.8%; Average loss: 0.0103
Iteration: 955; Percent complete: 23.9%; Average loss: 0.0101
Iteratio

Iteration: 1072; Percent complete: 26.8%; Average loss: 0.0202
Iteration: 1073; Percent complete: 26.8%; Average loss: 0.0101
Iteration: 1074; Percent complete: 26.9%; Average loss: 0.0307
Iteration: 1075; Percent complete: 26.9%; Average loss: 0.0103
Iteration: 1076; Percent complete: 26.9%; Average loss: 0.0180
Iteration: 1077; Percent complete: 26.9%; Average loss: 0.0079
Iteration: 1078; Percent complete: 27.0%; Average loss: 0.0083
Iteration: 1079; Percent complete: 27.0%; Average loss: 0.0143
Iteration: 1080; Percent complete: 27.0%; Average loss: 0.0130
Iteration: 1081; Percent complete: 27.0%; Average loss: 0.0079
Iteration: 1082; Percent complete: 27.1%; Average loss: 0.0085
Iteration: 1083; Percent complete: 27.1%; Average loss: 0.0082
Iteration: 1084; Percent complete: 27.1%; Average loss: 0.0084
Iteration: 1085; Percent complete: 27.1%; Average loss: 0.0093
Iteration: 1086; Percent complete: 27.2%; Average loss: 0.0109
Iteration: 1087; Percent complete: 27.2%; Average loss:

Iteration: 1203; Percent complete: 30.1%; Average loss: 0.0043
Iteration: 1204; Percent complete: 30.1%; Average loss: 0.0041
Iteration: 1205; Percent complete: 30.1%; Average loss: 0.0045
Iteration: 1206; Percent complete: 30.1%; Average loss: 0.0043
Iteration: 1207; Percent complete: 30.2%; Average loss: 0.0046
Iteration: 1208; Percent complete: 30.2%; Average loss: 0.0046
Iteration: 1209; Percent complete: 30.2%; Average loss: 0.0044
Iteration: 1210; Percent complete: 30.2%; Average loss: 0.0050
Iteration: 1211; Percent complete: 30.3%; Average loss: 0.0047
Iteration: 1212; Percent complete: 30.3%; Average loss: 0.0044
Iteration: 1213; Percent complete: 30.3%; Average loss: 0.0046
Iteration: 1214; Percent complete: 30.3%; Average loss: 0.0044
Iteration: 1215; Percent complete: 30.4%; Average loss: 0.0046
Iteration: 1216; Percent complete: 30.4%; Average loss: 0.0043
Iteration: 1217; Percent complete: 30.4%; Average loss: 0.0043
Iteration: 1218; Percent complete: 30.4%; Average loss:

Iteration: 1335; Percent complete: 33.4%; Average loss: 0.0030
Iteration: 1336; Percent complete: 33.4%; Average loss: 0.0031
Iteration: 1337; Percent complete: 33.4%; Average loss: 0.0032
Iteration: 1338; Percent complete: 33.5%; Average loss: 0.0033
Iteration: 1339; Percent complete: 33.5%; Average loss: 0.0032
Iteration: 1340; Percent complete: 33.5%; Average loss: 0.0034
Iteration: 1341; Percent complete: 33.5%; Average loss: 0.0033
Iteration: 1342; Percent complete: 33.6%; Average loss: 0.0032
Iteration: 1343; Percent complete: 33.6%; Average loss: 0.0032
Iteration: 1344; Percent complete: 33.6%; Average loss: 0.0032
Iteration: 1345; Percent complete: 33.6%; Average loss: 0.0030
Iteration: 1346; Percent complete: 33.7%; Average loss: 0.0034
Iteration: 1347; Percent complete: 33.7%; Average loss: 0.0032
Iteration: 1348; Percent complete: 33.7%; Average loss: 0.0034
Iteration: 1349; Percent complete: 33.7%; Average loss: 0.0036
Iteration: 1350; Percent complete: 33.8%; Average loss:

Iteration: 1466; Percent complete: 36.6%; Average loss: 0.0026
Iteration: 1467; Percent complete: 36.7%; Average loss: 0.0026
Iteration: 1468; Percent complete: 36.7%; Average loss: 0.0027
Iteration: 1469; Percent complete: 36.7%; Average loss: 0.0025
Iteration: 1470; Percent complete: 36.8%; Average loss: 0.0029
Iteration: 1471; Percent complete: 36.8%; Average loss: 0.0028
Iteration: 1472; Percent complete: 36.8%; Average loss: 0.0029
Iteration: 1473; Percent complete: 36.8%; Average loss: 0.0024
Iteration: 1474; Percent complete: 36.9%; Average loss: 0.0027
Iteration: 1475; Percent complete: 36.9%; Average loss: 0.0028
Iteration: 1476; Percent complete: 36.9%; Average loss: 0.0026
Iteration: 1477; Percent complete: 36.9%; Average loss: 0.0032
Iteration: 1478; Percent complete: 37.0%; Average loss: 0.0026
Iteration: 1479; Percent complete: 37.0%; Average loss: 0.0024
Iteration: 1480; Percent complete: 37.0%; Average loss: 0.0027
Iteration: 1481; Percent complete: 37.0%; Average loss:

Iteration: 1597; Percent complete: 39.9%; Average loss: 0.0023
Iteration: 1598; Percent complete: 40.0%; Average loss: 0.0023
Iteration: 1599; Percent complete: 40.0%; Average loss: 0.0024
Iteration: 1600; Percent complete: 40.0%; Average loss: 0.0023
Iteration: 1601; Percent complete: 40.0%; Average loss: 0.0022
Iteration: 1602; Percent complete: 40.1%; Average loss: 0.0023
Iteration: 1603; Percent complete: 40.1%; Average loss: 0.0021
Iteration: 1604; Percent complete: 40.1%; Average loss: 0.0026
Iteration: 1605; Percent complete: 40.1%; Average loss: 0.0023
Iteration: 1606; Percent complete: 40.2%; Average loss: 0.0024
Iteration: 1607; Percent complete: 40.2%; Average loss: 0.0023
Iteration: 1608; Percent complete: 40.2%; Average loss: 0.0021
Iteration: 1609; Percent complete: 40.2%; Average loss: 0.0021
Iteration: 1610; Percent complete: 40.2%; Average loss: 0.0027
Iteration: 1611; Percent complete: 40.3%; Average loss: 0.0023
Iteration: 1612; Percent complete: 40.3%; Average loss:

Iteration: 1729; Percent complete: 43.2%; Average loss: 0.0018
Iteration: 1730; Percent complete: 43.2%; Average loss: 0.0018
Iteration: 1731; Percent complete: 43.3%; Average loss: 0.0018
Iteration: 1732; Percent complete: 43.3%; Average loss: 0.0018
Iteration: 1733; Percent complete: 43.3%; Average loss: 0.0017
Iteration: 1734; Percent complete: 43.4%; Average loss: 0.0018
Iteration: 1735; Percent complete: 43.4%; Average loss: 0.0017
Iteration: 1736; Percent complete: 43.4%; Average loss: 0.0020
Iteration: 1737; Percent complete: 43.4%; Average loss: 0.0020
Iteration: 1738; Percent complete: 43.5%; Average loss: 0.0019
Iteration: 1739; Percent complete: 43.5%; Average loss: 0.0018
Iteration: 1740; Percent complete: 43.5%; Average loss: 0.0019
Iteration: 1741; Percent complete: 43.5%; Average loss: 0.0019
Iteration: 1742; Percent complete: 43.5%; Average loss: 0.0018
Iteration: 1743; Percent complete: 43.6%; Average loss: 0.0018
Iteration: 1744; Percent complete: 43.6%; Average loss:

Iteration: 1861; Percent complete: 46.5%; Average loss: 0.0019
Iteration: 1862; Percent complete: 46.6%; Average loss: 0.0016
Iteration: 1863; Percent complete: 46.6%; Average loss: 0.0017
Iteration: 1864; Percent complete: 46.6%; Average loss: 0.0017
Iteration: 1865; Percent complete: 46.6%; Average loss: 0.0017
Iteration: 1866; Percent complete: 46.7%; Average loss: 0.0017
Iteration: 1867; Percent complete: 46.7%; Average loss: 0.0016
Iteration: 1868; Percent complete: 46.7%; Average loss: 0.0017
Iteration: 1869; Percent complete: 46.7%; Average loss: 0.0016
Iteration: 1870; Percent complete: 46.8%; Average loss: 0.0016
Iteration: 1871; Percent complete: 46.8%; Average loss: 0.0016
Iteration: 1872; Percent complete: 46.8%; Average loss: 0.0018
Iteration: 1873; Percent complete: 46.8%; Average loss: 0.0016
Iteration: 1874; Percent complete: 46.9%; Average loss: 0.0016
Iteration: 1875; Percent complete: 46.9%; Average loss: 0.0017
Iteration: 1876; Percent complete: 46.9%; Average loss:

Iteration: 1993; Percent complete: 49.8%; Average loss: 0.0013
Iteration: 1994; Percent complete: 49.9%; Average loss: 0.0015
Iteration: 1995; Percent complete: 49.9%; Average loss: 0.0014
Iteration: 1996; Percent complete: 49.9%; Average loss: 0.0015
Iteration: 1997; Percent complete: 49.9%; Average loss: 0.0014
Iteration: 1998; Percent complete: 50.0%; Average loss: 0.0013
Iteration: 1999; Percent complete: 50.0%; Average loss: 0.0014
Iteration: 2000; Percent complete: 50.0%; Average loss: 0.0013
Iteration: 2001; Percent complete: 50.0%; Average loss: 0.0014
Iteration: 2002; Percent complete: 50.0%; Average loss: 0.0013
Iteration: 2003; Percent complete: 50.1%; Average loss: 0.0014
Iteration: 2004; Percent complete: 50.1%; Average loss: 0.0014
Iteration: 2005; Percent complete: 50.1%; Average loss: 0.0013
Iteration: 2006; Percent complete: 50.1%; Average loss: 0.0013
Iteration: 2007; Percent complete: 50.2%; Average loss: 0.0012
Iteration: 2008; Percent complete: 50.2%; Average loss:

Iteration: 2124; Percent complete: 53.1%; Average loss: 0.0012
Iteration: 2125; Percent complete: 53.1%; Average loss: 0.0012
Iteration: 2126; Percent complete: 53.1%; Average loss: 0.0013
Iteration: 2127; Percent complete: 53.2%; Average loss: 0.0012
Iteration: 2128; Percent complete: 53.2%; Average loss: 0.0012
Iteration: 2129; Percent complete: 53.2%; Average loss: 0.0013
Iteration: 2130; Percent complete: 53.2%; Average loss: 0.0013
Iteration: 2131; Percent complete: 53.3%; Average loss: 0.0012
Iteration: 2132; Percent complete: 53.3%; Average loss: 0.0012
Iteration: 2133; Percent complete: 53.3%; Average loss: 0.0012
Iteration: 2134; Percent complete: 53.3%; Average loss: 0.0012
Iteration: 2135; Percent complete: 53.4%; Average loss: 0.0012
Iteration: 2136; Percent complete: 53.4%; Average loss: 0.0013
Iteration: 2137; Percent complete: 53.4%; Average loss: 0.0012
Iteration: 2138; Percent complete: 53.4%; Average loss: 0.0014
Iteration: 2139; Percent complete: 53.5%; Average loss:

Iteration: 2256; Percent complete: 56.4%; Average loss: 0.0012
Iteration: 2257; Percent complete: 56.4%; Average loss: 0.0011
Iteration: 2258; Percent complete: 56.5%; Average loss: 0.0011
Iteration: 2259; Percent complete: 56.5%; Average loss: 0.0010
Iteration: 2260; Percent complete: 56.5%; Average loss: 0.0009
Iteration: 2261; Percent complete: 56.5%; Average loss: 0.0010
Iteration: 2262; Percent complete: 56.5%; Average loss: 0.0010
Iteration: 2263; Percent complete: 56.6%; Average loss: 0.0011
Iteration: 2264; Percent complete: 56.6%; Average loss: 0.0010
Iteration: 2265; Percent complete: 56.6%; Average loss: 0.0011
Iteration: 2266; Percent complete: 56.6%; Average loss: 0.0011
Iteration: 2267; Percent complete: 56.7%; Average loss: 0.0012
Iteration: 2268; Percent complete: 56.7%; Average loss: 0.0010
Iteration: 2269; Percent complete: 56.7%; Average loss: 0.0011
Iteration: 2270; Percent complete: 56.8%; Average loss: 0.0011
Iteration: 2271; Percent complete: 56.8%; Average loss:

Iteration: 2388; Percent complete: 59.7%; Average loss: 0.0008
Iteration: 2389; Percent complete: 59.7%; Average loss: 0.0009
Iteration: 2390; Percent complete: 59.8%; Average loss: 0.0010
Iteration: 2391; Percent complete: 59.8%; Average loss: 0.0009
Iteration: 2392; Percent complete: 59.8%; Average loss: 0.0011
Iteration: 2393; Percent complete: 59.8%; Average loss: 0.0010
Iteration: 2394; Percent complete: 59.9%; Average loss: 0.0009
Iteration: 2395; Percent complete: 59.9%; Average loss: 0.0009
Iteration: 2396; Percent complete: 59.9%; Average loss: 0.0010
Iteration: 2397; Percent complete: 59.9%; Average loss: 0.0010
Iteration: 2398; Percent complete: 60.0%; Average loss: 0.0009
Iteration: 2399; Percent complete: 60.0%; Average loss: 0.0009
Iteration: 2400; Percent complete: 60.0%; Average loss: 0.0009
Iteration: 2401; Percent complete: 60.0%; Average loss: 0.0009
Iteration: 2402; Percent complete: 60.1%; Average loss: 0.0009
Iteration: 2403; Percent complete: 60.1%; Average loss:

Iteration: 2519; Percent complete: 63.0%; Average loss: 0.0008
Iteration: 2520; Percent complete: 63.0%; Average loss: 0.0009
Iteration: 2521; Percent complete: 63.0%; Average loss: 0.0008
Iteration: 2522; Percent complete: 63.0%; Average loss: 0.0008
Iteration: 2523; Percent complete: 63.1%; Average loss: 0.0008
Iteration: 2524; Percent complete: 63.1%; Average loss: 0.0008
Iteration: 2525; Percent complete: 63.1%; Average loss: 0.0009
Iteration: 2526; Percent complete: 63.1%; Average loss: 0.0009
Iteration: 2527; Percent complete: 63.2%; Average loss: 0.0008
Iteration: 2528; Percent complete: 63.2%; Average loss: 0.0008
Iteration: 2529; Percent complete: 63.2%; Average loss: 0.0009
Iteration: 2530; Percent complete: 63.2%; Average loss: 0.0008
Iteration: 2531; Percent complete: 63.3%; Average loss: 0.0008
Iteration: 2532; Percent complete: 63.3%; Average loss: 0.0009
Iteration: 2533; Percent complete: 63.3%; Average loss: 0.0008
Iteration: 2534; Percent complete: 63.3%; Average loss:

Iteration: 2651; Percent complete: 66.3%; Average loss: 0.0007
Iteration: 2652; Percent complete: 66.3%; Average loss: 0.0007
Iteration: 2653; Percent complete: 66.3%; Average loss: 0.0008
Iteration: 2654; Percent complete: 66.3%; Average loss: 0.0007
Iteration: 2655; Percent complete: 66.4%; Average loss: 0.0007
Iteration: 2656; Percent complete: 66.4%; Average loss: 0.0008
Iteration: 2657; Percent complete: 66.4%; Average loss: 0.0007
Iteration: 2658; Percent complete: 66.5%; Average loss: 0.0007
Iteration: 2659; Percent complete: 66.5%; Average loss: 0.0007
Iteration: 2660; Percent complete: 66.5%; Average loss: 0.0008
Iteration: 2661; Percent complete: 66.5%; Average loss: 0.0007
Iteration: 2662; Percent complete: 66.5%; Average loss: 0.0007
Iteration: 2663; Percent complete: 66.6%; Average loss: 0.0007
Iteration: 2664; Percent complete: 66.6%; Average loss: 0.0007
Iteration: 2665; Percent complete: 66.6%; Average loss: 0.0007
Iteration: 2666; Percent complete: 66.6%; Average loss:

Iteration: 2783; Percent complete: 69.6%; Average loss: 0.0007
Iteration: 2784; Percent complete: 69.6%; Average loss: 0.0006
Iteration: 2785; Percent complete: 69.6%; Average loss: 0.0007
Iteration: 2786; Percent complete: 69.7%; Average loss: 0.0006
Iteration: 2787; Percent complete: 69.7%; Average loss: 0.0007
Iteration: 2788; Percent complete: 69.7%; Average loss: 0.0006
Iteration: 2789; Percent complete: 69.7%; Average loss: 0.0006
Iteration: 2790; Percent complete: 69.8%; Average loss: 0.0006
Iteration: 2791; Percent complete: 69.8%; Average loss: 0.0006
Iteration: 2792; Percent complete: 69.8%; Average loss: 0.0007
Iteration: 2793; Percent complete: 69.8%; Average loss: 0.0006
Iteration: 2794; Percent complete: 69.8%; Average loss: 0.0006
Iteration: 2795; Percent complete: 69.9%; Average loss: 0.0006
Iteration: 2796; Percent complete: 69.9%; Average loss: 0.0007
Iteration: 2797; Percent complete: 69.9%; Average loss: 0.0007
Iteration: 2798; Percent complete: 70.0%; Average loss:

Iteration: 2914; Percent complete: 72.9%; Average loss: 0.0006
Iteration: 2915; Percent complete: 72.9%; Average loss: 0.0006
Iteration: 2916; Percent complete: 72.9%; Average loss: 0.0006
Iteration: 2917; Percent complete: 72.9%; Average loss: 0.0006
Iteration: 2918; Percent complete: 73.0%; Average loss: 0.0006
Iteration: 2919; Percent complete: 73.0%; Average loss: 0.0006
Iteration: 2920; Percent complete: 73.0%; Average loss: 0.0006
Iteration: 2921; Percent complete: 73.0%; Average loss: 0.0006
Iteration: 2922; Percent complete: 73.0%; Average loss: 0.0006
Iteration: 2923; Percent complete: 73.1%; Average loss: 0.0006
Iteration: 2924; Percent complete: 73.1%; Average loss: 0.0006
Iteration: 2925; Percent complete: 73.1%; Average loss: 0.0006
Iteration: 2926; Percent complete: 73.2%; Average loss: 0.0005
Iteration: 2927; Percent complete: 73.2%; Average loss: 0.0006
Iteration: 2928; Percent complete: 73.2%; Average loss: 0.0006
Iteration: 2929; Percent complete: 73.2%; Average loss:

Iteration: 3045; Percent complete: 76.1%; Average loss: 0.0005
Iteration: 3046; Percent complete: 76.1%; Average loss: 0.0005
Iteration: 3047; Percent complete: 76.2%; Average loss: 0.0005
Iteration: 3048; Percent complete: 76.2%; Average loss: 0.0005
Iteration: 3049; Percent complete: 76.2%; Average loss: 0.0005
Iteration: 3050; Percent complete: 76.2%; Average loss: 0.0006
Iteration: 3051; Percent complete: 76.3%; Average loss: 0.0005
Iteration: 3052; Percent complete: 76.3%; Average loss: 0.0005
Iteration: 3053; Percent complete: 76.3%; Average loss: 0.0005
Iteration: 3054; Percent complete: 76.3%; Average loss: 0.0005
Iteration: 3055; Percent complete: 76.4%; Average loss: 0.0005
Iteration: 3056; Percent complete: 76.4%; Average loss: 0.0005
Iteration: 3057; Percent complete: 76.4%; Average loss: 0.0006
Iteration: 3058; Percent complete: 76.4%; Average loss: 0.0006
Iteration: 3059; Percent complete: 76.5%; Average loss: 0.0005
Iteration: 3060; Percent complete: 76.5%; Average loss:

Iteration: 3177; Percent complete: 79.4%; Average loss: 0.0005
Iteration: 3178; Percent complete: 79.5%; Average loss: 0.0005
Iteration: 3179; Percent complete: 79.5%; Average loss: 0.0005
Iteration: 3180; Percent complete: 79.5%; Average loss: 0.0005
Iteration: 3181; Percent complete: 79.5%; Average loss: 0.0005
Iteration: 3182; Percent complete: 79.5%; Average loss: 0.0005
Iteration: 3183; Percent complete: 79.6%; Average loss: 0.0005
Iteration: 3184; Percent complete: 79.6%; Average loss: 0.0005
Iteration: 3185; Percent complete: 79.6%; Average loss: 0.0005
Iteration: 3186; Percent complete: 79.7%; Average loss: 0.0004
Iteration: 3187; Percent complete: 79.7%; Average loss: 0.0005
Iteration: 3188; Percent complete: 79.7%; Average loss: 0.0005
Iteration: 3189; Percent complete: 79.7%; Average loss: 0.0005
Iteration: 3190; Percent complete: 79.8%; Average loss: 0.0005
Iteration: 3191; Percent complete: 79.8%; Average loss: 0.0004
Iteration: 3192; Percent complete: 79.8%; Average loss:

Iteration: 3308; Percent complete: 82.7%; Average loss: 0.0004
Iteration: 3309; Percent complete: 82.7%; Average loss: 0.0004
Iteration: 3310; Percent complete: 82.8%; Average loss: 0.0004
Iteration: 3311; Percent complete: 82.8%; Average loss: 0.0004
Iteration: 3312; Percent complete: 82.8%; Average loss: 0.0004
Iteration: 3313; Percent complete: 82.8%; Average loss: 0.0004
Iteration: 3314; Percent complete: 82.8%; Average loss: 0.0005
Iteration: 3315; Percent complete: 82.9%; Average loss: 0.0004
Iteration: 3316; Percent complete: 82.9%; Average loss: 0.0004
Iteration: 3317; Percent complete: 82.9%; Average loss: 0.0004
Iteration: 3318; Percent complete: 83.0%; Average loss: 0.0004
Iteration: 3319; Percent complete: 83.0%; Average loss: 0.0004
Iteration: 3320; Percent complete: 83.0%; Average loss: 0.0004
Iteration: 3321; Percent complete: 83.0%; Average loss: 0.0004
Iteration: 3322; Percent complete: 83.0%; Average loss: 0.0004
Iteration: 3323; Percent complete: 83.1%; Average loss:

Iteration: 3439; Percent complete: 86.0%; Average loss: 0.0004
Iteration: 3440; Percent complete: 86.0%; Average loss: 0.0004
Iteration: 3441; Percent complete: 86.0%; Average loss: 0.0004
Iteration: 3442; Percent complete: 86.1%; Average loss: 0.0004
Iteration: 3443; Percent complete: 86.1%; Average loss: 0.0005
Iteration: 3444; Percent complete: 86.1%; Average loss: 0.0004
Iteration: 3445; Percent complete: 86.1%; Average loss: 0.0003
Iteration: 3446; Percent complete: 86.2%; Average loss: 0.0004
Iteration: 3447; Percent complete: 86.2%; Average loss: 0.0004
Iteration: 3448; Percent complete: 86.2%; Average loss: 0.0004
Iteration: 3449; Percent complete: 86.2%; Average loss: 0.0004
Iteration: 3450; Percent complete: 86.2%; Average loss: 0.0004
Iteration: 3451; Percent complete: 86.3%; Average loss: 0.0004
Iteration: 3452; Percent complete: 86.3%; Average loss: 0.0003
Iteration: 3453; Percent complete: 86.3%; Average loss: 0.0004
Iteration: 3454; Percent complete: 86.4%; Average loss:

Iteration: 3570; Percent complete: 89.2%; Average loss: 0.0004
Iteration: 3571; Percent complete: 89.3%; Average loss: 0.0004
Iteration: 3572; Percent complete: 89.3%; Average loss: 0.0004
Iteration: 3573; Percent complete: 89.3%; Average loss: 0.0004
Iteration: 3574; Percent complete: 89.3%; Average loss: 0.0004
Iteration: 3575; Percent complete: 89.4%; Average loss: 0.0004
Iteration: 3576; Percent complete: 89.4%; Average loss: 0.0004
Iteration: 3577; Percent complete: 89.4%; Average loss: 0.0004
Iteration: 3578; Percent complete: 89.5%; Average loss: 0.0004
Iteration: 3579; Percent complete: 89.5%; Average loss: 0.0004
Iteration: 3580; Percent complete: 89.5%; Average loss: 0.0004
Iteration: 3581; Percent complete: 89.5%; Average loss: 0.0004
Iteration: 3582; Percent complete: 89.5%; Average loss: 0.0004
Iteration: 3583; Percent complete: 89.6%; Average loss: 0.0004
Iteration: 3584; Percent complete: 89.6%; Average loss: 0.0004
Iteration: 3585; Percent complete: 89.6%; Average loss:

Iteration: 3703; Percent complete: 92.6%; Average loss: 0.0003
Iteration: 3704; Percent complete: 92.6%; Average loss: 0.0003
Iteration: 3705; Percent complete: 92.6%; Average loss: 0.0004
Iteration: 3706; Percent complete: 92.7%; Average loss: 0.0004
Iteration: 3707; Percent complete: 92.7%; Average loss: 0.0004
Iteration: 3708; Percent complete: 92.7%; Average loss: 0.0003
Iteration: 3709; Percent complete: 92.7%; Average loss: 0.0003
Iteration: 3710; Percent complete: 92.8%; Average loss: 0.0003
Iteration: 3711; Percent complete: 92.8%; Average loss: 0.0003
Iteration: 3712; Percent complete: 92.8%; Average loss: 0.0004
Iteration: 3713; Percent complete: 92.8%; Average loss: 0.0003
Iteration: 3714; Percent complete: 92.8%; Average loss: 0.0004
Iteration: 3715; Percent complete: 92.9%; Average loss: 0.0003
Iteration: 3716; Percent complete: 92.9%; Average loss: 0.0003
Iteration: 3717; Percent complete: 92.9%; Average loss: 0.0003
Iteration: 3718; Percent complete: 93.0%; Average loss:

Iteration: 3834; Percent complete: 95.9%; Average loss: 0.0003
Iteration: 3835; Percent complete: 95.9%; Average loss: 0.0003
Iteration: 3836; Percent complete: 95.9%; Average loss: 0.0003
Iteration: 3837; Percent complete: 95.9%; Average loss: 0.0003
Iteration: 3838; Percent complete: 96.0%; Average loss: 0.0003
Iteration: 3839; Percent complete: 96.0%; Average loss: 0.0003
Iteration: 3840; Percent complete: 96.0%; Average loss: 0.0003
Iteration: 3841; Percent complete: 96.0%; Average loss: 0.0003
Iteration: 3842; Percent complete: 96.0%; Average loss: 0.0003
Iteration: 3843; Percent complete: 96.1%; Average loss: 0.0003
Iteration: 3844; Percent complete: 96.1%; Average loss: 0.0003
Iteration: 3845; Percent complete: 96.1%; Average loss: 0.0003
Iteration: 3846; Percent complete: 96.2%; Average loss: 0.0003
Iteration: 3847; Percent complete: 96.2%; Average loss: 0.0003
Iteration: 3848; Percent complete: 96.2%; Average loss: 0.0003
Iteration: 3849; Percent complete: 96.2%; Average loss:

Iteration: 3966; Percent complete: 99.2%; Average loss: 0.0003
Iteration: 3967; Percent complete: 99.2%; Average loss: 0.0003
Iteration: 3968; Percent complete: 99.2%; Average loss: 0.0003
Iteration: 3969; Percent complete: 99.2%; Average loss: 0.0003
Iteration: 3970; Percent complete: 99.2%; Average loss: 0.0003
Iteration: 3971; Percent complete: 99.3%; Average loss: 0.0003
Iteration: 3972; Percent complete: 99.3%; Average loss: 0.0003
Iteration: 3973; Percent complete: 99.3%; Average loss: 0.0003
Iteration: 3974; Percent complete: 99.4%; Average loss: 0.0003
Iteration: 3975; Percent complete: 99.4%; Average loss: 0.0003
Iteration: 3976; Percent complete: 99.4%; Average loss: 0.0003
Iteration: 3977; Percent complete: 99.4%; Average loss: 0.0003
Iteration: 3978; Percent complete: 99.5%; Average loss: 0.0003
Iteration: 3979; Percent complete: 99.5%; Average loss: 0.0003
Iteration: 3980; Percent complete: 99.5%; Average loss: 0.0003
Iteration: 3981; Percent complete: 99.5%; Average loss:

Run Evaluation
~~~~~~~~~~~~~~

To chat with your model, run the following block.




In [21]:
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
evaluateInput(encoder, decoder, searcher, voc)

> cook chicken
Bot: type macaroni recommend thi recip get minut recip


KeyboardInterrupt: 

Conclusion
----------

That’s all for this one, folks. Congratulations, you now know the
fundamentals to building a generative chatbot model! If you’re
interested, you can try tailoring the chatbot’s behavior by tweaking the
model and training parameters and customizing the data that you train
the model on.

Check out the other tutorials for more cool deep learning applications
in PyTorch!




In [22]:
filename = 'data/Batch_3648643_batch_results_rob/test_step_query_text.csv'
targetname = 'data/Batch_3648643_batch_results_rob/test_step_query_target.csv'
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
text, target, responses = evaluateFile(encoder, decoder, searcher, voc, filename=filename, targetname=targetname)

with open('data/Batch_3648643_batch_results_rob/test_step_query_predict.csv', 'wt') as f:
    for l in responses:
        f.write('{}\n'.format(l))

In [23]:
import os
res = os.popen('perl multi_bleu.perl data/Batch_3648643_batch_results_rob/test_step_query_target.csv < data/Batch_3648643_batch_results_rob/test_step_query_predict.csv')
res.read()

'BLEU = 1.05, 3.2/1.0/0.8/0.5 (BP=1.000, ratio=1.192, hyp_len=348, ref_len=292)\n'

In [26]:
res.close()

512

In [24]:
with open('seq2seq_step_query_pred.txt', 'wt') as f:
    for i, (t, g, p) in enumerate(zip(text, target, responses)):
        print('\n\n\n')
        f.write('\n\n\n')
        if len(g) == 0:
            g = [None]
        print('TEXT:\n{}\nTARGET:\n{}\nPREDICT:\n{}'.format(t[0], g[0], p.strip()))
        f.write('TEXT:\n{}\nTARGET:\n{}\nPREDICT:\n{}'.format(t[0], g[0], p.strip()))
#     if i == 100:
#         break





TEXT:
step_query
TARGET:
information
PREDICT:
none later later recip




TEXT:
5 would need ask question becaus understand step
TARGET:
none
PREDICT:
long blend hand use mixer also dri




TEXT:
bake dish layer 2 cup corn chip singl layer top 1 cup chees spoon beef mixtur chees top remain chees corn chip substitut
TARGET:
substitut corn chip gluten free option
PREDICT:
step clear pear snugli pear snugli mayb exact dish




TEXT:
pan one add cut piec boil potato aloo methi quick beginn recip
TARGET:
better phrase thi step like next add cut piec boil potato first pan mayb detail separ step like next cut boil potato finish boil add potato pan one
PREDICT:
heat stove way get nice brown effect




TEXT:
mix ingredi togeth bake bake muffin
TARGET:
bake muffin
PREDICT:
snow powder anoth work powder sugar




TEXT:
place fill popsicl mold care freezer allow freez 30 minut none
TARGET:
none
PREDICT:
none way get nice brown effect




TEXT:
hardli ani gravi left chicken caramelis char part e

In [33]:
target

[[],
 [],
 [],
 [],
 ['cups packed light brown sugar tablespoons margarine tablespoons vegetable shortening cups dark molasses tablespoon baking soda cup boiling water cups all purpose flour sifted tablespoon ground cloves tablespoons ground ginger tablespoon ground cinnamon'],
 [],
 [],
 [],
 ['image of finished product'],
 [],
 [],
 ['peeled and cut into inch thick rounds bechamel sauce cup butter cups hot milk tablespoons flour eggs cup grated kefalograviera cheese or parm teaspoon salt'],
 [],
 ['visualized instruction'],
 [],
 ['how the chicken will look like after frying .'],
 [],
 ['tell exact amounts of ingredients'],
 ['place the cookie crusts in the freezer and make the banana ice cream'],
 ['show a video of this step and show the correct consistency .'],
 [],
 [],
 [],
 ['visualized instructions for clarifying purposes'],
 ['make the pink coconut cream .'],
 [],
 [],
 ['baking dish'],
 ['show images or video of this step . suggest turning the oven on or have the assistant tu