# Transformer - Implemenation from scratch using pytorch

Part of **#30DaysOfBasics**, Lets build a chatbot application using Encoder-Decoder (transformer based) architecture.

Training Data: Cornell Movies dialogs (https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html)

Referencev tutorial: 

In [13]:
from collections import Counter

import re
import json
import itertools

In [4]:
corpus_movie_convo = '../data/movie_conversations.txt'
corpus_movie_lines = '../data/movie_lines.txt'

with open(corpus_movie_convo, 'r', encoding='iso-8859-1') as f:
    movie_convo = f.readlines()
    
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as f:
    movie_lines = f.readlines()

for line in movie_lines[:8]:
    print(line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [20]:
lines_dic = {}
for line in movie_lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

In [21]:
def remove_punc(string):
    punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    no_punct = ""
    for char in string:
        if char not in punctuations:
            no_punct = no_punct + char  # space is also a character
    return no_punct.lower()

In [24]:
MAX_LEN = 25


pairs = []
for con in movie_convo:
    ids = eval(con.split(" +++$+++ ")[-1])
    for i in range(len(ids)):
        qa_pairs = []
        
        if i==len(ids)-1:
            break
        
        first = remove_punc(lines_dic[ids[i]].strip())      
        second = remove_punc(lines_dic[ids[i+1]].strip())
        qa_pairs.append(first.split()[:MAX_LEN])
        qa_pairs.append(second.split()[:MAX_LEN])
        pairs.append(qa_pairs)

In [25]:
word_freq = Counter()
for pair in pairs:
    word_freq.update(pair[0])
    word_freq.update(pair[1])

In [26]:
min_word_freq = 5
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
word_map = {k: v + 1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

In [27]:
# with open('../data/WORDMAP_corpus.json', 'w') as j:
#     json.dump(word_map, j)

In [31]:
def encode_question(words, word_map, max_len=MAX_LEN):
    enc_c = [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c


def encode_reply(words, word_map, max_len=MAX_LEN):
    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in words] + \
    [word_map['<end>']] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c

In [32]:
pairs_encoded = []
for pair in pairs:
    qus = encode_question(pair[0], word_map)
    ans = encode_reply(pair[1], word_map)
    pairs_encoded.append([qus, ans])

In [34]:
# with open('../data/pairs_encoded.json', 'w') as p:
#     json.dump(pairs_encoded, p)

## Data preparation

In [49]:
import math
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [43]:
class Dataset(Dataset):

    def __init__(self):

        self.pairs = json.load(open('../data/pairs_encoded.json'))
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
            
        return question, reply

    def __len__(self):
        return self.dataset_size

In [40]:
train_loader = DataLoader(Dataset(),
                        batch_size = 100, 
                        shuffle=True, 
                        pin_memory=True)

In [46]:
qes, rply = next(iter(train_loader))
print(qes, rply)
print(qes.shape, rply.shape)

tensor([[6983,  725,   17,  ...,    0,    0,    0],
        [  26,   29,   17,  ...,    0,    0,    0],
        [ 100,  113,   20,  ...,    0,    0,    0],
        ...,
        [8543,  295,   82,  ...,    0,    0,    0],
        [  56,   55, 1732,  ...,    0,    0,    0],
        [  56,   57,  909,  ...,    0,    0,    0]]) tensor([[18241,    17,  8386,  ...,     0,     0,     0],
        [18241,  7652,    91,  ...,     0,     0,     0],
        [18241,   929,   399,  ...,     0,     0,     0],
        ...,
        [18241,    87,    29,  ...,     0,     0,     0],
        [18241,   347,    93,  ...,     0,     0,     0],
        [18241,  5518,  1239,  ...,     0,     0,     0]])
torch.Size([100, 25]) torch.Size([100, 27])


## Modelling

In [47]:
def create_masks(question, reply_input, reply_target):
    
    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)
    
    question_mask = question!=0
    question_mask = question_mask.to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)         # (batch_size, 1, 1, max_words)
     
    reply_input_mask = reply_input!=0
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) 
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target!=0              # (batch_size, max_words)
    
    return question_mask, reply_input_mask, reply_target_mask

In [65]:
class Embeddings(nn.Module):
    
    def __init__(self, vocab_size, d_model, max_len=50):
        super(Embeddings, self).__init__()
        
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.dropout = nn.Dropout(0.1)
        
        self.pe = self.create_positional_encoding(max_len, d_model)
     
    
    def create_positional_encoding(self, max_len, d_model):
            
        pe = torch.zeros(max_len, d_model).to(device)
            
        for pos in range(max_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2* i) / d_model)))
                pe[pos, i+1] = math.cos(pos / (10000 ** ((2* (i + 1)) / d_model)))
                    
        pe = torch.unsqueeze(pe, 0) #(1, max_len, d_model)
        return pe
    
    
    def forward(self, encoded_words):
        embeddings = self.embed(encoded_words) * math.sqrt(d_model) #(batch_size, max_words, d_model)
        
        embeddings += self.pe[:, :embeddings.size(1)]
        
        embeddings = self.dropout(embeddings)
        
        return embeddings   

In [131]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % heads == 0
        self.heads = heads
        self.d_k = d_model // heads
        
        self.dropout = nn.Dropout(0.1)
        
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        
        self.concat = nn.Linear(d_model, d_model)
        
    
    def forward(self, query, key, value, mask):
        
        query = self.query(query) #(batch_size, max_len, 512)
        key = self.key(key) #(batch_size, max_len, 512)
        value = self.value(value) #(batch_size, max_len, 512)
        
        # (batch_size, max_len, 512) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.size(0), -1, self.heads, self.d_k).permute(0,2,1,3)
        key = key.view(key.size(0), -1, self.heads, self.d_k).permute(0,2,1,3)
        value = value.view(value.size(0), -1, self.heads, self.d_k).permute(0,2,1,3)
        
        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0,1,3,2)) / math.sqrt(query.size(-1))
        
        scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(attn_weights, value)
        
        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
        context = context.permute(0,2,1,3).contiguous().view(context.size(0), -1, self.heads * self.d_k)
        
        interacted = self.concat(context)
        
        return interacted

In [72]:
class FeedForward(nn.Module):
    
    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)
        
    
    def forward(self, x):
        out = nn.functional.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

In [119]:
class EncoderLayer(nn.Module):
    
    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.layerNorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        
    
    def forward(self, embeddings, mask):
        
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        interacted = self.layerNorm(interacted + embeddings)
        
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        
        encoded = self.layerNorm(feed_forward_out + interacted)
        
        return encoded

In [59]:
class DecoderLayer(nn.Module):
    
        def __init__(self, d_model, heads):
            super(DecoderLayer, self).__init__()

            self.self_multihead = MultiHeadAttention(heads, d_model)
            self.src_multihead = MultiHeadAttention(heads, d_model)
            self.feed_forward = FeedForward(d_model)
            self.layerNorm = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(0.1)
        
        
        def forward(self, embedding, encoded_repr, src_mask, target_mask):
            
            query = self.dropout(self.self_multihead(embedding, embedding, embedding, target_mask))
            query = self.layerNorm(query + embedding)
            
            interacted = self.dropout(self.src_multihead(query, encoded_repr, encoded_repr, src_mask))
            interacted = self.layerNorm(interacted + query)
            
            feed_forward_out = self.dropout(self.feed_forward(interacted))
            
            decoded = self.layerNorm(feed_forward_out + interacted)
            
            return decoded

In [120]:
class Transformer(nn.Module):
    
    def __init__(self, d_model, heads, n_layers, word_map):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embeddings = Embeddings(self.vocab_size, d_model)
        
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(n_layers)])
        
        self.logit = nn.Linear(self.d_model, self.vocab_size)
        
    
    def encode(self, src_words, src_mask):
        src_embeddings = self.embeddings(src_words)
        
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        
        return src_embeddings
    
    
    def decode(self, target_words, target_mask, src_embeddings, src_mask):
        tgt_embeddings = self.embeddings(target_words)
        
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
            
        return tgt_embeddings
    
    
    def forward(self, src_words, src_mask, target_words, target_mask):
        
        encoded = self.encode(src_words=src_words, src_mask=src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        
        out = nn.functional.log_softmax(self.logit(decoded))
        return out

In [142]:
class AdamWarmup:
    
    def __init__(self, model_size, warmup_steps, optimizer):
        
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        
        self.current_step = 0
        self.lr = 0
        
    def get_lr(self):
        return self.model_size ** (-0.5) * \
            min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
    
    
    def step(self):
        self.current_step += 1
        lr = self.get_lr()
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        
        #update weights
        self.optimizer.step()

In [143]:
class LoseWithLS(nn.Module):
    
    def __init__(self, size, smooth):
        super(LoseWithLS, self).__init__()
        
        self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
        self.smooth = smooth
        self.confidence = 1.0 - smooth
        self.size = size
        
    
    def forward(self, prediction, target, mask):
        
        #prediction: (batch_size, max_words, vocab_size)
        prediction = prediction.view(-1, prediction.size(-1))
        target = target.contiguous().view(-1)
        mask = mask.float()
        mask = mask.view(-1)
        
        labels = prediction.data.clone()
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        
        loss = self.criterion(prediction, labels) #(batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        
        return loss

In [144]:
d_model = 512
heads=8
n_layers = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 1
warmup_steps = 4000

with open('../data/WORDMAP_corpus.json', 'r') as f:
    word_map = json.load(f)

transformer = Transformer(d_model, heads, n_layers, word_map)
transformer.to(device)

adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps= 1e-9)
transformer_optimizer = AdamWarmup(d_model, warmup_steps, adam_optimizer)
criterion = LoseWithLS(size=len(word_map), smooth=0.3)

In [145]:
def train(train_loader, transformer, transformer_optimizer, criterion, epoch, print_every):
    
    transformer.train()
    sum_loss = 0
    count = 0
    
    for i, (question, reply) in enumerate(train_loader):
        
        samples = question.shape[0]
        
        question = question.to(device)
        reply = reply.to(device)
        
        #sentence: <start> I went home today <end>
        #reply_input: <start> I went home today
        #reply_target: I went home today <end>
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]
        
        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)
        
        #run through transformer
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        loss = criterion(out, reply_target, reply_target_mask)
        
        #backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()
        
        #loss accumlation
        sum_loss += loss * samples
        count += samples
        
        if i % print_every == 0:
            print('Epoch: [{}][{}/{}]\t Loss: {:.3f}'.format(epoch, i, len(train_loader), sum_loss / count))

In [146]:
def evaluate(transformer, question, question_mask, max_len, word_map):
    
    rev_word_map = {v:k for k, v in word_map.items()}
    
    transformer.eval()
    
    start_token = word_map['<start>']
    encoded = transformer.encode(question, question_mask)
    words = torch.LongTensor([[start_token]]).to(device)
    
    for step in range(max_len - 1):
        
        size = words.shape[0]
        target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        target_mask = target_mask.to(device).unsqueeze(0)
        
        decoded = transformer.deocde(words, target_mask, encoded, question_mask) #(1,1,vocab_size)
        
        preds = transformer.logit(decoded[:, -1]) #(1, vocab_size)
        
        _, next_word = torch.max(pred, dim = 1) #(1,1)
        
        if next_word == word_map['<end>']:
            break
        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim=1)
        
    words = words.squeeze(0)
    words = words.tolist()
    
    sen_idx = [w for w in words if w not in {word_map['<start>']}]
    sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])
    
    return sentence

In [147]:
#Training

for epoch in range(epochs):
    
    train(train_loader, transformer, transformer_optimizer, criterion, epoch, print_every=1)
    state_dict = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
    torch.save(state_dict, '../models/checkpoint_{}.pth.tar'.format(epoch))

  out = nn.functional.log_softmax(self.logit(decoded))


Epoch: [0][0/2217]	 Loss: 1.189
Epoch: [0][1/2217]	 Loss: 1.191
Epoch: [0][2/2217]	 Loss: 1.187
Epoch: [0][3/2217]	 Loss: 1.190
Epoch: [0][4/2217]	 Loss: 1.191
Epoch: [0][5/2217]	 Loss: 1.187
Epoch: [0][6/2217]	 Loss: 1.186
Epoch: [0][7/2217]	 Loss: 1.186
Epoch: [0][8/2217]	 Loss: 1.186
Epoch: [0][9/2217]	 Loss: 1.186
Epoch: [0][10/2217]	 Loss: 1.186
Epoch: [0][11/2217]	 Loss: 1.185
Epoch: [0][12/2217]	 Loss: 1.186
Epoch: [0][13/2217]	 Loss: 1.186
Epoch: [0][14/2217]	 Loss: 1.187
Epoch: [0][15/2217]	 Loss: 1.186
Epoch: [0][16/2217]	 Loss: 1.186
Epoch: [0][17/2217]	 Loss: 1.185
Epoch: [0][18/2217]	 Loss: 1.184
Epoch: [0][19/2217]	 Loss: 1.184
Epoch: [0][20/2217]	 Loss: 1.183
Epoch: [0][21/2217]	 Loss: 1.183
Epoch: [0][22/2217]	 Loss: 1.182
Epoch: [0][23/2217]	 Loss: 1.181
Epoch: [0][24/2217]	 Loss: 1.181
Epoch: [0][25/2217]	 Loss: 1.181
Epoch: [0][26/2217]	 Loss: 1.180
Epoch: [0][27/2217]	 Loss: 1.179
Epoch: [0][28/2217]	 Loss: 1.178
Epoch: [0][29/2217]	 Loss: 1.177
Epoch: [0][30/2217]	

Epoch: [0][245/2217]	 Loss: 0.927
Epoch: [0][246/2217]	 Loss: 0.926
Epoch: [0][247/2217]	 Loss: 0.924
Epoch: [0][248/2217]	 Loss: 0.923
Epoch: [0][249/2217]	 Loss: 0.922
Epoch: [0][250/2217]	 Loss: 0.920
Epoch: [0][251/2217]	 Loss: 0.919
Epoch: [0][252/2217]	 Loss: 0.917
Epoch: [0][253/2217]	 Loss: 0.916
Epoch: [0][254/2217]	 Loss: 0.914
Epoch: [0][255/2217]	 Loss: 0.913
Epoch: [0][256/2217]	 Loss: 0.911
Epoch: [0][257/2217]	 Loss: 0.910
Epoch: [0][258/2217]	 Loss: 0.909
Epoch: [0][259/2217]	 Loss: 0.907
Epoch: [0][260/2217]	 Loss: 0.906
Epoch: [0][261/2217]	 Loss: 0.904
Epoch: [0][262/2217]	 Loss: 0.903
Epoch: [0][263/2217]	 Loss: 0.901
Epoch: [0][264/2217]	 Loss: 0.900
Epoch: [0][265/2217]	 Loss: 0.898
Epoch: [0][266/2217]	 Loss: 0.897
Epoch: [0][267/2217]	 Loss: 0.896
Epoch: [0][268/2217]	 Loss: 0.894
Epoch: [0][269/2217]	 Loss: 0.893
Epoch: [0][270/2217]	 Loss: 0.891
Epoch: [0][271/2217]	 Loss: 0.890
Epoch: [0][272/2217]	 Loss: 0.889
Epoch: [0][273/2217]	 Loss: 0.887
Epoch: [0][274

Epoch: [0][486/2217]	 Loss: 0.655
Epoch: [0][487/2217]	 Loss: 0.655
Epoch: [0][488/2217]	 Loss: 0.654
Epoch: [0][489/2217]	 Loss: 0.653
Epoch: [0][490/2217]	 Loss: 0.652
Epoch: [0][491/2217]	 Loss: 0.651
Epoch: [0][492/2217]	 Loss: 0.650
Epoch: [0][493/2217]	 Loss: 0.649
Epoch: [0][494/2217]	 Loss: 0.649
Epoch: [0][495/2217]	 Loss: 0.648
Epoch: [0][496/2217]	 Loss: 0.647
Epoch: [0][497/2217]	 Loss: 0.646
Epoch: [0][498/2217]	 Loss: 0.645
Epoch: [0][499/2217]	 Loss: 0.644
Epoch: [0][500/2217]	 Loss: 0.644
Epoch: [0][501/2217]	 Loss: 0.643
Epoch: [0][502/2217]	 Loss: 0.642
Epoch: [0][503/2217]	 Loss: 0.641
Epoch: [0][504/2217]	 Loss: 0.640
Epoch: [0][505/2217]	 Loss: 0.640
Epoch: [0][506/2217]	 Loss: 0.639
Epoch: [0][507/2217]	 Loss: 0.638
Epoch: [0][508/2217]	 Loss: 0.637
Epoch: [0][509/2217]	 Loss: 0.636
Epoch: [0][510/2217]	 Loss: 0.635
Epoch: [0][511/2217]	 Loss: 0.635
Epoch: [0][512/2217]	 Loss: 0.634
Epoch: [0][513/2217]	 Loss: 0.633
Epoch: [0][514/2217]	 Loss: 0.632
Epoch: [0][515

Epoch: [0][727/2217]	 Loss: 0.492
Epoch: [0][728/2217]	 Loss: 0.492
Epoch: [0][729/2217]	 Loss: 0.491
Epoch: [0][730/2217]	 Loss: 0.491
Epoch: [0][731/2217]	 Loss: 0.490
Epoch: [0][732/2217]	 Loss: 0.490
Epoch: [0][733/2217]	 Loss: 0.489
Epoch: [0][734/2217]	 Loss: 0.489
Epoch: [0][735/2217]	 Loss: 0.488
Epoch: [0][736/2217]	 Loss: 0.488
Epoch: [0][737/2217]	 Loss: 0.487
Epoch: [0][738/2217]	 Loss: 0.487
Epoch: [0][739/2217]	 Loss: 0.486
Epoch: [0][740/2217]	 Loss: 0.486
Epoch: [0][741/2217]	 Loss: 0.485
Epoch: [0][742/2217]	 Loss: 0.484
Epoch: [0][743/2217]	 Loss: 0.484
Epoch: [0][744/2217]	 Loss: 0.483
Epoch: [0][745/2217]	 Loss: 0.483
Epoch: [0][746/2217]	 Loss: 0.482
Epoch: [0][747/2217]	 Loss: 0.482
Epoch: [0][748/2217]	 Loss: 0.481
Epoch: [0][749/2217]	 Loss: 0.481
Epoch: [0][750/2217]	 Loss: 0.480
Epoch: [0][751/2217]	 Loss: 0.480
Epoch: [0][752/2217]	 Loss: 0.479
Epoch: [0][753/2217]	 Loss: 0.479
Epoch: [0][754/2217]	 Loss: 0.478
Epoch: [0][755/2217]	 Loss: 0.478
Epoch: [0][756

Epoch: [0][968/2217]	 Loss: 0.382
Epoch: [0][969/2217]	 Loss: 0.382
Epoch: [0][970/2217]	 Loss: 0.381
Epoch: [0][971/2217]	 Loss: 0.381
Epoch: [0][972/2217]	 Loss: 0.381
Epoch: [0][973/2217]	 Loss: 0.380
Epoch: [0][974/2217]	 Loss: 0.380
Epoch: [0][975/2217]	 Loss: 0.379
Epoch: [0][976/2217]	 Loss: 0.379
Epoch: [0][977/2217]	 Loss: 0.379
Epoch: [0][978/2217]	 Loss: 0.378
Epoch: [0][979/2217]	 Loss: 0.378
Epoch: [0][980/2217]	 Loss: 0.378
Epoch: [0][981/2217]	 Loss: 0.377
Epoch: [0][982/2217]	 Loss: 0.377
Epoch: [0][983/2217]	 Loss: 0.377
Epoch: [0][984/2217]	 Loss: 0.376
Epoch: [0][985/2217]	 Loss: 0.376
Epoch: [0][986/2217]	 Loss: 0.375
Epoch: [0][987/2217]	 Loss: 0.375
Epoch: [0][988/2217]	 Loss: 0.375
Epoch: [0][989/2217]	 Loss: 0.374
Epoch: [0][990/2217]	 Loss: 0.374
Epoch: [0][991/2217]	 Loss: 0.373
Epoch: [0][992/2217]	 Loss: 0.373
Epoch: [0][993/2217]	 Loss: 0.373
Epoch: [0][994/2217]	 Loss: 0.372
Epoch: [0][995/2217]	 Loss: 0.372
Epoch: [0][996/2217]	 Loss: 0.372
Epoch: [0][997

Epoch: [0][1203/2217]	 Loss: 0.306
Epoch: [0][1204/2217]	 Loss: 0.305
Epoch: [0][1205/2217]	 Loss: 0.305
Epoch: [0][1206/2217]	 Loss: 0.305
Epoch: [0][1207/2217]	 Loss: 0.305
Epoch: [0][1208/2217]	 Loss: 0.304
Epoch: [0][1209/2217]	 Loss: 0.304
Epoch: [0][1210/2217]	 Loss: 0.304
Epoch: [0][1211/2217]	 Loss: 0.303
Epoch: [0][1212/2217]	 Loss: 0.303
Epoch: [0][1213/2217]	 Loss: 0.303
Epoch: [0][1214/2217]	 Loss: 0.303
Epoch: [0][1215/2217]	 Loss: 0.302
Epoch: [0][1216/2217]	 Loss: 0.302
Epoch: [0][1217/2217]	 Loss: 0.302
Epoch: [0][1218/2217]	 Loss: 0.301
Epoch: [0][1219/2217]	 Loss: 0.301
Epoch: [0][1220/2217]	 Loss: 0.301
Epoch: [0][1221/2217]	 Loss: 0.300
Epoch: [0][1222/2217]	 Loss: 0.300
Epoch: [0][1223/2217]	 Loss: 0.300
Epoch: [0][1224/2217]	 Loss: 0.300
Epoch: [0][1225/2217]	 Loss: 0.299
Epoch: [0][1226/2217]	 Loss: 0.299
Epoch: [0][1227/2217]	 Loss: 0.299
Epoch: [0][1228/2217]	 Loss: 0.299
Epoch: [0][1229/2217]	 Loss: 0.298
Epoch: [0][1230/2217]	 Loss: 0.298
Epoch: [0][1231/2217

Epoch: [0][1438/2217]	 Loss: 0.247
Epoch: [0][1439/2217]	 Loss: 0.247
Epoch: [0][1440/2217]	 Loss: 0.247
Epoch: [0][1441/2217]	 Loss: 0.247
Epoch: [0][1442/2217]	 Loss: 0.246
Epoch: [0][1443/2217]	 Loss: 0.246
Epoch: [0][1444/2217]	 Loss: 0.246
Epoch: [0][1445/2217]	 Loss: 0.246
Epoch: [0][1446/2217]	 Loss: 0.245
Epoch: [0][1447/2217]	 Loss: 0.245
Epoch: [0][1448/2217]	 Loss: 0.245
Epoch: [0][1449/2217]	 Loss: 0.245
Epoch: [0][1450/2217]	 Loss: 0.245
Epoch: [0][1451/2217]	 Loss: 0.244
Epoch: [0][1452/2217]	 Loss: 0.244
Epoch: [0][1453/2217]	 Loss: 0.244
Epoch: [0][1454/2217]	 Loss: 0.244
Epoch: [0][1455/2217]	 Loss: 0.244
Epoch: [0][1456/2217]	 Loss: 0.243
Epoch: [0][1457/2217]	 Loss: 0.243
Epoch: [0][1458/2217]	 Loss: 0.243
Epoch: [0][1459/2217]	 Loss: 0.243
Epoch: [0][1460/2217]	 Loss: 0.243
Epoch: [0][1461/2217]	 Loss: 0.242
Epoch: [0][1462/2217]	 Loss: 0.242
Epoch: [0][1463/2217]	 Loss: 0.242
Epoch: [0][1464/2217]	 Loss: 0.242
Epoch: [0][1465/2217]	 Loss: 0.242
Epoch: [0][1466/2217

Epoch: [0][1673/2217]	 Loss: 0.202
Epoch: [0][1674/2217]	 Loss: 0.202
Epoch: [0][1675/2217]	 Loss: 0.202
Epoch: [0][1676/2217]	 Loss: 0.201
Epoch: [0][1677/2217]	 Loss: 0.201
Epoch: [0][1678/2217]	 Loss: 0.201
Epoch: [0][1679/2217]	 Loss: 0.201
Epoch: [0][1680/2217]	 Loss: 0.201
Epoch: [0][1681/2217]	 Loss: 0.201
Epoch: [0][1682/2217]	 Loss: 0.200
Epoch: [0][1683/2217]	 Loss: 0.200
Epoch: [0][1684/2217]	 Loss: 0.200
Epoch: [0][1685/2217]	 Loss: 0.200
Epoch: [0][1686/2217]	 Loss: 0.200
Epoch: [0][1687/2217]	 Loss: 0.200
Epoch: [0][1688/2217]	 Loss: 0.199
Epoch: [0][1689/2217]	 Loss: 0.199
Epoch: [0][1690/2217]	 Loss: 0.199
Epoch: [0][1691/2217]	 Loss: 0.199
Epoch: [0][1692/2217]	 Loss: 0.199
Epoch: [0][1693/2217]	 Loss: 0.198
Epoch: [0][1694/2217]	 Loss: 0.198
Epoch: [0][1695/2217]	 Loss: 0.198
Epoch: [0][1696/2217]	 Loss: 0.198
Epoch: [0][1697/2217]	 Loss: 0.198
Epoch: [0][1698/2217]	 Loss: 0.198
Epoch: [0][1699/2217]	 Loss: 0.197
Epoch: [0][1700/2217]	 Loss: 0.197
Epoch: [0][1701/2217

Epoch: [0][1908/2217]	 Loss: 0.165
Epoch: [0][1909/2217]	 Loss: 0.165
Epoch: [0][1910/2217]	 Loss: 0.165
Epoch: [0][1911/2217]	 Loss: 0.165
Epoch: [0][1912/2217]	 Loss: 0.165
Epoch: [0][1913/2217]	 Loss: 0.164
Epoch: [0][1914/2217]	 Loss: 0.164
Epoch: [0][1915/2217]	 Loss: 0.164
Epoch: [0][1916/2217]	 Loss: 0.164
Epoch: [0][1917/2217]	 Loss: 0.164
Epoch: [0][1918/2217]	 Loss: 0.164
Epoch: [0][1919/2217]	 Loss: 0.164
Epoch: [0][1920/2217]	 Loss: 0.163
Epoch: [0][1921/2217]	 Loss: 0.163
Epoch: [0][1922/2217]	 Loss: 0.163
Epoch: [0][1923/2217]	 Loss: 0.163
Epoch: [0][1924/2217]	 Loss: 0.163
Epoch: [0][1925/2217]	 Loss: 0.163
Epoch: [0][1926/2217]	 Loss: 0.162
Epoch: [0][1927/2217]	 Loss: 0.162
Epoch: [0][1928/2217]	 Loss: 0.162
Epoch: [0][1929/2217]	 Loss: 0.162
Epoch: [0][1930/2217]	 Loss: 0.162
Epoch: [0][1931/2217]	 Loss: 0.162
Epoch: [0][1932/2217]	 Loss: 0.162
Epoch: [0][1933/2217]	 Loss: 0.161
Epoch: [0][1934/2217]	 Loss: 0.161
Epoch: [0][1935/2217]	 Loss: 0.161
Epoch: [0][1936/2217

Epoch: [0][2143/2217]	 Loss: 0.135
Epoch: [0][2144/2217]	 Loss: 0.135
Epoch: [0][2145/2217]	 Loss: 0.135
Epoch: [0][2146/2217]	 Loss: 0.134
Epoch: [0][2147/2217]	 Loss: 0.134
Epoch: [0][2148/2217]	 Loss: 0.134
Epoch: [0][2149/2217]	 Loss: 0.134
Epoch: [0][2150/2217]	 Loss: 0.134
Epoch: [0][2151/2217]	 Loss: 0.134
Epoch: [0][2152/2217]	 Loss: 0.134
Epoch: [0][2153/2217]	 Loss: 0.134
Epoch: [0][2154/2217]	 Loss: 0.134
Epoch: [0][2155/2217]	 Loss: 0.133
Epoch: [0][2156/2217]	 Loss: 0.133
Epoch: [0][2157/2217]	 Loss: 0.133
Epoch: [0][2158/2217]	 Loss: 0.133
Epoch: [0][2159/2217]	 Loss: 0.133
Epoch: [0][2160/2217]	 Loss: 0.133
Epoch: [0][2161/2217]	 Loss: 0.133
Epoch: [0][2162/2217]	 Loss: 0.133
Epoch: [0][2163/2217]	 Loss: 0.132
Epoch: [0][2164/2217]	 Loss: 0.132
Epoch: [0][2165/2217]	 Loss: 0.132
Epoch: [0][2166/2217]	 Loss: 0.132
Epoch: [0][2167/2217]	 Loss: 0.132
Epoch: [0][2168/2217]	 Loss: 0.132
Epoch: [0][2169/2217]	 Loss: 0.132
Epoch: [0][2170/2217]	 Loss: 0.132
Epoch: [0][2171/2217

In [None]:
#eval

checkpoint = torch.load('../models/checkpoint_{}.tar'.format())
transformer = checkpoint['transformer']

while(1):
    question = input('Question: ')
    if question == 'q':
        break
    max_len = input('Max len of words you want to generate: ')
    enc_qus = [word_map(word, word_map['<unk>']) for word in question.split()]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, max_len, word_map)
    print(sentence)