In [None]:
import math, time, os, datetime, shutil, pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import import_ipynb
from MoveData import *
from EncoderDecoder import *
from Talk import *
from Trainer import *
from LearningDynamics import *

In [599]:
class MemoryTransformer(nn.Module):
    def __init__(self, in_vocab_size, out_vocab_size, emb_dim, n_layers, 
                 heads, mem_slots, dropout):
        
        super().__init__()
        
        self.batch_size = None
        dim_k = emb_dim // heads
        self.mem_slots = mem_slots
        
        self.encoder = Decoder(in_vocab_size, emb_dim, n_layers, heads, dropout)
        self.decoder = Decoder(out_vocab_size, emb_dim, n_layers, heads, dropout)
        self.out = nn.Linear(emb_dim, out_vocab_size)
        
        self.memory = torch.randn((1,self.mem_slots, emb_dim))
        self.mem_mask = torch.ones(1,1,self.mem_slots) == 1
        
        self.mem_update = MultiHeadAttention(heads, emb_dim, dim_k, dropout)
        self.normalizeMemory1 = Norm(emb_dim)
        self.z_gate = nn.Linear(emb_dim*2, emb_dim)

    def repackage_hidden(self, h):
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)
        
    def update_memory(self):
        mem_dialogue = torch.cat([self.memory, self.e_output, self.d_output], dim=-2) 
        new_memory, scores = self.mem_update(self.memory, mem_dialogue, mem_dialogue)
        new_mem_norm = self.normalizeMemory1(new_memory + self.memory)
        z_t = torch.sigmoid(self.z_gate(torch.cat([self.memory, new_mem_norm], dim=-1))) 
        self.memory = (1 - z_t)*self.memory + z_t*new_mem_norm
        
    def forward(self, in_toks, in_mask, out_toks, out_mask):
        self.memory = self.repackage_hidden(self.memory)  
        self.mem_mask = self.repackage_hidden(self.mem_mask)  
        self.e_output = self.encoder(in_toks, in_mask, self.memory, self.mem_mask)
        self.d_output = self.decoder(out_toks, out_mask, self.e_output, in_mask)
        output = self.out(self.d_output)
        return output

In [600]:
def talk_to_model(input_str, model, opt, infield, outfield):
    '''
    input:
        input_str is a string, it is what you want to say to the dialogue model
        model is a encoder, decoder and a last layer linear transformation
        opt is an options object with the maximum length of the output sequence opt.max_len
        infield and outfield are the data.fields that store the vocabulary
    output:
        an output string response from the dialogue model
    '''
    model.eval()
    model.cpu()
    input_sequence = string2tensor(input_str, infield) # string to tensor 
    input_mask = (input_sequence != infield.vocab.stoi['<pad>']).unsqueeze(-2) #make input mask
    model.e_output = model.encoder(input_sequence, input_mask, model.memory, model.mem_mask)
    init_tok = outfield.vocab.stoi['<sos>'] # this is the integer for the start token
    decoder_input = torch.LongTensor([[init_tok]]) # use start token to initiate the decoder
    
    for pos in range(opt.max_len):
        decoder_input_mask = nopeak_mask(size=pos+1, opt=opt) # make target mask, pos+1 casue pos starts at 0
        model.d_output = model.decoder(decoder_input, decoder_input_mask, model.e_output, input_mask)
        out = model.out(model.d_output)
        softout = F.softmax(out, dim=-1) 

        distr = Categorical(probs=softout)
        action = distr.sample()[:,-1].unsqueeze(0) # sample from that distribution to get next token
        decoder_input = torch.cat((decoder_input, action), dim=1) 

        if outfield.vocab.itos[action] == '<eos>':
            de_str = ' '.join([outfield.vocab.itos[tok] for tok in decoder_input[0][1:-1]])
            return de_str
        
    de_str = ' '.join([outfield.vocab.itos[tok] for tok in decoder_input[0]])
    return de_str

In [601]:
opt = Options(batchsize=1, device = torch.device("cpu"), epochs=20, lr=0.01, 
              max_len = 25, save_path = '../saved/weights/memory_weights')

data_iter, infield, outfield, opt = json2datatools(path='../saved/memory.json', opt=opt)

emb_dim, n_layers, heads, mem_slots, dropout = 8, 3, 4, 2, 0.1

chloe = MemoryTransformer(len(infield.vocab), len(outfield.vocab), emb_dim, n_layers, heads, mem_slots, dropout)

print(talk_to_model("my name is bobo", chloe, opt, infield, outfield))
print(talk_to_model("what is my name?", chloe, opt, infield, outfield))


<sos> ! pillow pillow <pad> , <pad> , ? hi taco the taco how so how ! ? ! hi snuggles <pad> snuggles you bunny hello


In [614]:

conversation_list = [
    {"listen":" ", "reply":"so, how are you ?"},
    {"listen":"my name is snuggles", "reply":"hello snuggles!"},
    {"listen":"what is my name?", "reply":"snuggles the bunny"},
    {"listen":"my name is bobo", "reply":"hi bobo!"},
    {"listen":"what is my name?", "reply":"you are bobo"},
    {"listen":"my name is taco", "reply":"sup taco !"},
    {"listen":"what is my name?", "reply":"taco"},
    {"listen":"my name is fluffy", "reply":"hey fluffy!"},
    {"listen":"what is my name?", "reply":"fluffy pillow"},
                    ]

opt.lr = 0.01
opt.epochs = 50 

optimizer = torch.optim.Adam(chloe.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

sos_tok = torch.LongTensor([[outfield.vocab.stoi['<sos>']]]) 
eos_tok = torch.LongTensor([[outfield.vocab.stoi['<eos>']]]) 

chloe.train()
start = time.time()
best_loss = 100

for epoch in range(opt.epochs):
    total_loss = 0
    for i in range(len(conversation_list)):
        listen_string = conversation_list[i]["listen"]
        reply_string = conversation_list[i]["reply"]
        listen_toks = string2tensor(listen_string, infield)
        reply_toks = string2tensor(reply_string, outfield)
        reply_start = torch.cat((sos_tok,reply_toks), dim=1)
        reply_labels = torch.cat((reply_toks,eos_tok), dim=1).contiguous().view(-1)
        
        listen_mask, reply_mask = create_masks(listen_toks, reply_start, opt)
        
        logits = chloe(listen_toks, listen_mask, reply_start, reply_mask)
        
        chloe.update_memory()
        
        flat_logits = logits.view(-1, logits.size(-1))
        optimizer.zero_grad()
        batch_loss=F.cross_entropy(flat_logits,reply_labels,ignore_index=opt.trg_pad)
        batch_loss.backward() 
        torch.nn.utils.clip_grad_norm_(chloe.parameters(), max_norm = 1.0) 
        optimizer.step()

        total_loss += batch_loss.item()

    epoch_loss = total_loss/len(conversation_list)
    scheduler.step(epoch_loss)

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(chloe.state_dict(), opt.save_path)
        print("%dm: epoch %d loss = %.3f" %((time.time() - start)//60, epoch, epoch_loss))
    
    total_loss = 0
    
print("finished")
load_subset_weights(chloe, opt)

0m: epoch 0 loss = 0.459
0m: epoch 1 loss = 0.312
0m: epoch 4 loss = 0.273
0m: epoch 8 loss = 0.273
0m: epoch 15 loss = 0.272
0m: epoch 24 loss = 0.253
0m: epoch 25 loss = 0.190
finished


In [616]:
load_subset_weights(chloe, opt) 
#scheduler = CosineWithRestarts(optimizer, T_max=len(conversation_list))

chloe.eval()

test_list = [
    " ", 
    "my name is snuggles", 
    "what is my name?", 
    "my name is bobo", 
    "what is my name?", 
    "my name is taco", 
    "what is my name?", 
    "my name is fluffy", 
    "what is my name?", 
    " ",
]

for i in test_list:
    print(" > ", i, " > ",  talk_to_model(i,chloe,opt,infield,outfield))
    chloe.update_memory()

 >     >  so , how are you ?
 >  my name is snuggles  >  hello snuggles !
 >  what is my name?  >  you are bobo
 >  my name is bobo  >  hi bobo !
 >  what is my name?  >  fluffy pillow
 >  my name is taco  >  sup taco !
 >  what is my name?  >  you are bobo
 >  my name is fluffy  >  hey fluffy !
 >  what is my name?  >  you are bobo
 >     >  so , how are you ?


In [584]:
 while True:
    tell_chloe = input("You > ")
    chloes_reply = talk_to_model(tell_chloe, chloe, opt, infield, outfield)
    chloe.update_memory()
    if ("bye chloe" in tell_chloe or "bye ttyl" in chloes_reply):
        print('Chloe > '+ chloes_reply + '\n')
        break
    else:
        print('Chloe > '+ chloes_reply + '\n') 

You > my name is fluffy
Chloe > hey fluffy !

You > what is my name
Chloe > you are bobo

You > my name is fluffy
Chloe > hey fluffy !

You > what is my name
Chloe > you are bobo

You > what is my name?
Chloe > taco

You > whats my name
Chloe > so , how are you ?



KeyboardInterrupt: 

Next we need to train the memory. How do we do this? we need to talk to the model and allow it to accumulate at least one cycle of conversation, then teach it to respond correctly given the previous listen-reply exchange