In [20]:
import math, time, os, datetime, shutil, pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import import_ipynb
from MoveData import *
from EncoderDecoder import *
from Talk import *
from Trainer import *
from LearningDynamics import *

#from Beam import *

In [21]:
class MemoryTransformer(nn.Module):
    '''
    Encoder-decoder transformer with a persistent, gated memory.

    self.memory is carried across forward() calls. In forward() it is detached and
    appended (along the sequence axis) to the encoder output, and the decoder receives
    [encoded input; memory] as its encoder states. update_memory() then refreshes the
    memory from the most recent decoder output using self.mhdpa, a residual + Norm,
    and a per-feature sigmoid gate (self.z_gate).

    NOTE(review): forward() detaches the memory before using it, so a loss computed
    from forward()'s output cannot back-propagate into mhdpa, z_gate or norm_mem;
    those parameters are not trained by such a loss.
    NOTE(review): mem_slots is stored but not used here; the memory is initialised
    from the first encoded input, so its length equals that input's length.
    '''
    def __init__(self, in_vocab_size, out_vocab_size, emb_dim, n_layers,
                 heads, mem_slots, dropout):

        super().__init__()

        self.emb_dim = emb_dim
        dim_k = emb_dim // heads
        self.mem_slots = mem_slots
        self.batch_size = None  # not used within this class
        self.memory = None      # populated by the first concat_mem() call
        self.d_output = None    # decoder output of the most recent forward()

        self.encoder = Encoder(in_vocab_size, emb_dim, n_layers, heads, dropout)
        self.decoder = Decoder(out_vocab_size, emb_dim, n_layers, heads, dropout)
        self.out = nn.Linear(emb_dim, out_vocab_size)

        # modules used only by update_memory()
        self.mhdpa = MultiHeadAttention(heads, emb_dim, dim_k, dropout)
        self.z_gate = nn.Linear(emb_dim, emb_dim)
        self.norm_mem = Norm(emb_dim)

    def update_memory(self):
        '''
        Refresh self.memory from the decoder output of the latest forward() call.

        self.mhdpa is applied with the memory as its first argument and
        [memory; decoder output] as the other two. The result is added back to the
        memory and normalised, and a sigmoid gate computed from the old memory blends
        the old and new memory element-wise.

        Raises:
            RuntimeError: if no forward() has run yet (memory or decoder output unset).
        '''
        if self.memory is None or getattr(self, 'd_output', None) is None:
            raise RuntimeError('update_memory() needs a prior forward() call '
                               '(memory and decoder output are not set yet)')
        # Tile the memory up to the decoder batch; assumes the memory batch is 1
        # whenever the two batch sizes differ.
        if self.memory.size(0) < self.d_output.size(0):
            self.memory = self.memory.repeat(self.d_output.size(0), 1, 1)
        mem_dialogue = torch.cat([self.memory, self.d_output], dim=-2)
        new_memory, _ = self.mhdpa(self.memory, mem_dialogue, mem_dialogue)
        new_mem_norm = self.norm_mem(new_memory + self.memory)
        z_t = torch.sigmoid(self.z_gate(self.memory))
        self.memory = (1 - z_t) * self.memory + z_t * new_mem_norm

    def concat_mem(self, in_encoded, mask):
        '''
        Append the memory to the encoder output along the sequence axis.

        On the first call (no memory yet) the encoder output itself becomes the memory
        and (in_encoded, mask) are returned unchanged.

        Otherwise returns (in_mem_encoded, mem_mask):
            in_mem_encoded -- [in_encoded; memory] concatenated on dim -2
            mem_mask       -- bool tensor of shape (batch, 1, input_len + memory_len)
                              on the same device as in_mem_encoded; input positions keep
                              the given padding mask (all visible if mask is None) and
                              every memory slot is visible.
        '''
        if not isinstance(self.memory, torch.Tensor):
            self.memory = in_encoded
            return in_encoded, mask

        # Tile the input up to the memory's batch; assumes the input batch is 1
        # whenever the two batch sizes differ.
        if in_encoded.size(0) < self.memory.size(0):
            in_encoded = in_encoded.repeat(self.memory.size(0), 1, 1)
        in_mem_encoded = torch.cat([in_encoded, self.memory], dim=-2)

        # Build the mask in torch so it lives on the activations' device, and keep the
        # input padding mask so padded tokens stay hidden once memory is attached.
        mem_mask = torch.ones(in_mem_encoded.size(0), 1, in_mem_encoded.size(-2),
                              dtype=torch.bool, device=in_mem_encoded.device)
        if mask is not None:
            # assumes mask broadcasts to (batch, 1, input_len), e.g. (1|batch, 1, input_len)
            mem_mask[..., :in_encoded.size(-2)] = mask.to(device=mem_mask.device,
                                                          dtype=torch.bool)
        return in_mem_encoded, mem_mask

    def repackage_hidden(self, h):
        '''
        Detach h from the autograd graph.

        h may be a tensor (returned detached), None (returned as None), or an iterable
        of such values (rebuilt as a tuple, each element repackaged recursively).
        '''
        if isinstance(h, torch.Tensor):
            return h.detach()
        if h is None:
            return None
        return tuple(self.repackage_hidden(v) for v in h)

    def forward(self, in_toks, in_mask, out_toks, out_mask):
        '''
        Run one dialogue turn and return output-vocabulary logits.

        The stored memory is detached first (truncating back-propagation at the turn
        boundary) and attached to the encoder output via concat_mem. The decoder output
        is cached in self.d_output for a later update_memory() call.
        '''
        self.memory = self.repackage_hidden(self.memory)
        in_encoded = self.encoder(in_toks, in_mask)
        in_encoded, in_mask = self.concat_mem(in_encoded, in_mask)
        self.d_output = self.decoder(out_toks, out_mask, in_encoded, in_mask)
        return self.out(self.d_output)

In [3]:
def talk_to_model(input_str, model, opt, infield, outfield):
    '''
    Sample one reply from the dialogue model for a single input utterance.

    input:
        input_str is a string, it is what you want to say to the dialogue model
        model exposes encoder, decoder and a final linear layer `out`
        opt is an options object; opt.max_len caps the number of sampled tokens
            (opt is also forwarded to nopeak_mask)
        infield and outfield are the data fields that store the vocabularies
    output:
        the reply as a space-joined string of tokens, never containing the leading
        <sos> token nor the terminating <eos> token

    Notes:
        - Tokens are sampled from the softmax distribution, so replies are stochastic.
        - model.encoder / model.decoder are called directly, so a MemoryTransformer's
          memory is neither read nor updated here.
        - Side effects: the model is switched to eval mode and moved to the CPU.
    '''
    model.eval()
    model.cpu()
    sos_idx = outfield.vocab.stoi['<sos>']  # integer id of the start token
    ended_with_eos = False

    with torch.no_grad():  # inference only: no autograd graph needed
        input_sequence = string2tensor(input_str, infield)
        input_mask = (input_sequence != infield.vocab.stoi['<pad>']).unsqueeze(-2)
        encoding = model.encoder(input_sequence, input_mask)
        decoder_input = torch.LongTensor([[sos_idx]])  # start token seeds the decoder

        for pos in range(opt.max_len):
            # causal mask over the pos+1 tokens produced so far (pos starts at 0)
            decoder_input_mask = nopeak_mask(size=pos + 1, opt=opt)
            logits = model.out(model.decoder(decoder_input, decoder_input_mask,
                                             encoding, input_mask))
            # only the distribution at the last position decides the next token
            probs = F.softmax(logits[:, -1], dim=-1)
            action = torch.distributions.Categorical(probs=probs).sample().unsqueeze(-1)
            decoder_input = torch.cat((decoder_input, action), dim=1)

            if outfield.vocab.itos[action.item()] == '<eos>':
                ended_with_eos = True
                break

    # drop <sos> always, and the trailing <eos> when generation stopped on it
    tokens = decoder_input[0, 1:-1] if ended_with_eos else decoder_input[0, 1:]
    return ' '.join(outfield.vocab.itos[tok] for tok in tokens.tolist())

In [70]:
# Run options: single-example batches on the CPU; save_path is where the training
# cell writes the best weights.
opt = Options(
    batchsize=1,
    device=torch.device("cpu"),
    epochs=20,
    lr=0.0001,
    max_len=25,
    save_path='../saved/weights/memory_weights',
)

# Vocabularies (infield / outfield) and the data iterator built from the JSON dialogue file.
data_iter, infield, outfield, opt = json2datatools(path='../saved/memory.json', opt=opt)

# Model hyper-parameters. Note: MemoryTransformer only stores mem_slots; its memory
# length is taken from the first encoded input.
emb_dim = 32
n_layers = 2
heads = 2
mem_slots = 3
dropout = 0.01

chloe = MemoryTransformer(
    in_vocab_size=len(infield.vocab),
    out_vocab_size=len(outfield.vocab),
    emb_dim=emb_dim,
    n_layers=n_layers,
    heads=heads,
    mem_slots=mem_slots,
    dropout=dropout,
)

load_subset_weights(chloe, opt)
print(talk_to_model("my name is fluffy", chloe, opt, infield, outfield))

hey fluffy !


In [71]:
# Replay a scripted conversation against the loaded weights, folding each turn into
# the model's memory before the next one.
load_subset_weights(chloe, opt)
chloe.eval()
# Start from an empty memory so the result does not depend on earlier cells or re-runs.
chloe.memory = None

test_list = [
    " my name is fluffy ",
    " what is my name? ",
    " my name is fluffy what is my name?",
    " my name is snuggles",
    " what is my name? ",
    " my name is snuggles what is my name? ",
    " my name is bobo ",
    " what is my name? ",
    " my name is bobo what is my name? "
]

opt.k = 10  # presumably read by translate_sentence (beam / top-k width) — TODO confirm

for utterance in test_list:
    print(" > ", utterance, " > ",  translate_sentence(utterance, chloe, opt, infield, outfield))
    chloe.update_memory()  # Update Memory with this turn before the next one

 >   my name is fluffy   >  hey fluffy!
 >   what is my name?   >  hey fluffy!
 >   my name is fluffy what is my name?  >  hey fluffy!
 >   my name is snuggles  >  hello snuggles!
 >   what is my name?   >  hey fluffy!
 >   my name is snuggles what is my name?   >  hello snuggles!
 >   my name is bobo   >  hi bobo!
 >   what is my name?   >  hey fluffy!
 >   my name is bobo what is my name?   >  <unk> <unk> bobo


In [69]:

# Teacher-forced training on one scripted conversation; the memory is updated after
# every turn so later turns see the earlier exchange.
conversation_list = [
{"listen":"my name is fluffy", "reply":"hey fluffy!"},
{"listen":"what is my name?", "reply":"fluffy pillow"},
{"listen":"my name is fluffy what is my name?", "reply":"fluffy pillow"},
{"listen":"my name is snuggles", "reply":"hello snuggles!"},
{"listen":"what is my name?", "reply":"snuggles the bunny"},
{"listen":"my name is snuggles what is my name?", "reply":"snuggles the bunny"},
{"listen":"my name is bobo", "reply":"hi bobo!"},
{"listen":"what is my name?", "reply":"you are bobo"},
{"listen":"my name is bobo what is my name?", "reply":"you are bobo"},
                    ]

optimizer = torch.optim.Adam(chloe.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.6, patience=3)

sos_tok = torch.LongTensor([[outfield.vocab.stoi['<sos>']]])
eos_tok = torch.LongTensor([[outfield.vocab.stoi['<eos>']]])

EARLY_STOP_LOSS = 0.09  # stop once the mean per-turn loss drops below this
opt.epochs = 50         # overrides the value given to Options(...) in the setup cell

chloe.train()
start = time.time()
best_loss = float('inf')  # so the first epoch's weights are always saved
for epoch in range(opt.epochs):
    # Each epoch replays the conversation from an empty memory, so epochs are
    # comparable and no state leaks in from a previous epoch or cell.
    chloe.memory = None
    total_loss = 0
    for turn in conversation_list:
        listen_toks = string2tensor(turn["listen"], infield)
        reply_toks = string2tensor(turn["reply"], outfield)
        # decoder input is <sos> + reply; the targets are reply + <eos>
        reply_start = torch.cat((sos_tok, reply_toks), dim=1)
        reply_labels = torch.cat((reply_toks, eos_tok), dim=1).contiguous().view(-1)

        listen_mask, reply_mask = create_masks(listen_toks, reply_start, opt)

        logits = chloe(listen_toks, listen_mask, reply_start, reply_mask)

        # NOTE(review): forward() detaches the memory before using it, so this update
        # is not on the loss's gradient path; mhdpa / z_gate / norm_mem get no gradient.
        chloe.update_memory()  # Update Memory

        flat_logits = logits.view(-1, logits.size(-1))
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(flat_logits, reply_labels, ignore_index=opt.trg_pad)

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(chloe.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += batch_loss.item()

    epoch_loss = total_loss / len(conversation_list)
    scheduler.step(epoch_loss)

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(chloe.state_dict(), opt.save_path)
        print("%dm: epoch %d loss = %.3f" %((time.time() - start)//60,
                                        epoch, epoch_loss))

    if epoch_loss < EARLY_STOP_LOSS:
        break

print("finished")

0m: epoch 0 loss = 0.765
0m: epoch 1 loss = 0.399
0m: epoch 2 loss = 0.152
0m: epoch 3 loss = 0.132
0m: epoch 4 loss = 0.129
0m: epoch 8 loss = 0.128
0m: epoch 17 loss = 0.113
finished


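Next, we need to train the memory. Why? Because `forward()` detaches the memory before the decoder uses it, the reply loss in the loop above never reaches the memory-update modules (`mhdpa`, `z_gate`, `norm_mem`), so they are still untrained. How do we do this? We need to talk to the model so that it accumulates at least one cycle of conversation, and then teach it to respond correctly given the previous listen–reply exchange.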