# Relational RNNs by Adam Santoro et al. in PyTorch



In [1]:
import numpy as np

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [95]:
class RelationalMemory(nn.Module):
    """
    Args:
        input_size: 
        mem_slots: The total number of memory slots to use.
        head_size: The size of an attention head.
        num_heads: The number of attention heads to use. Defaults to 1.
        num_blocks: Number of times to compute attention per time step. Defaults to 1.
        forget_bias:
        input_bias:
        gate_style:
        attention_mlp_layers:
        key_size:
    """
    def __init__(self, input_size,
                 mem_slots, head_size, num_heads=1, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style='unit',
                 attention_mlp_layers=2, key_size=None):
        super(RelationalMemory, self).__init__()
        
        self._mem_slots = mem_slots
        self._head_size = head_size
        self._num_heads = num_heads
        self._mem_size = self._head_size * self._num_heads # hidden_size

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self._num_blocks = num_blocks

        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self._gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self._attention_mlp_layers = attention_mlp_layers

        self._key_size = key_size if key_size else self._head_size

        self._linear = nn.Linear(in_features=input_size,
                                 out_features=self._mem_size)
        
        qkv_size = 2 * self._key_size + self._head_size
        total_size = qkv_size * self._num_heads
        self._attention_linear = nn.Linear(in_features=self._mem_size,
                                           out_features=total_size)
        self._attention_layer_norm = nn.LayerNorm(total_size)
        
        attention_mlp_module = nn.ModuleList([nn.Sequential(
                nn.Linear(in_features=self._mem_size,
                          out_features=self._mem_size),
                nn.ReLU())] * (self._attention_mlp_layers - 1) 
                + [nn.Linear(in_features=self._mem_size,
                             out_features=self._mem_size)]
        )
        self._attention_mlp = nn.Sequential(*attention_mlp_module)
        
        self._attend_layer_norm_1 = nn.LayerNorm(self._mem_size)
        self._attend_layer_norm_2 = nn.LayerNorm(self._mem_size)
        
        num_gates = 2 * self._calculate_gate_size()
        self._gate_inputs_linear = nn.Linear(in_features=self._mem_size,
                                             out_features=num_gates)
        
        self._gate_memory_linear = nn.Linear(in_features=self._mem_size,
                                             out_features=num_gates)
        
        
    def initial_state(self, batch_size):
        """Creates the initial memory.

        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity. We then pad or truncate
        as necessary so that init_state is of size (batch_size, self._mem_slots, self._mem_size).

        Returns:
            init_state: A truncated or padded matrix of size
            (batch_size, self._mem_slots, self._mem_size).
        """    
        init_state = torch.eye(n=self._mem_slots).repeat(batch_size, 1, 1)
        
        # Pad the matrix with zeros.
        if self._mem_size > self._mem_slots:
            difference = self._mem_size - self._mem_slots
            pad = torch.zeros((batch_size, self._mem_slots, difference))
            init_state = torch.cat([init_state, pad], dim=-1)
        # Truncation. Take the first `self._mem_size` components.
        elif self._mem_size < self._mem_slots:
            init_state = init_state[:, :, :self._mem_size]
        
        return init_state
        

    def _multihead_attention(self, memory): # memory: [B, MEM_SLOT, MEM_SIZE]
        # F = total_size
        # mem_slots = MEM_SLOT = N
        mem_slots = memory.size(1)
        
        # [B, MEM_SLOT, MEM_SIZE] -> [B*MEM_SLOT, MEM_SIZE] -> Linear -> [B*MEM_SLOT, F]
        qkv = self._attention_linear(memory.view(-1, memory.size(2)))
        
        # [B*MEM_SLOT, F] -> Layer Norm -> [B*MEM_SLOT, F] -> [B, MEM_SLOT, F]
        qkv = self._attention_layer_norm(qkv).view(memory.size(0), mem_slots, -1)
        
        # H = num_heads
        qkv_size = 2 * self._key_size + self._head_size
        
        # [B, N, F] -> [B, N, H, F/H]
        qkv_reshape = qkv.view(-1, mem_slots, self._num_heads, qkv_size)
        
        # [B, N, H, F/H] -> [B, H, N, F/H]
        qkv_transpose = qkv_reshape.permute(0, 2, 1, 3)
        # split q, k, v
        q, k, v = torch.split(qkv_transpose, [self._key_size, self._key_size, self._head_size], dim=-1)
        
        q *= qkv_size ** -0.5
        dot_product = torch.matmul(q, torch.transpose(k, 2, 3)) # [B, H, N, N]
        weights = F.softmax(dot_product, dim=-1)
        
        #[B, H, N, V]
        output = torch.matmul(weights, v)
        
        # [B, H, N, V] -> [B, N, H, V]
        output_transpose = output.permute(0, 2, 1, 3)
        
        # [B, N, H, V] -> [B, N, H * V]
        new_memory = output_transpose.contiguous().view(-1, output_transpose.size(1), 
                                                        output_transpose.size(2)*output_transpose.size(3))
        
        return new_memory #[B, MEM_SLOTS, MEM_SIZE]
    
    
    def _attend_over_memory(self, memory):
        # memory: [B, MEM_SLOT, MEM_SIZE]
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory) # [B, MEM_SLOT, MEM_SIZE]
            
            # add a skip connection the multiheaded attention's input.
            # memory = LN_1(memory + attended_memory) [B*MEM_SLOT, MEM_SIZE]
            memory = self._attend_layer_norm_1((memory + attended_memory).view(-1, memory.size(2)))
            
            # add a skip connection to the attention_mlp's input.
            # memory = LN_2( MLP(memory) + memory)
            memory = self._attend_layer_norm_2(self._attention_mlp(memory) + memory).view(-1, 
                                                                                          attended_memory.size(1),
                                                                                          attended_memory.size(2))
    
        return memory
    
    def _calculate_gate_size(self):
        if self._gate_style == 'unit':
            return self._mem_size
        elif self._gate_style == 'memory':
          return 1
        else:
            return 0
        
    def _create_gates(self, inputs, memory):
        memory = torch.tanh(memory)
        
        #inputs [B, 1, MEM_SIZE] -> [B, 1*MEM_SIZE]
        inputs = inputs.view(inputs.size(0), -1)
        
        # [B, 1*MEM_SIZE] -> Linear -> [B, num_gates] -> [B, 1, num_gates]
        gate_inputs = self._gate_inputs_linear(inputs).unsqueeze(1)
        
        # memory [B, MEM_SLOT, MEM_SIZE] -> [B*MEM_SLOT, MEM_SIZE] -> Linear -> [B*MEM_SLOT, 2*num_gates]
        # -> [B, MEM_SLOT, 2*num_gates]
        gate_memory = self._gate_memory_linear(memory.view(-1, memory.size(2))).view(memory.size(0),
                                                                                     memory.size(1),
                                                                                     -1)
        
        input_gate, forget_gate = torch.chunk(gate_memory + gate_inputs, 2, dim=2)
        
        input_gate = torch.sigmoid(input_gate + self._input_bias)
        forget_gate = torch.sigmoid(forget_gate + self._forget_bias)
        
        return input_gate, forget_gate #[B, MEM_SLOT, num_gate], [B, MEM_SLOT, num_gate]
        
                                              
    def forward(self, x, memory, treat_input_as_matrix=False):
        # x: [B, T, F=input_size]
        # memory: [B, MEM_SLOTS, MEM_SIZE]
        batch_size = x.size(0)
        total_timesteps = x.size(1)
        
        for index in range(total_timesteps):
            # For each time-step
            # inputs: [B, 1, F=input_size]
            inputs = x[:,index].unsqueeze(1)
            
            if treat_input_as_matrix:
                # [B, 1, F] -> [B*1, F] -> linear ->[B*1, MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs_reshape =  self._linear(inputs.view(-1, input.size(2))).view(input.size(0), -1, self._mem_size)
            else:
                # [B, 1, F] -> [B, 1*F] -> linear -> [B, MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs = inputs.view(inputs.size(0), -1)
                inputs = self._linear(inputs)
                inputs_reshape = inputs.unsqueeze(1)

            # [B, MEM_SLOTS, MEM_SIZE] -> [B, MEM_SLOT+1, MEM_SIZE]
            memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)

            next_memory = self._attend_over_memory(memory_plus_input)
            n = inputs_reshape.size(1)
            # [B, MEM_SLOT+1, MEM_SIZE] -> [B, MEM_SLOT, MEM_SIZE]
            next_memory = next_memory[:, :-n, :]

            if self._gate_style == 'unit' or self._gate_style == 'memory':
                input_gate, forget_gate = self._create_gates(inputs_reshape, memory) #[B, MEM_SLOT, num_gate] 
                next_memory = input_gate * torch.tanh(next_memory)
                next_memory += forget_gate * memory
            
            # output: [B, MEM_SLOT, MEM_SIZE] -> [B, MEM_SLOT*MEM_SIZE]
            output = next_memory.view(next_memory.size(0), -1)
            
            return output, next_memory


In [96]:
mem_slots = 4
head_size = 32
num_heads = 2
batch_size = 5
input_size = 3

In [97]:
input_shape = (batch_size, 3, 3)
inputs = torch.Tensor(batch_size, 3, 3)

In [98]:
mem = RelationalMemory(input_size, mem_slots, head_size, num_heads)

In [99]:
init_state = mem.initial_state(batch_size)

In [100]:
init_state.shape

torch.Size([5, 4, 64])

In [101]:
mem(inputs, init_state)

(tensor([[ 1.7596e+00,  1.2347e-39, -3.8805e-39,  ..., -1.2387e-39,
          -2.9040e-01,  3.2618e-02],
         [ 1.7596e+00,  2.9709e-01, -9.3371e-01,  ..., -2.9806e-01,
          -2.9040e-01,  3.2618e-02],
         [ 1.7596e+00,  1.2347e-39, -3.8805e-39,  ..., -1.2387e-39,
          -2.9040e-01,  3.2618e-02],
         [ 1.7596e+00,  2.9709e-01, -9.3371e-01,  ..., -2.9806e-01,
          -2.9040e-01,  3.2618e-02],
         [ 1.7596e+00,  1.2347e-39, -3.8805e-39,  ..., -1.2387e-39,
          -2.9040e-01,  3.2618e-02]], grad_fn=<ViewBackward>),
 tensor([[[ 1.7596e+00,  1.2347e-39, -3.8805e-39,  ..., -7.3236e-40,
           -4.7077e-01, -4.6448e-01],
          [-4.2098e-01,  8.0765e-39, -3.8067e-39,  ..., -5.9628e-40,
           -5.7438e-01, -4.5541e-01],
          [-5.8878e-01,  1.3328e-39,  1.0000e+00,  ..., -8.3561e-40,
           -4.3380e-01, -3.2370e-01],
          [-5.3090e-01,  2.2225e-39, -3.9518e-39,  ..., -1.2387e-39,
           -2.9040e-01,  3.2618e-02]],
 
         [[ 1.7596

In [31]:
inputs.shape

torch.Size([5, 3, 3])