In [6]:
import math, copy, sys, logging, json, time, random, os, string, pickle, re

import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt

from modules.TransformerComponents import Transformer
from modules.Vocabulary import Vocab
from modules.MetaLearnNeuralMemory import FFMemoryLearned
from modules.LoadTrainSave import save_model, load_model, Teacher

%matplotlib inline
%load_ext autoreload
%autoreload 2

np.random.seed(0) 
random.seed(0)
torch.manual_seed(0)

print('torch.version', torch.__version__)
print('torch.cuda.is_available()', torch.cuda.is_available())
print('torch.cuda.device_count()', torch.cuda.device_count())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
torch.version 1.7.0
torch.cuda.is_available() True
torch.cuda.device_count() 2


In [24]:
class MNMp(nn.Module):

    def __init__(self, dim_hidden, n_heads = 4):
        
        """ dim_hidden is the hidden size of the LSTM controller,
            the Memory Network, and the interaction vectors
            n_heads is the number of interaction heads """

        super(MNMp, self).__init__()
        
        self.dim_hidden = dim_hidden
        self.n_heads = n_heads
        
        self.control = nn.LSTMCell(dim_hidden*2, dim_hidden)
        
        dim_concat_interact = dim_hidden*n_heads*3 + dim_hidden
        self.interaction = nn.Linear(dim_hidden, dim_concat_interact)
        self.memfunc = FFMemoryLearned(dim_hidden)
        self.kv_rate = nn.Linear(dim_hidden, 1)
        self.read_out = nn.Linear(dim_hidden+dim_hidden, dim_hidden)
        
        #self.v_r = None
        #self.h_lstm = None
        #self.c_lstm = None
        
        #''' 
        self.v_r = torch.nn.Parameter( torch.zeros((1, self.dim_hidden)))
        self.h_lstm = torch.nn.Parameter(torch.zeros((1, self.dim_hidden)))
        self.c_lstm = torch.nn.Parameter(torch.zeros((1, self.dim_hidden)))
        #'''
        
    def repeat_v_h_c(self, batch_size):

            #self.v_r = torch.zeros((batch_size, self.dim_hidden)).float()
            #self.h_lstm = torch.zeros((batch_size, self.dim_hidden)).float()
            #self.c_lstm = torch.zeros((batch_size, self.dim_hidden)).float()
            
            #'''
            if self.v_r.shape[0] != batch_size:
                v_r = self.v_r.repeat(batch_size,1)
                h_lstm = self.h_lstm.repeat(batch_size,1)
                c_lstm = self.c_lstm.repeat(batch_size,1)
            #'''
            
            if next(self.parameters()).is_cuda:
                self.v_r = self.v_r.cuda()
                self.h_lstm = self.h_lstm.cuda()
                self.c_lstm = self.c_lstm.cuda()
                                         
            return v_r, h_lstm, c_lstm
            
    def forward(self, x):
        """ the input must have shape (batch_size, emb_dim) because it will be 
        concatenated with self.v_r of the same shape """

        v_r, h_lstm, c_lstm = self.repeat_v_h_c(x.shape[0])
        
        x = x.squeeze(1)
        
        h_lstm, c_lstm = self.control(torch.cat([x, v_r], dim=1), (h_lstm, c_lstm))
        
        int_vecs = torch.tanh(self.interaction(h_lstm))
        
        beta_, n_k_v = torch.split(int_vecs, 
                                   [self.dim_hidden,self.dim_hidden*self.n_heads*3],
                                   dim=1)  
        
        beta = torch.sigmoid(self.kv_rate(beta_)) #(batch_size,1)
        
        n_k_v = n_k_v.view(n_k_v.shape[0], self.n_heads, -1).contiguous()
        
        k_w, v_w, k_r = torch.chunk(n_k_v, 3, dim=2)
        
        reconst_loss, reconst_loss_init = self.memfunc.update(k_w, v_w, beta_rate=beta)
        
        v_r = self.memfunc.read(k_r)
        
        h_lstm = self.read_out(torch.cat([h_lstm, v_r], dim=1))

        return h_lstm.unsqueeze(1), reconst_loss, reconst_loss_init 

In [25]:

class Bot(nn.Module):
    
    def __init__(self, emb_dim, n_layers, heads, dropout, vocab):
        
        super().__init__()
        
        self.emb_dim = emb_dim
        
        self.vocab = vocab
        self.sos_tok = torch.LongTensor([[self.vocab.word2index["<SOS>"]]]) 
        self.eos_tok = torch.LongTensor([[self.vocab.word2index["<EOS>"]]]) 
        
        self.encodeInput = Transformer(emb_dim, n_layers, heads, dropout)
        self.encodeEncoding = Transformer(emb_dim, n_layers, heads, dropout)
        self.decodeEncoding = Transformer(emb_dim, n_layers, heads, dropout)

        self.mnm = MNMp(emb_dim, heads)
        
        self.context_vec = None
        
    def memory_utils(self, batch_size):

        if self.context_vec is None:
            cntxt_seq_len = 1
            self.context_vec = torch.randn(batch_size, cntxt_seq_len, self.emb_dim)
            
        if self.context_vec.shape[0] > batch_size:
            self.context_vec = self.context_vec[0,:,:]
            
        if self.context_vec.shape[0] < batch_size:
            self.context_vec = self.context_vec[0,:,:].repeat(batch_size, 1, 1)
    
        self.context_vec = self.context_vec.detach()
        self.mnm.memfunc.detach_mem()
        
    def forward(self, in_toks, in_mask, out_toks, out_mask):
        
        self.memory_utils(batch_size = in_toks.shape[0])
        
        in_vecs = self.vocab.embedding(in_toks)
        out_vec = self.vocab.embedding(out_toks)

        self.context_vec, rcl, rcli = self.mnm(self.context_vec)
        encin_vec = self.encodeInput(in_vecs, in_mask, self.context_vec, None)
        self.context_vec = self.encodeEncoding(self.context_vec, None, encin_vec, None)
        
        dout = self.decodeEncoding(out_vec, out_mask, encin_vec, in_mask)
        
        return dout, rcl, rcli
    
    def teacher_forcing(self, src, trg):
        
        self.train()
        trg_start = torch.cat((self.sos_tok.repeat(trg.shape[0],1), trg),dim=1)
        trg_end = torch.cat((trg, self.eos_tok.repeat(trg.shape[0],1)),dim=1)
        src_mask = (src != self.vocab.word2index["<PAD>"]).unsqueeze(-2)
        trg_mask = (trg_end != self.vocab.word2index["<PAD>"]).unsqueeze(-2)
        
        seq_len = trg_start.size(1) 
        np_mask = np.triu(np.ones((1,seq_len,seq_len)),k=1).astype('uint8')
        np_mask =  torch.from_numpy(np_mask) == 0
        
        if trg.is_cuda:
            np_mask = np_mask.cuda()
            
        trg_mask = trg_mask & np_mask
        
        out_vecs, rcl, rcli = self.forward(src, src_mask, trg_start, trg_mask)
        
        return out_vecs, trg_end, rcl, rcli
    
    def string2string(self, input_string, maxlen = 20):
        
        self.eval()
        in_toks = self.vocab.string2tensor(input_string)
        in_vecs = self.vocab.embedding(in_toks)
        
        self.memory_utils(batch_size=in_toks.shape[0])
        
        self.context_vec, rcl, rcli = self.mnm(self.context_vec)
        encin_vec = self.encodeInput(in_vecs, None, self.context_vec, None)
        self.context_vec = self.encodeEncoding(self.context_vec, None, encin_vec, None)
        
        decode_toks = self.sos_tok
        
        for pos in range(maxlen):
            
            decode_vecs = self.vocab.embedding(decode_toks)
            dout = self.decodeEncoding(decode_vecs, None, encin_vec, None)
            vocabdist = self.vocab.emb2vocab(dout)
            next_toks = torch.argmax(vocabdist, dim=2)
            decode_toks = torch.cat((decode_toks, next_toks[:,-1].unsqueeze(0)), dim=1) 
            
            if next_toks[:,-1] == self.eos_tok.squeeze(0):
                
                toks = decode_toks[0][1:-1].data.cpu().numpy()
                de_str = ' '.join([self.vocab.index2word[int(tok)] for tok in toks])

                return de_str
            
        toks = decode_toks[0].data.cpu().numpy()
        de_str = ' '.join([self.vocab.index2word[tok] for tok in toks])
        return de_str

In [26]:
emb_dim, n_layers, heads, dropout = 32, 2, 2, 0.05

vocab = Vocab(emb_dim)
model = Bot(emb_dim, n_layers, heads, dropout, vocab)
teacher = Teacher(model.vocab)

In [28]:
total_batches = 64*10
best_acc = 0
lamda = 8
batch_size = 64
learning_rate = 0.01 #0.001

optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate,betas=(0.9, 0.98),eps=1e-9)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,'min',factor=0.99,patience=100)

loss_all_list = []
rcloss_all_list = []
accuracy_list = []

for batch in range(total_batches):
    
    intro, introtarget, whatmyname, yournameis = teacher.get_batch(batch_size)
    
    out_vecs, trg_end, rcl, rcli = model.teacher_forcing(intro, introtarget)
    
    vocab_logits = model.vocab.emb2vocab(out_vecs)
    
    predictions = vocab_logits.view(-1, vocab_logits.size(-1))
    
    target = trg_end.view(-1)

    batch_loss = F.cross_entropy(predictions, target, 
                                 ignore_index = model.vocab.word2index["<PAD>"])

    reconstruction_loss = lamda*rcl
    
    ################# Next Part of Conversation ########################
    
    out_vecs, trg_end, rcl, rcli = model.teacher_forcing(whatmyname, yournameis)
    
    vocab_logits = model.vocab.emb2vocab(out_vecs)

    predictions = vocab_logits.view(-1, vocab_logits.size(-1))
    
    target = trg_end.view(-1)
    
    acc = accuracy_score(target, torch.argmax(predictions, dim=1))

    batch_loss += F.cross_entropy(predictions, target, 
                                 ignore_index = model.vocab.word2index["<PAD>"])
    
    reconstruction_loss += lamda*rcl
    
    ################ Cobine Losses +++++++++++++++++++++++++++++
    conversation_loss = batch_loss + reconstruction_loss
    
    scheduler.step(conversation_loss)
    optimizer.zero_grad()
    conversation_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    if batch % int(total_batches/20 + 1) == 0:
        
        loss_all_list.append(conversation_loss.float().item())
        rcloss_all_list.append(reconstruction_loss.float().item())
        accuracy_list.append(acc)
        mean_accuracy = np.mean(accuracy_list[-10:])

        if mean_accuracy > best_acc:
            print('Saving Model...')
            best_acc = mean_accuracy
            
            ''' 
            pickle.dump(model.vocab.word2index,open("modelstate/word2index.p","wb"))
            pickle.dump(model.vocab.index2word,open("modelstate/index2word.p","wb"))
            pickle.dump(model.vocab.emb2vocab.weight,open("modelstate/emb2vocab.weight.p","wb"))
            pickle.dump(model.vocab.embedding.weight,open("modelstate/embedding.weight.p","wb"))
            pickle.dump(model.context_vec,open("modelstate/context_vec.p","wb"))
            pickle.dump(model.mnm.memfunc.Ws,open("modelstate/Ws.p","wb"))
            save_model(model,"modelstate/task.pth")
            '''
            
        print("mean accuracy", round(mean_accuracy,4), 
              "celoss", round(batch_loss.float().item(),4), 
              "rcloss", round(reconstruction_loss.float().item(),6), 
              "d_rcloss", round((rcli - rcl).float().item(),4),
              "training progress", round(batch/total_batches,4),
              "learning rate", scheduler._last_lr)
            
        if mean_accuracy > 0.97:
            break

Saving Model...
mean accuracy 0.7031 celoss 0.7518 rcloss 7.6e-05 d_rcloss 0.0424 training progress 0.0 learning rate [0.01]
Saving Model...
mean accuracy 0.7422 celoss 0.4669 rcloss 0.004161 d_rcloss 0.0468 training progress 0.0516 learning rate [0.01]
Saving Model...
mean accuracy 0.7639 celoss 0.5378 rcloss 0.006587 d_rcloss 0.043 training progress 0.1031 learning rate [0.01]
Saving Model...
mean accuracy 0.7682 celoss 0.4938 rcloss 0.000683 d_rcloss 0.0445 training progress 0.1547 learning rate [0.01]
mean accuracy 0.7667 celoss 0.4806 rcloss 0.001113 d_rcloss 0.0479 training progress 0.2062 learning rate [0.01]
Saving Model...
mean accuracy 0.7752 celoss 0.4967 rcloss 0.002815 d_rcloss 0.0473 training progress 0.2578 learning rate [0.01]
Saving Model...
mean accuracy 0.7783 celoss 0.4513 rcloss 0.006464 d_rcloss 0.0531 training progress 0.3094 learning rate [0.01]
Saving Model...
mean accuracy 0.7812 celoss 0.4197 rcloss 0.001991 d_rcloss 0.0513 training progress 0.3609 learning r