# Metalearned Neural Memory with local memory updates

In [1]:
import math, copy, sys, logging, json, time, random, os, string, pickle, re

import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
#from torch.distributions import Categorical

from modules.MetaLearnNeuralMemory import FFMemoryLearned

# MNMp - The Memory Controller

Imagine that an LSTM is the agent. An LSTM takes in the current state x_t and the previous hidden state h_t-1, and outputs the current hidden state h_t.

Now imagine that our LSTM has a storage center for memory besides h_t, the `neural memory network` (NMN), and that at each step it receives the current state x_t, the previous hidden state h_t-1 and a readout v_t-1 from the NMN.

Each x_t has shape `(batch size, n_units)` and each readout vector v_t-1 has shape `(batch size, n_in_mem)`.

Since the LSTM cell is constructed as `torch.nn.LSTMCell(input_size, hidden_size, bias=True)` and accepts a single input tensor, we concatenate the input x_t with the memory readout vector v_t-1, so that `input_size = n_units + n_in_mem`:

`self.lstm_l1 = nn.LSTMCell(n_units + n_in_mem, n_units)`

In the paper this LSTM controller is 

$$h_t = LSTM(x_t, v^{r}_{t-1}, h_{t-1})$$

The first time the MNMp forward pass is called, the memory readout vector `self.v` and the LSTM controller states `self.h_lstm` and `self.c_lstm` are zero vectors. Each subsequent forward pass on a new x_t then updates them.

In PyTorch, `LSTMCell` takes an input and a `(hidden state, cell state)` tuple, so this expression becomes

`self.h_lstm, self.c_lstm = self.lstm_l1(torch.cat([x, self.v], dim=1), (self.h_lstm, self.c_lstm))`
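
A minimal sketch of a single controller step, to make the shapes concrete. The sizes (`batch_size = 2`, `n_units = 8`, `n_in_mem = 8`) are illustrative assumptions, and the zero readout `v` only stands in for what the NMN would return:

```python
import torch
import torch.nn as nn

batch_size, n_units, n_in_mem = 2, 8, 8  # illustrative sizes, not from the paper

# The controller sees the input concatenated with the previous memory readout.
lstm_l1 = nn.LSTMCell(n_units + n_in_mem, n_units)

# Zero readout and controller states, as on the very first forward pass.
v = torch.zeros(batch_size, n_in_mem)
h_lstm = torch.zeros(batch_size, n_units)
c_lstm = torch.zeros(batch_size, n_units)

x = torch.randn(batch_size, n_units)  # current state x_t
h_lstm, c_lstm = lstm_l1(torch.cat([x, v], dim=1), (h_lstm, c_lstm))

print(h_lstm.shape)  # torch.Size([2, 8])
```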

The hidden state output of the LSTM controller is passed through an affine transformation and a tanh non-linearity to produce one large vector, which is then split into the interaction vectors.

$$[k^{r}_{t,1} \dots k^{r}_{t,H};\; k^{w}_{t,1} \dots k^{w}_{t,H};\; v^{w}_{t,1} \dots v^{w}_{t,H};\; B_{t}] = \tanh(W_v h_t + b_v)$$

These interaction vectors contain value vectors to write to memory, write-key vectors that are meant to retrieve those memorized values if they are seen again in the future, and read-key vectors that retrieve memories for immediate use.

Leaving out the number of heads, each step of the LSTM agent generates two keys and one value: a read key to retrieve memory, and a write key and write value to store memory.
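
The split into interaction vectors can be sketched as follows. The sizes are illustrative assumptions and `h_t` is a random stand-in for the controller output. The slicing order (rate block first, then write key, write value, read key) follows the `MNMp.forward` code in this notebook rather than the order written in the equation:

```python
import torch
import torch.nn as nn

batch_size, dim_hidden, n_heads = 2, 8, 4  # illustrative sizes

# One affine map produces every interaction vector at once:
# n_heads * (write key + write value + read key) plus one block for the rate.
interaction = nn.Linear(dim_hidden, dim_hidden * n_heads * 3 + dim_hidden)

h_t = torch.randn(batch_size, dim_hidden)  # stand-in for the controller output
int_vecs = torch.tanh(interaction(h_t))

# Peel off the rate block, then split the rest per head and per role.
beta_, n_k_v = torch.split(int_vecs, [dim_hidden, dim_hidden * n_heads * 3], dim=1)
n_k_v = n_k_v.reshape(batch_size, n_heads, -1)
k_w, v_w, k_r = torch.chunk(n_k_v, 3, dim=2)

print(k_w.shape, v_w.shape, k_r.shape)  # each is torch.Size([2, 4, 8])
```

Every head therefore gets its own write key, write value and read key, each of size `dim_hidden`.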

In [None]:
class MNMp(nn.Module):

    def __init__(self, dim_hidden, n_heads = 4):
        
        """ dim_hidden is the hidden size of the LSTM controller,
            the Memory Network, and the interaction vectors
            n_heads is the number of interaction heads """

        super(MNMp, self).__init__()
        
        self.dim_hidden = dim_hidden
        self.n_heads = n_heads
        
        self.control = nn.LSTMCell(dim_hidden*2, dim_hidden)
        
        dim_concat_interact = dim_hidden*n_heads*3 + dim_hidden
        self.interaction = nn.Linear(dim_hidden, dim_concat_interact)
        self.memfunc = FFMemoryLearned(dim_hidden)
        self.kv_rate = nn.Linear(dim_hidden, 1)
        self.read_out = nn.Linear(dim_hidden+dim_hidden, dim_hidden)
        
        #self.v_r = None
        #self.h_lstm = None
        #self.c_lstm = None
        self.v_r = torch.zeros((1, self.dim_hidden)).float()
        self.h_lstm = torch.zeros((1, self.dim_hidden)).float()
        self.c_lstm = torch.zeros((1, self.dim_hidden)).float()
            
    def repeat_v_h_c(self, batch_size):
        """ expand the (1, dim_hidden) initial states to batch_size rows;
            states that already have a batch dimension are left alone, so
            calling forward repeatedly does not keep growing them """
        if self.v_r.shape[0] == 1 and batch_size > 1:
            self.v_r = self.v_r.repeat(batch_size, 1)
            self.h_lstm = self.h_lstm.repeat(batch_size, 1)
            self.c_lstm = self.c_lstm.repeat(batch_size, 1)

        # keep the states on the same device as the module parameters
        if next(self.parameters()).is_cuda:
            self.v_r = self.v_r.cuda()
            self.h_lstm = self.h_lstm.cuda()
            self.c_lstm = self.c_lstm.cuda()
            
    def forward(self, x):
        """ x has shape (batch_size, 1, dim_hidden) or (batch_size, dim_hidden);
        a singleton dim 1 is squeezed away so that x can be concatenated with
        self.v_r of shape (batch_size, dim_hidden) """

        self.repeat_v_h_c(x.shape[0])
        x = x.squeeze(1)
        self.h_lstm, self.c_lstm = self.control(torch.cat([x, self.v_r], dim=1), 
                                                (self.h_lstm, self.c_lstm))
        
        int_vecs = torch.tanh(self.interaction(self.h_lstm))
        beta_, n_k_v = torch.split(int_vecs, 
                                   [self.dim_hidden, 
                                   self.dim_hidden*self.n_heads*3],
                                   dim=1)  
        
        beta = torch.sigmoid(self.kv_rate(beta_)) #(batch_size,1)
        n_k_v = n_k_v.view(n_k_v.shape[0], self.n_heads, -1).contiguous()
        k_w, v_w, k_r = torch.chunk(n_k_v, 3, dim=2)
        reconst_loss, reconst_loss_init = self.memfunc.update(k_w, v_w, 
                                                                beta_rate=beta)
        self.v_r = self.memfunc.read(k_r)
        h_lstm = self.read_out(torch.cat([self.h_lstm, self.v_r], dim=1))

        return h_lstm.unsqueeze(1), reconst_loss, reconst_loss_init 

## Perceptron Learning Rule

The perceptron learning rule is used to do fast gradient free updates to the neural memory network. This perceptron learning rule is explained well here: http://hagan.okstate.edu/4_Perceptron.pdf 

The perceptron learning rule allows you to change the weights of a linear transformation matrix in such a way as to nudge its output closer to or away from a desired target or vector of targets, without having to calculate a mean squared error and perform backpropagation. 

Consider some layer of a feed forward neural network with activations "a" = relu(Wx+b).

"a" might have several dimensions, but the first dimension of "a" depends only on x, the first row of W and the first entry of b. So if you consider W to be a single row and "a" to be a single activation, then without loss of generality, what applies to this row applies to the entire matrix W and the entire vector a. 

If `W = [-1, 1]` and `b = [-1]`, you can see by plotting `W[x_1, x_2]^T + b = 0` on a horizontal `x_1` by vertical `x_2` axis that this is a line with slope `1` that intersects the `x_1` axis at `-1`. 

<img src = 'files/grapher1.gif' width=400 height =400>

This line is the decision boundary: on it `Wx+b=0`, and on either side of it `Wx+b` is either positive or negative. The point `p1 = (x_1, x_2) = (2, 1)` evaluates to `-2` when plugged into `Wx+b`, so it becomes zero after the activation function `f(a) = max(0,a)`. You can tell which side of the line is the positive regime by looking at the direction in which the vector `W = [-1, 1]` points. Moving `-1` along `x_1` is one step to the left and `1` along `x_2` is one step upwards, which draws an arrow that points up and to the left. The direction in which `W` points is the direction towards the side of the decision boundary where `Wx+b` is positive; the opposite side is negative. The point `[2,1]` is therefore on the wrong side of this line if you want it to be positive.
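
The arithmetic for `p1` can be checked with a few lines of plain Python. The helper names `pre_activation` and `relu` are made up for this sketch:

```python
def pre_activation(W, x, b):
    """Return Wx + b for a single row W, with W and x as plain lists."""
    return sum(w_i * x_i for w_i, x_i in zip(W, x)) + b

def relu(a):
    """f(a) = max(0, a)."""
    return max(0, a)

W, b = [-1, 1], -1
p1 = [2, 1]

a = pre_activation(W, p1, b)
print(a, relu(a))  # -2 0
```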

Suppose that, without calculating a gradient, you want to change `W` so that the point `p1 = (x_1, x_2) = (2, 1)` is classified as positive. If you add `p1` to `W` you get `W = [-1, 1] + [2, 1] = [1, 2]`, and if you add the error `error = target - activation = 1 - 0 = 1` to the bias you get `b_new = b_old + error = -1 + 1 = 0`.

<img src = 'files/grapher2.gif' width=400 height =400>

The new vector `W` points up and to the right, and `p1` is now classified correctly: `Wx + b = 2 + 2 + 0 = 4`. Notice that `b` also changes in the correct direction. If you had not changed `b`, the resulting decision boundary `[1, 2]x - 1 = 0` would have been a negatively sloped line that intersects the `x_1` axis at `x_1 = 1`, whereas now that `b = 0` the intersect is at `0`. The update has not only rotated the vector `W` to point towards `p1` but also shifted the decision boundary downwards, pulling the boundary away from `p1` so that `p1` lies farther inside the positive region and is therefore more positive.
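
A short sketch of this single update in plain Python, using the numbers from the example. The function name `perceptron_update` is an assumption made for illustration:

```python
def perceptron_update(W, b, x, target, activation):
    """Apply one perceptron step to a single row W and scalar bias b."""
    error = target - activation
    W_new = [w_i + error * x_i for w_i, x_i in zip(W, x)]
    b_new = b + error
    return W_new, b_new

W, b = [-1, 1], -1
p1 = [2, 1]

# relu(Wx + b) = relu(-2) = 0 before the update
activation = max(0, sum(w * x for w, x in zip(W, p1)) + b)
W, b = perceptron_update(W, b, p1, target=1, activation=activation)

print(W, b)                                   # [1, 2] 0
print(sum(w * x for w, x in zip(W, p1)) + b)  # 4
```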

The perceptron learning rule formula is as follows:

$$ W_t := W_{t-1} + (target - activation) \otimes input^{T} $$

$$ b_t := b_{t-1} + (target - activation) $$

`t` is the update iteration timestep and $\otimes$ (the circle with an x inside) is the outer product. This fits the shape of W: if the activation is m-dimensional and the input is n-dimensional, then W has shape m x n and the bias has shape m x 1. The update to W must have the same shape as W, which is exactly what the outer product (target - activation) $\otimes$ input^T gives: (m x 1)(1 x n) = (m x n).

In our example, (target - activation) $\otimes$ input^T = (1 - 0) x `[[2],[1]]^T` = `[2, 1]`.
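
For a W with more than one row, the same rule can be sketched with an explicit outer product. The first row below is the example from the text. The second row and its target are hypothetical, chosen so that its activation already matches the target and the row stays unchanged:

```python
def outer(u, v):
    """Outer product of an m-vector and an n-vector as an m x n nested list."""
    return [[u_i * v_j for v_j in v] for u_i in u]

def perceptron_update_matrix(W, b, x, target, activation):
    """W <- W + (target - activation) outer x, b <- b + (target - activation)."""
    error = [t - a for t, a in zip(target, activation)]
    delta = outer(error, x)
    W_new = [[w + d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]
    b_new = [b_i + e_i for b_i, e_i in zip(b, error)]
    return W_new, b_new

W = [[-1, 1], [1, 0]]  # row 0 from the text, row 1 hypothetical
b = [-1, 0]
x = [2, 1]

# relu(Wx + b), one entry per row
activation = [max(0, sum(w * x_j for w, x_j in zip(row, x)) + b_i)
              for row, b_i in zip(W, b)]
target = [1, 2]

W, b = perceptron_update_matrix(W, b, x, target, activation)
print(W, b)  # [[1, 2], [1, 0]] [0, 0]
```

Only the row whose activation missed its target moves, which is why each row can be reasoned about on its own.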

For us the target vector is not a supervised label. Instead, it is a target that the `backward feedback prediction functions (BFPF)` learn to produce in order to update all the layers of the `neural memory network` (NMN) in such a way as to bind the many-to-many mappings of key-value pairs.

Our memory is stored in a multi-layer feed forward NMN:

$$ v^{r}_{t,i} = {NMN}(k^{r}_{t,i})$$

$$ z'^{l} = {BFPF}^{l}(v^{w}_{t}) $$

v^w is the memory we want to remember, i.e. write to memory. k^r is the thought that occurs in real time and prompts us to retrieve v^r. 
v^r is the retrieved memory for timestep t and the i-th head: i-th because there might be several memories we want to retrieve, t because we might be performing this retrieval at each step of a task.

Let's rewrite the above perceptron rule in a format that is analogous to the layer-by-layer update rule in [Metalearned Neural Memory](https://arxiv.org/abs/1907.09720)

$$ W_t := W_{t-1} - (activation - target) \otimes x^{T} $$

M^l is the weight matrix for layer l of the NMN; the subscript t marks the matrix after the update and t-1 the matrix before it. B^l_t is a learning rate, z^l is the current activation of layer l prior to the update, z'^l is the target activation for layer l, and z^{l-1} are the activations of the previous layer:

$$ M^{l}_{t} :=  M^{l}_{t-1} - B^{l}_{t}(z^{l} - z'^{l}) {z^{l-1}}^{T} $$
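
A minimal sketch of this local update for a single layer, in plain Python. It assumes a linear layer with no bias or non-linearity, a hand-picked target in place of the BFPF output, and a fixed rate (in `MNMp` the rate comes from `self.kv_rate`). All names and numbers here are hypothetical:

```python
def matvec(M, z):
    """Multiply an m x n nested-list matrix by an n-vector."""
    return [sum(m_ij * z_j for m_ij, z_j in zip(row, z)) for row in M]

def local_update(M, z_prev, z_target, beta):
    """Return M - beta * (z - z') z_prev^T, where z = M z_prev."""
    z = matvec(M, z_prev)
    return [
        [m_ij - beta * (z_i - t_i) * p_j for m_ij, p_j in zip(row, z_prev)]
        for row, z_i, t_i in zip(M, z, z_target)
    ]

def sq_error(z, z_target):
    """Squared distance between an activation and its target."""
    return sum((a - t) ** 2 for a, t in zip(z, z_target))

M = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical layer weights M^l
z_prev = [1.0, 2.0]           # activations of the previous layer, z^{l-1}
z_target = [0.0, 0.0]         # hand-picked target z'^l (stand-in for the BFPF output)
beta = 0.1                    # fixed rate, chosen for illustration

before = sq_error(matvec(M, z_prev), z_target)
M = local_update(M, z_prev, z_target, beta)
after = sq_error(matvec(M, z_prev), z_target)

print(before, after)  # the error shrinks after one update
```

For such a linear layer the new activation is `z - beta * ||z_prev||^2 * (z - z')`, so under these assumptions one step moves the activation towards the target whenever `beta * ||z_prev||^2` lies between 0 and 2.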