In [3]:
%matplotlib inline

In [4]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [5]:
# corpus_name = "cornell movie-dialogs corpus"
# corpus = os.path.join("data", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

# printLines(os.path.join(corpus, "movie_lines.txt"))

In [6]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

Now we’ll call these functions and create the file. We’ll call it
*formatted_movie_lines.txt*.




In [7]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [8]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split(',')] for l in lines[1:]]
    voc = Voc(corpus_name)
    return voc, pairs

import ipdb
# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    # Input sequences need to preserve the last word for EOS token
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Filter pairs using filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
#     pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

corpus_name = 'sarch_query'
datafile = 'data/Batch_generation_2/train_step.csv'

# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 1789 sentence pairs
Trimmed to 1789 sentence pairs
Counting words...
Counted words: 3514

pairs:
['remove chicken thighs brine pat dry paper towel', 'dry chicken needs']
['place paper towels drain excess oil', 'include reason draining excess oil']
['add oregano garlic powder cumin chili powder cayenne salt pepper stir well covered cook another minutes', 'would useful know heat needs adjusted step']
['traditionally turkish kisir eaten lettuce leaves serve lettuce leaf leave side wrap kisir inside lettuce leaves eat also serve tomatoes alongside', 'paprika subbed red pepper flakes']
['deglaze skillet wine add cream chile puree cook reduced desired consistency stir chives', 'making recipe gathering event ensure inform guests guest recipe calls wine']
['large skillet heat oil butter medium high heat add bacon onions salt pepper sautee medium heat minutes onions browned bit bacon crisped add garlic cook additional seconds remove mixtur

In [9]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    # Filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # Check input sentence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        # Check output sentence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 1621 / 3511 = 0.4617
Trimmed from 1789 pairs to 714, 0.3991 of total


In [10]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[  12,   67,   87,  626,  876],
        [ 619,   87,  106,  176, 1046],
        [  86,  181,  177,  178,    2],
        [ 683,   49,   56,   87,    0],
        [ 100,  721,  522,  111,    0],
        [1289,  901,  671,  983,    0],
        [ 860, 1549,  730,    2,    0],
        [ 886,   20,    2,    0,    0],
        [  86,  633,    0,    0,    0],
        [  16,  552,    0,    0,    0],
        [ 210,  248,    0,    0,    0],
        [ 124,  355,    0,    0,    0],
        [ 214,  348,    0,    0,    0],
        [  18,  780,    0,    0,    0],
        [ 683,   24,    0,    0,    0],
        [ 100,  488,    0,    0,    0],
        [ 217,  489,    0,    0,    0],
        [ 118,   25,    0,    0,    0],
        [ 214,    2,    0,    0,    0],
        [ 215,    0,    0,    0,    0],
        [ 227,    0,    0,    0,    0],
        [  77,    0,    0,    0,    0],
        [  78,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([24, 19

In [11]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

In [12]:
# Luong attention layer
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [13]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

In [14]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [15]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Forward batch of sequences through decoder one time step at a time
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # Calculate and accumulate loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # Perform backpropatation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [16]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    # Load batches for each iteration
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [17]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [18]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    try:
        indexes_batch = [indexesFromSentence(voc, sentence)]
    except:
        print(sentence)
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")

In [22]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

corpus_name = 'search_query'
alldatafile = 'data/Batch_generation_2/all.csv'
traindatafile = 'data/Batch_generation_2/train_step.csv'

voc, _ = loadPrepareData(corpus_name, alldatafile, save_dir)
_, pairs = loadPrepareData(corpus_name, traindatafile, save_dir)

voc.addWord('')



# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Start preparing training data ...
Reading lines...
Read 2558 sentence pairs
Trimmed to 2558 sentence pairs
Counting words...
Counted words: 4955
Start preparing training data ...
Reading lines...
Read 1789 sentence pairs
Trimmed to 1789 sentence pairs
Counting words...
Counted words: 3514
Building encoder and decoder ...
Models built and ready to go!


In [23]:
def evaluateFile(encoder, decoder, searcher, voc, filename, targetname):
    text = list(csv.reader(open(filename, 'rt')))
    target = list(csv.reader(open(targetname, 'rt')))
    responses = []
#     input_sentence = ''
    for i, input_sentence in enumerate(text):
        try:
            # Get input sentence
#             input_sentence = input('> ')
            # Check if it is quit case
#             if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            try:
                input_sentence = normalizeString(input_sentence[0])
            except:
                print(i, input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            responses.append(' '.join(output_words))
#             print('Bot:', ' '.join(output_words))
        except KeyError:
            responses.append(' ')
#             print("Error: Encountered unknown word.")
    return text, target, responses

In [24]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)


# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.5092
Iteration: 2; Percent complete: 0.1%; Average loss: 8.4773
Iteration: 3; Percent complete: 0.1%; Average loss: 8.4576
Iteration: 4; Percent complete: 0.1%; Average loss: 8.4173
Iteration: 5; Percent complete: 0.1%; Average loss: 8.3424
Iteration: 6; Percent complete: 0.1%; Average loss: 8.2327
Iteration: 7; Percent complete: 0.2%; Average loss: 8.0289
Iteration: 8; Percent complete: 0.2%; Average loss: 7.8094
Iteration: 9; Percent complete: 0.2%; Average loss: 7.4572
Iteration: 10; Percent complete: 0.2%; Average loss: 7.1638
Iteration: 11; Percent complete: 0.3%; Average loss: 7.7565
Iteration: 12; Percent complete: 0.3%; Average loss: 7.8383
Iteration: 13; Percent complete: 0.3%; Average loss: 7.8675
Iteration: 14; Percent complete: 0.4%; Average loss: 7.6309
Iteration: 15; Percent complete: 0.4%; Average loss: 7.6456
Iteration: 16; Percent complete: 0.4%

Iteration: 137; Percent complete: 3.4%; Average loss: 5.9431
Iteration: 138; Percent complete: 3.5%; Average loss: 5.8233
Iteration: 139; Percent complete: 3.5%; Average loss: 5.8584
Iteration: 140; Percent complete: 3.5%; Average loss: 6.0487
Iteration: 141; Percent complete: 3.5%; Average loss: 5.8040
Iteration: 142; Percent complete: 3.5%; Average loss: 5.9742
Iteration: 143; Percent complete: 3.6%; Average loss: 5.9390
Iteration: 144; Percent complete: 3.6%; Average loss: 6.2692
Iteration: 145; Percent complete: 3.6%; Average loss: 6.1085
Iteration: 146; Percent complete: 3.6%; Average loss: 6.1157
Iteration: 147; Percent complete: 3.7%; Average loss: 5.9369
Iteration: 148; Percent complete: 3.7%; Average loss: 5.8895
Iteration: 149; Percent complete: 3.7%; Average loss: 6.0102
Iteration: 150; Percent complete: 3.8%; Average loss: 5.7556
Iteration: 151; Percent complete: 3.8%; Average loss: 6.0170
Iteration: 152; Percent complete: 3.8%; Average loss: 5.8017
Iteration: 153; Percent 

Iteration: 272; Percent complete: 6.8%; Average loss: 5.2260
Iteration: 273; Percent complete: 6.8%; Average loss: 5.1976
Iteration: 274; Percent complete: 6.9%; Average loss: 5.1513
Iteration: 275; Percent complete: 6.9%; Average loss: 5.0851
Iteration: 276; Percent complete: 6.9%; Average loss: 5.3305
Iteration: 277; Percent complete: 6.9%; Average loss: 4.8839
Iteration: 278; Percent complete: 7.0%; Average loss: 5.0500
Iteration: 279; Percent complete: 7.0%; Average loss: 5.1859
Iteration: 280; Percent complete: 7.0%; Average loss: 5.1725
Iteration: 281; Percent complete: 7.0%; Average loss: 4.9561
Iteration: 282; Percent complete: 7.0%; Average loss: 5.1228
Iteration: 283; Percent complete: 7.1%; Average loss: 4.9959
Iteration: 284; Percent complete: 7.1%; Average loss: 5.3180
Iteration: 285; Percent complete: 7.1%; Average loss: 4.7123
Iteration: 286; Percent complete: 7.1%; Average loss: 5.2672
Iteration: 287; Percent complete: 7.2%; Average loss: 4.9720
Iteration: 288; Percent 

Iteration: 407; Percent complete: 10.2%; Average loss: 4.2142
Iteration: 408; Percent complete: 10.2%; Average loss: 3.8022
Iteration: 409; Percent complete: 10.2%; Average loss: 3.9697
Iteration: 410; Percent complete: 10.2%; Average loss: 4.3920
Iteration: 411; Percent complete: 10.3%; Average loss: 4.0877
Iteration: 412; Percent complete: 10.3%; Average loss: 4.3438
Iteration: 413; Percent complete: 10.3%; Average loss: 4.2854
Iteration: 414; Percent complete: 10.3%; Average loss: 3.9374
Iteration: 415; Percent complete: 10.4%; Average loss: 4.3666
Iteration: 416; Percent complete: 10.4%; Average loss: 4.2117
Iteration: 417; Percent complete: 10.4%; Average loss: 4.5921
Iteration: 418; Percent complete: 10.4%; Average loss: 4.1801
Iteration: 419; Percent complete: 10.5%; Average loss: 4.3298
Iteration: 420; Percent complete: 10.5%; Average loss: 4.0632
Iteration: 421; Percent complete: 10.5%; Average loss: 4.4955
Iteration: 422; Percent complete: 10.5%; Average loss: 4.1049
Iteratio

Iteration: 540; Percent complete: 13.5%; Average loss: 3.6538
Iteration: 541; Percent complete: 13.5%; Average loss: 3.6419
Iteration: 542; Percent complete: 13.6%; Average loss: 3.2920
Iteration: 543; Percent complete: 13.6%; Average loss: 3.0126
Iteration: 544; Percent complete: 13.6%; Average loss: 3.2603
Iteration: 545; Percent complete: 13.6%; Average loss: 3.3331
Iteration: 546; Percent complete: 13.7%; Average loss: 3.1879
Iteration: 547; Percent complete: 13.7%; Average loss: 3.3580
Iteration: 548; Percent complete: 13.7%; Average loss: 3.3272
Iteration: 549; Percent complete: 13.7%; Average loss: 3.1302
Iteration: 550; Percent complete: 13.8%; Average loss: 3.0494
Iteration: 551; Percent complete: 13.8%; Average loss: 3.2034
Iteration: 552; Percent complete: 13.8%; Average loss: 2.6962
Iteration: 553; Percent complete: 13.8%; Average loss: 3.5828
Iteration: 554; Percent complete: 13.9%; Average loss: 3.3253
Iteration: 555; Percent complete: 13.9%; Average loss: 3.3449
Iteratio

Iteration: 675; Percent complete: 16.9%; Average loss: 2.7547
Iteration: 676; Percent complete: 16.9%; Average loss: 2.3706
Iteration: 677; Percent complete: 16.9%; Average loss: 2.4942
Iteration: 678; Percent complete: 17.0%; Average loss: 2.3101
Iteration: 679; Percent complete: 17.0%; Average loss: 2.5942
Iteration: 680; Percent complete: 17.0%; Average loss: 2.5716
Iteration: 681; Percent complete: 17.0%; Average loss: 2.5687
Iteration: 682; Percent complete: 17.1%; Average loss: 2.4339
Iteration: 683; Percent complete: 17.1%; Average loss: 2.3385
Iteration: 684; Percent complete: 17.1%; Average loss: 2.6662
Iteration: 685; Percent complete: 17.1%; Average loss: 2.5766
Iteration: 686; Percent complete: 17.2%; Average loss: 2.4688
Iteration: 687; Percent complete: 17.2%; Average loss: 2.4427
Iteration: 688; Percent complete: 17.2%; Average loss: 2.7035
Iteration: 689; Percent complete: 17.2%; Average loss: 2.6649
Iteration: 690; Percent complete: 17.2%; Average loss: 2.5337
Iteratio

Iteration: 810; Percent complete: 20.2%; Average loss: 1.9948
Iteration: 811; Percent complete: 20.3%; Average loss: 1.9463
Iteration: 812; Percent complete: 20.3%; Average loss: 1.8189
Iteration: 813; Percent complete: 20.3%; Average loss: 1.9172
Iteration: 814; Percent complete: 20.3%; Average loss: 1.7909
Iteration: 815; Percent complete: 20.4%; Average loss: 1.9852
Iteration: 816; Percent complete: 20.4%; Average loss: 1.9835
Iteration: 817; Percent complete: 20.4%; Average loss: 1.8973
Iteration: 818; Percent complete: 20.4%; Average loss: 2.0986
Iteration: 819; Percent complete: 20.5%; Average loss: 1.7807
Iteration: 820; Percent complete: 20.5%; Average loss: 1.7954
Iteration: 821; Percent complete: 20.5%; Average loss: 1.9116
Iteration: 822; Percent complete: 20.5%; Average loss: 2.0794
Iteration: 823; Percent complete: 20.6%; Average loss: 1.8888
Iteration: 824; Percent complete: 20.6%; Average loss: 1.7763
Iteration: 825; Percent complete: 20.6%; Average loss: 2.0072
Iteratio

Iteration: 944; Percent complete: 23.6%; Average loss: 1.3639
Iteration: 945; Percent complete: 23.6%; Average loss: 1.3696
Iteration: 946; Percent complete: 23.6%; Average loss: 1.3872
Iteration: 947; Percent complete: 23.7%; Average loss: 1.2983
Iteration: 948; Percent complete: 23.7%; Average loss: 1.4106
Iteration: 949; Percent complete: 23.7%; Average loss: 1.3167
Iteration: 950; Percent complete: 23.8%; Average loss: 1.2362
Iteration: 951; Percent complete: 23.8%; Average loss: 1.3283
Iteration: 952; Percent complete: 23.8%; Average loss: 1.3383
Iteration: 953; Percent complete: 23.8%; Average loss: 1.2214
Iteration: 954; Percent complete: 23.8%; Average loss: 1.2558
Iteration: 955; Percent complete: 23.9%; Average loss: 1.1587
Iteration: 956; Percent complete: 23.9%; Average loss: 1.2170
Iteration: 957; Percent complete: 23.9%; Average loss: 1.3823
Iteration: 958; Percent complete: 23.9%; Average loss: 1.2811
Iteration: 959; Percent complete: 24.0%; Average loss: 1.2885
Iteratio

Iteration: 1076; Percent complete: 26.9%; Average loss: 1.0200
Iteration: 1077; Percent complete: 26.9%; Average loss: 0.8760
Iteration: 1078; Percent complete: 27.0%; Average loss: 0.9104
Iteration: 1079; Percent complete: 27.0%; Average loss: 0.9722
Iteration: 1080; Percent complete: 27.0%; Average loss: 0.9865
Iteration: 1081; Percent complete: 27.0%; Average loss: 0.9708
Iteration: 1082; Percent complete: 27.1%; Average loss: 0.9151
Iteration: 1083; Percent complete: 27.1%; Average loss: 0.9646
Iteration: 1084; Percent complete: 27.1%; Average loss: 0.8426
Iteration: 1085; Percent complete: 27.1%; Average loss: 0.9211
Iteration: 1086; Percent complete: 27.2%; Average loss: 0.9566
Iteration: 1087; Percent complete: 27.2%; Average loss: 0.9259
Iteration: 1088; Percent complete: 27.2%; Average loss: 0.9410
Iteration: 1089; Percent complete: 27.2%; Average loss: 1.1435
Iteration: 1090; Percent complete: 27.3%; Average loss: 0.8371
Iteration: 1091; Percent complete: 27.3%; Average loss:

Iteration: 1209; Percent complete: 30.2%; Average loss: 0.6314
Iteration: 1210; Percent complete: 30.2%; Average loss: 0.6598
Iteration: 1211; Percent complete: 30.3%; Average loss: 0.5374
Iteration: 1212; Percent complete: 30.3%; Average loss: 0.6918
Iteration: 1213; Percent complete: 30.3%; Average loss: 0.5990
Iteration: 1214; Percent complete: 30.3%; Average loss: 0.5670
Iteration: 1215; Percent complete: 30.4%; Average loss: 0.5867
Iteration: 1216; Percent complete: 30.4%; Average loss: 0.5691
Iteration: 1217; Percent complete: 30.4%; Average loss: 0.6363
Iteration: 1218; Percent complete: 30.4%; Average loss: 0.6466
Iteration: 1219; Percent complete: 30.5%; Average loss: 0.6262
Iteration: 1220; Percent complete: 30.5%; Average loss: 0.7083
Iteration: 1221; Percent complete: 30.5%; Average loss: 0.6484
Iteration: 1222; Percent complete: 30.6%; Average loss: 0.6803
Iteration: 1223; Percent complete: 30.6%; Average loss: 0.4973
Iteration: 1224; Percent complete: 30.6%; Average loss:

Iteration: 1342; Percent complete: 33.6%; Average loss: 0.5874
Iteration: 1343; Percent complete: 33.6%; Average loss: 0.4466
Iteration: 1344; Percent complete: 33.6%; Average loss: 0.4460
Iteration: 1345; Percent complete: 33.6%; Average loss: 0.3941
Iteration: 1346; Percent complete: 33.7%; Average loss: 0.4407
Iteration: 1347; Percent complete: 33.7%; Average loss: 0.3676
Iteration: 1348; Percent complete: 33.7%; Average loss: 0.3910
Iteration: 1349; Percent complete: 33.7%; Average loss: 0.4355
Iteration: 1350; Percent complete: 33.8%; Average loss: 0.3972
Iteration: 1351; Percent complete: 33.8%; Average loss: 0.4156
Iteration: 1352; Percent complete: 33.8%; Average loss: 0.4022
Iteration: 1353; Percent complete: 33.8%; Average loss: 0.3494
Iteration: 1354; Percent complete: 33.9%; Average loss: 0.3857
Iteration: 1355; Percent complete: 33.9%; Average loss: 0.4478
Iteration: 1356; Percent complete: 33.9%; Average loss: 0.3667
Iteration: 1357; Percent complete: 33.9%; Average loss:

Iteration: 1475; Percent complete: 36.9%; Average loss: 0.3082
Iteration: 1476; Percent complete: 36.9%; Average loss: 0.2706
Iteration: 1477; Percent complete: 36.9%; Average loss: 0.2759
Iteration: 1478; Percent complete: 37.0%; Average loss: 0.2475
Iteration: 1479; Percent complete: 37.0%; Average loss: 0.2510
Iteration: 1480; Percent complete: 37.0%; Average loss: 0.2722
Iteration: 1481; Percent complete: 37.0%; Average loss: 0.2698
Iteration: 1482; Percent complete: 37.0%; Average loss: 0.2908
Iteration: 1483; Percent complete: 37.1%; Average loss: 0.2648
Iteration: 1484; Percent complete: 37.1%; Average loss: 0.2490
Iteration: 1485; Percent complete: 37.1%; Average loss: 0.2153
Iteration: 1486; Percent complete: 37.1%; Average loss: 0.2456
Iteration: 1487; Percent complete: 37.2%; Average loss: 0.2375
Iteration: 1488; Percent complete: 37.2%; Average loss: 0.2572
Iteration: 1489; Percent complete: 37.2%; Average loss: 0.2364
Iteration: 1490; Percent complete: 37.2%; Average loss:

Iteration: 1607; Percent complete: 40.2%; Average loss: 0.1548
Iteration: 1608; Percent complete: 40.2%; Average loss: 0.1680
Iteration: 1609; Percent complete: 40.2%; Average loss: 0.2170
Iteration: 1610; Percent complete: 40.2%; Average loss: 0.1720
Iteration: 1611; Percent complete: 40.3%; Average loss: 0.2124
Iteration: 1612; Percent complete: 40.3%; Average loss: 0.1928
Iteration: 1613; Percent complete: 40.3%; Average loss: 0.2052
Iteration: 1614; Percent complete: 40.4%; Average loss: 0.1620
Iteration: 1615; Percent complete: 40.4%; Average loss: 0.1873
Iteration: 1616; Percent complete: 40.4%; Average loss: 0.1955
Iteration: 1617; Percent complete: 40.4%; Average loss: 0.1471
Iteration: 1618; Percent complete: 40.5%; Average loss: 0.1721
Iteration: 1619; Percent complete: 40.5%; Average loss: 0.1452
Iteration: 1620; Percent complete: 40.5%; Average loss: 0.1702
Iteration: 1621; Percent complete: 40.5%; Average loss: 0.1886
Iteration: 1622; Percent complete: 40.6%; Average loss:

Iteration: 1738; Percent complete: 43.5%; Average loss: 0.1186
Iteration: 1739; Percent complete: 43.5%; Average loss: 0.1043
Iteration: 1740; Percent complete: 43.5%; Average loss: 0.1444
Iteration: 1741; Percent complete: 43.5%; Average loss: 0.1231
Iteration: 1742; Percent complete: 43.5%; Average loss: 0.1163
Iteration: 1743; Percent complete: 43.6%; Average loss: 0.0988
Iteration: 1744; Percent complete: 43.6%; Average loss: 0.1073
Iteration: 1745; Percent complete: 43.6%; Average loss: 0.1389
Iteration: 1746; Percent complete: 43.6%; Average loss: 0.1147
Iteration: 1747; Percent complete: 43.7%; Average loss: 0.1318
Iteration: 1748; Percent complete: 43.7%; Average loss: 0.1262
Iteration: 1749; Percent complete: 43.7%; Average loss: 0.1350
Iteration: 1750; Percent complete: 43.8%; Average loss: 0.1223
Iteration: 1751; Percent complete: 43.8%; Average loss: 0.1349
Iteration: 1752; Percent complete: 43.8%; Average loss: 0.1050
Iteration: 1753; Percent complete: 43.8%; Average loss:

Iteration: 1870; Percent complete: 46.8%; Average loss: 0.0855
Iteration: 1871; Percent complete: 46.8%; Average loss: 0.1072
Iteration: 1872; Percent complete: 46.8%; Average loss: 0.1142
Iteration: 1873; Percent complete: 46.8%; Average loss: 0.1161
Iteration: 1874; Percent complete: 46.9%; Average loss: 0.0907
Iteration: 1875; Percent complete: 46.9%; Average loss: 0.1054
Iteration: 1876; Percent complete: 46.9%; Average loss: 0.1150
Iteration: 1877; Percent complete: 46.9%; Average loss: 0.0956
Iteration: 1878; Percent complete: 46.9%; Average loss: 0.0821
Iteration: 1879; Percent complete: 47.0%; Average loss: 0.1308
Iteration: 1880; Percent complete: 47.0%; Average loss: 0.1211
Iteration: 1881; Percent complete: 47.0%; Average loss: 0.0922
Iteration: 1882; Percent complete: 47.0%; Average loss: 0.0906
Iteration: 1883; Percent complete: 47.1%; Average loss: 0.0893
Iteration: 1884; Percent complete: 47.1%; Average loss: 0.0677
Iteration: 1885; Percent complete: 47.1%; Average loss:

Iteration: 2001; Percent complete: 50.0%; Average loss: 0.0706
Iteration: 2002; Percent complete: 50.0%; Average loss: 0.0755
Iteration: 2003; Percent complete: 50.1%; Average loss: 0.0745
Iteration: 2004; Percent complete: 50.1%; Average loss: 0.0829
Iteration: 2005; Percent complete: 50.1%; Average loss: 0.0652
Iteration: 2006; Percent complete: 50.1%; Average loss: 0.0745
Iteration: 2007; Percent complete: 50.2%; Average loss: 0.0819
Iteration: 2008; Percent complete: 50.2%; Average loss: 0.0863
Iteration: 2009; Percent complete: 50.2%; Average loss: 0.0687
Iteration: 2010; Percent complete: 50.2%; Average loss: 0.0767
Iteration: 2011; Percent complete: 50.3%; Average loss: 0.1096
Iteration: 2012; Percent complete: 50.3%; Average loss: 0.0723
Iteration: 2013; Percent complete: 50.3%; Average loss: 0.0873
Iteration: 2014; Percent complete: 50.3%; Average loss: 0.0747
Iteration: 2015; Percent complete: 50.4%; Average loss: 0.0681
Iteration: 2016; Percent complete: 50.4%; Average loss:

Iteration: 2133; Percent complete: 53.3%; Average loss: 0.0492
Iteration: 2134; Percent complete: 53.3%; Average loss: 0.0633
Iteration: 2135; Percent complete: 53.4%; Average loss: 0.0607
Iteration: 2136; Percent complete: 53.4%; Average loss: 0.0463
Iteration: 2137; Percent complete: 53.4%; Average loss: 0.0489
Iteration: 2138; Percent complete: 53.4%; Average loss: 0.0559
Iteration: 2139; Percent complete: 53.5%; Average loss: 0.0689
Iteration: 2140; Percent complete: 53.5%; Average loss: 0.0597
Iteration: 2141; Percent complete: 53.5%; Average loss: 0.0758
Iteration: 2142; Percent complete: 53.5%; Average loss: 0.0474
Iteration: 2143; Percent complete: 53.6%; Average loss: 0.0455
Iteration: 2144; Percent complete: 53.6%; Average loss: 0.0504
Iteration: 2145; Percent complete: 53.6%; Average loss: 0.0629
Iteration: 2146; Percent complete: 53.6%; Average loss: 0.0774
Iteration: 2147; Percent complete: 53.7%; Average loss: 0.0760
Iteration: 2148; Percent complete: 53.7%; Average loss:

Iteration: 2264; Percent complete: 56.6%; Average loss: 0.0688
Iteration: 2265; Percent complete: 56.6%; Average loss: 0.0538
Iteration: 2266; Percent complete: 56.6%; Average loss: 0.0477
Iteration: 2267; Percent complete: 56.7%; Average loss: 0.0366
Iteration: 2268; Percent complete: 56.7%; Average loss: 0.0374
Iteration: 2269; Percent complete: 56.7%; Average loss: 0.0494
Iteration: 2270; Percent complete: 56.8%; Average loss: 0.0402
Iteration: 2271; Percent complete: 56.8%; Average loss: 0.0854
Iteration: 2272; Percent complete: 56.8%; Average loss: 0.0647
Iteration: 2273; Percent complete: 56.8%; Average loss: 0.0422
Iteration: 2274; Percent complete: 56.9%; Average loss: 0.0550
Iteration: 2275; Percent complete: 56.9%; Average loss: 0.0358
Iteration: 2276; Percent complete: 56.9%; Average loss: 0.0425
Iteration: 2277; Percent complete: 56.9%; Average loss: 0.0704
Iteration: 2278; Percent complete: 57.0%; Average loss: 0.0626
Iteration: 2279; Percent complete: 57.0%; Average loss:

Iteration: 2396; Percent complete: 59.9%; Average loss: 0.0413
Iteration: 2397; Percent complete: 59.9%; Average loss: 0.0425
Iteration: 2398; Percent complete: 60.0%; Average loss: 0.0387
Iteration: 2399; Percent complete: 60.0%; Average loss: 0.0483
Iteration: 2400; Percent complete: 60.0%; Average loss: 0.0575
Iteration: 2401; Percent complete: 60.0%; Average loss: 0.0486
Iteration: 2402; Percent complete: 60.1%; Average loss: 0.0539
Iteration: 2403; Percent complete: 60.1%; Average loss: 0.0513
Iteration: 2404; Percent complete: 60.1%; Average loss: 0.0420
Iteration: 2405; Percent complete: 60.1%; Average loss: 0.0587
Iteration: 2406; Percent complete: 60.2%; Average loss: 0.0322
Iteration: 2407; Percent complete: 60.2%; Average loss: 0.0500
Iteration: 2408; Percent complete: 60.2%; Average loss: 0.0672
Iteration: 2409; Percent complete: 60.2%; Average loss: 0.0366
Iteration: 2410; Percent complete: 60.2%; Average loss: 0.0580
Iteration: 2411; Percent complete: 60.3%; Average loss:

Iteration: 2530; Percent complete: 63.2%; Average loss: 0.0429
Iteration: 2531; Percent complete: 63.3%; Average loss: 0.0574
Iteration: 2532; Percent complete: 63.3%; Average loss: 0.0433
Iteration: 2533; Percent complete: 63.3%; Average loss: 0.0500
Iteration: 2534; Percent complete: 63.3%; Average loss: 0.0483
Iteration: 2535; Percent complete: 63.4%; Average loss: 0.0292
Iteration: 2536; Percent complete: 63.4%; Average loss: 0.0496
Iteration: 2537; Percent complete: 63.4%; Average loss: 0.0374
Iteration: 2538; Percent complete: 63.4%; Average loss: 0.0547
Iteration: 2539; Percent complete: 63.5%; Average loss: 0.0500
Iteration: 2540; Percent complete: 63.5%; Average loss: 0.0514
Iteration: 2541; Percent complete: 63.5%; Average loss: 0.0359
Iteration: 2542; Percent complete: 63.5%; Average loss: 0.0472
Iteration: 2543; Percent complete: 63.6%; Average loss: 0.0444
Iteration: 2544; Percent complete: 63.6%; Average loss: 0.0336
Iteration: 2545; Percent complete: 63.6%; Average loss:

Iteration: 2662; Percent complete: 66.5%; Average loss: 0.0299
Iteration: 2663; Percent complete: 66.6%; Average loss: 0.0304
Iteration: 2664; Percent complete: 66.6%; Average loss: 0.0598
Iteration: 2665; Percent complete: 66.6%; Average loss: 0.0469
Iteration: 2666; Percent complete: 66.6%; Average loss: 0.0622
Iteration: 2667; Percent complete: 66.7%; Average loss: 0.0241
Iteration: 2668; Percent complete: 66.7%; Average loss: 0.0284
Iteration: 2669; Percent complete: 66.7%; Average loss: 0.0247
Iteration: 2670; Percent complete: 66.8%; Average loss: 0.0312
Iteration: 2671; Percent complete: 66.8%; Average loss: 0.0232
Iteration: 2672; Percent complete: 66.8%; Average loss: 0.0344
Iteration: 2673; Percent complete: 66.8%; Average loss: 0.0581
Iteration: 2674; Percent complete: 66.8%; Average loss: 0.0433
Iteration: 2675; Percent complete: 66.9%; Average loss: 0.0518
Iteration: 2676; Percent complete: 66.9%; Average loss: 0.0357
Iteration: 2677; Percent complete: 66.9%; Average loss:

Iteration: 2795; Percent complete: 69.9%; Average loss: 0.0336
Iteration: 2796; Percent complete: 69.9%; Average loss: 0.0382
Iteration: 2797; Percent complete: 69.9%; Average loss: 0.0372
Iteration: 2798; Percent complete: 70.0%; Average loss: 0.0248
Iteration: 2799; Percent complete: 70.0%; Average loss: 0.0290
Iteration: 2800; Percent complete: 70.0%; Average loss: 0.0318
Iteration: 2801; Percent complete: 70.0%; Average loss: 0.0367
Iteration: 2802; Percent complete: 70.0%; Average loss: 0.0414
Iteration: 2803; Percent complete: 70.1%; Average loss: 0.0228
Iteration: 2804; Percent complete: 70.1%; Average loss: 0.0257
Iteration: 2805; Percent complete: 70.1%; Average loss: 0.0440
Iteration: 2806; Percent complete: 70.2%; Average loss: 0.0310
Iteration: 2807; Percent complete: 70.2%; Average loss: 0.0212
Iteration: 2808; Percent complete: 70.2%; Average loss: 0.0504
Iteration: 2809; Percent complete: 70.2%; Average loss: 0.0282
Iteration: 2810; Percent complete: 70.2%; Average loss:

Iteration: 2926; Percent complete: 73.2%; Average loss: 0.0434
Iteration: 2927; Percent complete: 73.2%; Average loss: 0.0259
Iteration: 2928; Percent complete: 73.2%; Average loss: 0.0391
Iteration: 2929; Percent complete: 73.2%; Average loss: 0.0268
Iteration: 2930; Percent complete: 73.2%; Average loss: 0.0426
Iteration: 2931; Percent complete: 73.3%; Average loss: 0.0391
Iteration: 2932; Percent complete: 73.3%; Average loss: 0.0255
Iteration: 2933; Percent complete: 73.3%; Average loss: 0.0511
Iteration: 2934; Percent complete: 73.4%; Average loss: 0.0286
Iteration: 2935; Percent complete: 73.4%; Average loss: 0.0166
Iteration: 2936; Percent complete: 73.4%; Average loss: 0.0531
Iteration: 2937; Percent complete: 73.4%; Average loss: 0.0384
Iteration: 2938; Percent complete: 73.5%; Average loss: 0.0257
Iteration: 2939; Percent complete: 73.5%; Average loss: 0.0367
Iteration: 2940; Percent complete: 73.5%; Average loss: 0.0226
Iteration: 2941; Percent complete: 73.5%; Average loss:

Iteration: 3057; Percent complete: 76.4%; Average loss: 0.0229
Iteration: 3058; Percent complete: 76.4%; Average loss: 0.0350
Iteration: 3059; Percent complete: 76.5%; Average loss: 0.0245
Iteration: 3060; Percent complete: 76.5%; Average loss: 0.0236
Iteration: 3061; Percent complete: 76.5%; Average loss: 0.0304
Iteration: 3062; Percent complete: 76.5%; Average loss: 0.0382
Iteration: 3063; Percent complete: 76.6%; Average loss: 0.0340
Iteration: 3064; Percent complete: 76.6%; Average loss: 0.0405
Iteration: 3065; Percent complete: 76.6%; Average loss: 0.0318
Iteration: 3066; Percent complete: 76.6%; Average loss: 0.0324
Iteration: 3067; Percent complete: 76.7%; Average loss: 0.0265
Iteration: 3068; Percent complete: 76.7%; Average loss: 0.0264
Iteration: 3069; Percent complete: 76.7%; Average loss: 0.0386
Iteration: 3070; Percent complete: 76.8%; Average loss: 0.0538
Iteration: 3071; Percent complete: 76.8%; Average loss: 0.0486
Iteration: 3072; Percent complete: 76.8%; Average loss:

Iteration: 3189; Percent complete: 79.7%; Average loss: 0.0156
Iteration: 3190; Percent complete: 79.8%; Average loss: 0.0345
Iteration: 3191; Percent complete: 79.8%; Average loss: 0.0326
Iteration: 3192; Percent complete: 79.8%; Average loss: 0.0252
Iteration: 3193; Percent complete: 79.8%; Average loss: 0.0287
Iteration: 3194; Percent complete: 79.8%; Average loss: 0.0441
Iteration: 3195; Percent complete: 79.9%; Average loss: 0.0166
Iteration: 3196; Percent complete: 79.9%; Average loss: 0.0232
Iteration: 3197; Percent complete: 79.9%; Average loss: 0.0367
Iteration: 3198; Percent complete: 80.0%; Average loss: 0.0196
Iteration: 3199; Percent complete: 80.0%; Average loss: 0.0148
Iteration: 3200; Percent complete: 80.0%; Average loss: 0.0281
Iteration: 3201; Percent complete: 80.0%; Average loss: 0.0323
Iteration: 3202; Percent complete: 80.0%; Average loss: 0.0242
Iteration: 3203; Percent complete: 80.1%; Average loss: 0.0373
Iteration: 3204; Percent complete: 80.1%; Average loss:

Iteration: 3322; Percent complete: 83.0%; Average loss: 0.0587
Iteration: 3323; Percent complete: 83.1%; Average loss: 0.0162
Iteration: 3324; Percent complete: 83.1%; Average loss: 0.0337
Iteration: 3325; Percent complete: 83.1%; Average loss: 0.0381
Iteration: 3326; Percent complete: 83.2%; Average loss: 0.0492
Iteration: 3327; Percent complete: 83.2%; Average loss: 0.0330
Iteration: 3328; Percent complete: 83.2%; Average loss: 0.0288
Iteration: 3329; Percent complete: 83.2%; Average loss: 0.0418
Iteration: 3330; Percent complete: 83.2%; Average loss: 0.0263
Iteration: 3331; Percent complete: 83.3%; Average loss: 0.0195
Iteration: 3332; Percent complete: 83.3%; Average loss: 0.0252
Iteration: 3333; Percent complete: 83.3%; Average loss: 0.0227
Iteration: 3334; Percent complete: 83.4%; Average loss: 0.0289
Iteration: 3335; Percent complete: 83.4%; Average loss: 0.0224
Iteration: 3336; Percent complete: 83.4%; Average loss: 0.0244
Iteration: 3337; Percent complete: 83.4%; Average loss:

Iteration: 3455; Percent complete: 86.4%; Average loss: 0.0415
Iteration: 3456; Percent complete: 86.4%; Average loss: 0.0172
Iteration: 3457; Percent complete: 86.4%; Average loss: 0.0575
Iteration: 3458; Percent complete: 86.5%; Average loss: 0.0259
Iteration: 3459; Percent complete: 86.5%; Average loss: 0.0208
Iteration: 3460; Percent complete: 86.5%; Average loss: 0.0272
Iteration: 3461; Percent complete: 86.5%; Average loss: 0.0341
Iteration: 3462; Percent complete: 86.6%; Average loss: 0.0280
Iteration: 3463; Percent complete: 86.6%; Average loss: 0.0174
Iteration: 3464; Percent complete: 86.6%; Average loss: 0.0148
Iteration: 3465; Percent complete: 86.6%; Average loss: 0.0283
Iteration: 3466; Percent complete: 86.7%; Average loss: 0.0496
Iteration: 3467; Percent complete: 86.7%; Average loss: 0.0166
Iteration: 3468; Percent complete: 86.7%; Average loss: 0.0197
Iteration: 3469; Percent complete: 86.7%; Average loss: 0.0215
Iteration: 3470; Percent complete: 86.8%; Average loss:

Iteration: 3587; Percent complete: 89.7%; Average loss: 0.0403
Iteration: 3588; Percent complete: 89.7%; Average loss: 0.0229
Iteration: 3589; Percent complete: 89.7%; Average loss: 0.0215
Iteration: 3590; Percent complete: 89.8%; Average loss: 0.0175
Iteration: 3591; Percent complete: 89.8%; Average loss: 0.0311
Iteration: 3592; Percent complete: 89.8%; Average loss: 0.0139
Iteration: 3593; Percent complete: 89.8%; Average loss: 0.0118
Iteration: 3594; Percent complete: 89.8%; Average loss: 0.0287
Iteration: 3595; Percent complete: 89.9%; Average loss: 0.0362
Iteration: 3596; Percent complete: 89.9%; Average loss: 0.0334
Iteration: 3597; Percent complete: 89.9%; Average loss: 0.0256
Iteration: 3598; Percent complete: 90.0%; Average loss: 0.0170
Iteration: 3599; Percent complete: 90.0%; Average loss: 0.0239
Iteration: 3600; Percent complete: 90.0%; Average loss: 0.0217
Iteration: 3601; Percent complete: 90.0%; Average loss: 0.0134
Iteration: 3602; Percent complete: 90.0%; Average loss:

Iteration: 3721; Percent complete: 93.0%; Average loss: 0.0146
Iteration: 3722; Percent complete: 93.0%; Average loss: 0.0183
Iteration: 3723; Percent complete: 93.1%; Average loss: 0.0222
Iteration: 3724; Percent complete: 93.1%; Average loss: 0.0217
Iteration: 3725; Percent complete: 93.1%; Average loss: 0.0348
Iteration: 3726; Percent complete: 93.2%; Average loss: 0.0194
Iteration: 3727; Percent complete: 93.2%; Average loss: 0.0266
Iteration: 3728; Percent complete: 93.2%; Average loss: 0.0399
Iteration: 3729; Percent complete: 93.2%; Average loss: 0.0184
Iteration: 3730; Percent complete: 93.2%; Average loss: 0.0141
Iteration: 3731; Percent complete: 93.3%; Average loss: 0.0220
Iteration: 3732; Percent complete: 93.3%; Average loss: 0.0203
Iteration: 3733; Percent complete: 93.3%; Average loss: 0.0167
Iteration: 3734; Percent complete: 93.3%; Average loss: 0.0159
Iteration: 3735; Percent complete: 93.4%; Average loss: 0.0150
Iteration: 3736; Percent complete: 93.4%; Average loss:

Iteration: 3852; Percent complete: 96.3%; Average loss: 0.0267
Iteration: 3853; Percent complete: 96.3%; Average loss: 0.0153
Iteration: 3854; Percent complete: 96.4%; Average loss: 0.0154
Iteration: 3855; Percent complete: 96.4%; Average loss: 0.0106
Iteration: 3856; Percent complete: 96.4%; Average loss: 0.0213
Iteration: 3857; Percent complete: 96.4%; Average loss: 0.0357
Iteration: 3858; Percent complete: 96.5%; Average loss: 0.0086
Iteration: 3859; Percent complete: 96.5%; Average loss: 0.0215
Iteration: 3860; Percent complete: 96.5%; Average loss: 0.0270
Iteration: 3861; Percent complete: 96.5%; Average loss: 0.0328
Iteration: 3862; Percent complete: 96.5%; Average loss: 0.0154
Iteration: 3863; Percent complete: 96.6%; Average loss: 0.0183
Iteration: 3864; Percent complete: 96.6%; Average loss: 0.0102
Iteration: 3865; Percent complete: 96.6%; Average loss: 0.0265
Iteration: 3866; Percent complete: 96.7%; Average loss: 0.0435
Iteration: 3867; Percent complete: 96.7%; Average loss:

Iteration: 3983; Percent complete: 99.6%; Average loss: 0.2859
Iteration: 3984; Percent complete: 99.6%; Average loss: 0.1630
Iteration: 3985; Percent complete: 99.6%; Average loss: 0.1314
Iteration: 3986; Percent complete: 99.7%; Average loss: 0.2152
Iteration: 3987; Percent complete: 99.7%; Average loss: 0.2588
Iteration: 3988; Percent complete: 99.7%; Average loss: 0.1786
Iteration: 3989; Percent complete: 99.7%; Average loss: 0.1001
Iteration: 3990; Percent complete: 99.8%; Average loss: 0.2286
Iteration: 3991; Percent complete: 99.8%; Average loss: 0.2189
Iteration: 3992; Percent complete: 99.8%; Average loss: 0.2454
Iteration: 3993; Percent complete: 99.8%; Average loss: 0.1889
Iteration: 3994; Percent complete: 99.9%; Average loss: 0.2453
Iteration: 3995; Percent complete: 99.9%; Average loss: 0.2124
Iteration: 3996; Percent complete: 99.9%; Average loss: 0.2010
Iteration: 3997; Percent complete: 99.9%; Average loss: 0.2115
Iteration: 3998; Percent complete: 100.0%; Average loss

Run Evaluation
~~~~~~~~~~~~~~

To chat with your model, run the following block.




In [49]:
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
evaluateInput(encoder, decoder, searcher, voc)

> cook
Bot: serv size e yield oz serv made life
> cook chicken
Bot: cook time temperatur would add garlic sprinkl chicken


KeyboardInterrupt: 

Conclusion
----------

That’s all for this one, folks. Congratulations, you now know the
fundamentals to building a generative chatbot model! If you’re
interested, you can try tailoring the chatbot’s behavior by tweaking the
model and training parameters and customizing the data that you train
the model on.

Check out the other tutorials for more cool deep learning applications
in PyTorch!




In [25]:
filename = 'data/Batch_generation_2/test_step_text.csv'
targetname = 'data/Batch_generation_2/test_step_target.csv'
encoder.eval()
decoder.eval()

# Initialize search module
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting (uncomment and run the following line to begin)
text, target, responses = evaluateFile(encoder, decoder, searcher, voc, filename=filename, targetname=targetname)

with open('data/Batch_generation_2/test_step_predict.csv', 'wt') as f:
    for l in responses:
        f.write('{}\n'.format(l))

In [26]:
import os
res = os.popen('perl multi_bleu.perl data/Batch_generation_2/test_step_target.csv < data/Batch_generation_2/test_step_predict.csv')
res.read()

'BLEU = 0.80, 2.7/0.6/0.5/0.5 (BP=1.000, ratio=1.029, hyp_len=3149, ref_len=3059)\n'

In [28]:
with open('data/Batch_generation_2/seq2seq_step_pred.txt', 'wt') as f:
    for i, (t, g, p) in enumerate(zip(text, target, responses)):
        print('\n\n\n')
        f.write('\n\n\n')
        if len(g) == 0:
            g = [None]
        print('TEXT:\n{}\nTARGET:\n{}\nPREDICT:\n{}'.format(t[0], g[0], p.strip()))
        f.write('TEXT:\n{}\nTARGET:\n{}\nPREDICT:\n{}'.format(t[0], g[0], p.strip()))
#     if i == 100:
#         break





TEXT:
additional recipes barbecue ideas available www walmart ca recipes
TARGET:
use raw material bought walmart make walmart recipes process easier
PREDICT:
couple pieces pepper red onion pieces tofu skewers heavy




TEXT:
combine ingredients pour ice sugared rim glass
TARGET:
sugared rim glass
PREDICT:
vinegar used instead lemon




TEXT:
using pastry brush coat sides pork tenderloin place roasting pan rack
TARGET:
happens pastry brush
PREDICT:
long oven preheated requisite temperature generally minutes




TEXT:
preheat oven 375 degrees f guarantee scones donγçöt stick line baking sheet parchment paper
TARGET:
temperature needs change altitude
PREDICT:
tell full ingredients steps shown pictures please generally




TEXT:
instruction given package
TARGET:
need know long mix ingredients
PREDICT:
need know blended smooth past




TEXT:
roll dough ball place top large sheet parchment paper flatten hands place another sheet parchment paper top flattened dough roll using rolling pin 

In [33]:
target

[[],
 [],
 [],
 [],
 ['cups packed light brown sugar tablespoons margarine tablespoons vegetable shortening cups dark molasses tablespoon baking soda cup boiling water cups all purpose flour sifted tablespoon ground cloves tablespoons ground ginger tablespoon ground cinnamon'],
 [],
 [],
 [],
 ['image of finished product'],
 [],
 [],
 ['peeled and cut into inch thick rounds bechamel sauce cup butter cups hot milk tablespoons flour eggs cup grated kefalograviera cheese or parm teaspoon salt'],
 [],
 ['visualized instruction'],
 [],
 ['how the chicken will look like after frying .'],
 [],
 ['tell exact amounts of ingredients'],
 ['place the cookie crusts in the freezer and make the banana ice cream'],
 ['show a video of this step and show the correct consistency .'],
 [],
 [],
 [],
 ['visualized instructions for clarifying purposes'],
 ['make the pink coconut cream .'],
 [],
 [],
 ['baking dish'],
 ['show images or video of this step . suggest turning the oven on or have the assistant tu