In [1]:
import math, time, os, datetime, shutil, pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import import_ipynb
from MoveData import *
from EncoderDecoder import *
from Talk import *
from Trainer import *
from Talk import *

importing Jupyter notebook from MoveData.ipynb
importing Jupyter notebook from EncoderDecoder.ipynb
importing Jupyter notebook from Elements.ipynb
importing Jupyter notebook from Talk.ipynb
importing Jupyter notebook from Trainer.ipynb


[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/carsonlam/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## MemoryTransformer

In [28]:
class MemoryTransformer(nn.Module):
    """Encoder-decoder transformer whose decoder also attends over a memory tensor.

    On every forward pass the current memory is concatenated (along the
    sequence axis) in front of the encoder output, and the decoder attends
    over that combined sequence. The memory only changes when the caller
    invokes ``update_memory()``, which currently replaces it with the most
    recent encoder output.

    Args:
        in_vocab_size: size of the input vocabulary, forwarded to ``Encoder``.
        out_vocab_size: size of the output vocabulary, forwarded to ``Decoder``
            and used as the width of the final linear projection.
        emb_dim: embedding / model dimension.
        n_layers: number of layers, forwarded to ``Encoder`` and ``Decoder``.
        heads: number of attention heads, forwarded to ``Encoder`` and ``Decoder``.
        mem_slots: number of rows in the initial random memory.
        dropout: dropout rate, forwarded to ``Encoder`` and ``Decoder``.

    Limitation: the memory's batch dimension is fixed by the first call to
    ``forward``; a later call with a different batch size will fail in
    ``torch.cat``.
    """

    def __init__(self, in_vocab_size, out_vocab_size, emb_dim, n_layers, 
                 heads, mem_slots, dropout):
        super().__init__()

        # Filled in lazily by batch_memory() on the first forward() call.
        self.batch_size = None
        self.mem_slots = mem_slots

        self.encoder = Encoder(in_vocab_size, emb_dim, n_layers, heads, dropout)
        self.decoder = Decoder(out_vocab_size, emb_dim, n_layers, heads, dropout)
        self.out = nn.Linear(emb_dim, out_vocab_size)

        # The memory is deliberately a plain tensor attribute (not a Parameter
        # or buffer): it is not trained directly and is not part of the
        # state_dict. Because of that, model.cuda()/model.to() do NOT move it;
        # forward() moves it to the encoder output's device instead.
        self.memory = torch.randn((self.mem_slots, emb_dim))

        # Most recent encoder / decoder outputs; populated by forward().
        self.e_output = None
        self.d_output = None

        # Disabled experiment: gated, attention-based memory update.
        # (requires dim_k = emb_dim // heads)
        #   mem_mask = np.ones((1,self.mem_slots)).astype('uint8')
        #   self.mem_mask =  torch.from_numpy(mem_mask) == 1
        #   self.mem_update = MultiHeadAttention(heads, emb_dim, dim_k, dropout)
        #   self.z_gate = nn.Linear(emb_dim, emb_dim)
        #   self.NormalizeMemory = Norm(emb_dim)

    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        # needed for truncated BPTT, called at every batch forward pass
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    def batch_memory(self, in_toks):
        """Record the batch size of ``in_toks`` and replicate the memory once per batch row.

        Turns the initial ``(mem_slots, emb_dim)`` memory into
        ``(batch_size, mem_slots, emb_dim)``.
        """
        self.batch_size = in_toks.size(0)
        self.memory = torch.stack([self.memory for _ in range(self.batch_size)])
        #self.mem_mask = torch.stack([self.mem_mask for _ in range(self.batch_size)])
        print("setting batch size to ", self.batch_size)

    def update_memory(self):
        """Replace the memory with the encoder output of the latest ``forward`` call.

        Raises:
            AttributeError: if ``forward`` has not been called yet, so there is
                no encoder output to store.
        """
        # Disabled experiment: gated update mixing old memory with attention
        # over [memory, decoder output].
        #mem_dialogue = torch.cat([self.memory, self.e_output, self.d_output], dim=-2) 
        #mem_dialogue = torch.cat([self.memory, self.d_output], dim=-2)
        #new_memory, scores = self.mem_update(self.memory, mem_dialogue, mem_dialogue)
        #new_mem_norm = self.NormalizeMemory(new_memory + self.memory)
        #z_t = torch.sigmoid(self.z_gate(self.memory)) # (batch size, memory slots, memory size)
        #self.memory = (1 - z_t)*self.memory + z_t*new_mem_norm
        if self.e_output is None:
            raise AttributeError(
                "update_memory() called before forward(); "
                "there is no encoder output to store as memory")
        self.memory = self.e_output

    def forward(self, in_toks, in_mask, out_seq, out_mask):
        """Encode ``in_toks``, prepend the memory, decode ``out_seq`` and project to vocab logits.

        Args:
            in_toks: input token ids; dimension 0 is treated as the batch.
            in_mask: mask passed through to the encoder.
            out_seq: decoder input token ids.
            out_mask: mask passed through to the decoder for ``out_seq``.

        Returns:
            The decoder output passed through the final linear layer
            (last dimension is ``out_vocab_size``).
        """
        # Cut the graph to earlier steps so backward() never reaches into a
        # previous (already freed) forward pass.
        self.memory = self.repackage_hidden(self.memory)
        if self.batch_size is None:
            self.batch_memory(in_toks)

        self.e_output = self.encoder(in_toks, in_mask)

        # The memory is not moved by model.cuda(); follow the encoder output
        # so the concatenation below does not hit a device mismatch.
        if self.memory.device != self.e_output.device:
            self.memory = self.memory.to(self.e_output.device)

        mem_en_vecs = torch.cat([self.memory, self.e_output], dim=-2)
        # All-True mask over memory + encoder positions, created directly on
        # the right device. `== 1` keeps the mask dtype identical to what the
        # original numpy-based construction produced.
        mem_en_mask = torch.ones(
            (mem_en_vecs.size(0), 1, mem_en_vecs.size(-2)),
            device=mem_en_vecs.device) == 1
        self.d_output = self.decoder(out_seq, out_mask, mem_en_vecs, mem_en_mask)
        output = self.out(self.d_output)

        return output

In [40]:
# Run configuration: CPU, one conversation turn per batch.
opt = Options(
    batchsize=1,
    device=torch.device("cpu"),
    epochs=20,
    lr=0.01,
    max_len=25,
    save_path='../saved/weights/memory_weights',
)

# Build the data iterator and the input/output vocab fields from the JSON corpus.
data_iter, infield, outfield, opt = json2datatools(path='../saved/memory.json', opt=opt)

# Deliberately tiny model: this cell is only a smoke test of the memory wiring.
emb_dim = 32
n_layers = 2
heads = 8
mem_slots = 1
dropout = 0.01
chloe = MemoryTransformer(
    len(infield.vocab),
    len(outfield.vocab),
    emb_dim,
    n_layers,
    heads,
    mem_slots,
    dropout,
)

#load_subset_weights(chloe, opt)
# Untrained model, so the reply printed here is expected to be noise.
print(talk_to_chloe("my name is fluffy", chloe, opt, infield, outfield))

<pad> <sos> <sos> <pad> snuggles silly silly its


In [None]:
def trainer(model, data_iterator, options, optimizer, scheduler):
    """Train ``model`` with teacher forcing and checkpoint the best epoch.

    Args:
        model: module called as ``model(src, src_mask, trg_input, trg_mask)``
            and returning logits whose last dimension is the target vocab.
        data_iterator: iterable of batches exposing ``.listen`` and ``.reply``
            tensors; both are transposed on axes 0/1 before use.
        options: object providing ``device``, ``epochs``, ``trg_pad`` and
            ``save_path``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: scheduler stepped once per epoch with the epoch loss
            (e.g. ``ReduceLROnPlateau``).

    Returns:
        The trained model (moved to ``options.device`` when CUDA is usable).
    """
    # Accept any CUDA device spec ("cuda", "cuda:0", "cuda:1", ...), not only
    # an exact match with torch.device("cuda:0").
    device = torch.device(options.device)
    if device.type == "cuda" and torch.cuda.is_available():
        print("a GPU was detected, model will be trained on GPU")
        model = model.to(device)
    else:
        print("training on cpu")

    model.train()
    start = time.time()
    # Start from +inf so the first epoch is always checkpointed, even when the
    # initial loss is large.
    best_loss = float("inf")
    for epoch in range(options.epochs):
        total_loss = 0
        for batch in data_iterator:
            src = batch.listen.transpose(0,1)
            trg = batch.reply.transpose(0,1)
            # Teacher forcing: feed everything but the last token, predict
            # everything but the first.
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, options)
            preds = model(src, src_mask, trg_input, trg_mask)
            #model.update_memory() # Update Memory 

            ys = trg[:, 1:].contiguous().view(-1)
            optimizer.zero_grad()
            batch_loss = F.cross_entropy(preds.view(-1, preds.size(-1)), 
                                         ys, ignore_index = options.trg_pad)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        # NOTE(review): the "+1" presumably compensates for num_batches()
        # under-counting a partial final batch - confirm against its definition.
        epoch_loss = total_loss/(num_batches(data_iterator)+1)
        scheduler.step(epoch_loss)

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), options.save_path)
        print("%dm: epoch %d loss = %.3f" %((time.time() - start)//60, epoch, epoch_loss))

    return model

# Adam with the betas/eps used for this model; the LR is cut by 10% once the
# epoch loss has plateaued (patience=3).
optimizer = torch.optim.Adam(
    chloe.parameters(),
    lr=opt.lr,
    betas=(0.9, 0.98),
    eps=1e-9,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=3,
)

chloe = trainer(chloe, data_iter, opt, optimizer, scheduler)

# Probe the memory: state a name, store the encoding, then ask for it back.
print(talk_to_chloe("my name is snuggles", chloe, opt, infield, outfield))
chloe.update_memory() # Update Memory 
print(talk_to_chloe("what is my name?", chloe, opt, infield, outfield))

training on cpu
0m: epoch 0 loss = 0.263
0m: epoch 1 loss = 0.136
0m: epoch 2 loss = 0.173
0m: epoch 3 loss = 0.277
0m: epoch 4 loss = 0.193
0m: epoch 5 loss = 0.148
0m: epoch 6 loss = 0.152
0m: epoch 7 loss = 0.113
0m: epoch 8 loss = 0.131
0m: epoch 9 loss = 0.080
0m: epoch 10 loss = 0.064
0m: epoch 11 loss = 0.167
0m: epoch 12 loss = 0.122
0m: epoch 13 loss = 0.100
0m: epoch 14 loss = 0.159
0m: epoch 15 loss = 0.081
0m: epoch 16 loss = 0.087


In [42]:
opt = Options(batchsize=1, device = torch.device("cpu"), epochs=20, lr=0.01, 
              max_len = 25, save_path = '../saved/weights/memory_weights')

data_iter, infield, outfield, opt = json2datatools(path='../saved/memory.json', opt=opt)

# Ordered dialogue: every "what is my name?" can only be answered from the
# turn immediately before it, so the turns must be replayed in this order.
conversation_list = [
    {"listen":"my name is fluffy", "reply":"hello fluffy!"},
    {"listen":"what is my name?", "reply":"its fluffy silly"},
    {"listen":"my name is snuggles", "reply":"hello snuggles!"},
    {"listen":"what is my name?", "reply":"its snuggles silly"},
    {"listen":"my name is bobo", "reply":"hello bobo!"},
    {"listen":"what is my name?", "reply":"its bobo silly"},
                    ]

optimizer = torch.optim.Adam(chloe.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=4)

# (1, 1) tensors used to build the decoder input (<sos> + reply) and the
# training labels (reply + <eos>).
sos_tok = torch.LongTensor([[outfield.vocab.stoi['<sos>']]]) 
eos_tok = torch.LongTensor([[outfield.vocab.stoi['<eos>']]]) 

chloe.train()
start = time.time()
# Start from +inf so the first epoch is always checkpointed, whatever its loss.
best_loss = float("inf")
opt.epochs = 20
for epoch in range(opt.epochs):
    total_loss = 0
    for turn in conversation_list:
        listen_toks = string2tensor(turn["listen"], infield)
        reply_toks = string2tensor(turn["reply"], outfield)
        reply_start = torch.cat((sos_tok,reply_toks), dim=1)
        reply_labels = torch.cat((reply_toks,eos_tok), dim=1).contiguous().view(-1)

        listen_mask, reply_mask = create_masks(listen_toks, reply_start, opt)

        logits = chloe(listen_toks, listen_mask, reply_start, reply_mask)

        # Store this turn's encoding so the NEXT turn can attend to it.
        # forward() detaches the memory again, so backward() below only
        # covers the current turn.
        chloe.update_memory() # Update Memory

        flat_logits = logits.view(-1, logits.size(-1))
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(flat_logits, reply_labels, ignore_index = opt.trg_pad)

        batch_loss.backward() #batch_loss.backward(retain_graph=True) #
        torch.nn.utils.clip_grad_norm_(chloe.parameters(), max_norm = 0.5) 
        optimizer.step()

        total_loss += batch_loss.item()

    epoch_loss = total_loss/len(conversation_list)
    scheduler.step(epoch_loss)

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(chloe.state_dict(), opt.save_path)
    print("%dm: epoch %d loss = %.3f" %((time.time() - start)//60, 
                                        epoch, epoch_loss))


# Replay the same dialogue in eval mode, refreshing the memory between turns
# (but not after the last one).
chloe.eval()
for turn_idx, turn in enumerate(conversation_list):
    if turn_idx > 0:
        chloe.update_memory()
    prompt = turn["listen"]
    print("> %s >" % prompt, talk_to_chloe(prompt, chloe, opt, infield, outfield))

0m: epoch 0 loss = 0.446
0m: epoch 1 loss = 0.719
0m: epoch 2 loss = 0.354
0m: epoch 3 loss = 0.475
0m: epoch 4 loss = 0.280
0m: epoch 5 loss = 0.335
0m: epoch 6 loss = 0.250
0m: epoch 7 loss = 0.215
0m: epoch 8 loss = 0.194
0m: epoch 9 loss = 0.190
0m: epoch 10 loss = 0.190
0m: epoch 11 loss = 0.193
0m: epoch 12 loss = 0.192
0m: epoch 13 loss = 0.193
0m: epoch 14 loss = 0.191
0m: epoch 15 loss = 0.191
0m: epoch 16 loss = 0.193
0m: epoch 17 loss = 0.191
0m: epoch 18 loss = 0.191
0m: epoch 19 loss = 0.190
> my name is fluffy > hello fluffy !
> what is my name? > hello bobo !
> my name is snuggles > hello snuggles !
> what is my name? > hello bobo !
> my name is bobo > its bobo silly
> what is my name? > its bobo silly


Next, we need to train the memory. How do we do this? We need to talk to the model and let it accumulate at least one cycle of conversation, then teach it to respond correctly given the previous listen-reply exchange.

meowci beaucoup !


thank am hi
<unk> meowci <unk>
thank meowci hi
<unk> meowci <unk>


You > my name is fluffy
Chloe > <sos> <unk> meowci <unk> <eos>

You > hi
Chloe > <sos> thank am hi <eos>

You > how?
Chloe > <sos> <unk> am <unk> <eos>



KeyboardInterrupt: 