# Relational RNNs by Adam Santoro et al. in PyTorch



In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class RelationalMemory(nn.Module):
    """Relational memory core (work in progress).

    Only argument validation and `initial_state` are functional so far;
    `_multihead_attention` and `forward` are unfinished stubs.

    Args:
        mem_slots: The total number of memory slots to use.
        head_size: The size of an attention head.
        num_heads: The number of attention heads to use. Defaults to 1.
        num_blocks: Number of times to compute attention per time step.
            Must be >= 1. Defaults to 1.
        forget_bias: Stored on the module but not read by any code in this
            class yet. Presumably the bias added to the forget gate --
            TODO confirm once gating is implemented. Defaults to 1.0.
        input_bias: Stored on the module but not read by any code in this
            class yet. Presumably the bias added to the input gate --
            TODO confirm once gating is implemented. Defaults to 0.0.
        gate_style: One of 'unit', 'memory' or None. Only validated and
            stored for now; the gating itself is not implemented yet.
            Defaults to 'unit'.
        attention_mlp_layers: Must be >= 1. Only validated and stored for
            now; presumably the depth of the post-attention MLP -- TODO
            confirm once that MLP is implemented. Defaults to 2.
        key_size: Size counted twice (presumably once for queries and once
            for keys -- TODO confirm) when `_multihead_attention` sizes
            its combined projection. When None (or any other falsy value)
            `head_size` is used instead.

    Raises:
        ValueError: If `num_blocks` < 1, `attention_mlp_layers` < 1, or
            `gate_style` is not one of 'unit', 'memory', None.
    """
    def __init__(self, mem_slots, head_size, num_heads=1, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style='unit',
                 attention_mlp_layers=2, key_size=None):
        super(RelationalMemory, self).__init__()

        self._mem_slots = mem_slots
        self._head_size = head_size
        self._num_heads = num_heads
        # Width of one memory row: all heads' outputs side by side.
        self._mem_size = self._head_size * self._num_heads

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self._num_blocks = num_blocks

        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self._gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self._attention_mlp_layers = attention_mlp_layers

        # Falls back to the head size whenever key_size is falsy (None or 0).
        self._key_size = key_size if key_size else self._head_size

    def initial_state(self, batch_size):
        """Creates the initial memory.

        Each row of the memory should be unique, so every batch element
        starts from an identity matrix that is zero-padded (when
        `_mem_size > _mem_slots`) or truncated to its first `_mem_size`
        columns (when `_mem_size < _mem_slots`).

        Args:
            batch_size: The size of the batch.

        Returns:
            init_state: A tensor of size
            (batch_size, self._mem_slots, self._mem_size) holding one
            padded or truncated identity matrix per batch element.
        """
        # torch.eye takes its sizes positionally (n rows, m columns); the
        # TensorFlow-style `num_rows=` keyword raises a TypeError in PyTorch.
        # A rectangular eye is exactly the identity padded with zero columns
        # or cut down to the first m columns, so no manual cat/slice is needed.
        identity = torch.eye(self._mem_slots, self._mem_size)
        return identity.repeat(batch_size, 1, 1)

    def _multihead_attention(self, memory):
        """Unfinished stub: only derives the projection sizes.

        Computes the per-head size `2 * key_size + value_size` and the
        total across heads, but does nothing with them yet. `memory` is
        not read and the method implicitly returns None.

        Args:
            memory: Currently unused.
        """
        key_size = self._key_size
        value_size = self._head_size

        # key_size is counted twice, value_size once.
        qkv_size = 2 * key_size + value_size
        total_size = qkv_size * self._num_heads
        # TODO: project `memory` to `total_size` features and run attention.

    def forward(self, inputs, memory, treat_input_as_matrix=False):
        """Forward pass -- not implemented yet.

        Args:
            inputs: Input tensor; an existing note here read "[B, T, F]",
                presumably the expected layout -- TODO confirm.
            memory: Memory tensor, presumably as returned by
                `initial_state` -- TODO confirm.
            treat_input_as_matrix: Selects between two input-preparation
                branches, neither of which has been written yet.

        Raises:
            NotImplementedError: Always.
        """
        # The original draft had two dangling `inputs = ` assignments here
        # (one per `treat_input_as_matrix` branch, the True branch marked
        # "Flatten"), which made the whole class fail to parse.
        # TODO: implement both input-preparation branches and the memory update.
        raise NotImplementedError(
            'RelationalMemory.forward is not implemented yet.')
            
        