In [47]:
import math, time, os, datetime, shutil, pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import torch.onnx

import import_ipynb
from Elements import MultiHeadAttention, Norm, FeedForward
%matplotlib inline

From a big-picture perspective, the transformer we have built so far only lets us map an input sequence to an output sequence. All the layers we put together essentially exist to perform this action -> reaction task. When you are talking to someone, your responses are not based only on what they said last. They are also based on what you said earlier in the conversation, your memory of past conversations, your current impression of the person, and other knowledge about this person and their relationship with the world. This requires us to be able to hold an internal state that persists through time: a memory. In your brain, there are just neurons and the cells that support those neurons. Here we are building a form of neural memory.

In the work by [Santoro, A. et al](https://arxiv.org/pdf/1806.01822.pdf), Santoro and colleagues built on earlier work on adding memory to neural networks and devised what they called the Relational Memory Core (RMC). The image below summarizes this type of neural memory.

<img src="../saved/images/rmc.png" height=600 width=800>

In (b) of this image, notice the 4 x 6 matrix of grey dots labeled memory. This matrix stores information from the past: past memories. The light grey vector labeled input represents new information from the current time point. We wish to incorporate it into our past memories so that it is saved for the future. (c) describes a multi-headed attention mechanism for updating the past memories; it is the same multi-headed attention we have already learned in the Transformer. In this example the past memories play the role of the q sequence, and the past memories with the input appended as an extra row play the role of the k and v sequences. The <font color='green'>weights matrix (q_seq_len, k_seq_len)</font> plays the role of the score matrix. The output of the attention mechanism is the updated memory.

In [151]:
# Set to True to run the step-by-step walkthrough cells below (they are all guarded by `if teaching:`).
teaching = False

def initial_memory(mem_slots, mem_size, batch_size, device=None, dtype=None):
    """Creates the initial memory.

    Each row of the memory should start out unique, so begin from the identity
    matrix, then zero-pad (mem_size > mem_slots) or truncate (mem_size < mem_slots)
    its columns so every batch element is a (mem_slots, mem_size) matrix.

    Args:
      mem_slots: rows in memory matrix
      mem_size: columns in memory matrix
      batch_size: batch size (0 yields an empty batch instead of an error)
      device: optional device for the returned tensor (default: torch's default device)
      dtype: optional dtype for the returned tensor (default: torch's default dtype)
    Returns:
      init_state: a padded or truncated identity matrix per batch element,
        of size (batch_size, mem_slots, mem_size)
    """
    identity = torch.eye(mem_slots, device=device, dtype=dtype)

    if mem_size > mem_slots:
        # Pad the extra columns with zeros.
        pad = torch.zeros((mem_slots, mem_size - mem_slots), device=device, dtype=dtype)
        identity = torch.cat([identity, pad], dim=-1)
    elif mem_size < mem_slots:
        # Truncation: keep the first `mem_size` columns.
        identity = identity[:, :mem_size]

    # repeat (rather than expand) gives every batch element its own storage,
    # and unlike torch.stack over a list it also works for batch_size == 0.
    return identity.unsqueeze(0).repeat(batch_size, 1, 1)


if teaching:
    # Walkthrough sizes: 4 memory slots, each a row of width 8, for a single batch element.
    mem_slots, mem_size, batch_size = 4, 8, 1
    memory = initial_memory(
        mem_slots=mem_slots,
        mem_size=mem_size,
        batch_size=batch_size,
    )
    print(memory, memory.shape)

In [78]:
if teaching:
    # One fresh input vector per batch element, as wide as a memory row.
    input_vector = torch.randn((batch_size, mem_size))
    print(input_vector.shape)
    # Append the input as an extra row: (batch, mem_slots + 1, mem_size).
    input_row = input_vector.unsqueeze(1)
    memory_plus_input = torch.cat([memory, input_row], dim=-2)
    print(memory_plus_input, memory_plus_input.shape)

torch.Size([1, 8])
tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [-1.3526,  0.1950,  0.5727, -0.2224, -0.1243, -0.9050, -2.7317,
          -0.3789]]]) torch.Size([1, 5, 8])


In [79]:
if teaching:
    # Memory rows act as queries; [memory; input] supplies the keys and values.
    updatememory = MultiHeadAttention(num_heads=3, emb_dim=mem_size, dim_k=4, dropout=0.0)
    new_memory, scores = updatememory(memory, memory_plus_input, memory_plus_input)
    print(new_memory.shape)
    # Residual connection followed by normalization, as in a Transformer layer.
    NormalizeMemory1 = Norm(emb_dim=mem_size)
    new_mem_norm = NormalizeMemory1(new_memory + memory)
    # A bare expression nested inside `if` is not echoed by Jupyter, so print it explicitly.
    print(new_mem_norm.shape)

torch.Size([1, 4, 8])


In [97]:
if teaching:
    # Position-wise MLP (hidden width 2 * mem_size, as in RelMemCore) with residual + norm.
    MLP = FeedForward(emb_dim=mem_size, ff_dim=2 * mem_size, dropout=0.2)
    mem_mlp = MLP(new_mem_norm)
    NormalizeMemory2 = Norm(emb_dim=mem_size)
    new_mem_norm2 = NormalizeMemory2(mem_mlp + new_mem_norm)
    # A bare expression nested inside `if` is not echoed by Jupyter, so print it explicitly.
    print(new_mem_norm2.shape)

torch.Size([1, 4, 8])

In [98]:
if teaching:
    print(memory.shape, input_vector.shape)
    # Repeat the input once per memory slot so it can sit beside every memory row.
    input_stack = torch.stack([input_vector] * mem_slots, dim=1)
    print(input_stack.shape)
    # Feature-wise concatenation [m_{t-1}, x_t]: (batch, mem_slots, 2 * mem_size).
    h_old_x = torch.cat([memory, input_stack], dim=-1)
    print(h_old_x.shape)
    ZGATE = nn.Linear(2 * mem_size, mem_size)
    z_t = torch.sigmoid(ZGATE(h_old_x))  # (batch size, memory slots, memory size)
    print(z_t.shape)
    print(ZGATE.weight.shape)
    # z_t near 1 takes the candidate memory; near 0 keeps the previous memory.
    new_memory = (1 - z_t)*memory + z_t*new_mem_norm2

torch.Size([1, 4, 8]) torch.Size([1, 8])
torch.Size([1, 4, 8])
torch.Size([1, 4, 16])


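$$z_t = \sigma(W_z \cdot [m_{t - 1}, x_t] + b_z)$$

Here $[m_{t-1}, x_t]$ concatenates each memory row with a copy of the input vector along the feature dimension.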

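$$m_{t} = (1 - z_t) \circ m_{t - 1} + z_t \circ \tilde{m}_{t}$$

where $\tilde{m}_t$ is the candidate memory produced by the attention, MLP and normalization steps, and $\circ$ is element-wise multiplication.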

In [138]:
class RelMemCore(nn.Module):
    """Relational Memory Core (Santoro et al.) assembled from the Transformer building blocks.

    A memory of shape (batch_size, mem_slots, mem_size) is updated from one input
    vector per step:
      1. multi-head attention of the memory (queries) over [memory; input] (keys/values),
      2. residual + Norm, then MLP + residual + Norm -> candidate memory,
      3. a sigmoid gate computed from [memory, input] blends the previous memory
         with the candidate.

    Args:
      mem_slots: rows in the memory matrix
      mem_size: columns in the memory matrix (also the input vector size)
      num_heads: number of attention heads
      dim_k: per-head query/key/value size; when falsy, defaults to mem_size // num_heads
      dropout: dropout rate passed to the attention and MLP sub-layers
    """

    def __init__(self, mem_slots, mem_size, num_heads, dim_k=None, dropout=0.1):
        super(RelMemCore, self).__init__()
        self.mem_slots = mem_slots
        self.mem_size = mem_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.dim_k = dim_k if dim_k else self.mem_size // num_heads
        if self.dim_k < 1:
            # e.g. mem_size < num_heads with no explicit dim_k would build zero-width projections.
            raise ValueError(
                f"dim_k must be >= 1 (mem_size={mem_size}, num_heads={num_heads}); pass dim_k explicitly"
            )
        # Sub-layers are built with keyword arguments so a change in argument order cannot mix them up.
        self.attn_mem_update = MultiHeadAttention(
            num_heads=self.num_heads, emb_dim=self.mem_size, dim_k=self.dim_k, dropout=self.dropout
        )
        self.normalizeMemory1 = Norm(emb_dim=self.mem_size)
        self.normalizeMemory2 = Norm(emb_dim=self.mem_size)
        self.MLP = FeedForward(emb_dim=self.mem_size, ff_dim=self.mem_size * 2, dropout=dropout)
        # Gate input is [memory row, input vector] concatenated along features.
        self.ZGATE = nn.Linear(self.mem_size * 2, self.mem_size)

    def initial_memory(self, batch_size, device=None, dtype=None):
        """Creates the initial memory.

        Each row of the memory should start out unique, so begin from the identity
        matrix, then zero-pad (mem_size > mem_slots) or truncate (mem_size < mem_slots)
        its columns so each batch element is (mem_slots, mem_size).

        Args:
          batch_size: number of independent memories to create (0 gives an empty batch)
          device: optional device for the result, e.g. the device the module lives on
          dtype: optional dtype for the result (default: torch's default dtype)
        Returns:
          init_mem: tensor of size (batch_size, mem_slots, mem_size)
        """
        identity = torch.eye(self.mem_slots, device=device, dtype=dtype)

        if self.mem_size > self.mem_slots:
            # Pad the extra columns with zeros.
            pad = torch.zeros(
                (self.mem_slots, self.mem_size - self.mem_slots), device=device, dtype=dtype
            )
            identity = torch.cat([identity, pad], dim=-1)
        elif self.mem_size < self.mem_slots:
            # Truncation: keep the first `self.mem_size` columns.
            identity = identity[:, :self.mem_size]

        # repeat gives each batch element its own copy and handles batch_size == 0.
        return identity.unsqueeze(0).repeat(batch_size, 1, 1)

    def update_memory(self, input_vector, cur_memory):
        '''
        inputs
         input_vector (batch_size, mem_size)
         cur_memory - current_memory (batch_size, mem_slots, mem_size)
        output
         new_memory (batch_size, mem_slots, mem_size)
        '''
        # Keys/values: the current memory with the input appended as an extra row.
        mem_plus_input = torch.cat([cur_memory, input_vector.unsqueeze(1)], dim=-2)
        attended, _ = self.attn_mem_update(cur_memory, mem_plus_input, mem_plus_input)
        mem_norm1 = self.normalizeMemory1(attended + cur_memory)
        candidate = self.normalizeMemory2(self.MLP(mem_norm1) + mem_norm1)

        # Broadcast the input to every memory slot (no copy) for the gate.
        input_stack = input_vector.unsqueeze(1).expand(-1, self.mem_slots, -1)
        h_old_x = torch.cat([cur_memory, input_stack], dim=-1)
        z_t = torch.sigmoid(self.ZGATE(h_old_x))  # (batch size, memory slots, memory size)
        # Blend with the memory passed in (previously this read the notebook global `memory`).
        return (1 - z_t) * cur_memory + z_t * candidate

    def forward(self, input_vector, cur_memory):
        """nn.Module entry point; same as update_memory, so `rmc(x, mem)` works."""
        return self.update_memory(input_vector, cur_memory)

In [139]:
# Seed so the randomly initialised weights and the demo input are reproducible on Restart & Run All.
torch.manual_seed(0)
rmc = RelMemCore(mem_slots=4, mem_size=8, num_heads=3)
cur_mem = rmc.initial_memory(batch_size=1)
# Size the input from the core itself: the globals `batch_size`/`mem_size` only exist when `teaching` is True.
input_vector = torch.randn((cur_mem.size(0), rmc.mem_size))
new_memory = rmc.update_memory(input_vector, cur_mem)

In [141]:
# Show the initial memory and the memory after one update, each followed by its shape.
for tensor in (cur_mem, new_memory):
    print(tensor, tensor.shape)

tensor([[[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.]]]) torch.Size([1, 4, 8])
tensor([[[ 1.4589, -0.5712, -0.2194, -0.2879, -0.3375,  0.1780, -0.1343,
           0.3922],
         [ 0.4755,  0.9693, -0.3751, -0.4355, -0.6177,  0.1510, -0.4546,
           0.7467],
         [ 0.3362, -0.7275,  1.2541, -0.3972, -0.4413,  0.0264, -0.1823,
           0.5694],
         [ 0.4066, -0.6822, -0.2506,  1.2347, -0.4806,  0.2021, -0.2281,
           0.5219]]], grad_fn=<AddBackward0>) torch.Size([1, 4, 8])


In [150]:
# Name and shape of every parameter that will receive gradients.
trainable = ((name, p.shape) for name, p in rmc.named_parameters() if p.requires_grad)
for name, shape in trainable:
    print(name, shape)

attn_mem_update.q_linear.weight torch.Size([6, 8])
attn_mem_update.q_linear.bias torch.Size([6])
attn_mem_update.k_linear.weight torch.Size([6, 8])
attn_mem_update.k_linear.bias torch.Size([6])
attn_mem_update.v_linear.weight torch.Size([6, 8])
attn_mem_update.v_linear.bias torch.Size([6])
attn_mem_update.out.weight torch.Size([8, 6])
attn_mem_update.out.bias torch.Size([8])
normalizeMemory1.alpha torch.Size([8])
normalizeMemory1.bias torch.Size([8])
normalizeMemory2.alpha torch.Size([8])
normalizeMemory2.bias torch.Size([8])
MLP.linear_1.weight torch.Size([16, 8])
MLP.linear_1.bias torch.Size([16])
MLP.linear_2.weight torch.Size([8, 16])
MLP.linear_2.bias torch.Size([8])
ZGATE.weight torch.Size([8, 16])
ZGATE.bias torch.Size([8])


In [149]:
# Shapes only, in registration order (same order as named_parameters above).
for shape in (p.shape for p in rmc.parameters()):
    print(shape)

torch.Size([6, 8])
torch.Size([6])
torch.Size([6, 8])
torch.Size([6])
torch.Size([6, 8])
torch.Size([6])
torch.Size([8, 6])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([16, 8])
torch.Size([16])
torch.Size([8, 16])
torch.Size([8])
torch.Size([8, 16])
torch.Size([8])
