# Relational RNNs by Adam Santoro et al. in PyTorch



In [1]:
import numpy as np
from scipy.spatial import distance as spdistance

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
class RelationalMemory(nn.Module):
    """
    Args:
        input_size: The size of the input features
        mem_slots: The total number of memory slots to use.
        head_size: The size of an attention head.
        num_heads: The number of attention heads to use. Defaults to 1.
        num_blocks: Number of times to compute attention per time step. Defaults to 1.
        forget_bias: 
        input_bias:
        gate_style:
        attention_mlp_layers:
        key_size:
    """
    def __init__(self, input_size,
                 mem_slots, head_size, num_heads=1, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style='unit',
                 attention_mlp_layers=2, key_size=None):
        super(RelationalMemory, self).__init__()
        
        self._mem_slots = mem_slots
        self._head_size = head_size
        self._num_heads = num_heads
        self._mem_size = self._head_size * self._num_heads

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self._num_blocks = num_blocks

        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self._gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self._attention_mlp_layers = attention_mlp_layers

        self._key_size = key_size if key_size else self._head_size

        self._linear = nn.Linear(in_features=input_size,
                                 out_features=self._mem_size)
        
        qkv_size = 2 * self._key_size + self._head_size
        total_size = qkv_size * self._num_heads
        self._attention_linear = nn.Linear(in_features=self._mem_size,
                                           out_features=total_size)
        self._attention_layer_norm = nn.LayerNorm(total_size)
        
        attention_mlp_modules = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=self._mem_size,
                          out_features=self._mem_size),
                nn.ReLU())] * (self._attention_mlp_layers - 1) +
            [nn.Linear(in_features=self._mem_size,
                       out_features=self._mem_size)]
        )
        self._attention_mlp = nn.Sequential(*attention_mlp_modules)
        
        self._attend_layer_norm_1 = nn.LayerNorm(self._mem_size)
        self._attend_layer_norm_2 = nn.LayerNorm(self._mem_size)
        
        num_gates = 2 * self._calculate_gate_size()
        self._gate_inputs_linear = nn.Linear(in_features=self._mem_size,
                                             out_features=num_gates)
        
        self._gate_hidden_linear = nn.Linear(in_features=self._mem_size,
                                             out_features=num_gates)
        
    def initial_state(self, batch_size):
        """Creates the initial memory.

        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity. We then pad or truncate
        as necessary so that init_state is of size 
        (batch_size, self._mem_slots, self._mem_size).

        Returns:
            init_state: A truncated or padded matrix of size
            (batch_size, self._mem_slots, self._mem_size).
        """    
        init_state = torch.eye(n=self._mem_slots).repeat(batch_size, 1, 1)
        
        if self._mem_size > self._mem_slots:
            # Pad the matrix with zeros.
            difference = self._mem_size - self._mem_slots
            pad = torch.zeros((batch_size, self._mem_slots, difference))
            init_state = torch.cat([init_state, pad], dim=-1)
        elif self._mem_size < self._mem_slots:
            # Truncation. Take the first `self._mem_size` components.
            init_state = init_state[:, :, :self._mem_size]
        
        return init_state.detach()

    def _multihead_attention(self, memory): # memory: [B, MEM_SLOT, MEM_SIZE]
        # F = total_size
        # mem_slots = MEM_SLOT = N
        mem_slots = memory.size(1)
        
        # [B, MEM_SLOT, MEM_SIZE] -> [B*MEM_SLOT, MEM_SIZE] -> Linear -> [B*MEM_SLOT, F]
        qkv = self._attention_linear(memory.view(-1, memory.size(2)))
        
        # [B*MEM_SLOT, F] -> Layer Norm -> [B*MEM_SLOT, F] -> [B, MEM_SLOT, F]
        qkv = self._attention_layer_norm(qkv).view(memory.size(0), mem_slots, -1)
        
        # H = num_heads
        qkv_size = 2 * self._key_size + self._head_size
        
        # [B, N, F] -> [B, N, H, F/H]
        qkv_reshape = qkv.view(-1, mem_slots, self._num_heads, qkv_size)
        
        # [B, N, H, F/H] -> [B, H, N, F/H]
        qkv_transpose = qkv_reshape.permute(0, 2, 1, 3)
        # split q, k, v
        q, k, v = torch.split(qkv_transpose, [self._key_size, self._key_size, self._head_size], dim=-1)
        
        q *= qkv_size ** -0.5
        dot_product = torch.matmul(q, torch.transpose(k, 2, 3)) # [B, H, N, N]
        weights = F.softmax(dot_product, dim=-1)
        
        #[B, H, N, V]
        output = torch.matmul(weights, v)
        
        # [B, H, N, V] -> [B, N, H, V]
        output_transpose = output.permute(0, 2, 1, 3)
        
        # [B, N, H, V] -> [B, N, H * V]
        new_memory = output_transpose.contiguous().view(-1, output_transpose.size(1), 
                                                        output_transpose.size(2)*output_transpose.size(3))
        
        return new_memory #[B, MEM_SLOTS, MEM_SIZE]
    
    
    def _attend_over_memory(self, memory):
        # memory: [B, MEM_SLOT, MEM_SIZE]
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory) # [B, MEM_SLOT, MEM_SIZE]
            
            # add a skip connection the multiheaded attention's input.
            # memory = LN_1(memory + attended_memory) [B*MEM_SLOT, MEM_SIZE]
            memory = self._attend_layer_norm_1((memory + attended_memory).view(-1, memory.size(2)))
            
            # add a skip connection to the attention_mlp's input.
            # memory = LN_2( MLP(memory) + memory)
            memory = self._attend_layer_norm_2(self._attention_mlp(memory) + memory).view(-1, 
                                                                                          attended_memory.size(1),
                                                                                          attended_memory.size(2))
        return memory
    
    def _calculate_gate_size(self):
        if self._gate_style == 'unit':
            return self._mem_size
        elif self._gate_style == 'memory':
          return 1
        else:
            return 0
        
    def _create_gates(self, inputs, memory):
        hidden = torch.tanh(memory)
        
        #inputs [B, 1, MEM_SIZE] -> [B, 1*MEM_SIZE]
        inputs = inputs.view(inputs.size(0), -1)
        
        # [B, 1*MEM_SIZE] -> Linear -> [B, num_gates] -> [B, 1, num_gates]
        gate_inputs = self._gate_inputs_linear(inputs).unsqueeze(1)
        
        # hidden [B, MEM_SLOT, MEM_SIZE] -> [B*MEM_SLOT, MEM_SIZE] -> Linear -> [B*MEM_SLOT, num_gates]
        # -> [B, MEM_SLOT, num_gates]
        gate_hidden = self._gate_hidden_linear(hidden.view(-1, hidden.size(2))).view(hidden.size(0),
                                                                                     hidden.size(1),
                                                                                     -1)
        
        input_gate, forget_gate = torch.chunk(gate_hidden + gate_inputs, 2, dim=2)
        
        input_gate = torch.sigmoid(input_gate + self._input_bias)
        forget_gate = torch.sigmoid(forget_gate + self._forget_bias)
        
        return input_gate, forget_gate #[B, MEM_SLOT, num_gates/2], [B, MEM_SLOT, num_gates/2]
                                                   
    def forward(self, x, memory=None, treat_input_as_matrix=False):
        # x: [B, T, F=input_size]
        # memory: [B, MEM_SLOTS, MEM_SIZE]
        batch_size = x.size(0)
        total_timesteps = x.size(1)
        
        output_accumulator = x.new_zeros(batch_size, total_timesteps, self._mem_slots*self._mem_size)
        
        for index in range(total_timesteps):
            # For each time-step
            # inputs: [B, 1, F=input_size]
            inputs = x[:,index].unsqueeze(1)
            
            if treat_input_as_matrix:
                # [B, 1, F] -> [B*1, F] -> linear ->[B*1, MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs_reshape =  self._linear(inputs.view(-1, inputs.size(2))).view(inputs.size(0), 
                                                                                     -1, 
                                                                                     self._mem_size)
            else:
                # [B, 1, F] -> [B, 1*F] -> linear -> [B, MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs = inputs.view(inputs.size(0), -1)
                inputs = self._linear(inputs)
                inputs_reshape = inputs.unsqueeze(1)

            # [B, MEM_SLOTS, MEM_SIZE] -> [B, MEM_SLOT+1, MEM_SIZE]
            memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)

            next_memory = self._attend_over_memory(memory_plus_input)
            n = inputs_reshape.size(1)
            # [B, MEM_SLOT+1, MEM_SIZE] -> [B, MEM_SLOT, MEM_SIZE]
            next_memory = next_memory[:, :-n, :]

            if self._gate_style == 'unit' or self._gate_style == 'memory':
                input_gate, forget_gate = self._create_gates(inputs_reshape, memory) #[B, MEM_SLOT, num_gates/2] 
                next_memory = input_gate * torch.tanh(next_memory)
                next_memory += forget_gate * memory
            
            # output: [B, MEM_SLOT, MEM_SIZE] -> [B, MEM_SLOT*MEM_SIZE]
            output = next_memory.view(next_memory.size(0), -1)
            
            output_accumulator[:,index] = output
            memory = next_memory
        
        return output_accumulator, memory

In [17]:
class NthFarthest(Dataset):
    def __init__(self, num_objects, num_features, batch_size, epochs, transform=None, target_transform=None):
        super(NthFarthest, self).__init__()
        
        self._num_objects = num_objects
        self._num_features = num_features
        self._transform = transform
        self._target_transform = target_transform
        
    def _get_single_set(self, num_objects, num_features):
        # Generate random binary vectors
        data = np.random.uniform(-1, 1, size=(num_objects, num_features))

        distances = spdistance.squareform(spdistance.pdist(data))
        distance_idx = np.argsort(distances)

        # Choose random distance
        nth = np.random.randint(0, num_objects)

        # Pick out the nth furthest for each object
        nth_furthest = np.where(distance_idx == nth)[1]

        # Choose random reference object
        reference = np.random.randint(0, num_objects)

        # Get identity of object that is the nth furthest from reference object
        labels = nth_furthest[reference]

        # Compile data
        object_ids = np.identity(num_objects)
        nth_matrix = np.zeros((num_objects, num_objects))
        nth_matrix[:, nth] = 1
        reference_object = np.zeros((num_objects, num_objects))
        reference_object[:, reference] = 1

        inputs = np.concatenate([data, object_ids, reference_object, nth_matrix],
                                axis=-1)
        inputs = np.random.permutation(inputs)
        labels = np.expand_dims(labels, axis=0)
        return inputs.astype(np.float32), labels.astype(np.long)
    
    def __getitem__(self, index):
        inputs, labels = self._get_single_set(self._num_objects, self._num_features)
        
        if self._transform is not None:
            inputs = self._transform(inputs)
        if self._target_transform is not None:
            labels = self._target_transform(labels)

        return inputs, labels

    def __len__(self):
        return batch_size*epochs

In [18]:
mem_slots = 4
head_size = 2048

In [27]:
batch_size = 1600
epochs = 1000000

learning_rate = 1e-2

num_objects = 4
num_features = 4
input_size = num_features + 3 * num_objects

mlp_size = 256
mlp_layers = 4

In [28]:
n_furthest = NthFarthest(num_objects=num_objects,
                         num_features=num_features, 
                         batch_size=batch_size, 
                         epochs=epochs)

In [29]:
train_loader = torch.utils.data.DataLoader(batch_size=batch_size, 
                                           dataset=n_furthest)
test_loader = torch.utils.data.DataLoader(batch_size=batch_size,
                                          dataset=n_furthest)

In [30]:
class SequenceModel(nn.Module):
    def __init__(self, input_size, mem_slots, head_size, batch_size,
                 mlp_size, mlp_layers, num_objects):
        super(SequenceModel, self).__init__()
        
        self._core = RelationalMemory(input_size=input_size,
                                      mem_slots=mem_slots, 
                                      head_size=head_size)
        self.initial_memory = self._core.initial_state(batch_size=batch_size)
        
        final_mlp_modules = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(in_features=self._core._mem_size * self._core._mem_slots,
                          out_features=mlp_size),
                nn.ReLU())] +
            [nn.Sequential(
                nn.Linear(in_features=mlp_size,
                          out_features=mlp_size),
                nn.ReLU())] * (mlp_layers - 2) +
            [nn.Linear(in_features=mlp_size,
                       out_features=mlp_size)]
        )
        self._final_mlp = nn.Sequential(*final_mlp_modules)
        
        self._linear = nn.Linear(in_features=mlp_size,
                                 out_features=num_objects)
        
    # inputs: [B, T, F]
    def forward(self, inputs, memory):
        output_sequence, output_memory = self._core(inputs, memory)
        outputs = output_sequence[:, -1, :].unsqueeze(1)
        
        outputs = self._final_mlp(outputs)
        logits = self._linear(outputs)
        
        return logits, output_memory

In [31]:
model = SequenceModel(input_size=input_size,
                      mem_slots=mem_slots, 
                      head_size=head_size, 
                      batch_size=batch_size,
                      mlp_size=mlp_size, 
                      mlp_layers=mlp_layers, 
                      num_objects=num_objects).to(device)

In [32]:
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [33]:
criterion = nn.CrossEntropyLoss()

In [34]:
model.train()
memory = model.initial_memory

for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    optimiser.zero_grad()
    output, output_memory = model(inputs, memory)
    
    loss = criterion(output.squeeze(1), labels.squeeze(1))
    loss.backward()
    
    print("loss: ", loss.cpu().item())
    
    memory = output_memory.detach()

loss:  1.3867828845977783
loss:  1.388818621635437
loss:  1.388674259185791
loss:  1.3847616910934448
loss:  1.3875209093093872
loss:  1.3877465724945068
loss:  1.386612892150879
loss:  1.3863935470581055
loss:  1.387014627456665
loss:  1.3877519369125366
loss:  1.3856565952301025
loss:  1.388579249382019
loss:  1.388862133026123
loss:  1.3882534503936768
loss:  1.3880913257598877
loss:  1.3872520923614502
loss:  1.3843721151351929
loss:  1.386006474494934
loss:  1.3887149095535278
loss:  1.3873852491378784
loss:  1.3886222839355469
loss:  1.3893730640411377
loss:  1.3872851133346558
loss:  1.3868238925933838
loss:  1.3878239393234253
loss:  1.3880730867385864
loss:  1.3896894454956055
loss:  1.3878532648086548
loss:  1.3880445957183838
loss:  1.3889265060424805
loss:  1.3878422975540161
loss:  1.387606143951416
loss:  1.3861229419708252
loss:  1.3868069648742676
loss:  1.389318585395813
loss:  1.388189673423767
loss:  1.3883622884750366
loss:  1.3864812850952148
loss:  1.3886764049530

KeyboardInterrupt: 