From <a href="https://pytorch.org/tutorials/beginner/chatbot_tutorial.html">Pytorch Tutorial</a> by Matthew Inkawhich

In [1]:
import codecs
import csv
import itertools
import math
import random
import re
import os
import unicodedata
from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.jit import script, trace

In [2]:
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')

# Load and Preprocess Data 
### (Cornell Movie-Dialogues Corpus)

In [3]:
DATA = '../../data'
corpus_name = 'cornell movie-dialogs corpus'
corpus = os.path.join(DATA, corpus_name)

In [4]:
def print_lines(file, n=10):
    with open(file, 'rb') as f:
        lines = f.readlines()
    for line in lines[:n]:
        print(line)

In [5]:
print_lines(os.path.join(corpus, 'movie_lines.txt'))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


### Create Formatted Data File

In [6]:
def load_lines(file_name, fields):
    '''Splits each line of the file into dict of fields'''
    lines = {}
    with open(file_name, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            line_obj = {}
            for i, field in enumerate(fields):
                line_obj[field] = values[i]
            lines[line_obj['lineID']] = line_obj
    return lines

In [7]:
def load_conversations(file_name, lines, fields):
    '''
    Groups fields of lines from load_lines() into conversations based on
    movie_conversations.txt
    '''
    conversations = []
    with open(file_name, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            conv_obj = {}
            for i, field in enumerate(fields):
                conv_obj[field] = values[i]
            line_ids = eval(conv_obj['utteranceIDs'])
            conv_obj['lines'] = []
            for line_id in line_ids:
                conv_obj['lines'].append(lines[line_id])
            conversations.append(conv_obj)
    return conversations

In [8]:
def extract_sentence_pairs(conversations):
    '''Extract pairs of sentences from conversations'''
    qa_pairs = []
    for conversation in conversations:
        for i in range(len(conversation['lines']) - 1):
            input_line = conversation['lines'][i]['text'].strip()
            target_line = conversation['lines'][i + 1]['text'].strip()
            # Filter if one of the lists is empty
            if input_line and target_line:
                qa_pairs.append([input_line, target_line])
    return qa_pairs

In [9]:
datafile = os.path.join(corpus, 'formatted_movie_lines.txt')
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, 'unicode_escape'))

In [10]:
lines = {}
conversations = []
MOVIE_LINES_FIELDS = [
    'lineID', 'characterID', 'movieID', 'character', 'text']
MOVIE_CONVERSATION_FIELDS = [
    'character1ID', 'character2ID', 'movieID', 'utteranceIDs']

In [11]:
print('\nProcessing corpus...')
lines = load_lines(os.path.join(corpus, 'movie_lines.txt'), 
                   MOVIE_LINES_FIELDS)

print('\nLoading conversations...')
conversations = load_conversations(
    os.path.join(corpus, 'movie_conversations.txt'), 
    lines,
    MOVIE_CONVERSATION_FIELDS)


Processing corpus...

Loading conversations...


In [12]:
print('\nWriting newly formatted file...')
with open(datafile, 'w', encoding='utf-8') as out:
    writer = csv.writer(out, delimiter=delimiter)
    for pair in extract_sentence_pairs(conversations):
        writer.writerow(pair)


Writing newly formatted file...


In [13]:
print('\nSample lines from file')
print_lines(datafile)


Sample lines from file
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\n"

# Load and Trim Data

In [14]:
PAD = 0 # to pad short sentences
SOS = 1 # start of sentence
EOS = 2 # end of sentence

In [15]:
class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD: 'PAD', SOS: 'SOS', EOS: 'EOS'}
        self.n_words = 3 # Count SOS, EOS, PAD
        
    def add_sentence(self, sentence):
        for word in sentence.split():
            self.add_word(word)
            
    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
        else:
            self.word2count[word] += 1
            
    # Remove words below min count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), 
            len(self.word2index), 
            len(keep_words) / len(self.word2index)))
        
        # Reinit dicts
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD: 'PAD', SOS: 'SOS', EOS: 'EOS'}
        self.n_words = 3
        for word in keep_words:
            self.add_word(word)

In [16]:
MAX_LEN = 10

In [17]:
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

In [18]:
def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r'([.!?])', r' \1', s)
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    s = re.sub(r'\s+', r' ', s).strip()
    return s

In [19]:
def read_vocs(datafile, corpus_name):
    print('Reading lines...')
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    pairs = [[normalize_string(s) for s in line.split('\t')] 
             for line in lines]
    voc = Voc(corpus_name)
    return voc, pairs

In [20]:
def filter_pair(p):
    return len(p[0].split()) < MAX_LEN and len(p[1].split()) < MAX_LEN

In [21]:
def filter_pairs(pairs):
    return [pair for pair in pairs if filter_pair(pair)]

In [22]:
def load_prep_data(corpus, corpus_name, datafile, save_dir):
    print('Start prepping training data...')
    voc, pairs = read_vocs(datafile, corpus_name)
    print('Read {!s} sentence pairs'.format(len(pairs)))
    pairs = filter_pairs(pairs)
    print('Trimmed to {!s} sentence pairs'.format(len(pairs)))
    print('Counting words...')
    for pair in pairs:
        voc.add_sentence(pair[0])
        voc.add_sentence(pair[1])
    print('Counted words:', voc.n_words)
    return voc, pairs

In [23]:
save_dir = '%s/save' % DATA
voc, pairs = load_prep_data(corpus, corpus_name, datafile, save_dir)
print('\npairs:')
for pair in pairs[:10]:
    print(pair)

Start prepping training data...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 3

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [24]:
MIN_COUNT = 3

In [25]:
def trim_rare_words(voc, pairs):
    voc.trim(MIN_COUNT)
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        for word in input_sentence.split():
            if word not in voc.word2index:
                keep_input = False
                break
        for word in output_sentence.split():
            if word not in voc.word2index:
                keep_output = False
                break
        if keep_input and keep_output:
            keep_pairs.append(pair)
    print('Trimmed from {} pairs to {}, {:.4f} of total'.format(
        len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs

In [26]:
pairs = trim_rare_words(voc, pairs)

keep_words 7822 / 18004 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


### Prepare Data for Models

In [27]:
def indices_from_sentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split()] + [EOS]

In [28]:
def zero_padding(l, fill_value=PAD):
    return list(itertools.zip_longest(*l, fillvalue=fill_value))

In [29]:
def binary_matrix(l, value=PAD):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [30]:
def input_var(l, voc):
    indices_batch = [indices_from_sentence(voc, sentence) 
                     for sentence in l]
    lengths = torch.tensor([len(indices) for indices in indices_batch])
    pad_list = zero_padding(indices_batch)
    pad_var = torch.LongTensor(pad_list)
    return pad_var, lengths

In [31]:
def output_var(l, voc):
    indices_batch = [indices_from_sentence(voc, sentence) 
                     for sentence in l]
    max_target_len = max([len(indices) for indices in indices_batch])
    pad_list = zero_padding(indices_batch)
    mask = binary_matrix(pad_list)
    mask = torch.ByteTensor(mask)
    pad_var = torch.LongTensor(pad_list)
    return pad_var, mask, max_target_len

In [32]:
def batch2train_data(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split()), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = input_var(input_batch, voc)
    output, mask, max_target_len = output_var(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [33]:
# Example for validataion
BATCH = 5
batches = batch2train_data(voc, 
                           [random.choice(pairs) for _ in range(BATCH)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print('input_variable:', input_variable)
print('lengths:', lengths)
print('target_variable:', target_variable)
print('mask:', mask)
print('max_target_len:', max_target_len)

input_variable: tensor([[3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3],
        [3, 3, 3, 2, 2],
        [3, 3, 3, 0, 0],
        [3, 3, 3, 0, 0],
        [3, 2, 2, 0, 0],
        [2, 0, 0, 0, 0]])
lengths: tensor([8, 7, 7, 4, 4])
target_variable: tensor([[3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3],
        [3, 3, 3, 3, 2],
        [3, 2, 3, 3, 0],
        [3, 0, 3, 3, 0],
        [3, 0, 3, 3, 0],
        [3, 0, 3, 3, 0],
        [3, 0, 2, 3, 0],
        [3, 0, 0, 3, 0],
        [2, 0, 0, 2, 0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 0, 1, 1, 0],
        [1, 0, 1, 1, 0],
        [1, 0, 1, 1, 0],
        [1, 0, 1, 1, 0],
        [1, 0, 0, 1, 0],
        [1, 0, 0, 1, 0]], dtype=torch.uint8)
max_target_len: 10


# Define Models
### Encoder

In [34]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        
        # Init GRU; input_size and hidden_size are both set to 
        #   'hidden_size' bc input size is a word embedding wtih 
        #    n_features == hidden_size
        self.gru = nn.GRU(hidden_size, 
                          hidden_size, 
                          n_layers, 
                          dropout=(0 if n_layers == 1 else dropout), 
                          bidirectional=True)
        
    def forward(self, input_seq, input_lens, hidden=None):
        # Word index -> embedding
        embedded = self.embedding(input_seq)
        # Pack padded batch of seqs for RNN mod
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, 
                                                         input_lens)
        # GRU forward pass
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequences(outputs)
        # Sum bidirectional outputs
        outputs = (outputs[:, :, :self.hidden_size] 
                   + outputs[:, :, self.hidden_size:])
        return outputs, hidden

### Decoder (with Attention)

In [35]:
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, 
                             'is not an appropriate attention method')
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nnLinear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hedden_size*2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
            
    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)
    
    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)
    
    def concat_score(self, hidden, encoder_output):
        energy = self.attn(
            torch.cat(
                (hidden.expand(encoder_output.size(0), -1, -1),
                 encoder_output),
                2))\
            .tanh()
        return torch.sum(self.v * energy, dim=2)
    
    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) according to method
        attn_energies = {
            'general': self.general_score,
            'concat': self.concat_score,
            'dot': self.dot_score}[self.method](hidden, encoder_outputs)
        # Transpose max_length and batch size dims
        attn_energies = attn_energies.t()
        # Return softmax normalized probs with added dim
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [36]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(
            self, attn_model, embedding, hidden_size, output_size, 
            n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, 
                          hidden_size, 
                          n_layers, 
                          dropout=0 if n_layers == 1 else dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)
        
    def forward(self, input_step, last_hidden, encoder_outputs):
        # Runs one word at a time
        # Get embedding of current input word
        embedded = self.embedding(inpuy_step)
        embedded = slef.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Attn weights for current GRU output
        attn_weighs = self.attn(rnn_output, encoder_outputs)
        # Mult attn weights and encoder out to get 'weighted sum' vec
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concat weighted ctxt vec and GRU out using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict with Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Final hidden state:
        return output, hidden

# Define Training Procedure
### Masked Loss

In [37]:
def maskNLLLoss(inp, target, mask):
    n_total = mask.sum()
    cross_ent = -torch.log(
        torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = cross_ent.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, n_totel.item()

In [38]:
def train(
        input_var, lens, target_var, mask, max_target_len, encoder,
        decoder, embedding, encoder_optimizer, decoder_optimizer, 
        batch_size, clip, max_len):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    input_var = input_var.to(device)
    lens = lens.to(device)
    target_var = target_var.to(device)
    mask = mask.to(device)
    loss = 0
    print_losses = []
    n_totals = 0
    
    # Forward pass
    encoder_outputs, encoder_hidden = encoder(input_var, lens)
    
    # Init decoder input with SOS at start of each sent
    decoder_input = torch.LongTensor([[SOS for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)
    # Trigger init state from prev output
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    use_teacher_forcing = (True if random.random() < teacher_forcing_ratio 
                           else False)
    # Forward through decoder
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            # Teacher forcing: next input is current target
            decoder_input = target_var[t].view(1, -1)
            mask_loss, n_total = maskNLLLoss(
                decoder_output, target_var[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * n_total)
            n_totals += n_total
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            # No teacher forcing: next input is deecoder's own output
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([
                [topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            mask_loss, n_total = maskNLLLoss(
                decoder_output, target_var[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * n_total)
            n_totals += n_total
    
    # Backprop
    loss.backward()
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
    
    # Adj weights
    encoder_optimizer.step()
    decoder_optimizer.step()
    return sum(print_losses) / n_totals

### Training Iterations

In [39]:
def train_iters(
        mod_name, voc, pairs, encoder, decoder, encoder_optimizer, 
        decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers,
        save_dir, n_iter, batch_size, print_every, save_every, clip,
        corpus_name, load_file_name):
    # Load batches
    training_batches = [
        batch2train_data(
            voc, 
            [random.choice(pairs) for _ in range(batch_size)]) 
        for _ in range(n_iter)]
    # Init
    print('Initializing...')
    start_iter = 1
    print_loss = 0
    if load_file_name:
        start_iter = checkpoint['iteration'] + 1
    # Train
    print('Training...')
    for i in range(start_iter, n_iter + 1):
        training_batch = training_batches[i - 1]
        # Extract fields from batch
        input_var, lens, target_var, mask, max_target_len = training_batch
        # Run a training iteration with batch
        loss = train(
            input_var, lens, target_var, mask, max_target_len, encoder,
            decoder, embedding, encoder_optimizer, decoder_optimizer, 
            batch_size, clip)
        print_loss += loss
        # Print progress
        if i % print_every == 0:
            print_loss_avg = print_loss / print_every
            print('Iter: {}; Perc. completer: {:.1f}%; Avg. loss: {:.4f}'\
                  .format(i, i / n_iter * 100, print_loss_avg))
            print_loss = 0
        # Save checkpoint
        if i % save_every == 0:
            directory = os.path.join(
                save_dir, 
                mod_name, 
                corpus_name, 
                '{}-{}_{}'.format(
                    encoder_n_layears, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torchs.save({'iteration': i,
                         'en': encoder.state_dict(),
                         'de': decoder.state_dict(),
                         'en_opt': encoder_optimizer.state_dict(),
                         'de_opt': decoder_optimizer.state_dict(),
                         'loss': loss,
                         'voc_dict': voc.__dict__,
                         'embedding': embedding.state_dict()},
                       os.path.join(directory, '{}_{}.tar'.format(
                           i, 'checkpoint')))

In [40]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, input_seq, input_len, max_len):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, 
                                                       input_len)
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        decoder_input = (
            torch.ones(1, 1, device=device, dtype=torch.long) * SOS)
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_len):
            decoder_output, decoder_hidden = slef.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_scores, decoder_input = torch.max(decoder_output, 
                                                      dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

### Evaluate Text

In [41]:
def evalauate(encoder, decoder, searcher, voc, sentence, max_len=MAX_LEN):
    # words -> indices
    index_batch = [indices_from_sentences(voc, sentence)]
    lens = torch.tensor([len(indices) for indices in index_batch])
    input_batch = torch.LongTensor(index_batch).transpose(0, 1)
    input_batch = input_batch.to(device)
    lens = lens.to(device)
    # Decode
    tokens, scores = searcher(input_batch, lens, max_len)
    # ind -> word
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words

In [42]:
def evaluate_input(encoder, decoder, searcher, voc):
    input_sentence = ''
    while True:
        try:
            input_sentence = input('> ')
            if input_sentence == 'q' or input_sentence == 'quit': break
            input_sentence = normaliz_string(input_sentence)
            output_words = evaluate(
                encoder, decoder, searcher, voc, input_sentence)
            output_words[:] = [
                x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))
        except KeyError:
            print('Error: Encountered unknown word.')

# Run Model

In [43]:
model_name = 'cb_model'
attn_model = 'dot' # 'general', 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

In [44]:
load_filename = None
chckpoint_iter = 4000
#load_filename = os.path.join(
#    save_dir, 
#    model_name, 
#    corpus_name, 
#    '{}-{}-{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#    '{}_checkpoint.tar'.format(checkpoint_iter))

In [45]:
if load_filename:
    checkpoint = torch.load(load_filename)
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimiser_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']

In [46]:
print('Building encoder and decoder...')
embedding = nn.Embedding(voc.n_words, hidden_size)
if load_filename:
    embedding.load_state_dict(embedding_sde)

Building encoder and decoder...


In [48]:
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(
    attn_model, 
    embedding, 
    hidden_size, 
    voc.n_words, 
    decoder_n_layers, 
    dropout)

In [49]:
if load_filename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)

encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models ready to go.')

Models ready to go.


# Run Training