# Relational RNNs by Adam Santoro et al. in PyTorch



In [1]:
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

In [None]:
class RelationalMemory(nn.Module):
    """Relational Memory Core (Santoro et al.) in PyTorch, following Sonnet's layout.

    The memory is a [B, mem_slots, mem_size] tensor with
    ``mem_size = head_size * num_heads``. For every time step of the input the
    step is projected to one memory-sized row and appended to the memory. The
    result is refined by ``num_blocks`` rounds of multi-head self-attention
    followed by a row-wise MLP (each with a residual connection and LayerNorm).
    The appended row is then dropped again. With ``gate_style`` set, the result
    is combined with the previous memory through LSTM-style input/forget gates.

    Args:
        input_size: Size of the last dimension of one input time step.
        mem_slots: The total number of memory slots to use.
        head_size: The size of an attention head.
        num_heads: The number of attention heads to use. Defaults to 1.
        num_blocks: Number of times to compute attention per time step. Defaults to 1.
        forget_bias: Constant added to the forget-gate pre-activation. Defaults to 1.0.
        input_bias: Constant added to the input-gate pre-activation. Defaults to 0.0.
        gate_style: 'unit' (one gate value per memory unit), 'memory' (one gate
            value per memory slot) or None (no gating). Defaults to 'unit'.
        attention_mlp_layers: Number of Linear layers in the MLP applied after
            each attention round. Defaults to 2.
        key_size: Per-head size of queries and keys. Defaults to ``head_size``.

    Raises:
        ValueError: If ``num_blocks`` or ``attention_mlp_layers`` is < 1, or
            ``gate_style`` is not one of 'unit', 'memory', None.
    """

    def __init__(self, input_size,
                 mem_slots, head_size, num_heads=1, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style='unit',
                 attention_mlp_layers=2, key_size=None):
        super(RelationalMemory, self).__init__()

        self._mem_slots = mem_slots
        self._head_size = head_size
        self._num_heads = num_heads
        self._mem_size = self._head_size * self._num_heads  # hidden_size

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self._num_blocks = num_blocks

        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self._gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self._attention_mlp_layers = attention_mlp_layers

        self._key_size = key_size if key_size else self._head_size

        # Input projection: one time step -> one memory-sized row.
        self._linear = nn.Linear(in_features=input_size,
                                 out_features=self._mem_size)

        # Fused query/key/value projection for all heads, followed by LayerNorm.
        qkv_size = 2 * self._key_size + self._head_size
        total_size = qkv_size * self._num_heads
        self._attention_linear = nn.Linear(in_features=self._mem_size,
                                           out_features=total_size)
        self._attention_layer_norm = nn.LayerNorm(total_size)

        # Row-wise MLP: `attention_mlp_layers` Linear(mem_size, mem_size) layers
        # with ReLU between them and no activation after the last one.
        mlp_layers = []
        for layer_index in range(attention_mlp_layers):
            if layer_index > 0:
                mlp_layers.append(nn.ReLU())
            mlp_layers.append(nn.Linear(self._mem_size, self._mem_size))
        self._attention_mlp = nn.Sequential(*mlp_layers)
        self._attend_layer_norm_1 = nn.LayerNorm(self._mem_size)
        self._attend_layer_norm_2 = nn.LayerNorm(self._mem_size)

        # Input and forget gates are produced together (hence 2x the gate size)
        # from the input row and from the previous memory.
        num_gates = 2 * self._calculate_gate_size()
        if num_gates > 0:
            self._input_gate_linear = nn.Linear(self._mem_size, num_gates)
            self._memory_gate_linear = nn.Linear(self._mem_size, num_gates)

    def _calculate_gate_size(self):
        """Return the per-gate width: mem_size for 'unit', 1 for 'memory', 0 for None."""
        if self._gate_style == 'unit':
            return self._mem_size
        if self._gate_style == 'memory':
            return 1
        return 0

    def initial_state(self, batch_size):
        """Create the initial memory for a batch.

        Rows are made distinct by starting from an identity matrix, which is
        then zero-padded on the right or truncated to ``mem_size`` columns.

        Args:
            batch_size: Number of independent memories to create.

        Returns:
            Tensor of shape [batch_size, mem_slots, mem_size] with the same
            device and dtype as the module's parameters.
        """
        weight = self._linear.weight
        init_state = torch.eye(self._mem_slots, dtype=weight.dtype,
                               device=weight.device).repeat(batch_size, 1, 1)

        if self._mem_size > self._mem_slots:
            # Pad the matrix with zeros.
            difference = self._mem_size - self._mem_slots
            pad = torch.zeros((batch_size, self._mem_slots, difference),
                              dtype=weight.dtype, device=weight.device)
            init_state = torch.cat([init_state, pad], dim=-1)
        elif self._mem_size < self._mem_slots:
            # Truncation. Take the first `self._mem_size` components.
            init_state = init_state[:, :, :self._mem_size]

        return init_state

    def _multihead_attention(self, memory):
        """Multi-head scaled dot-product self-attention over memory rows.

        Args:
            memory: [B, N, MEM_SIZE], where N = mem_slots + appended input rows.

        Returns:
            [B, N, MEM_SIZE] attended memory (heads concatenated).
        """
        batch_size, mem_slots = memory.size(0), memory.size(1)
        qkv_size = 2 * self._key_size + self._head_size

        # [B, N, MEM_SIZE] -> Linear -> LayerNorm -> [B, N, F], F = H * qkv_size
        qkv = self._attention_layer_norm(self._attention_linear(memory))

        # [B, N, F] -> [B, N, H, F/H] -> [B, H, N, F/H]
        qkv = qkv.view(batch_size, mem_slots, self._num_heads, qkv_size).permute(0, 2, 1, 3)
        q, k, v = torch.split(qkv, [self._key_size, self._key_size, self._head_size], dim=-1)

        # Scale by 1/sqrt(d_k). Out-of-place: split() returns views, which
        # autograd does not allow to be modified in place.
        q = q * (self._key_size ** -0.5)
        dot_product = torch.matmul(q, k.transpose(-2, -1))  # [B, H, N, N]
        weights = F.softmax(dot_product, dim=-1)
        output = torch.matmul(weights, v)  # [B, H, N, V]

        # [B, H, N, V] -> [B, N, H, V] -> [B, N, H * V]; reshape because the
        # permuted tensor is not contiguous.
        return output.permute(0, 2, 1, 3).reshape(batch_size, mem_slots, -1)

    def _attend_over_memory(self, memory):
        """Apply `num_blocks` rounds of attention + MLP, each with residual + LayerNorm."""
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory)
            # Skip connection around the multi-head attention.
            memory = self._attend_layer_norm_1(memory + attended_memory)
            # Skip connection around the attention MLP.
            memory = self._attend_layer_norm_2(memory + self._attention_mlp(memory))
        return memory

    def _create_gates(self, inputs, memory):
        """Compute sigmoid input and forget gates.

        Args:
            inputs: [B, 1, MEM_SIZE] projected input row.
            memory: [B, MEM_SLOTS, MEM_SIZE] previous memory.

        Returns:
            (input_gate, forget_gate), each [B, MEM_SLOTS, gate_size], which
            broadcasts against the memory.
        """
        memory = torch.tanh(memory)
        # [B, 1, MEM_SIZE] -> [B, MEM_SIZE] -> [B, 2G] -> [B, 1, 2G]
        gate_inputs = self._input_gate_linear(inputs.reshape(inputs.size(0), -1)).unsqueeze(1)
        gate_memory = self._memory_gate_linear(memory)  # [B, MEM_SLOTS, 2G]
        input_gate, forget_gate = torch.chunk(gate_memory + gate_inputs, 2, dim=2)
        input_gate = torch.sigmoid(input_gate + self._input_bias)
        forget_gate = torch.sigmoid(forget_gate + self._forget_bias)
        return input_gate, forget_gate

    def forward(self, x, memory=None, treat_input_as_matrix=False):
        """Run the core over every time step of ``x``.

        Args:
            x: Input of shape [B, T, input_size].
            memory: Memory of shape [B, mem_slots, mem_size]. If None,
                ``initial_state(B)`` is used.
            treat_input_as_matrix: If True, apply the input projection to the
                last dimension of each step. Otherwise flatten the step first.
                For a [B, T, input_size] input both give the same result.

        Returns:
            (outputs, memory): ``outputs`` is [B, T, mem_slots * mem_size], the
            flattened memory after every step. ``memory`` is the final memory,
            [B, mem_slots, mem_size].
        """
        batch_size, total_timesteps = x.size(0), x.size(1)
        if memory is None:
            memory = self.initial_state(batch_size)
        if total_timesteps == 0:
            # torch.stack cannot handle an empty list.
            return x.new_zeros(batch_size, 0, self._mem_slots * self._mem_size), memory

        outputs = []
        for index in range(total_timesteps):
            # [B, T, F] -> [B, F] -> [B, 1, F]
            inputs = x[:, index].unsqueeze(1)

            if treat_input_as_matrix:
                # Linear acts on the last dim: [B, 1, F] -> [B, 1, MEM_SIZE]
                inputs_reshape = self._linear(inputs)
            else:
                # [B, 1, F] -> [B, 1*F] -> linear -> [B, MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs_reshape = self._linear(inputs.reshape(batch_size, -1)).unsqueeze(1)

            # [B, MEM_SLOTS, MEM_SIZE] -> [B, MEM_SLOTS + n, MEM_SIZE]
            memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)
            next_memory = self._attend_over_memory(memory_plus_input)

            # Drop the n appended input rows: -> [B, MEM_SLOTS, MEM_SIZE]
            n = inputs_reshape.size(1)
            next_memory = next_memory[:, :-n, :]

            if self._gate_style == 'unit' or self._gate_style == 'memory':
                # Kept on the module so the most recent gates can be inspected.
                self._input_gate, self._forget_gate = self._create_gates(inputs_reshape, memory)
                next_memory = self._input_gate * torch.tanh(next_memory)
                next_memory = next_memory + self._forget_gate * memory

            outputs.append(next_memory.reshape(batch_size, -1))
            memory = next_memory

        return torch.stack(outputs, dim=1), memory

In [33]:
# Dummy input shaped like `forward`'s x: [B=64, T=10, F=32]. Zeros rather than
# torch.Tensor(...), which returns uninitialized (possibly NaN) memory.
insputs = torch.zeros(64, 10, 32)

In [36]:
# Take time step 9 and restore a length-1 time axis, as `forward` does per step.
insputs[:, 9].unsqueeze(dim=1).size()

torch.Size([64, 1, 32])

In [3]:
import tensorflow as tf
tf.enable_eager_execution()
from sonnet.python.modules import relational_memory
from sonnet.python.modules import basic

  return f(*args, **kwds)
  return f(*args, **kwds)


In [14]:
# Reference run of Sonnet's (TensorFlow) RelationalMemory with a small
# configuration, used to check the shapes the PyTorch port should reproduce.
mem_slots = 4
head_size = 32
num_heads = 2
batch_size = 5

# (batch, 1, 3): one row of 3 features per batch element.
input_shape = (batch_size, 1, 3)

In [15]:
# Same hyper-parameters as the PyTorch class, except that no input size is passed.
mem = relational_memory.RelationalMemory(mem_slots, head_size, num_heads)

In [16]:
init_state = mem.initial_state(batch_size)

In [17]:
# Expected [batch_size, mem_slots, head_size * num_heads] = [5, 4, 64].
init_state.shape

TensorShape([Dimension(5), Dimension(4), Dimension(64)])

In [18]:
# The recorded output shows the 3-D input keeps its shape here.
inputs = basic.BatchFlatten(preserve_dims=2)(tf.zeros(input_shape))

In [19]:
inputs.shape

TensorShape([Dimension(5), Dimension(1), Dimension(3)])

In [20]:
# Project the last dim to 64 while applying over the first 2 dims, mirroring the
# `treat_input_as_matrix` branch of the PyTorch `forward`.
# NOTE(review): 64 is hard-coded; it must equal head_size * num_heads.
inputs_reshape = basic.BatchApply(basic.Linear(64), n_dims=2)(inputs)

In [21]:
inputs_reshape.shape

TensorShape([Dimension(5), Dimension(1), Dimension(64)])

In [22]:
# Number of input rows appended to the memory (1 here); the PyTorch `forward`
# removes them again with next_memory[:, :-n, :].
n = inputs_reshape.get_shape().as_list()[1]

In [23]:
n

1