In [47]:
import math, time, os, datetime, shutil, pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import torch.onnx

import import_ipynb
from Elements import MultiHeadAttention, Norm, FeedForward
%matplotlib inline

<img src="../saved/images/rmc.png" height=600 width=800>

In [122]:
def initial_memory(mem_slots, mem_size, batch_size):
    """Creates the initial memory for a whole batch.

    Each row of the memory should start out unique, so the memory is
    initialised to an identity matrix. ``torch.eye(mem_slots, mem_size)``
    builds the rectangular identity directly: it equals the square identity
    zero-padded on the right when ``mem_size > mem_slots`` and truncated to
    its first ``mem_size`` columns when ``mem_size < mem_slots``.

    Args:
      mem_slots: rows in memory matrix
      mem_size: columns in memory matrix
      batch_size: number of copies of the memory; 0 yields an empty batch
    Returns:
      init_state: A truncated or padded identity matrix of size
        (batch_size, mem_slots, mem_size)
    """
    identity = torch.eye(mem_slots, mem_size)
    # repeat() materialises one independent copy per batch element, so a later
    # in-place write to one example cannot leak into the others (expand() would
    # alias a single buffer). It also accepts batch_size == 0.
    return identity.unsqueeze(0).repeat(batch_size, 1, 1)

# Demo configuration: four memory slots of width eight, one example per batch.
mem_slots, mem_size, batch_size = 4, 8, 1
memory = initial_memory(
    mem_slots=mem_slots,
    mem_size=mem_size,
    batch_size=batch_size,
)
print(memory, memory.shape)

tensor([[[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.]]]) torch.Size([1, 4, 8])


In [78]:
# Random stand-in for the input at this step, one row per batch element.
input_vector = torch.randn(batch_size, mem_size)
print(input_vector.shape)
# Indexing with None inserts a slot axis so the input has the same rank as the
# memory; it is then appended as one extra row along the slot dimension.
memory_plus_input = torch.cat((memory, input_vector[:, None]), dim=1)
print(memory_plus_input, memory_plus_input.shape)

torch.Size([1, 8])
tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [-1.3526,  0.1950,  0.5727, -0.2224, -0.1243, -0.9050, -2.7317,
          -0.3789]]]) torch.Size([1, 5, 8])


In [79]:
# Attention step: the first argument is the current memory, the other two are
# the memory with the input row appended.
# NOTE(review): presumably the arguments are (query, key, value) — confirm
# against Elements.MultiHeadAttention.
updatememory = MultiHeadAttention(num_heads=3, emb_dim=8, dim_k=4, dropout=0.0)
new_memory, scores = updatememory(memory, memory_plus_input, memory_plus_input)
print(new_memory.shape)

torch.Size([1, 4, 8])


In [96]:
# Residual sum of the attention output and the original memory, passed through Norm.
NormalizeMemory1 = Norm(emb_dim=8)
new_mem_norm = NormalizeMemory1(new_memory + memory)
# Bare last expression so the notebook displays the resulting shape.
new_mem_norm.shape

torch.Size([1, 4, 8])

In [97]:
# Feed-forward block applied to the normalised memory.
# NOTE(review): dropout is non-zero and no seed is set in this cell, so the
# values presumably differ between runs unless the module is in eval mode — confirm.
MLP = FeedForward(emb_dim=8, ff_dim=16, dropout=0.2)
mem_mlp = MLP(new_mem_norm)
# Second residual sum (MLP output + MLP input), passed through its own Norm instance.
NormalizeMemory2 = Norm(emb_dim=8)
new_mem_norm2 = NormalizeMemory2(mem_mlp + new_mem_norm)
# Bare last expression so the notebook displays the resulting shape.
new_mem_norm2.shape

torch.Size([1, 4, 8])

In [98]:
print(memory.shape, input_vector.shape)
# One copy of the input per memory slot, stacked along a new slot axis.
input_stack = torch.stack([input_vector] * mem_slots, dim=1)
print(input_stack.shape)
# Pair every memory row with the input by concatenating along the feature axis.
h_old_x = torch.cat((memory, input_stack), dim=2)
print(h_old_x.shape)

torch.Size([1, 4, 8]) torch.Size([1, 8])
torch.Size([1, 4, 8])
torch.Size([1, 4, 16])


In [121]:
# Gate projection: maps the concatenated [memory, input] features back to mem_size.
ZGATE = nn.Linear(in_features=2 * mem_size, out_features=mem_size)
z_t = ZGATE(h_old_x).sigmoid()  # (batch size, memory slots, memory size)
print(z_t.shape)
print(ZGATE.weight.shape)

torch.Size([1, 4, 8])
torch.Size([8, 16])


In [100]:
# Gated update: z_t interpolates element-wise between the old memory and the
# candidate memory new_mem_norm2.
# NOTE(review): this rebinds `new_memory`, the name the attention cell assigns
# and the first Norm cell reads; re-running that Norm cell after this one would
# feed it the gated memory instead of the attention output.
new_memory = (1 - z_t)*memory + z_t*new_mem_norm2

$$z_t = \sigma(W_z \cdot [m_{t - 1}, x_t])$$

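$$m_{t} = (1 - z_t) \circ m_{t - 1} + z_t \circ \tilde{m}_{t}$$

where $\tilde{m}_{t}$ is the candidate memory produced by the attention and MLP steps (`new_mem_norm2` above).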

In [None]:
class RelMemCore(nn.Module):
    
    def __init__(self, mem_slots, mem_size, emb_dim, num_heads, dim_k=None, dropout=0.1):
        """Builds the attention and normalisation layers used to update the memory.

        Args:
          mem_slots: rows in memory matrix
          mem_size: columns in memory matrix
          emb_dim: embedding width handed to MultiHeadAttention and Norm
          num_heads: number of attention heads
          dim_k: per-head key size; falls back to emb_dim // num_heads when
            omitted (or otherwise falsy)
          dropout: dropout rate handed to MultiHeadAttention
        """
        # nn.Module has to be initialised before any sub-module is assigned to
        # self; otherwise the assignments below raise.
        super().__init__()

        self.mem_slots = mem_slots
        self.mem_size = mem_size
        self.emb_dim = emb_dim  # read below when the attention layer is built
        self.num_heads = num_heads
        self.dropout = dropout
        self.dim_k = dim_k if dim_k else emb_dim // num_heads
        # NOTE(review): update_memory adds the attention output to the memory,
        # so emb_dim presumably has to equal mem_size — confirm.
        self.attn_mem_update = MultiHeadAttention(self.num_heads, self.emb_dim, self.dim_k, self.dropout)
        self.NormalizeMemory1 = Norm(emb_dim)
        
    def initial_memory(self, mem_slots, mem_size, batch_size):
        """Creates the initial memory for a whole batch.

        Each row of the memory should start out unique, so the memory is
        initialised to an identity matrix. ``torch.eye(mem_slots, mem_size)``
        builds the rectangular identity directly: it equals the square identity
        zero-padded on the right when ``mem_size > mem_slots`` and truncated to
        its first ``mem_size`` columns when ``mem_size < mem_slots``.

        Args:
          mem_slots: rows in memory matrix
          mem_size: columns in memory matrix
          batch_size: number of copies of the memory; 0 yields an empty batch
        Returns:
          init_mem: A truncated or padded identity matrix of size
            (batch_size, mem_slots, mem_size).
        """
        identity = torch.eye(mem_slots, mem_size)
        # repeat() materialises one independent copy per batch element, so a
        # later in-place write to one example cannot leak into the others
        # (expand() would alias a single buffer). It also accepts batch_size == 0.
        return identity.unsqueeze(0).repeat(batch_size, 1, 1)
        
    def update_memory(self, input_vector, cur_memory):
        '''
        inputs
         input_vector (batch_size, mem_size)
         cur_memory - current_memory (batch_size, mem_slots, mem_size)
        '''
        mem_plus_input = torch.cat([cur_memory, input_vector.unsqueeze(1)], dim=-2) 
        new_mem, scores = self.attn_mem_update(cur_memory, mem_plus_input, mem_plus_input)
        new_mem_norm = self.NormalizeMemory1(new_mem + cur_memory)