# Relational Memories
--- 
Implementation based on the details listed in [Relational recurrent neural networks](https://arxiv.org/pdf/1806.01822.pdf).

In [1]:
import math
import torch
import torch.nn as nn

In [2]:
NUM_MEM = 10   # number of memory slots (rows of M)
MEM_DIM = 256  # width of each memory slot
H = 8          # number of attention heads

# Randomly initialised memory matrix and per-head projection weights;
# each weight tensor has shape (H, MEM_DIM // H, MEM_DIM // H).
M = torch.randn(NUM_MEM, MEM_DIM)
WQ, WK, WV = (torch.randn(H, MEM_DIM // H, MEM_DIM // H) for _ in range(3))

In [3]:
def split_heads(x, num_heads=8):
    """Split each row of a rank-2 tensor into ``num_heads`` equal chunks.

    Args:
        x: tensor of shape (slots, dim); ``dim`` must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.

    Returns:
        Tensor of shape (slots, num_heads, dim // num_heads), where
        ``out[s, h]`` is the h-th contiguous chunk of row ``x[s]``.

    Raises:
        ValueError: if ``x`` is not rank 2.
    """
    if x.dim() != 2:
        raise ValueError("Input must have rank 2")
    num_slots, dim = x.shape
    return x.reshape(num_slots, num_heads, dim // num_heads)

## Allowing memories to interact using multi-head dot-product attention (MHDPA)

In [4]:
M_h = split_heads(M, num_heads=H)  # (NUM_MEM, H, MEM_DIM // H)

$Q = M W^{q}; \quad K = M W^{k}; \quad V = M W^{v}$  

In [5]:
# Per-head projections: each (H, d, d) weight is applied to the
# (slots, H, d) split memory, giving (slots, H, d) outputs.
Q = torch.einsum('hmd,shm->shd', [WQ, M_h])
K = torch.einsum('hmd,shm->shd', [WK, M_h])
V = torch.einsum('hmd,shm->shd', [WV, M_h])

$A_{\theta}(M) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V; \quad \theta = (W^{q}, W^{k}, W^{v})$

In [6]:
# Scaled dot-product scores between every pair of memory slots, per head.
# Q and K are (slots, H, d), so the scores are indexed (s, h, t) with
# s = query slot and t = key slot.
scores = torch.einsum('shf,thf->sht', [Q, K]) / math.sqrt(MEM_DIM // H)
# Normalise over the key axis t so each slot's weights sum to 1 per head.
attn = torch.softmax(scores, dim=-1)
_M_ = torch.einsum('sht,thf->shf', [attn, V])
# Merge heads back: (slots, H, d) -> (slots, MEM_DIM).
_M_ = _M_.reshape(-1, MEM_DIM)

## Encoding new memories  
$A_{\theta}(M) = \mathrm{softmax}\left(\frac{M W^{q} ([M;x] W^{k})^{T}}{\sqrt{d_{k}}}\right)[M;x]W^{v}; \quad \theta = (W^{q}, W^{k}, W^{v})$

In [7]:
x = torch.randn(1, MEM_DIM)     # new input row, same width as a memory slot
M_x = torch.cat([M, x], dim=0)  # [M; x]: (NUM_MEM + 1, MEM_DIM)

M_h_x = split_heads(M_x, H)
M_h = split_heads(M, H)

# Queries come from the memory only; keys and values come from [M; x],
# so every slot can attend to the other slots and to the new input.
Q = torch.einsum('hmd,shm->shd', [WQ, M_h])
K = torch.einsum('hmd,shm->shd', [WK, M_h_x])
V = torch.einsum('hmd,shm->shd', [WV, M_h_x])

# Scaled scores indexed (s, h, t); softmax over the key axis t.
scores = torch.einsum('shf,thf->sht', [Q, K]) / math.sqrt(MEM_DIM // H)
attn = torch.softmax(scores, dim=-1)
_M_ = torch.einsum('sht,thf->shf', [attn, V])

# Merge heads back: (NUM_MEM, H, d) -> (NUM_MEM, MEM_DIM).
_M_ = _M_.contiguous().view(-1, MEM_DIM)

_M_.shape

torch.Size([10, 256])

# Memory Module

In [8]:
class Memory(torch.nn.Module):
    """Relational memory attended over by multi-head dot-product attention.

    Each call attends from every memory slot over the memory concatenated
    with the new input ([M; x]) and returns the attended memory.

    Args:
        num_mem: number of memory slots (rows of ``M``).
        mem_dim: width of each slot; must be divisible by ``H``.
        H: number of attention heads.

    Raises:
        ValueError: if ``mem_dim`` is not divisible by ``H``.
    """

    def __init__(self, num_mem, mem_dim, H=8):
        super(Memory, self).__init__()
        if mem_dim % H != 0:
            raise ValueError(f"mem_dim ({mem_dim}) must be divisible by H ({H})")
        self.H = H
        self.MEM_DIM = mem_dim
        self.head_dim = mem_dim // H
        # The memory is module state rather than a learned weight: as a buffer
        # it follows .to()/.cuda() and is saved in state_dict, but it is not
        # returned by parameters().
        self.register_buffer("M", torch.randn(num_mem, mem_dim))
        # Per-head projection matrices, each (H, head_dim, head_dim). They are
        # registered as Parameters so optimizers and device moves see them.
        self.WQ, self.WK, self.WV = [
            nn.Parameter(torch.randn(H, self.head_dim, self.head_dim))
            for _ in range(3)
        ]

    def split_heads(self, x):
        """Reshape (slots, MEM_DIM) into (slots, H, MEM_DIM // H)."""
        if x.dim() != 2:
            raise ValueError("Input must have rank 2")
        return x.reshape(x.shape[0], self.H, x.shape[1] // self.H)

    def project_memory(self, W, M):
        """Apply per-head weights W (H, d, d) to M (slots, H, d) -> (slots, H, d)."""
        return torch.einsum('hmd,shm->shd', [W, M])

    def encode_memory(self, x):
        """Attend from each memory slot over [M; x] and return the result.

        Args:
            x: rows to append below the memory, shape (n, MEM_DIM)
                (e.g. a single input row).

        Returns:
            Tensor of shape (num_mem, MEM_DIM). ``self.M`` is not modified.
        """
        # concatenate new input to memory matrix: [M; x]
        M_x = torch.cat([self.M, x], dim=0)

        # split memory for MHDPA, using this module's head count
        M_h_x = self.split_heads(M_x)
        M_h = self.split_heads(self.M)

        # queries from memory only; keys/values from memory + input
        Q = self.project_memory(self.WQ, M_h)
        K = self.project_memory(self.WK, M_h_x)
        V = self.project_memory(self.WV, M_h_x)

        # scaled scores indexed (s, h, t); softmax over the key axis t
        scores = torch.einsum('shf,thf->sht', [Q, K]) / math.sqrt(self.head_dim)
        attn = torch.softmax(scores, dim=-1)
        _M_ = torch.einsum('sht,thf->shf', [attn, V])

        # merge heads back: (num_mem, H, head_dim) -> (num_mem, MEM_DIM)
        return _M_.reshape(-1, self.MEM_DIM)

    def forward(self, x):
        """Alias for :meth:`encode_memory`."""
        return self.encode_memory(x)
        

In [9]:
mem = Memory(NUM_MEM, MEM_DIM, H=H)
mem(x).shape

torch.Size([10, 256])