In [2]:
import math, time, os, datetime, shutil, pickle, sys
sys.path.append("../../")
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from scripts.MoveData import *
from scripts.Transformer import *
from scripts.TalkTrain import *

import nltk
nltk.download('wordnet') 

%load_ext autoreload
%autoreload 2

print('torch.version',torch.__version__)
print('torch.cuda.is_available()',torch.cuda.is_available())

torch.version 1.3.0
torch.cuda.is_available() False


[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/carsonlam/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


This is a working version of the TransformerNTM. I show that with long, tapered training, starting at a learning rate of 0.01 and tapering to very small rates (~0.0005), you can get a Transformer chatbot to memorize your name; you can then change your name and the chatbot will remember the new one. The final loss needed for this was 0.010, reached after 400 to 800 epochs, which is way too many.

The [Neural Turing Machine paper](https://arxiv.org/pdf/1410.5401.pdf) explains the components of neural memory. [This blog](https://rylanschaeffer.github.io/content/research/neural_turing_machine/main.html) does a great job explaining the paper to technical non-researchers.

## Addressing
Addressing means creating weight vectors across the rows of the memory to determine where to read and write. Each stage generates an intermediate weight vector that gets passed to the next stage. First is content addressing:

### Content Addressing
Content addressing generates a weight vector based on how similar each row in memory is to a length-C key vector k_t emitted by the controller (C is the number of columns in the memory, called `M` in the code below).

<img src="https://rylanschaeffer.github.io/content/research/neural_turing_machine/ntm_addr_1.png">

For each head, the controller produces a key vector k_t that is compared to each row of M_t using a similarity measure. In this paper, the authors use cosine similarity:

$$ K(k_t, M_t(i)) = \frac{k_t \cdot M_t(i)}{\|k_t\| \cdot \|M_t(i)\|}$$ 

The PyTorch version of this formula is 

`F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1)`

The variable `wc` in `wc = self._similarity(k, beta)` is the softmax of these similarities scaled by beta, and it can be used as an attention weighting over the rows of the memory based on their similarity to a generated vector k. Larger betas concentrate the distribution over the rows of the memory on the row with the highest cosine similarity, so beta is called the key strength or focus.

$$w_t^c(i) = \frac{exp\Big(\beta_t K (k_t, M_t(i))\Big)}{\sum_j exp\Big(\beta_t K(k_t, M_t(j))\Big)}$$

`wc = F.softmax(beta * F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1), dim=1)`
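To make the effect of beta concrete, here is a minimal pure-Python sketch of content addressing on a toy memory. The function names `cosine_similarity` and `content_weights` and the 3 x 2 memory are illustrative assumptions, not part of the notebook's code; the notebook itself does this with `F.cosine_similarity` and `F.softmax` on tensors.

```python
import math

def cosine_similarity(u, v, eps=1e-16):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / max(norm_u * norm_v, eps)

def content_weights(memory, key, beta):
    """Softmax over the rows of beta * cosine_similarity(row, key)."""
    scores = [beta * cosine_similarity(row, key) for row in memory]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# toy memory with 3 rows and 2 columns (hypothetical values)
memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
key = [1.0, 0.0]

soft = content_weights(memory, key, beta=1.0)    # spread across rows
sharp = content_weights(memory, key, beta=20.0)  # focused on the best match

print([round(w, 3) for w in soft])
print([round(w, 3) for w in sharp])
```

With the small beta the weight is spread over all three rows; with the large beta almost all of it lands on the first row, which matches the key exactly.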

## Location-Based Addressing

In some cases, we may want to read from specific memory locations instead of looking for specific memory values. The example the authors give is the function f(x,y)=x∗y. In this case, we don't care what the values of x and y are, just that x and y are consistently read from the same memory locations. This is called location-based addressing, and to implement it, we'll need three more stages. In the second stage, a scalar parameter g ∈ (0,1), called the interpolation gate, blends the content weight vector wc with the previous time step's weight vector w_t−1 to produce the gated weighting wg. This allows the system to learn when to use (or ignore) content-based addressing.

<img src="https://rylanschaeffer.github.io/content/research/neural_turing_machine/ntm_addr_2.png">

$$w_t^g \leftarrow g_t w_t^c + (1- g_t) w_{t-1}$$

`wg = g * wc + (1 - g) * w_prev`
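As a quick illustration of the gate, the sketch below applies the interpolation to plain Python lists. The helper name `interpolate` and the example weightings are assumptions made for this example only.

```python
def interpolate(w_prev, wc, g):
    """Blend the content weighting wc with the previous weighting w_prev."""
    return [g * c + (1.0 - g) * p for c, p in zip(wc, w_prev)]

w_prev = [1.0, 0.0, 0.0, 0.0]  # previous focus on row 0
wc = [0.1, 0.2, 0.3, 0.4]      # content weighting for this step

for g in (0.0, 0.5, 1.0):
    print(g, interpolate(w_prev, wc, g))
```

At g = 0 the head ignores the content lookup and keeps its previous focus; at g = 1 it uses the content weighting alone. Because both inputs sum to 1, every blend also sums to 1.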
                                              
## Shift

s - Shift weighting (batch_size, 3) (each row sums to 1)

We'd like the controller to be able to shift focus to other rows. Let's suppose that as one of the system's parameters, the range of allowable shifts is specified. For example, a head's attention could shift forward a row (+1), stay still (0), or shift backward a row (-1). This notebook uses exactly these three shifts, so s has three entries per sample.

<img src="https://rylanschaeffer.github.io/content/research/neural_turing_machine/ntm_addr_3.png">

We'll perform the shifts modulo R, where R is the number of rows in memory, so that a shift forward at the bottom row of memory moves the head's attention to the top row, and similarly for a shift backward at the top row. After interpolation, each head emits a normalized shift weighting s_t, and the following convolutional shift is performed to produce the shifted weight w_hat:

$$\tilde{w}_t(i) \leftarrow \sum\limits_{j=0}^{R-1} w_t^g(j) s_t(i-j)$$

`F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)`

I never liked this notation for convolution; it leaves a lot unsaid. Think of s as a sliding-window dot product. If you have already done the intro to PyTorch lesson with 2D image convolutions, you might think of s as a filter of size 3 being applied to the image wg. To create the wrap-around padding effect, the last element of wg is prepended to the beginning and the first element is appended to the end. Since the filter size is 3, this is just the right padding to produce an output tensor of the same size as the input. Note that `F.conv1d` slides the filter without flipping it (a cross-correlation), so which end of s corresponds to a forward shift is the mirror image of the formula above; because s is learned, this only changes how its entries are labeled.

The `F.conv1d` function takes an input of shape (batch_size, input channels, sequence length) and a filter of shape (output channels, input channels, filter length), and outputs a tensor of shape (batch_size, output channels, output sequence length).

If you are wondering why we have a for loop that goes through each sample in the batch, it is because we are not sharing weights across samples. Each sample's filter is independent and is part of its own independent history of states and actions; this is not a 1:1 mapping task where we apply the same filter to every sample in the batch.
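The sliding-window view of the shift can be sketched in pure Python as follows. This mirrors the wrap-around padding and unflipped filter described for this notebook rather than the paper's exact indexing, and the helper name `circular_shift` is an assumption for the example.

```python
def circular_shift(w, s):
    """Slide a size-3 filter s over w with wrap-around padding.

    out[i] = s[0] * w[i-1] + s[1] * w[i] + s[2] * w[i+1], indices modulo len(w).
    """
    assert len(s) == 3
    padded = [w[-1]] + list(w) + [w[0]]  # last element in front, first at the end
    return [
        sum(s[k] * padded[i + k] for k in range(3))
        for i in range(len(w))
    ]

w = [0.0, 1.0, 0.0, 0.0]  # all attention on row 1

print(circular_shift(w, [0.0, 1.0, 0.0]))  # stay put
print(circular_shift(w, [1.0, 0.0, 0.0]))  # attention moves to row 2
print(circular_shift(w, [0.0, 0.0, 1.0]))  # attention moves to row 0

# wrap-around: attention on the last row moves to the first row
print(circular_shift([0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0]))
```

With this unflipped filter, putting all the mass on `s[0]` moves attention one row forward and putting it on `s[2]` moves it one row backward. A soft s, such as `[0.1, 0.8, 0.1]`, spreads the attention over neighboring rows, which is the blurring that the next stage corrects.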

The fourth and final stage, sharpening, is used to prevent the shifted weight w_hat from blurring. To do this, each weight is raised to the power of a scalar gamma >= 1 and the result is renormalized:

$$w_t(i) \leftarrow \frac{\tilde{w}_t(i)^{\gamma_t}}{\sum\limits_j \tilde{w}_t(j)^{\gamma_t}}$$
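A small pure-Python sketch shows how larger values of gamma sharpen a blurred weighting. The helper name `sharpen` and the example weights are assumptions for illustration; the small `eps` in the denominator follows the same idea as the `1e-16` used in this notebook's `_sharpen`.

```python
def sharpen(w_hat, gamma, eps=1e-16):
    """Raise each weight to the power gamma, then renormalize."""
    powered = [w ** gamma for w in w_hat]
    total = sum(powered) + eps  # eps guards against division by zero
    return [p / total for p in powered]

blurred = [0.1, 0.7, 0.2]  # hypothetical weighting after a soft shift

for gamma in (1.0, 2.0, 10.0):
    print(gamma, [round(w, 3) for w in sharpen(blurred, gamma)])
```

With gamma = 1 the weighting is unchanged; as gamma grows, the largest entry takes an ever larger share of the total.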

## Writing 

$$\mathcal{M}_t^{erased}(i) \leftarrow \mathcal{M}_{t-1}(i)[\mathbf{1} - w_t(i) e_t ]$$

$$\mathcal{M}_t(i) \leftarrow \mathcal{M}_t^{erased}(i) + w_t(i) a_t$$

$$\mathcal{M}_t(i) \leftarrow \mathcal{M}_{t-1}(i)[\mathbf{1} - w_t(i) e_t ] + w_t(i) a_t $$

The initial write function sequentially overwrites rows in the memory, whereas another possibility might be to intelligently choose which row to overwrite by learning a write weighting ww and an erase vector e.
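The erase-then-add update can be sketched on plain lists as below. The function name `write_rows` and the toy values are assumptions for this example; with a one-hot write weighting and an all-ones erase vector, as in the sequential write described here, the selected row is simply replaced by the add vector.

```python
def write_rows(memory, w, erase, add):
    """Return a new memory after M(i) * (1 - w(i) * e) + w(i) * a for every row i."""
    return [
        [m * (1.0 - w_i * e) + w_i * a for m, e, a in zip(row, erase, add)]
        for row, w_i in zip(memory, w)
    ]

memory = [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]]  # toy 3 x 2 memory
w = [0.0, 1.0, 0.0]                            # one-hot: write to row 1 only
erase = [1.0, 1.0]                             # erase the whole selected row
add = [0.3, -0.4]                              # vector to store

updated = write_rows(memory, w, erase, add)
for row in updated:
    print(row)
```

Rows with zero write weight are left untouched. A soft write weighting or a partial erase vector would blend old and new content instead of overwriting it.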

In [3]:
class NTM(nn.Module):
    """ Neural Turing Machine Memory"""
    def __init__(self, N, M, controller_size):
        
        """Initialize the Memory matrix.
        The memory's dimensions are (batch_size x N x M).
        Each sample in the batch has its own memory matrix.
        N: Number of rows in the memory.
        M: Number of columns/features in the memory.
        """
        super(NTM, self).__init__()

        self.N = N
        self.M = M
        self.controller_size = controller_size

        self.memory0 = torch.ones(self.N,self.M).abs_()*1e-6

        # create Fully Connected layer for addressing using controller output
        self.address_param_sizes = [self.M, 1, 1, 3, 1]
        self.addresses =  nn.Linear(self.controller_size, 
                                    sum(self.address_param_sizes))  

    def reset(self, batch_size):
        """Reset the memory"""
        self.batch_size = batch_size
        self.write_loc = 0
        self.memory = self.memory0.clone().repeat(batch_size, 1, 1)
        
    def write(self, a):
        
        ww = torch.zeros(a.size(0), self.N)
        ww[:, self.write_loc] = 1.0
        e = torch.ones(a.size(0), self.M)

        erase = torch.matmul(ww.unsqueeze(-1), e.unsqueeze(1))
        add = torch.matmul(ww.unsqueeze(-1), a.unsqueeze(1))

        # write to memory
        self.memory = self.memory * (1 - erase) + add
        self.write_loc = (self.write_loc + 1) % self.N

    def read(self, wr):
        """Read from memory (according to section 3.1)"""
        return torch.matmul(wr.unsqueeze(1), self.memory).squeeze(1)
        
    def address(self, controller_output, w_prev):
        """NTM Addressing (according to section 3.3)
           both w_prev and w are Softmax weightings over rows 
           of the memory matrix with shapes (batch_size, memory_n)
        input:
            controller_output- (batch_size, controller_size)
            w_prev - The weighting produced in the previous time step
        output:
            w - new Softmax weighting over rows of the memory matrix
        """
        address_params = self.addresses(controller_output)

        k, beta, g, s, gamma = self.split_cols(address_params, 
                                               self.address_param_sizes)
        """
        k - The key vector (batch_size, memory_m) (a vector)
        beta - The key strength (focus) (batch_size, 1) (0,infinity)
        g - Scalar interpolation gate with w_prev (batch_size, 1) (0,1)
        s - Shift weighting (batch_size, 3) (sums to 1)
        gamma - Sharpen weighting scalar (batch_size, 1) (1,infinity)
        """
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=1)
        gamma = 1 + F.softplus(gamma)
        # Content Addressing
        wc = self._similarity(k, beta)
        # Location Adressing
        wg = self._interpolate(w_prev, wc, g)
        w_hat = self._shift(wg, s)
        wr = self._sharpen(w_hat, gamma)
  
        return wr

    def split_cols(self, mat, lengths):
        """Split a 2D matrix to variable length columns."""
        assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns"
        l = np.cumsum([0] + lengths) # e.g. [0, 20, 21, 22, 25, 26] when M = 20
        results = []
        for s, e in zip(l[:-1], l[1:]):  # 0 20, 20 21, ... 
            results += [mat[:, s:e]]
        return results
    
    def _similarity(self, k, beta):
        k = k.view(self.batch_size, 1, -1)
        w = F.softmax(beta * F.cosine_similarity(self.memory+1e-16, 
                                                 k+1e-16,dim=-1),dim=1)
        return w

    def _interpolate(self, w_prev, wc, g):
        return g * wc + (1 - g) * w_prev

    def convolve(self, w, s):
        """Circular convolution implementation."""
        assert s.size(0) == 3
        t = torch.cat([w[-1:], w, w[:1]])
        c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1) 
        # .view(-1) gets rid of the first two 1 dims
        return c
    
    def _shift(self, wg, s):
        result = torch.zeros(wg.size())
        for b in range(self.batch_size):
            result[b] = self.convolve(wg[b], s[b])
        return result

    def _sharpen(self, w_hat, gamma):
        w = w_hat ** gamma
        w = torch.div(w, torch.sum(w, dim=1).view(-1, 1) + 1e-16)
        return w
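To see how one linear layer's output is carved into the five addressing parameters, here is a small pure-Python sketch of the same cumulative-sum slicing used by `split_cols`. It works on a single row of integers instead of a batched tensor, and it assumes M = 8 to match the `emb_dim` used later in this notebook; the stand-in values are hypothetical.

```python
from itertools import accumulate

def split_cols(row, lengths):
    """Split one flat row into consecutive chunks of the given lengths."""
    assert len(row) == sum(lengths), "lengths must sum to the number of columns"
    bounds = [0] + list(accumulate(lengths))  # chunk boundaries
    return [row[s:e] for s, e in zip(bounds[:-1], bounds[1:])]

M = 8                          # memory width (assumed for this example)
sizes = [M, 1, 1, 3, 1]        # k, beta, g, s, gamma
row = list(range(sum(sizes)))  # stand-in for one row of the linear layer's output

k, beta, g, s, gamma = split_cols(row, sizes)
print(k, beta, g, s, gamma)
```

The key takes the first M values, and the remaining six values become the key strength, the gate, the three shift logits, and the sharpening parameter, each of which is then squashed into its valid range inside `address`.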
    

In [4]:
class Controller(nn.Module):
    """
    A Neural Turing Machine controller based on LSTM
    summarizes sequences into a hidden state that can
    be used to generate write weights (addressing)
    
    we use the same LSTM to generate write weights
    as we use to generate the vector that is stored
    in memory, this is probably a design flaw 
    """
    def __init__(self, memory_module, hidden_size, num_layers, batch_size=1):
        
        super(Controller, self).__init__()

        self.memory_module = memory_module
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.summarizer = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True)
        
        # learned initial LSTM states, sized from the constructor arguments
        # instead of the notebook globals n_layers and emb_dim
        self.summarizer_h = nn.Parameter(torch.randn(num_layers,1,hidden_size))
        self.summarizer_c = nn.Parameter(torch.randn(num_layers,1,hidden_size))
        
        self.recaller = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True)
        
        self.recaller_h = nn.Parameter(torch.randn(num_layers,1,hidden_size))
        self.recaller_c = nn.Parameter(torch.randn(num_layers,1,hidden_size))
        
        # batch_size defaults to 1, the value used throughout this notebook
        self.context = torch.randn(batch_size, 1, hidden_size)
        self.mask = (torch.ones(batch_size, 1, 1) == 1.)
        self.w = torch.zeros(batch_size, self.memory_module.N)
        self.w[:,0] = 1.0 # set reader attention at first spot in the memory

    def reset(self, batch_size):
        self.batch_size = batch_size
        h = self.summarizer_h.clone().repeat(1, batch_size, 1)
        c = self.summarizer_c.clone().repeat(1, batch_size, 1)
        self.summarizer_s = h, c 
        
        h = self.recaller_h.clone().repeat(1, batch_size, 1)
        c = self.recaller_c.clone().repeat(1, batch_size, 1)
        self.recaller_s = h, c 
        
        self.memory_module.reset(batch_size)

    def forward(self, x, prev_state):
        # this class defines no self.lstm; route through the summarizer LSTM
        out, state = self.summarizer(x, prev_state)
        return out, state
    
    def write2memory(self, encoding):
        """writes encoding to memory"""
        summary, self.summarizer_s  = self.summarizer(encoding, self.summarizer_s)
        last_h = summary[:,-1,:] # last layer output[:,-1,:] = end_token_hidden[-1,:,:]
        self.memory_module.write(last_h)

    def memory2context(self, encoding):
        """
        uses encoding to query the memory and update the context
        """
        query, self.recaller_s  = self.recaller(encoding, self.recaller_s)
        last_h = query[:,-1,:] # last layer output[:,-1,:] = end_token_hidden[-1,:,:]
        self.w = self.memory_module.address(last_h, self.w)
        context = self.memory_module.read(self.w)
        self.context = context.unsqueeze(1) # add the seq_len =1 time dimension

    def detach_memory(self):
        self.memory_module.memory = self.memory_module.memory.detach()
        self.w = self.w.detach()
        self.context = self.context.detach()
        self.mask = self.mask.detach()

In [5]:
class Transformer(nn.Module):
    
    def __init__(self, in_vocab_size, out_vocab_size, emb_dim, n_layers, 
                 heads, mem_slots, dropout):
        
        super().__init__()
        
        self.emb_dim = emb_dim
        dim_k = emb_dim // heads
        self.mem_slots = mem_slots
        
        self.memcoder = Decoder(in_vocab_size, emb_dim, n_layers, heads, dropout)
        self.decoder = Decoder(out_vocab_size, emb_dim, n_layers, heads, dropout)
        self.out = nn.Linear(emb_dim, out_vocab_size)
        
        self.ntm = NTM(mem_slots, emb_dim, emb_dim)
        self.controller = Controller(self.ntm, emb_dim, n_layers)
        
    def forward(self, in_toks, in_mask, out_toks, out_mask):
        self.controller.detach_memory()
        self.incoding = self.memcoder(in_toks, in_mask, self.controller.context, self.controller.mask)
        self.dout = self.decoder(out_toks, out_mask, self.incoding, in_mask)
        output = self.out(self.dout)
        return output

In [23]:
opt = Options(batchsize=1, device = torch.device("cpu"), epochs=50, lr=0.01, 
              max_len = 25, save_path = 'weights/ntm_weights')

data_iter, infield, outfield, opt = json2datatools(path='../../saved/memory.json', opt=opt)

emb_dim, n_layers, heads, mem_slots, dropout, batch_size = 8, 1, 3, 4, 0.001, 1

chloe = Transformer(len(infield.vocab), len(outfield.vocab), 
                    emb_dim, n_layers, heads, mem_slots, dropout)

chloe.controller.reset(batch_size)
load_subset_weights(chloe, opt)

In [24]:
from torch.distributions import Categorical # explicit import for the sampling step

def talk_to_model(input_str, model, opt, infield, outfield):
    '''
    input:
        input_str is a string, it is what you want to say to the dialogue model
        model is an encoder, decoder and a last layer linear transformation
        opt is an options object with the maximum length of the output sequence opt.max_len
        infield and outfield are the data.fields that store the vocabulary
    output:
        an output string response from the dialogue model
    '''
    model.eval()
    model.cpu()
    in_toks = string2tensor(input_str, infield) # string to tensor 
    in_mask = (in_toks != infield.vocab.stoi['<pad>']).unsqueeze(-2) #make input mask
    
    model.incoding = model.memcoder(in_toks, in_mask, model.controller.context, model.controller.mask)
    model.controller.memory2context(model.incoding)
    
    # Initialize decoder output as the start token decoder input 
    init_tok = outfield.vocab.stoi['<sos>'] # this is the integer for the start token
    decoder_input = torch.LongTensor([[init_tok]]) # start token to initiate the decoder
    
    for pos in range(opt.max_len):
        # make target mask, pos+1 cause pos starts at 0
        decoder_input_mask = nopeak_mask(size=pos+1, opt=opt) 
        model.dout = model.decoder(decoder_input, decoder_input_mask, model.incoding, in_mask)
        out = model.out(model.dout)
        softout = F.softmax(out, dim=-1) 

        distr = Categorical(probs=softout)
        action = distr.sample()[:,-1].unsqueeze(0) # sample from that distribution to get next token
        decoder_input = torch.cat((decoder_input, action), dim=1) 

        if outfield.vocab.itos[action] == '<eos>':
            de_str = ' '.join([outfield.vocab.itos[tok] for tok in decoder_input[0][1:-1]])
            return de_str
        
    # max_len reached without '<eos>': skip the leading '<sos>' token
    de_str = ' '.join([outfield.vocab.itos[tok] for tok in decoder_input[0][1:]])
    return de_str
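The decoding loop above samples the next token from the softmax distribution instead of always taking the most likely one. The pure-Python sketch below contrasts the two choices on a toy distribution; the vocabulary, the logits, and the fixed seed are hypothetical values for illustration, and `random.Random.choices` stands in for `Categorical(...).sample()`.

```python
import math
import random

def softmax(logits):
    """Turn a list of logits into probabilities."""
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["<eos>", "hi", "bobo", "!"]  # hypothetical toy vocabulary
logits = [0.1, 2.0, 1.0, 0.5]         # hypothetical decoder output for one step
probs = softmax(logits)

rng = random.Random(0)                # fixed seed so the run is repeatable
sampled = rng.choices(range(len(vocab)), weights=probs, k=1)[0]
greedy = max(range(len(vocab)), key=lambda i: probs[i])

print("sampled:", vocab[sampled])
print("greedy: ", vocab[greedy])
```

Sampling can return any token with nonzero probability, so replies may vary between runs; the greedy choice always returns the most likely token. Once the model is trained to a very low loss, the distribution is usually so peaked that the two tend to agree.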

In [28]:
load_subset_weights(chloe, opt)

print(talk_to_model("my name is bobo", chloe, opt, infield, outfield))

chloe.controller.write2memory(chloe.dout)
chloe.controller.memory2context(chloe.dout)
print(chloe.controller.memory_module.memory.data)
print(chloe.controller.w.data)
print(talk_to_model("what is my name?", chloe, opt, infield, outfield))

hi bobo !
tensor([[[ 0.1757, -0.0697,  0.3409,  0.0694, -0.1293,  0.0456,  0.3600,
           0.1633],
         [ 0.1502, -0.0755,  0.3560,  0.1028, -0.1073,  0.0203,  0.3796,
           0.1731],
         [ 0.1426, -0.0790,  0.3639,  0.1227, -0.1037,  0.0104,  0.3839,
           0.1745],
         [ 0.1360, -0.0829,  0.3770,  0.1473, -0.1023,  0.0115,  0.3865,
           0.1750]]])
tensor([[0.2476, 0.2475, 0.2507, 0.2542]])
you are chloe


In [231]:
load_subset_weights(chloe, opt)

conversation_list = [
{"listen":"my name is chloe", "reply":"hi chloe!"},
{"listen":"what is my name?", "reply":"you are chloe"},
{"listen":"my name is fluffy", "reply":"hey fluffy!"},
{"listen":"what is my name?", "reply":"fluffy pillow"},
{"listen":"my name is snuggles", "reply":"hello snuggles!"},
{"listen":"what is my name?", "reply":"snuggles the bunny"},
{"listen":"my name is bobo", "reply":"hi bobo!"},
{"listen":"what is my name?", "reply":"you are bobo"},
                    ]

sos_tok = torch.LongTensor([[outfield.vocab.stoi['<sos>']]]) 
eos_tok = torch.LongTensor([[outfield.vocab.stoi['<eos>']]]) 

chloe.train()
start = time.time()
best_loss = 100

opt.epochs = 40
opt.lr = 0.0005
optimizer = torch.optim.Adam(chloe.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=4)

for epoch in range(opt.epochs):
    total_loss = 0
    for i in range(len(conversation_list)):
        
        # init optimizer
        optimizer.zero_grad()
        
        # prepare source and targets 
        listen_string = conversation_list[i]["listen"]
        reply_string = conversation_list[i]["reply"]
        listen_toks = string2tensor(listen_string, infield)
        reply_toks = string2tensor(reply_string, outfield)
        reply_start = torch.cat((sos_tok,reply_toks), dim=1) #teacher forcing 
        reply_labels = torch.cat((reply_toks,eos_tok), dim=1).contiguous().view(-1) #target
        listen_mask, reply_mask = create_masks(listen_toks, reply_start, opt)
        
        # forward pass and write decoder output to memory
        logits = chloe(listen_toks, listen_mask, reply_start, reply_mask)
        #chloe.incoding = model.memcoder(listen_toks, listen_mask, chloe.controller.context, chloe.controller.mask)
        #chloe.controller.memory2context(chloe.incoding)
        #chloe.dout = chloe.decoder(reply_start, reply_mask, self.incoding, in_mask)
        #output = self.out(self.dout)
        
        chloe.controller.write2memory(chloe.dout)
        chloe.controller.memory2context(chloe.dout)
        
        # calculate loss
        flat_logits = logits.view(-1, logits.size(-1))
        batch_loss = F.cross_entropy(flat_logits, reply_labels, ignore_index = opt.trg_pad)
        
        # calculate gradients
        batch_loss.backward()
        
        # update weights
        torch.nn.utils.clip_grad_norm_(chloe.parameters(), max_norm = 1.0) 
        optimizer.step()

        total_loss += batch_loss.item()

    epoch_loss = total_loss/len(conversation_list)
    scheduler.step(epoch_loss)

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(chloe.state_dict(), opt.save_path)
        print("%dm: epoch %d loss = %.3f" %((time.time() - start)//60, 
                                        epoch, epoch_loss))
    
    total_loss = 0
    
print("finished")

0m: epoch 0 loss = 0.117
0m: epoch 1 loss = 0.073
0m: epoch 5 loss = 0.067
0m: epoch 6 loss = 0.043
0m: epoch 7 loss = 0.038
0m: epoch 10 loss = 0.036
0m: epoch 20 loss = 0.032
finished
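The training loop builds the decoder input and the labels from the same reply, offset by one position (teacher forcing). A minimal sketch with made-up token ids shows the alignment; the helper name `teacher_forcing_pair` and the id values are assumptions for this example.

```python
def teacher_forcing_pair(reply_ids, sos_id, eos_id):
    """Build (decoder_input, labels) for one reply, shifted by one position."""
    decoder_input = [sos_id] + list(reply_ids)  # what the decoder sees
    labels = list(reply_ids) + [eos_id]         # what it should predict next
    return decoder_input, labels

SOS, EOS = 1, 2           # hypothetical special-token ids
reply = [17, 42, 5]       # hypothetical ids for a three-token reply

decoder_input, labels = teacher_forcing_pair(reply, SOS, EOS)
for seen, target in zip(decoder_input, labels):
    print("input", seen, "-> predict", target)
```

At every position the decoder receives the correct previous token and is asked to predict the next one, ending with the end-of-sequence token. This is why `reply_start` and `reply_labels` have the same length in the loop above.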


In [19]:
load_subset_weights(chloe, opt)
test_list = [
    " my name is fluffy ",
    " what is my name? ",
    " my name is snuggles",
    " what is my name? ",
    " my name is bobo ",
    " what is my name? ",
    " my name is chloe",
    " what is my name? ",
]

for i in test_list:
    print(" > ", i, " > ",  talk_to_model(i,chloe,opt,infield,outfield))
    chloe.controller.write2memory(chloe.dout)
    chloe.controller.memory2context(chloe.dout)

 >   my name is fluffy   >  hey fluffy !
 >   what is my name?   >  fluffy pillow
 >   my name is snuggles  >  hello snuggles !
 >   what is my name?   >  snuggles the bunny
 >   my name is bobo   >  hi bobo !
 >   what is my name?   >  you are bobo
 >   my name is chloe  >  hi chloe !
 >   what is my name?   >  you are chloe
