# Relational Memories
--- 
Implementation based on the details listed in [Relational recurrent neural networks](https://arxiv.org/pdf/1806.01822.pdf)

In [1]:
import math
import torch
import torch.nn as nn

In [2]:
NUM_MEM = 10   # number of memory slots (rows of M)
MEM_DIM = 256  # feature width of each memory slot
H = 8          # number of attention heads; split_heads needs MEM_DIM divisible by H

# Random initial memory matrix: one row per slot.
M = torch.randn(NUM_MEM, MEM_DIM)

# Query/key/value projections, each sized for a single head's MEM_DIM // H
# features. Applied to head-split input, the same weights act on every head.
WQ, WK, WV = (nn.Linear(MEM_DIM // H, MEM_DIM // H) for _ in range(3))

In [3]:
def split_heads(x, num_heads=8):
    """Split the feature axis of a rank-2 tensor into ``num_heads`` heads.

    Args:
        x: tensor of shape ``(rows, features)``.
        num_heads: number of heads; must be positive and divide ``features``.

    Returns:
        Tensor of shape ``(rows, num_heads, features // num_heads)`` where
        ``out[r, h]`` is the ``h``-th consecutive chunk of ``x[r]``. The
        row (memory-slot) axis stays first.

    Raises:
        ValueError: if ``x`` is not rank 2, or ``num_heads`` is not a
            positive divisor of the feature dimension.
    """
    if x.dim() != 2:
        raise ValueError("Input must have rank 2")
    rows, features = x.shape
    if num_heads <= 0 or features % num_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be a positive divisor of the "
            f"feature dimension ({features})"
        )
    # reshape (not view) so non-contiguous inputs such as transposes also
    # work; for contiguous inputs it still returns a view of x.
    return x.reshape(rows, num_heads, features // num_heads)

## Allowing memories to interact using multi-head dot product attention

In [4]:
M_h = split_heads(M, num_heads=H)  # (NUM_MEM, H, MEM_DIM // H)

$Q = M W^{q};\quad K = M W^{k};\quad V = M W^{v}$

In [5]:
# Project the head-split memory into queries, keys and values (in that order).
Q, K, V = (proj(M_h) for proj in (WQ, WK, WV))

$A_{\theta}(M) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V; \quad \theta = (W^{q}, W^{k}, W^{v})$

In [6]:
# Scaled dot-product attention across memory slots, independently per head.
# Q, K, V are laid out (NUM_MEM, H, d) by split_heads; move the head axis to the
# front so the batched matmul compares slots with slots, not heads with heads.
Q_h, K_h, V_h = (t.transpose(0, 1) for t in (Q, K, V))  # (H, NUM_MEM, d)
scores = torch.matmul(Q_h, K_h.transpose(-2, -1)) / math.sqrt(MEM_DIM // H)
weights = torch.softmax(scores, dim=-1)  # each query slot's weights sum to 1
# Back to (NUM_MEM, H, d); contiguous so the following .view() succeeds.
_M_ = torch.matmul(weights, V_h).transpose(0, 1).contiguous()

In [7]:
# Merge the head axis back: (NUM_MEM, H, MEM_DIM // H) -> (NUM_MEM, MEM_DIM).
# reshape rather than view so this also works if _M_ is non-contiguous.
_M_ = _M_.reshape(-1, MEM_DIM)

## Encoding new memories  
$A_{\theta}(M) = \mathrm{softmax}\left(\frac{M W^{q} ([M;x] W^{k})^{T}}{\sqrt{d_{k}}}\right)[M;x]W^{v}; \quad \theta = (W^{q}, W^{k}, W^{v})$

In [8]:
x = torch.randn(1, MEM_DIM)
# Per the formula above, keys/values come from the memory M stacked with the
# new input x; queries come from M alone.
M_x = torch.cat([M, x], dim=0)

M_h_x = split_heads(M_x, H)  # (NUM_MEM + 1, H, MEM_DIM // H)
M_h = split_heads(M, H)      # (NUM_MEM, H, MEM_DIM // H)

Q, K, V = WQ(M_h), WK(M_h_x), WV(M_h_x)
# s: query slot, t: key slot, h: head, f: per-head feature.
scores = torch.einsum('shf,thf->sht', [Q, K]) / math.sqrt(MEM_DIM // H)
weights = torch.softmax(scores, dim=-1)  # normalise over the NUM_MEM + 1 key slots
_M_ = torch.einsum('sht,thf->shf', [weights, V])

In [9]:
_M_.size()  # expected: torch.Size([NUM_MEM, H, MEM_DIM // H])

torch.Size([10, 8, 32])

# Memory Module

In [10]:
class Memory(torch.nn.Module):
    """Placeholder for the relational memory module (not implemented yet).

    It currently registers no parameters, and ``forward`` returns ``None``.
    """

    def __init__(self):
        # nn.Module.__init__ must run first: it creates the parameter, buffer
        # and hook registries that __call__, .parameters(), repr() and
        # sub-module assignment all depend on.
        super().__init__()

    def forward(self):
        # TODO: implement the memory update; placeholder returns None.
        return None