## Methods
### Outer product attention (OPA)
Outer product attention (OPA) is a natural extension of the query-key-value dot product attention (Vaswani et al. 2017). Dot product attention (DPA) for a single query $q$ and $n_{kv}$ key-value pairs can be formulated as follows   
<font size="4">$A^o(q, K, V) = \sum_{i=1}^{n_{kv}}S(q \cdot k_i)v_i$</font>  
where $A^o ∈ R^{d_v}, v_i ∈ R^{d_v}$ and $q,k_i ∈ R^{d_{qk}}$
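
The DPA formula can be checked on a toy example. The following is a minimal sketch, assuming $S$ is a softmax over the $n_{kv}$ scores (the notebook does not define $S$); the function name `dot_product_attention` and the tensor sizes are illustrative only.

```python
import torch

def dot_product_attention(q, K, V):
    """Single-query dot product attention.

    q: [d_qk], K: [n_kv, d_qk], V: [n_kv, d_v] -> result: [d_v]
    """
    # S(q . k_i): assumed here to be a softmax over the n_kv scores
    weights = torch.softmax(K @ q, dim=0)
    # Weighted sum of the value vectors
    return weights @ V

torch.manual_seed(0)
q = torch.rand(7)        # d_qk = 7
K = torch.rand(10, 7)    # n_kv = 10
V = torch.rand(10, 5)    # d_v = 5
A_dpa = dot_product_attention(q, K, V)
print(A_dpa.shape)       # torch.Size([5])
```

Because the weights are non-negative and sum to one, the result is a convex combination of the value vectors: DPA retrieves a single attended item of size $d_v$.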

**The new outer product attention is proposed as follows.**  
<font size="4">$A^⊗(q, K, V) = \sum_{i=1}^{n_{kv}}F(q⊙k_i)⊗v_i$</font>  
where $A^⊗ ∈ R^{d_{qk} \times d_v}, v_i ∈ R^{d_v}$ and $q,k_i ∈ R^{d_{qk}}$.  
Here ⊙ is element-wise multiplication, ⊗ is the outer product, and $F$ is chosen as the element-wise $\tanh$ function.  
A crucial difference between DPA and OPA is that while the former retrieves an attended item $A^o$, the latter forms a relational representation $A^⊗$.
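
To make the shape difference concrete, here is a small sketch of OPA for a single query. The function name `outer_product_attention` and the sizes are illustrative, not part of the notebook; the explicit loop is there only to confirm that the `einsum` computes the sum of outer products from the formula.

```python
import torch

def outer_product_attention(q, K, V):
    """Single-query OPA: sum_i tanh(q * k_i) (outer product) v_i.

    q: [d_qk], K: [n_kv, d_qk], V: [n_kv, d_v] -> result: [d_qk, d_v]
    """
    gated = torch.tanh(q * K)                   # F(q ⊙ k_i) for every i: [n_kv, d_qk]
    return torch.einsum('ik,iv->kv', gated, V)  # sum over i of outer products

torch.manual_seed(0)
q = torch.rand(7)        # d_qk = 7
K = torch.rand(10, 7)    # n_kv = 10
V = torch.rand(10, 5)    # d_v = 5
A_opa = outer_product_attention(q, K, V)

# Reference: accumulate the outer products one key-value pair at a time.
A_loop = torch.zeros(7, 5)
for k_i, v_i in zip(K, V):
    A_loop += torch.outer(torch.tanh(q * k_i), v_i)

print(A_opa.shape)       # torch.Size([7, 5])
```

The result is a $d_{qk} \times d_v$ matrix rather than a $d_v$ vector, which is what allows it to serve as a relational representation.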

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
def op_att(q, k, v):
    # Debug prints of the input shapes.
    print(q.size()) # [b x s x k]
    print(k.size()) # [b x s x k]
    print(v.size()) # [b x s x v]
    # Pair every query with every key/value by tiling along a new axis.
    qq = q.unsqueeze(2).repeat(1, 1, k.shape[1], 1) # [b x s x s x k]
    kk = k.unsqueeze(1).repeat(1, q.shape[1], 1, 1) # [b x s x s x k]
    vv = v.unsqueeze(1).repeat(1, q.shape[1], 1, 1) # [b x s x s x v]
    # Outer product tanh(q*k) (x) v per pair: [b x s x s x k x 1] @ [b x s x s x 1 x v]
    output = torch.matmul(torch.tanh(qq*kk).unsqueeze(4), vv.unsqueeze(3)) # [b x s x s x k x v]
    # Sum over the key/value axis.
    output = torch.sum(output, dim=2) # [b x s x k x v]
    return output
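
The `repeat` calls above materialise tiled copies of `q`, `k` and `v`. As a possible alternative, the same result can be obtained with broadcasting and `torch.einsum`. The sketch below is not part of the original notebook: `op_att_einsum` is a hypothetical name, and `op_att_repeat` re-implements the repeat-based version (without the debug prints) only so the two can be compared.

```python
import torch

def op_att_einsum(q, k, v):
    """Batched OPA via broadcasting.

    q, k: [b, n, d_qk], v: [b, n, d_v] -> result: [b, n, d_qk, d_v]
    """
    # Broadcasting pairs every query with every key: [b, n_q, n_kv, d_qk]
    gated = torch.tanh(q.unsqueeze(2) * k.unsqueeze(1))
    # Outer product with the values, summed over the key/value axis i
    return torch.einsum('bqik,biv->bqkv', gated, v)

def op_att_repeat(q, k, v):
    """Repeat-based reference, mirroring op_att without the prints."""
    qq = q.unsqueeze(2).repeat(1, 1, k.shape[1], 1)
    kk = k.unsqueeze(1).repeat(1, q.shape[1], 1, 1)
    vv = v.unsqueeze(1).repeat(1, q.shape[1], 1, 1)
    out = torch.matmul(torch.tanh(qq * kk).unsqueeze(4), vv.unsqueeze(3))
    return out.sum(dim=2)

torch.manual_seed(0)
q = torch.rand(2, 4, 3)
k = torch.rand(2, 4, 3)
v = torch.rand(2, 4, 5)
out_einsum = op_att_einsum(q, k, v)
out_repeat = op_att_repeat(q, k, v)
print(out_einsum.shape)  # torch.Size([2, 4, 3, 5])
```

Both versions still build the `[b, n, n, d_qk]` tensor of gated products; the einsum form only avoids the explicit tiled copies and the five-dimensional intermediate.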

In [13]:
a = torch.rand(5, 10, 7)
b = torch.rand(5,  10, 7)
v = torch.rand(5, 10, 7)
aa = a.unsqueeze(2).repeat(1, 1, a.shape[1], 1) # [b x s x s x k]
bb = b.unsqueeze(1).repeat(1, b.shape[1], 1, 1) # [b x s x s x k]
print(aa.size(), bb.size())

torch.Size([5, 10, 10, 7]) torch.Size([5, 10, 10, 7])


In [15]:
output = torch.matmul(torch.tanh(aa*bb).unsqueeze(4), v.unsqueeze(1).repeat(1, a.shape[1], 1, 1).unsqueeze(3))  # [b x s x s x k x v]
output.size()

torch.Size([5, 10, 10, 7, 7])

![SAM](Images/SAM.PNG)

**$M^r$-Read** As relationships stored in $M^r$ are represented as associative memories, the relational memory can be read to reconstruct previously seen items.  
<font size="4">$v^r_t = softmax(f_3(x_t)^T)M^r_{t-1}f_2(x_t)$</font>  
where $f_3$ is a feed-forward neural network that outputs an $n_q$-dimensional vector. The read value provides an additional input, coming from the previous state of $M^r$, to the relational construction process as follows.  
<font size="4">$M^r_t = M^r_{t-1} + \alpha_1SAM_θ(M^i_t + \alpha_2v^r_t⊗f_2(x_t))$</font>  
where $\alpha_1$ and $\alpha_2$ are blending hyper-parameters.  
**$M^i$-Read, $M^r$-Write** SAM is used to read from $M^i$ and construct a candidate relational memory, which is simply added to the previous relational memory to perform the relational update.  
The input for SAM is the combination of the current item memory $M^i_t$ and the association between the item extracted from the previous relational memory, $v^r_t$, and the current input data $x_t$.
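
The read step can be sketched with two `einsum` calls. This is an illustration under assumptions: the random tensors `f2_x` and `f3_x` stand in for $f_2(x_t)$ and $f_3(x_t)$, the sizes are toy values, and the contraction pattern is copied from the `einsum` used in `STM.compute` further down, which contracts $f_2(x_t)$ with the first matrix axis of each memory slice.

```python
import torch

torch.manual_seed(0)
batch, n_q, d = 2, 8, 6                 # illustrative sizes, not the notebook's
M_r = torch.rand(batch, n_q, d, d)      # previous relational memory M^r_{t-1}
f2_x = torch.rand(batch, d)             # stand-in for f_2(x_t)
f3_x = torch.rand(batch, n_q)           # stand-in for f_3(x_t)

# Attention weights over the n_q slices of the relational memory
w = torch.softmax(f3_x, dim=-1)                       # [batch, n_q]
# Read value v^r_t, using the same contraction as STM.compute
v_r = torch.einsum('bn,bd,bndf->bf', w, f2_x, M_r)    # [batch, d]
# Association v^r_t (outer product) f_2(x_t) that is blended into the SAM input
assoc = torch.einsum('bd,bf->bdf', v_r, f2_x)         # [batch, d, d]

# The same read in two explicit steps: blend the slices, then apply f_2(x_t)
blended = torch.einsum('bn,bndf->bdf', w, M_r)        # [batch, d, d]
v_r_two_step = torch.einsum('bd,bdf->bf', f2_x, blended)

print(v_r.shape, assoc.shape)
```

The two-step form shows that the softmax first mixes the $n_q$ relational slices into one $d \times d$ matrix, which is then queried with the projected input.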

In [4]:
class MLP(nn.Module):
    def __init__(self, in_dim=28*28,  out_dim=10, hid_dim=-1, layers=1):
        super(MLP, self).__init__()
        self.layers = layers
        if hid_dim<=0:
            self.layers=-1
        if self.layers<0:
            hid_dim=out_dim
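        # input layer (in_dim -> hid_dim); this is the only layer when hid_dim <= 0
        self.fc1 = nn.Linear(in_dim, hid_dim)
        # optional hidden layers (hid_dim -> hid_dim); note the list repeats one shared nn.Linear
        if self.layers>0:
            self.fc2h = nn.ModuleList([nn.Linear(hid_dim, hid_dim)]*self.layers)
        # output layer (hid_dim -> out_dim)
        if self.layers>=0:
            self.fc3 = nn.Linear(hid_dim, out_dim)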


    def forward(self, x):
        o = self.fc1(x)
        if self.layers>0:
            for l in range(self.layers):
                o = self.fc2h[l](o)
        if self.layers >= 0:
            o = self.fc3(o)
        return o

In STM, at every timestep, the item memory $M^i_t$ is updated with the new input $x_t$ using gating mechanisms as follows.  
<font size="5"> $M^i_t = F_t(M^i_{t-1}, x_t)⊙M^i_{t-1}+I_t(M^i_{t-1},x_t)⊙X_t$  </font>  
where $F_t$ and $I_t$ are the forget and input gates, respectively  
<font size="3">$F_t(M^i_{t-1}, x_t) = W_Fx_t + U_F\tanh(M^i_{t-1}) + b_F$</font>  
<font size="3">$I_t(M^i_{t-1}, x_t) = W_Ix_t + U_I\tanh(M^i_{t-1}) + b_I$</font>  
Here, $W_F, U_F, W_I, U_I ∈ R^{d \times d}$ are parametric weights, $b_F, b_I ∈ R$ are the biases, and $+$ is broadcast if needed.
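
The gated update can be sketched in a few lines. This is a simplified illustration rather than the notebook's implementation: the layers `W` and `U` and the toy sizes are assumptions, each layer produces the input and forget pre-activations stacked along the last axis, and, following `compute_gates` in the `STM` class below, a sigmoid is applied to the gates and $\tanh$ to $X_t$ (neither appears in the formulas above).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, d = 2, 6                          # illustrative sizes
W = nn.Linear(d, 2 * d)                  # W_I x and W_F x, stacked (hypothetical layer)
U = nn.Linear(d, 2 * d)                  # U_I tanh(M) and U_F tanh(M), stacked (hypothetical layer)

x = torch.rand(batch, d)                 # projected input at step t
M_prev = torch.rand(batch, d, d)         # item memory M^i_{t-1}
X = torch.einsum('bd,bf->bdf', x, x)     # outer product of the input with itself, as in STM.compute

# The input term [batch, 1, 2d] is broadcast over the rows of the memory term [batch, d, 2d]
gates = U(torch.tanh(M_prev)) + W(x).unsqueeze(1)
input_gate, forget_gate = gates.chunk(2, dim=2)     # each [batch, d, d]
input_gate = torch.sigmoid(input_gate)
forget_gate = torch.sigmoid(forget_gate + 1.0)      # the notebook initialises the forget bias to 1

# M^i_t = F_t ⊙ M^i_{t-1} + I_t ⊙ X_t
M_new = forget_gate * M_prev + input_gate * torch.tanh(X)
print(M_new.shape)                       # torch.Size([2, 6, 6])
```

A forget bias of 1 pushes the forget gate towards 1 at the start of training, so the memory is mostly retained until the gates are learned.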


In [38]:
class STM(nn.Module):
    def __init__(self, input_size, output_size, step = 1, num_slot=8,
                 mlp_size = 128, slot_size = 96, rel_size = 96,
                 out_att_size=64, rd=True,
                 init_alphas=[None,None,None],
                 learn_init_mem=True, mlp_hid=-1):
        super(STM, self).__init__()
        self.mlp_size = mlp_size # 128
        self.slot_size = slot_size # 96
        self.rel_size = rel_size # 96
        self.rnn_hid = slot_size # 96
        self.num_slot = num_slot # 8
        self.step = step # 1
        self.rd = rd # True
        self.learn_init_mem = learn_init_mem # True

        self.out_att_size = out_att_size # 64
        
        #==================== create qkv attention projectors for each step ==============

        self.qkv_projector = nn.ModuleList([nn.Linear(slot_size, num_slot*3)]*step) # [96 -> 24]
        self.qkv_layernorm = nn.ModuleList([nn.LayerNorm([slot_size, num_slot*3])]*step)
    
        #=================== create alpha values =========================
        if init_alphas[0] is None:
            self.alpha1 = [nn.Parameter(torch.zeros(1))] * step
            for ia, a in enumerate(self.alpha1):
                setattr(self, 'alpha1' + str(ia), self.alpha1[ia])
        else:
            self.alpha1 = [init_alphas[0]]* step

        if init_alphas[1] is None:
            self.alpha2 = [nn.Parameter(torch.zeros(1))] * step
            for ia, a in enumerate(self.alpha2):
                setattr(self, 'alpha2' + str(ia), self.alpha2[ia])
        else:
            self.alpha2 = [init_alphas[1]] * step

        if init_alphas[2] is None:
            self.alpha3 = [nn.Parameter(torch.zeros(1))] * step
            for ia, a in enumerate(self.alpha3):
                setattr(self, 'alpha3' + str(ia), self.alpha3[ia])
        else:
            self.alpha3 = [init_alphas[2]] * step


        self.input_projector = MLP(input_size, slot_size, hid_dim=mlp_hid)  # [feat_size x 96, -1]
        self.input_projector2 = MLP(input_size, slot_size, hid_dim=mlp_hid)  # [feat_size x 96, -1]
        self.input_projector3 = MLP(input_size, num_slot, hid_dim=mlp_hid)  # [feat_size x 8, -1]


        self.input_gate_projector = nn.Linear(self.slot_size, self.slot_size*2) # [96 -> 192]
        self.memory_gate_projector = nn.Linear(self.slot_size, self.slot_size*2) # [96 -> 192]
        
        # trainable scalar gate bias tensors
        self.forget_bias = nn.Parameter(torch.tensor(1., dtype=torch.float32))
        self.input_bias = nn.Parameter(torch.tensor(0., dtype=torch.float32))

        self.rel_projector = nn.Linear(slot_size * slot_size, rel_size) # [96 x 96, 96]
        self.rel_projector2 = nn.Linear(num_slot * slot_size, slot_size) # [8 x 96, 96]
        self.rel_projector3 = nn.Linear(num_slot * rel_size, out_att_size) # [8 x 96, 64]

        self.mlp = nn.Sequential(
            nn.Linear(out_att_size, self.mlp_size),
            nn.ReLU(),
            nn.Linear(self.mlp_size, self.mlp_size),
            nn.ReLU(),
        )

        self.out = nn.Linear(self.mlp_size, output_size)
        
        # Initialize memory units 
        # item_memory_state_bias=[96x96]
        # rel_memory_state_bias=[8x96x96]
        if self.learn_init_mem:
            # Learnable initial memories. They are created on the CPU here;
            # call model.to(device) to move them together with the other parameters.
            self.item_memory_state_bias = nn.Parameter(
                torch.empty(self.slot_size, self.slot_size))
            self.rel_memory_state_bias = nn.Parameter(
                torch.empty(self.num_slot, self.slot_size, self.slot_size))

            stdev = 1 / (np.sqrt(self.slot_size + self.slot_size))
            nn.init.uniform_(self.item_memory_state_bias, -stdev, stdev)
            stdev = 1 / (np.sqrt(self.slot_size + self.slot_size + self.num_slot))
            nn.init.uniform_(self.rel_memory_state_bias, -stdev, stdev)


    def create_new_state(self, batch_size):
        """Create a new state for a batch of size batch_size."""
        # Allocate the state on the same device as the module's parameters.
        device = self.out.weight.device
        read_heads = torch.zeros(batch_size, self.out_att_size, device=device) # [bs x 64]
        if self.learn_init_mem: # True
            # Start every sequence from the learned initial memories.
            item_memory_state = self.item_memory_state_bias.clone().repeat(batch_size, 1, 1) # [bs x 96 x 96]
            rel_memory_state = self.rel_memory_state_bias.clone().repeat(batch_size, 1, 1, 1) # [bs x 8 x 96 x 96]
        else:
            # Start every sequence from empty memories.
            item_memory_state = torch.zeros(batch_size, self.slot_size, self.slot_size, device=device)
            rel_memory_state = torch.zeros(batch_size, self.num_slot, self.slot_size, self.slot_size, device=device)

        return read_heads, item_memory_state, rel_memory_state # [bs x 64], [bs x 96x96],  [bs x 8x96x96]



    def compute_gates(self, inputs, memory):
        # inputs = [bs x 1 x 96], memory = [bs x 96 x 96]
        memory = torch.tanh(memory)
        if len(inputs.shape) == 3:
            if inputs.shape[1] > 1:
                raise ValueError(
                    "input seq length is larger than 1. compute_gates is meant to be called for each step, with input seq length of 1")
            inputs = inputs.view(inputs.shape[0], -1) # [bs x 96]

            gate_inputs = self.input_gate_projector(inputs) # [bs x 192]
            gate_inputs = gate_inputs.unsqueeze(dim=1) # [bs x 1 x 192]
            gate_memory = self.memory_gate_projector(memory) # [bs x 96 x 192]
        else:
            raise ValueError("compute_gates expects a 3-dimensional input, got {} dimensions".format(len(inputs.shape)))

        gates = gate_memory + gate_inputs
        gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2)
        input_gate, forget_gate = gates
        assert input_gate.shape[2] == forget_gate.shape[2]

        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)

        return input_gate, forget_gate

    def compute(self, input_step, prev_state):
        # [bs x feat_size], [[bs x 64], [bs x 96x96],  [bs x 8x96x96]]
        hid = prev_state[0] # read_heads [bs x 64]
        item_memory_state = prev_state[1] # [bs x 96x96]
        rel_memory_state = prev_state[2] #  [bs x 8x96x96]

        #transform input
        controller_outp = self.input_projector(input_step) # [bs x 96]
        controller_outp2 = self.input_projector2(input_step) # [bs x 96]
        controller_outp3 = self.input_projector3(input_step)  # [bs x 8]


        #Mi write
        X = torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1))  # outer product: [bs x 96 x 1] @ [bs x 1 x 96] -> [bs x 96 x 96]
        input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), item_memory_state) # inputs [bs x 1 x 96], memory [bs x 96 x 96]


        #Mr read
        controller_outp3 = F.softmax(controller_outp3, dim=-1)
        controller_outp4 = torch.einsum('bn,bd,bndf->bf', controller_outp3, controller_outp2, rel_memory_state)
        X2 = torch.einsum('bd,bf->bdf', controller_outp4, controller_outp2)

        if self.rd:
            # Mi write gating
            R = input_gate * torch.tanh(X)
            R += forget_gate * item_memory_state
        else:
            #Mi write
            R =  item_memory_state + torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1))#Bxdxd

        for i in range(self.step):
            #SAM
            qkv = self.qkv_projector[i](R+self.alpha2[i]*X2)
            qkv = self.qkv_layernorm[i](qkv)
            qkv = qkv.permute(0,2,1) #Bx3Nxd

            q,k,v = torch.split(qkv, [self.num_slot]*3, 1)#BxNxd


            R0 = op_att(q, k, v) #BxNxdxd

            #Mr transfer to Mi
            R2= self.rel_projector2(R0.view(R0.shape[0], -1, R0.shape[3]).permute(0, 2, 1))
            R =  R + self.alpha3[i] * R2

            #Mr write
            rel_memory_state = self.alpha1[i]*rel_memory_state + R0

        #Mr transfer to output
        r_vec = self.rel_projector(rel_memory_state.view(rel_memory_state.shape[0],
                                                         rel_memory_state.shape[1],
                                                         -1)).view(input_step.shape[0],-1)
        out = self.rel_projector3(r_vec)

        # if self.gating_after:
        #     #Mi write gating
        #     input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), R)
        #     if self.rd:
        #         R = input_gate * torch.tanh(R)
        #         R += forget_gate * item_memory_state

        return out, (out, R, rel_memory_state)

    def forward(self, input_step, hidden=None): # input_step=> [seq_size, batch_size, feature_size]

        if len(input_step.shape)==3:
            self.init_sequence(input_step.shape[1])#   self.previous_state = > [bs x 64], [bs x 96x96],  [bs x 8x96x96]
            for i in range(input_step.shape[0]): # for every sequence [word]
                logit, self.previous_state = self.compute(input_step[i], self.previous_state) # [bs x feat_size], [[bs x 64], [bs x 96x96],  [bs x 8x96x96]]

        else:
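            if hidden is not None:
                # Start from the caller-supplied state and keep the updated state,
                # so the value returned below is defined on this path as well.
                logit, self.previous_state = self.compute(input_step, hidden)
            else:
                logit, self.previous_state = self.compute(input_step,  self.previous_state)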
        mlp = self.mlp(logit)
        out = self.out(mlp)
        return out, self.previous_state

    def init_sequence(self, batch_size):
        """Initializing the state."""
        self.previous_state = self.create_new_state(batch_size) # create new state [bs x 64], [bs x 96x96],  [bs x 8x96x96]

    def calculate_num_params(self):
        """Returns the total number of parameters."""
        num_params = 0
        for p in self.parameters():
            num_params += p.data.view(-1).size(0)
        return num_params

In [39]:
stm = STM(10, 20)

In [56]:
a = torch.rand(10,96) # [feat_size x 96, 1]x [feat_size x 1  x 96]
b = torch.matmul(a.unsqueeze(2), a.unsqueeze(1))

In [58]:
b = torch.tanh(b)
b.size()

torch.Size([10, 96, 96])

In [42]:
a = torch.rand(5, 7, 10) # [seq_size x batch_size x feature_size]
b = stm(a)

torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])
torch.Size([7, 8, 96])


In [50]:
alpha1 = nn.Parameter(torch.tensor(1., dtype=torch.float32))

In [54]:
mem1 = torch.nn.Parameter(torch.Tensor(96, 96))

In [55]:
mem1.size()

torch.Size([96, 96])

## Self-Attention Transformer Memory

### [International Conference on NLP Techniques and Applications- Sep 5](http://www.wikicfp.com/cfp/servlet/event.showcfp?eventid=106242&copyownerid=33993)
### [13th International Conference on Agents and Artificial Intelligence - Sep 14](http://www.wikicfp.com/cfp/servlet/event.showcfp?eventid=105618&copyownerid=45217)
Heretofore, neural networks with external memory are restricted to single memory with lossy representations of memory interactions. A rich representation of relationships between memory pieces urges a high-order and segregated relational memory. In this paper, we propose to separate the storage of individual experiences (item memory) and their occurring relationships (relational memory). 

The model proposed here creates an item memory based on attention. The attention model is used to read from and write to the memory module, and the module uses gates to read and write over the attended sections.  
The input is used to create relational reasoning (an outer product with each memory section under the given attention). The attention should be weighted so that the memory section to select is determined by the (key, query, value) attention mechanism.
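
One possible reading of this idea is sketched below. It is a hypothetical illustration, not the proposed model: the function `attend_and_write`, the gate weights `W_gate`, and the convex-blend write rule are all assumptions made for the example.

```python
import torch

def attend_and_write(x, memory, W_gate):
    """Attention-weighted read and gated write over memory slots (hypothetical sketch).

    x: [b, d], memory: [b, n_slot, d], W_gate: [d, d] -> (read, new_memory, attn)
    """
    d = x.shape[-1]
    # Scaled dot product scores between the input (query) and each slot (key)
    scores = torch.einsum('bd,bnd->bn', x, memory) / d ** 0.5
    attn = torch.softmax(scores, dim=-1)               # which slots to select: [b, n_slot]
    read = torch.einsum('bn,bnd->bd', attn, memory)    # attended read: [b, d]
    # Write gate per feature, scaled by the attention each slot receives
    gate = torch.sigmoid(x @ W_gate).unsqueeze(1)      # [b, 1, d]
    write = attn.unsqueeze(-1) * gate                  # [b, n_slot, d]
    new_memory = (1 - write) * memory + write * x.unsqueeze(1)
    return read, new_memory, attn

torch.manual_seed(0)
x = torch.rand(3, 16)
memory = torch.rand(3, 5, 16)
W_gate = torch.rand(16, 16)
read, new_memory, attn = attend_and_write(x, memory, W_gate)
print(read.shape, new_memory.shape, attn.shape)
```

Slots that receive little attention are left almost unchanged, while strongly attended slots are moved towards the input in proportion to the gate.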

In [None]:
import math

class PositionalEncoding(nn.Module):
    """Implement the PE function."""
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # pe is a registered buffer, so it carries no gradient and needs no Variable wrapper.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute Scaled Dot Product Attention"""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1))/ math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

In [None]:
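import copy

def clones(module, N):
    """Produce N identical, independently parameterised copies of a module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class MultiHeadedAttention(nn.Module):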
    def __init__(self, h, d_model, dropout=0.1):
        """Take in model size and number of heads."""
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        """Implements Figure 2"""
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        
        # 1) Do all the linear projections in batch from d_model => h x d_k 
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]
        
        # 2) Apply attention on all the projected vectors in batch. 
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        
        # 3) "Concat" using a view and apply a final linear. 
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

In [None]:
class Encoder(nn.Module):
    """Core encoder is a stack of N layers"""
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, mask):
        """Pass the input (and mask) through each layer in turn."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [None]:
class LayerNorm(nn.Module):
    """Construct a layernorm module (See citation for details). [Norm]"""
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class EncoderLayer(nn.Module):
    """Encoder is made up of self-attn and feed forward (defined below)"""
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size
        self.slot_size = size
        self.num_slot = 5
        
        # ============= init memory model ================
        # learn_init_mem is not passed to this constructor; True is assumed here, mirroring the STM default.
        self.learn_init_mem = True
        if self.learn_init_mem:
            # Learnable initial item memory, created on the CPU; model.to(device) moves it.
            self.item_memory_state_bias = nn.Parameter(torch.empty(self.num_slot, self.slot_size))

            stdev = 1 / (np.sqrt(self.num_slot + self.slot_size))
            nn.init.uniform_(self.item_memory_state_bias, -stdev, stdev)
        
    def compute_gates(self, inputs, memory): # compute gates as LSTM on attention mechanism
        # inputs = [feat_size x 1  x 96], memory = [bs x 96 x 96]
        memory = torch.tanh(memory)
        if len(inputs.shape) == 3:
            if inputs.shape[1] > 1:
                raise ValueError(
                    "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1")
            inputs = inputs.view(inputs.shape[0], -1) # [feat_size  x 96]

            gate_inputs = self.input_gate_projector(inputs) # [bs x 192]
            gate_inputs = gate_inputs.unsqueeze(dim=1) # [bs x 1 x 192]
            gate_memory = self.memory_gate_projector(memory) # [bs x 96 x 192]
        else:
            raise ValueError("input shape of create_gate function is 2, expects 3")

        gates = gate_memory + gate_inputs
        gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2)
        input_gate, forget_gate = gates
        assert input_gate.shape[2] == forget_gate.shape[2]

        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)

        return input_gate, forget_gate
    
    def compute(self, input_step, prev_state):
        # [bs x feat_size], [[bs x 64], [bs x 96x96],  [bs x 8x96x96]]
        hid = prev_state[0] # read_heads [bs x 64]
        item_memory_state = prev_state[1] # [bs x 96x96]
        rel_memory_state = prev_state[2] #  [bs x 8x96x96]

        #transform input
        controller_outp = self.input_projector(input_step) # [feat_size x 96]
        controller_outp2 = self.input_projector2(input_step) # [feat_size x 96]
        controller_outp3 = self.input_projector3(input_step)  # [feat_size x 8]


        #Mi write
        X = torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1))  #[10, 96, 96] Bxdxd [feat_size x 96, 1]x [feat_size x 1  x 96]
        input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), item_memory_state) # [feat_size x 1  x 96] [bs x 96x96]


        #Mr read
        controller_outp3 = F.softmax(controller_outp3, dim=-1)
        controller_outp4 = torch.einsum('bn,bd,bndf->bf', controller_outp3, controller_outp2, rel_memory_state)
        X2 = torch.einsum('bd,bf->bdf', controller_outp4, controller_outp2)

        if self.rd:
            # Mi write gating
            R = input_gate * torch.tanh(X)
            R += forget_gate * item_memory_state
        else:
            #Mi write
            R =  item_memory_state + torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1))#Bxdxd

        for i in range(self.step):
            #SAM
            qkv = self.qkv_projector[i](R+self.alpha2[i]*X2)
            qkv = self.qkv_layernorm[i](qkv)
            qkv = qkv.permute(0,2,1) #Bx3Nxd

            q,k,v = torch.split(qkv, [self.num_slot]*3, 1)#BxNxd


            R0 = op_att(q, k, v) #BxNxdxd

            #Mr transfer to Mi
            R2= self.rel_projector2(R0.view(R0.shape[0], -1, R0.shape[3]).permute(0, 2, 1))
            R =  R + self.alpha3[i] * R2

            #Mr write
            rel_memory_state = self.alpha1[i]*rel_memory_state + R0

        #Mr transfer to output
        r_vec = self.rel_projector(rel_memory_state.view(rel_memory_state.shape[0],
                                                         rel_memory_state.shape[1],
                                                         -1)).view(input_step.shape[0],-1)
        out = self.rel_projector3(r_vec)

        # if self.gating_after:
        #     #Mi write gating
        #     input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), R)
        #     if self.rd:
        #         R = input_gate * torch.tanh(R)
        #         R += forget_gate * item_memory_state

        return out, (out, R, rel_memory_state)


    def forward(self, x, mask):
        """Follow Figure 1 (left) for connections."""
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [None]:
def make_encoder_model(src_vocab, tgt_vocab, N=6, 
               d_model=512, d_ff=2048, h=8, dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)

    # model = EncoderDecoder(
    #     Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
    #     Decoder(DecoderLayer(d_model, c(attn), c(attn), 
    #                          c(ff), dropout), N),
    #     nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
    #     nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
    #     Generator(d_model, tgt_vocab))
    
    # This was important from their code. 
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model