In [1]:
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import re
import csv
import codecs
import itertools
import unicodedata

In [2]:
gpu = torch.cuda.is_available()
device = torch.device('cuda' if gpu else 'cpu')

In [3]:
FOLDER = "cornell_movie_dialogs_corpus"

In [4]:
line_path = os.path.join(FOLDER, 'movie_lines.txt')
conv_path = os.path.join(FOLDER, 'movie_conversations.txt')

In [5]:
with open(line_path, 'rb') as f:
    lines = f.readlines()
for l in lines[:10]:
    print(l.rstrip())

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go."
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie."
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?'


In [6]:
fields = ['lineid', 'charid', 'movieid', 'charname', 'text']
lines = {}
with open(line_path, 'r', encoding='iso-8859-1') as f:
    for line in f.readlines():
        split = line.split(' +++$+++ ')
        lineobj = {}
        for i, fname in enumerate(fields):
            lineobj[fname] = split[i]
        lines[lineobj['lineid']] = lineobj

In [7]:
lines['L105']

{'lineid': 'L105',
 'charid': 'u8',
 'movieid': 'm0',
 'charname': 'MISS PERKY',
 'text': "Well, yes, compared to your other choices of expression this year, today's events are quite mild.  By the way, Bobby Rictor's gonad retrieval operation went quite well, in case you're interested.\n"}

In [8]:
with open(conv_path, 'rb') as f:
    c_lines = f.readlines()
for l in c_lines[:10]:
    print(l.rstrip())

b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L280', 'L281']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L363', 'L364']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L365', 'L366']"


In [9]:
fields = ['char1id', 'char2id', 'movieid', 'lineids']
conversations = []
with open(conv_path, 'r', encoding='iso-8859-1') as f:
    for line in f.readlines():
        split = line.split(' +++$+++ ')
        convobj = {}
        for i, fname in enumerate(fields):
            convobj[fname] = split[i]
        convobj['lineids'] = eval(convobj['lineids'])
        #creating a key for the dict
        all_lines = []
        for lid in convobj['lineids']:
            all_lines.append(lines[lid])
        convobj['lines'] = all_lines
        conversations.append(convobj)

In [10]:
conversations[0]

{'char1id': 'u0',
 'char2id': 'u2',
 'movieid': 'm0',
 'lineids': ['L194', 'L195', 'L196', 'L197'],
 'lines': [{'lineid': 'L194',
   'charid': 'u0',
   'movieid': 'm0',
   'charname': 'BIANCA',
   'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
  {'lineid': 'L195',
   'charid': 'u2',
   'movieid': 'm0',
   'charname': 'CAMERON',
   'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
  {'lineid': 'L196',
   'charid': 'u0',
   'movieid': 'm0',
   'charname': 'BIANCA',
   'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
  {'lineid': 'L197',
   'charid': 'u2',
   'movieid': 'm0',
   'charname': 'CAMERON',
   'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}

In [11]:
qa_pairs = []
for conversation in conversations:
    for i in range(len(conversation['lines']) - 1):
        q = conversation['lines'][i]['text'].strip()
        a = conversation['lines'][1+i]['text'].strip()
        if q and a:
            qa_pairs.append((q,a))

In [12]:
qa_pairs[0]

('Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you.")

In [13]:
#save to csv
csv_name = 'formatted_lines.txt'
csv_path = os.path.join(FOLDER, csv_name)
delim = '\t'
delim = str(codecs.decode(delim, 'unicode_escape'))

with open(csv_path, 'w', encoding='utf-8') as fn:
    writer = csv.writer(fn, delimiter=delim)
    for qa in qa_pairs:
        writer.writerow(qa)
print('done')

done


In [14]:
with open(csv_path, 'r') as f:
    c_lines = f.readlines()[:5]
print(c_lines)

["Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n", "Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n", "Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n", "You're asking me out.  That's so cute. What's your name again?\tForget it.\n", "No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"]


In [15]:
PAD = 0
SOS = 1
EOS = 2

class Vocabulary(object):
    def __init__(self, name):
        self.name = name
        self.word2id = {}
        self.word2count = {}
        self.id2word = {PAD:'PAD', SOS:'SOS', EOS:'EOS'}
        self.word_count = 3
    
    def add_word(self, word):
        if word in self.word2id:
            self.word2count[word] += 1
        else:
            self.word2id[word] = self.word_count
            self.word2count[word] = 1
            self.id2word[self.word_count] = word
            self.word_count += 1
    
    def add_sentence(self, sent):
        for word in sent.split():
            self.add_word(word)
            
    def trim(self, min_count):
        keep_list = []
        for k,v in zip(self.word2count.keys(), self.word2count.values()):
            if v >= min_count:
                keep_list.append(k)
        self.word2id = {}
        self.word2count = {}
        self.id2word = {PAD:'PAD', SOS:'SOS', EOS:'EOS'}
        self.word_count = 3
        for word in keep_list:
            self.add_word(word)

In [16]:
def unicode2ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')
    #normal form decomposed, non-marking space

In [17]:
unicode2ascii('Montréal')

'Montreal'

In [18]:
def normalize_string(s):
    s = unicode2ascii(s.lower().strip())
    
    #replace .?! by ' .'
    s = re.sub(r'([.!?])', r' \1', s)
    #replace non alphabet, punctuation with whitespace
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    #remove continuous whites
    s = re.sub(r'\s+', r' ', s).strip()
    return s

In [19]:
normalize_string('aa!   n')

'aa ! n'

In [20]:
lines = open(csv_path, 'r', encoding = 'utf-8').read().strip().split('\n')
pairs = []
for line in lines:
    pairs.append([normalize_string(s) for s in line.split('\t')])

In [21]:
print(len(pairs))
pairs[0]

221282


['can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .',
 'well i thought we d start with pronunciation if that s okay with you .']

In [22]:
vocab = Vocabulary('cornell')

In [23]:
MAX_LENGTH = 10
def filter_words(qa):
    return len(qa[0].split()) < MAX_LENGTH and len(qa[1].split()) < MAX_LENGTH

In [24]:
def filter_pairs(pairs):
    return [p for p in pairs if filter_words(p)]

In [25]:
pairs = [p for p in pairs if len(p)>1]
print('pairs of length >= 2:', len(pairs))
pairs = filter_pairs(pairs)
print('pairs after filtering', len(pairs))

pairs of length >= 2: 221282
pairs after filtering 64271


In [26]:
len(pairs[0][0].split())

2

In [27]:
for pair in pairs:
    vocab.add_sentence(pair[0])
    vocab.add_sentence(pair[1])
vocab.word_count

18007

In [28]:
#trim rare words and the pairs in which they appear
MIN_COUNT = 4
def trim_rare_words(voc, pairs, min_c = MIN_COUNT):
    voc.trim(min_c)
    temp_pairs = []
    for p in pairs:
        keep = True
        for wq in p[0].split():
            if wq not in voc.word2id :
                keep = False
                break
        if keep:
            for wa in p[1].split():
                if wa not in voc.word2id :
                    keep = False
                    break
            if keep:
                temp_pairs.append(p)
    return temp_pairs   

In [29]:
pairs = trim_rare_words(vocab, pairs)
len(pairs)

49781

In [30]:
def sentence_to_index(voc, s):
    return [voc.word2id[w] for w in s.split()] + [EOS]

In [31]:
print(pairs[0][0])
sentence_to_index(vocab, pairs[0][0])

there .


[3, 4, 2]

In [32]:
#zero pad for variable length
def zero_pad(v, val=0):
    return list(itertools.zip_longest(*v, fillvalue=val))

In [33]:
i = sentence_to_index(vocab,pairs[1][0])
j = sentence_to_index(vocab,pairs[0][0])
print(i,j)
print(zero_pad((i,j)))

[7, 8, 9, 10, 4, 11, 12, 13, 2] [3, 4, 2]
[(7, 3), (8, 4), (9, 2), (10, 0), (4, 0), (11, 0), (12, 0), (13, 0), (2, 0)]


In [34]:
def binarizer(l, val=0):
    m = []
    for i, row in enumerate(l):
        m.append([])
        for v in row:
            if v is not PAD:
                m[i].append(1)
            else:
                m[i].append(val)
    return m

In [35]:
def prepare_input(qs, voc):
    indices = [sentence_to_index(voc, q) for q in qs]
    lens = torch.tensor([len(index) for index in indices])
    zp = torch.LongTensor(zero_pad(indices))
    return zp, lens

In [36]:
def prepare_output(ans, voc):
    indices = [sentence_to_index(voc, a) for a in ans]
    max_target_len = max([len(index) for index in indices])
    zp = zero_pad(indices)
    mask = torch.ByteTensor(binarizer(zp))
    zp = torch.LongTensor(zp)
    return zp, mask, max_target_len

In [37]:
def batch_train_data(voc, pair_batch):
    #sort pairs in desc order by ques length
    pair_batch.sort(key = lambda x:len(x[0].split(' ')), reverse = True)
    ip, op = [], []
    for pair in pair_batch:
        ip.append(pair[0])
        op.append(pair[1])
    inputs, lens = prepare_input(ip, voc)
    outputs, mask, max_len = prepare_output(op, voc)
    return inputs, lens, outputs, mask, max_len

In [38]:
#testing the function above
small_sample_size = 7
small_pairs = [random.choice(pairs) for _ in range(small_sample_size)]
batches = batch_train_data(vocab, small_pairs)
for item in batches:
    print(item)

tensor([[  91,    7,   64,   33, 3973,  155,   38],
        [   7,  195,   58,   27,  748,  557,  261],
        [ 190,  116,    7,   14, 3052,  558,    4],
        [   7,   73,   21,  188,    4,    4,    2],
        [   4,    7,    6,    2,    2,    2,    0],
        [   4,  131,    2,    0,    0,    0,    0],
        [   4,   59,    0,    0,    0,    0,    0],
        [   6,    2,    0,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0,    0,    0]])
tensor([9, 8, 6, 5, 5, 5, 4])
tensor([[ 312,   25,  165,   99,   49,   49,  375],
        [   4,  195,    4,  265,    6,   46,    7],
        [   2,  116,    2,  116,    2,    7,    4],
        [   0,   24,    0,  100,    0,  157,  664],
        [   0,    4,    0,  621,    0,   97, 1480],
        [   0,    2,    0,  110,    0,   36,    4],
        [   0,    0,    0,    6,    0,    6,    2],
        [   0,    0,    0,    2,    0,    2,    0]])
tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1

In [70]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.hidden = hidden
        self.embedding = embedding
        self.n_layers = n_layers
        self.gru = nn.GRU(hidden, hidden, n_layers, dropout = (0 if n_layers==1 else dropout), bidirectional = True)
    def forward(self, input_seq, input_lens, hidden=None):
        #input_seq = (max_len, batch_size)
        #input_lens = (batch_size, )
        #hidden = (n_layers*n_directions, batch_size, hidden)
        
        #calculate embeddings
        embs = self.embedding(input_seq)
        
        #packed padded
        packed = torch.nn.utils.rnn.pack_padded_sequence(embs, input_lens)
        
        output, hidden = self.gru(packed, hidden)
        
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(output)
        output = output[:,:,:self.hidden] + output[:,:,self.hidden:]
        return output, hidden
        #output = (max_len, batch, hidden)
        #hidden = (n_layers*n_directions, batch, hidden)

In [40]:
class Attn(nn.Module):
    def __init__(self, hidden, method='dot'):
        super(Attn, self).__init__()
        self.hidden = hidden
        self.method = method
    
    def score(self, dec_hidden, enc_op):
        #dec_hidden = (1, batch, hidden)
        #enc_op = (max_len, batch, hidden)
        #prod = (max_len, batch, hidden)=> sum over dim=2 => (max_len, batch)
        if self.method == 'dot':
            return torch.sum(dec_hidden * enc_op, dim=2)
    
    def forward(self, dec_h, enc_ops):
        
        attn_wts = self.score(dec_h, enc_ops)
        attn_wts = attn_wts.t() #(batch, max_len)
        #take softmax over max_len(timesteps) and unsqueeze
        return F.softmax(attn_wts, dim = 1).unsqueeze(1) #(batch, 1, max_len)

In [41]:
class DecoderWithAttn(nn.Module):
    def __init__(self, embedding, hidden, output_size, n_layers=1, dropout=0):
        super(DecoderWithAttn, self).__init__()
        self.hidden = hidden
        self.embedding = embedding
        self.output_size = output_size
        self.n_layers = n_layers
        self.attn = Attn(hidden)
        self.deconcat = nn.Linear(2*hidden, hidden)
        self.out = nn.Linear(hidden, output_size)
        self.gru = nn.GRU(hidden, hidden, n_layers, dropout=dropout)
        
    def forward(self, input_step, last_hidden, encoder_states):
        #input_step = (1, batch)
        #last_hidden = (n_layers*n_directions, batch, hidden)
        #encoder_states = (max_len, batch, n_directions*hidden)
        #run one batch word at a time
        
        embs = self.embedding(input_step)
        dec_op, hidden = self.gru(embs, last_hidden)
        #(1,batch, n_directions*hidden)
        #(n_layers*n_directions, batch, hidden)
        #find attn
        attn_wt = self.attn(hidden, encoder_states)
        #batch, 1, max_len
        #multiply attn with enc op
        context = attn_wt.bmm(encoder_states.transpose(0,1))
        #batch,1,max_len * batch,max_len, dir*hidd = batch, 1, dir*hidd
        dec_op = dec_op.squeeze(0) #batch, dir*hidd
        context = context.squeeze(1) #batch, dir*hidd
        concat_ip = torch.cat((dec_op, context), 1)
        concat_op = torch.tanh(self.deconcat(concat_ip))
        
        softmax = nn.Softmax(dim=1)
        output = softmax(self.out(concat_op))
        return output, hidden
        
        
        

In [42]:
def maskedNLL(output, true, mask):
    n_elems = mask.sum()
    true = true.view(-1,1)
    gathered = torch.gather(output, 1, true)
    ce = -torch.log(gathered)
    loss= ce.masked_select(mask)
    loss = loss.mean().to(device)
    return loss, n_elems.item()

In [43]:
#train for small batch
small_batch_size = 5
small_pairs = [random.choice(pairs) for _ in range(small_batch_size)]
batches = batch_train_data(vocab, small_pairs)
inputs, lens, outputs, mask, max_len = batches
print('inputs:', inputs)
print('lens:', lens)
print('outputs:', outputs)
print('mask:', mask)
print('max_len:', max_len)


hidden_size=500

embeds = nn.Embedding(vocab.word_count, hidden_size)
encoder = EncoderRNN(hidden_size, embeds).to(device)
decoder = DecoderWithAttn(embeds, hidden_size, vocab.word_count).to(device)
encoder.train()
decoder.train()

encoder_optim = optim.Adam(encoder.parameters(), 0.001)
decoder_optim = optim.Adam(decoder.parameters(), 0.001)
encoder_optim.zero_grad()
decoder_optim.zero_grad()

inputs = inputs.to(device)
lens = lens.to(device)
outputs = outputs.to(device)
mask = mask.to(device)

loss = 0.
print_loss = []
total = 0.

encoder_out, encoder_hid = encoder(inputs, lens)
print('encoder output:', encoder_out.size())
print('encoder hidden:', encoder_hid.size())

#first decoder input is start symbol
decoder_in = torch.LongTensor([[SOS for _ in range(small_batch_size)]]).to(device)
print('decoder input:', decoder_in.size())

#first hidden of decoder = last hidden of encoder
decoder_hid = encoder_hid[:1].to(device)
print('decoder hidden:', decoder_hid.size())

#using teacher forcing
for t in range(max_len):
    decoder_out, decoder_hid = decoder(decoder_in, decoder_hid, encoder_out)
    print('decoder output:', decoder_out.size())
    print('decoder hidden:', decoder_hid.size())
    
    
    print('decoder input shape initial:', outputs[t].size())
    #teacher forcing
    decoder_in = outputs[t].view(1, -1)
    print('decoder input shape final:', decoder_in.size())
    print('mask shape:', mask[t].size())
    print('mask:', mask[t])
    
    mask_loss, nt = maskedNLL(decoder_out, outputs[t], mask[t])
    print('mask loss', mask_loss)
    print('nt',nt)
    total+=nt
    print('total', total)
    print_loss.append(mask_loss.item()*nt)
    print('print loss', print_loss)
    ls = sum(print_loss)
    return_loss = ls/total
    print('return', return_loss)
    encoder_optim.step()
    decoder_optim.step()
    
    

inputs: tensor([[2789,   85,   36,    5,  123],
        [ 101,  296,   37,  209,    9],
        [   9,    7,   75,  210,  124],
        [ 267,  214,    4, 4437,    4],
        [ 584,   82,    4,    6,    2],
        [ 186,    4,    4,    2,    0],
        [ 316,    2,    2,    0,    0],
        [  65,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lens: tensor([9, 7, 7, 6, 5])
outputs: tensor([[  36,   25,   17, 1423, 1474],
        [  37,  214, 1137,    4,    7],
        [ 343,   75,   40,    4, 2931],
        [  65,    4, 2999,    4,   65],
        [   2,    2,    4,    2,    2],
        [   0,    0,    2,    0,    0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 0, 1, 0, 0]], dtype=torch.uint8)
max_len: 6
encoder output: torch.Size([9, 5, 500])
encoder hidden: torch.Size([2, 5, 500])
decoder input: torch.Size([1, 5])
decoder hidden: torch.Size([1, 5, 500])
decoder ou

In [64]:
#actual train function
def train(inputs, lens, outputs, mask, max_target_len, batch, embs, encoder, decoder, encoder_optim, decoder_optim, teacher_forcing_ratio):
    encoder_optim.zero_grad()
    decoder_optim.zero_grad()

    inputs = inputs.to(device)
    lens = lens.to(device)
    outputs = outputs.to(device)
    mask = mask.to(device)

    loss = 0.
    print_loss = []
    total = 0.

    encoder_out, encoder_hid = encoder(inputs, lens)

    #first decoder input is start symbol
    decoder_in = torch.LongTensor([[SOS for _ in range(batch)]]).to(device)

    #first hidden of decoder = last hidden of encoder
    decoder_hid = encoder_hid[:1].to(device)

    use_tf = True if random.random()<teacher_forcing_ratio else False
    
    if use_tf:
        for t in range(max_len):
            decoder_out, decoder_hid = decoder(decoder_in, decoder_hid, encoder_out)
           
            #teacher forcing
            decoder_in = outputs[t].view(1, -1)

            mask_loss, nt = maskedNLL(decoder_out, outputs[t], mask[t])
            total += nt
            print_loss.append(mask_loss.item()*nt)
            loss += mask_loss
    else:
        for t in range(max_len):
            decoder_out, decoder_hid = decoder(decoder_in, decoder_hid, encoder_out)
           
            _, topk = decoder_out.topk(1)
            decoder_in = torch.LongTensor([[topk[i][0] for i in range(batch)]]).to(device)
            
            mask_loss, nt = maskedNLL(decoder_out, outputs[t], mask[t])
            total += nt
            print_loss.append(mask_loss.item()*nt)
            loss += mask_loss
    
    loss.backward()
    encoder_optim.step()
    decoder_optim.step()
    
    return sum(print_loss)/total

In [72]:
def train_iters(voc, pairs, encoder, decoder, encoder_optim, decoder_optim, embs, batch, n_iters):
    training_batches = [batch_train_data(voc, [random.choice(pairs) for _ in range(batch)])
                      for _ in range(n_iters)]
    print_loss = 0.
    for i in range(n_iters):
        current_batch = training_batches[i]
        inputs, lens, outputs, mask, max_len = current_batch
        loss = train(inputs, lens, outputs, mask, max_len, batch, embs, encoder, decoder, encoder_optim, decoder_optim, 1.0)
        print_loss += loss
        
        if i%200 == 0 and i!=0:
            print('Iteration', i)
            print(print_loss/200)
            print_loss = 0.

In [80]:
#to talk to the bot we need a decoder
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, inputs, lens, max_len):
        encoder_out, encoder_hid = self.encoder(inputs, lens)
        #enc final hidden is in first of decoder
        decoder_hid = encoder_hid[:1]
        # start symbol
        decoder_in = torch.ones(1, 1, device=device, dtype=torch.long) * SOS
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_len):
            # Forward pass through decoder
            decoder_out, decoder_hid = self.decoder(decoder_in, decoder_hid, encoder_out)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_in = torch.max(decoder_out, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_in), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_in = torch.unsqueeze(decoder_in, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [82]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=10):
    # Format input sentence as a batch
    indexes_batch = [sentence_to_index(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch 
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    #index back to word
    decoded_words = [voc.id2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while(1):
        try:
            input_sentence = input('> ')
            # exit condition
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalize_string(input_sentence)
            # Evaluate sentence
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: <UNK>")

In [73]:
hidden_dim = 500
batch_size = 64
n_iters = 2000
embedding = nn.Embedding(vocab.word_count, hidden_dim)
encoder = EncoderRNN(hidden_dim, embedding).to(device)
decoder = DecoderWithAttn(embedding, hidden_dim, vocab.word_count).to(device)

l_rate = 0.001
encoder.train()
decoder.train()

encoder_optim = optim.Adam(encoder.parameters(), l_rate)
decoder_optim = optim.Adam(decoder.parameters(), l_rate)
encoder_optim.zero_grad()
decoder_optim.zero_grad()

train_iters(vocab, pairs, encoder, decoder, encoder_optim, decoder_optim, embedding, batch_size, n_iters)


Iteration 200
4.251150501117423
Iteration 400
3.6769184004841566
Iteration 600
3.47263782340546
Iteration 800
3.264304235343397
Iteration 1000
3.1538263822650983
Iteration 1200
3.0448477410860137
Iteration 1400
2.9013089056025727
Iteration 1600
2.8010474879619944
Iteration 1800
2.666254231161008


In [84]:
encoder.eval()
decoder.eval()

searcher = GreedySearchDecoder(encoder, decoder)

evaluateInput(encoder, decoder, searcher, vocab)


> hi
Bot: hi .
> how are you?
Bot: fine .
> quit
