In [618]:
import math, time, os, datetime, shutil, pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import import_ipynb
from MoveData import *
from EncoderDecoder import *
from Elements import * 
from Talk import *
from Trainer import *
from LearningDynamics import *

In [745]:
class MemoryTransformer(nn.Module):
    def __init__(self, in_vocab_size, out_vocab_size, emb_dim, n_layers, 
                 heads, mem_slots, dropout):
        
        super().__init__()
        
        self.batch_size = None
        dim_k = emb_dim // heads
        self.mem_slots = mem_slots
        self.emb_dim = emb_dim
        
        self.encoder = Decoder(in_vocab_size, emb_dim, n_layers, heads, dropout)
        self.decoder = Decoder(out_vocab_size, emb_dim, n_layers, heads, dropout)
        self.out = nn.Linear(emb_dim, out_vocab_size)
        
        with torch.no_grad():
            self.memory = torch.randn((1,self.mem_slots, emb_dim))
        '''  
        self.memory = torch.stack([torch.eye(self.mem_slots) for _ in range(1)])
        if emb_dim > self.mem_slots:
            difference = emb_dim - self.mem_slots
            pad = torch.zeros((1, self.mem_slots, difference))
            self.memory = torch.cat([self.memory, pad], -1)
        elif emb_dim < self.mem_slots:
            self.memory = self.memory[:, :, :emb_dim]
        '''
        self.mem_mask = torch.ones(1,1,self.mem_slots) == 1
        
        self.mem_update = MultiHeadAttention(heads, emb_dim, dim_k, dropout)
        self.normalizeMemory1 = Norm(emb_dim)
        self.z_gate = nn.Linear(emb_dim*2, emb_dim)

    def repackage_hidden(self, h):
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)
        
    def update_memory(self):
        mem_dialogue = torch.cat([self.memory, self.e_output, self.d_output], dim=-2) 
        #new_memory, scores = self.mem_update(self.memory, mem_dialogue, mem_dialogue)
        self.memory, scores = self.mem_update(self.memory, mem_dialogue, mem_dialogue)
        #new_mem_norm = self.normalizeMemory1(new_memory + self.memory)
        #z_t = torch.sigmoid(self.z_gate(torch.cat([self.memory, new_mem_norm], dim=-1))) 
        #self.memory = (1 - z_t)*self.memory + z_t*new_mem_norm
        
    def forward(self, in_toks, in_mask, out_toks, out_mask):
        self.memory = self.repackage_hidden(self.memory)  
        self.mem_mask = self.repackage_hidden(self.mem_mask)  
        self.e_output = self.encoder(in_toks, in_mask, self.memory, self.mem_mask)
        self.d_output = self.decoder(out_toks, out_mask, self.e_output, in_mask)
        output = self.out(self.d_output)
        return output

In [746]:
def talk_to_model(input_str, model, opt, infield, outfield):
    '''
    input:
        input_str is a string, it is what you want to say to the dialogue model
        model is a encoder, decoder and a last layer linear transformation
        opt is an options object with the maximum length of the output sequence opt.max_len
        infield and outfield are the data.fields that store the vocabulary
    output:
        an output string response from the dialogue model
    '''
    model.eval()
    model.cpu()
    input_sequence = string2tensor(input_str, infield) # string to tensor 
    input_mask = (input_sequence != infield.vocab.stoi['<pad>']).unsqueeze(-2) #make input mask
    model.e_output = model.encoder(input_sequence, input_mask, model.memory, model.mem_mask)
    init_tok = outfield.vocab.stoi['<sos>'] # this is the integer for the start token
    decoder_input = torch.LongTensor([[init_tok]]) # use start token to initiate the decoder
    
    for pos in range(opt.max_len):
        decoder_input_mask = nopeak_mask(size=pos+1, opt=opt) # make target mask, pos+1 casue pos starts at 0
        model.d_output = model.decoder(decoder_input, decoder_input_mask, model.e_output, input_mask)
        out = model.out(model.d_output)
        softout = F.softmax(out, dim=-1) 

        distr = Categorical(probs=softout)
        action = distr.sample()[:,-1].unsqueeze(0) # sample from that distribution to get next token
        decoder_input = torch.cat((decoder_input, action), dim=1) 

        if outfield.vocab.itos[action] == '<eos>':
            de_str = ' '.join([outfield.vocab.itos[tok] for tok in decoder_input[0][1:-1]])
            return de_str
        
    de_str = ' '.join([outfield.vocab.itos[tok] for tok in decoder_input[0]])
    return de_str

In [747]:
opt = Options(batchsize=1, device = torch.device("cpu"), epochs=20, lr=0.01, 
              max_len = 25, save_path = '../saved/weights/memory_weights')

data_iter, infield, outfield, opt = json2datatools(path='../saved/memory.json', opt=opt)

emb_dim, n_layers, heads, mem_slots, dropout = 8, 1, 4, 4, 0.0

chloe = MemoryTransformer(len(infield.vocab), len(outfield.vocab), emb_dim, n_layers, heads, mem_slots, dropout)


print(chloe.memory[0])
print(talk_to_model("my name is bobo", chloe, opt, infield, outfield))
chloe.update_memory()
print(chloe.memory[0])
print(talk_to_model("what is my name", chloe, opt, infield, outfield))
chloe.update_memory()
print(chloe.memory[0])

tensor([[-0.4841, -0.0458,  1.2232,  0.0397, -2.1136, -0.5114,  0.5870, -0.9800],
        [-0.8795, -0.2978,  1.3578, -1.0898, -0.5419, -1.7764,  0.3256,  0.7855],
        [-0.4614,  0.4574,  1.4177, -0.8479,  0.9467,  0.0997,  0.4975, -0.1688],
        [ 1.2256,  0.1544, -0.6595,  0.8973, -0.0573, -1.2790, -0.1369,  0.1134]])

tensor([[-0.0134,  0.0326,  0.1355, -0.1515,  0.4848,  0.0264,  0.2180, -0.2135],
        [-0.0732, -0.0125,  0.1349, -0.0699,  0.4909, -0.0727,  0.2234, -0.1524],
        [-0.0239, -0.0292,  0.1792, -0.0644,  0.4947, -0.0371,  0.2724, -0.1650],
        [ 0.0107,  0.0119,  0.1586, -0.1304,  0.4849,  0.0373,  0.2529, -0.2087]],
       grad_fn=<SelectBackward>)
hi bobo taco <pad> <unk> bobo <sos> hi <pad> taco hi hello hello hi hi
tensor([[ 0.0932,  0.1204,  0.0775, -0.1688,  0.4691, -0.0655, -0.0106, -0.0411],
        [ 0.0927,  0.1201,  0.0777, -0.1688,  0.4695, -0.0678, -0.0110, -0.0408],
        [ 0.0921,  0.1200,  0.0771, -0.1685,  0.4696, -0.0668, -0.0111, -

In [748]:

conversation_list = [
    {"listen":"hello", "reply":"hi"},
    {"listen":"im sunggles", "reply":"hello snuggles"},
    {"listen":"whats my name", "reply":"snuggles"},
    {"listen":"im bobo", "reply":"bobo"},
    {"listen":"whats my name", "reply":"bobo"},
    {"listen":"im taco", "reply":"hello taco"},
    {"listen":"whats my name", "reply":"taco"},
    {"listen":"my name is fluffy", "reply":"hello fluffy"},
    {"listen":"whats my name", "reply":"fluffy"},
                    ]

opt.lr = 0.01
opt.epochs = 10

optimizer = torch.optim.Adam(chloe.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=5)

sos_tok = torch.LongTensor([[outfield.vocab.stoi['<sos>']]]) 
eos_tok = torch.LongTensor([[outfield.vocab.stoi['<eos>']]]) 

chloe.train()
start = time.time()
best_loss = 100

for epoch in range(opt.epochs):
    total_loss = 0
    for i in range(len(conversation_list)):
        listen_string = conversation_list[i]["listen"]
        reply_string = conversation_list[i]["reply"]
        listen_toks = string2tensor(listen_string, infield)
        reply_toks = string2tensor(reply_string, outfield)
        reply_start = torch.cat((sos_tok,reply_toks), dim=1)
        reply_labels = torch.cat((reply_toks, eos_tok), dim=1).contiguous().view(-1)
        
        listen_mask, reply_mask = create_masks(listen_toks, reply_start, opt)
        
        logits = chloe(listen_toks, listen_mask, reply_start, reply_mask)
        
        chloe.update_memory()
        
        flat_logits = logits.view(-1, logits.size(-1))
        optimizer.zero_grad()
        batch_loss=F.cross_entropy(flat_logits,reply_labels,ignore_index=opt.trg_pad)
        batch_loss.backward() 
        torch.nn.utils.clip_grad_norm_(chloe.parameters(), max_norm = 1.0) 
        optimizer.step()

        total_loss += batch_loss.item()

    epoch_loss = total_loss/len(conversation_list)
    scheduler.step(epoch_loss)

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(chloe.state_dict(), opt.save_path)
        print("%dm: epoch %d loss = %.3f" %((time.time() - start)//60, epoch, epoch_loss))
        print(chloe.memory[0])
    
    total_loss = 0
    
print("finished")

load_subset_weights(chloe, opt)

0m: epoch 0 loss = 2.273
tensor([[-0.0274,  0.0043,  0.0720, -0.0387,  0.4485, -0.1252,  0.1340, -0.0538],
        [-0.0274,  0.0043,  0.0720, -0.0387,  0.4485, -0.1252,  0.1340, -0.0538],
        [-0.0274,  0.0043,  0.0720, -0.0387,  0.4485, -0.1252,  0.1340, -0.0538],
        [-0.0274,  0.0043,  0.0720, -0.0387,  0.4485, -0.1252,  0.1340, -0.0538]],
       grad_fn=<SelectBackward>)
0m: epoch 1 loss = 1.607
tensor([[-0.0637, -0.0379,  0.0776,  0.0279,  0.4715, -0.1717,  0.1498, -0.0385],
        [-0.0637, -0.0379,  0.0776,  0.0279,  0.4715, -0.1717,  0.1498, -0.0385],
        [-0.0637, -0.0379,  0.0776,  0.0279,  0.4715, -0.1717,  0.1498, -0.0385],
        [-0.0637, -0.0379,  0.0776,  0.0279,  0.4715, -0.1717,  0.1498, -0.0385]],
       grad_fn=<SelectBackward>)
0m: epoch 2 loss = 1.315
tensor([[-0.0945, -0.0718,  0.0657,  0.0864,  0.4979, -0.1910,  0.1399, -0.0256],
        [-0.0945, -0.0718,  0.0657,  0.0864,  0.4979, -0.1910,  0.1399, -0.0256],
        [-0.0945, -0.0718,  0.0657,  

In [726]:
in_vec = torch.randn(1,1,chloe.emb_dim)
mem_dialogue = torch.cat([chloe.memory, in_vec], dim=-2) 
print(mem_dialogue[0])
new_memory, scores = chloe.mem_update(chloe.memory, mem_dialogue, mem_dialogue)
print("----------------------------------------------------------------------")
print(new_memory[0])
print("----------------------------------------------------------------------")
new_mem_norm = chloe.normalizeMemory1(new_memory + chloe.memory)
print("----------------------------------------------------------------------")
print(new_mem_norm[0])
print("----------------------------------------------------------------------")
z_t = torch.sigmoid(chloe.z_gate(torch.cat([chloe.memory, new_mem_norm], dim=-1))) 
chloe.memory = (1 - z_t)*chloe.memory + z_t*new_mem_norm
print("----------------------------------------------------------------------")
print(chloe.memory[0])

tensor([[-2.1513,  0.4629,  0.5391, -0.7875,  0.3189,  1.0446, -0.5672, -0.4214],
        [-1.2432,  1.0948,  0.5867, -2.5005, -0.4963,  0.3767, -0.7254, -0.2532],
        [ 1.6809, -0.3526, -0.0733, -1.1703,  0.1123, -0.2663, -0.3607,  0.8543],
        [ 0.2006,  0.8663, -1.2880, -1.7130, -0.4821, -0.1936,  2.7281,  0.2105],
        [ 1.0299, -0.1425,  0.8465,  1.0078, -0.7210, -1.8339,  1.4362, -0.3073]])
----------------------------------------------------------------------
tensor([[-0.1576,  0.2983, -0.0238, -0.6226, -0.2082,  0.6605,  0.4974, -0.2098],
        [-0.1808,  0.1269, -0.0847, -0.8431, -0.1231,  0.8690,  0.6174, -0.3190],
        [-0.2514,  0.1243, -0.0298, -0.7493, -0.1235,  0.8246,  0.6049, -0.2453],
        [-0.2788, -0.0535, -0.0585, -0.8879, -0.0593,  0.7805,  0.5836, -0.1028]],
       grad_fn=<SelectBackward>)
----------------------------------------------------------------------
----------------------------------------------------------------------
tensor([[-1.68

tensor([[ 0.1629, -2.0674,  1.2915, -0.1818,  0.0793, -0.3673,  0.1813,  0.8741],
        [ 0.1629, -2.0674,  1.2915, -0.1818,  0.0793, -0.3673,  0.1813,  0.8741],
        [ 0.1629, -2.0674,  1.2915, -0.1818,  0.0793, -0.3673,  0.1813,  0.8741],
        [ 0.1629, -2.0674,  1.2915, -0.1818,  0.0793, -0.3673,  0.1813,  0.8741]],
       grad_fn=<SelectBackward>)

In [671]:
load_subset_weights(chloe, opt) 
#scheduler = CosineWithRestarts(optimizer, T_max=len(conversation_list))

chloe.eval()

test_list = [
    "hello", 
    "im sunggles",
    "whats my name", 
    "im bobo", 
    "whats my name", 
    "im taco", 
    "whats my name", 
    "my name is fluffy",
    "whats my name", 
]

for i in test_list:
    print(" > ", i, " > ",  talk_to_model(i,chloe,opt,infield,outfield))
    chloe.update_memory()

 >  hello  >  hi
 >  im sunggles  >  hello snuggles
 >  whats my name  >  snuggles
 >  im bobo  >  bobo
 >  whats my name  >  fluffy
 >  im taco  >  hello taco
 >  whats my name  >  snuggles
 >  my name is fluffy  >  hello fluffy
 >  whats my name  >  bobo


In [584]:
 while True:
    tell_chloe = input("You > ")
    chloes_reply = talk_to_model(tell_chloe, chloe, opt, infield, outfield)
    chloe.update_memory()
    if ("bye chloe" in tell_chloe or "bye ttyl" in chloes_reply):
        print('Chloe > '+ chloes_reply + '\n')
        break
    else:
        print('Chloe > '+ chloes_reply + '\n') 

You > my name is fluffy
Chloe > hey fluffy !

You > what is my name
Chloe > you are bobo

You > my name is fluffy
Chloe > hey fluffy !

You > what is my name
Chloe > you are bobo

You > what is my name?
Chloe > taco

You > whats my name
Chloe > so , how are you ?



KeyboardInterrupt: 

Next we need to train the memory. How do we do this? we need to talk to the model and allow it to accumulate at least one cycle of conversation, then teach it to respond correctly given the previous listen-reply exchange