# Relational RNNs by Adam Santoro et al. in PyTorch



In [1]:
import numpy as np
from scipy.spatial import distance as spdistance

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class RelationalMemory(nn.Module):
    """
    Args:
        input_size: The size of the input features
        mem_slots: The total number of memory slots to use.
        head_size: The size of an attention head.
        num_heads: The number of attention heads to use. Defaults to 1.
        num_blocks: Number of times to compute attention per time step. Defaults to 1.
        forget_bias: 
        input_bias:
        gate_style:
        attention_mlp_layers:
        key_size:
    """
    def __init__(self, input_size,
                 mem_slots, head_size, num_heads=1, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style='unit',
                 attention_mlp_layers=2, key_size=None):
        super(RelationalMemory, self).__init__()
        
        self._mem_slots = mem_slots
        self._head_size = head_size
        self._num_heads = num_heads
        self._mem_size = self._head_size * self._num_heads

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self._num_blocks = num_blocks

        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self._gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self._attention_mlp_layers = attention_mlp_layers

        self._key_size = key_size if key_size else self._head_size

        self._linear = nn.Linear(in_features=input_size,
                                 out_features=self._mem_size)
        
        self.qkv_size = 2 * self._key_size + self._head_size
        total_size = self.qkv_size * self._num_heads
        self._attention_linear = nn.Linear(in_features=self._mem_size,
                                           out_features=total_size)
        self._attention_layer_norm = nn.LayerNorm(total_size)
        
        attention_mlp_modules = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=self._mem_size,
                          out_features=self._mem_size),
                nn.ReLU())] * (self._attention_mlp_layers - 1) +
            [nn.Linear(in_features=self._mem_size,
                       out_features=self._mem_size)]
        )
        self._attention_mlp = nn.Sequential(*attention_mlp_modules)
        
        self._attend_layer_norm_1 = nn.LayerNorm(self._mem_size)
        self._attend_layer_norm_2 = nn.LayerNorm(self._mem_size)
        
        num_gates = 2 * self._calculate_gate_size()
        self._gate_inputs_linear = nn.Linear(in_features=self._mem_size,
                                             out_features=num_gates)
        
        self._gate_hidden_linear = nn.Linear(in_features=self._mem_size,
                                             out_features=num_gates)
        
    def initial_state(self, batch_size):
        """Creates the initial memory.

        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity. We then pad or truncate
        as necessary so that init_state is of size 
        (batch_size, self._mem_slots, self._mem_size).

        Returns:
            init_state: A truncated or padded matrix of size
            (batch_size, self._mem_slots, self._mem_size).
        """    
        init_state = torch.eye(n=self._mem_slots).repeat(batch_size, 1, 1)
        
        if self._mem_size > self._mem_slots:
            # Pad the matrix with zeros.
            difference = self._mem_size - self._mem_slots
            pad = torch.zeros((batch_size, self._mem_slots, difference))
            init_state = torch.cat([init_state, pad], dim=-1)
        elif self._mem_size < self._mem_slots:
            # Truncation. Take the first `self._mem_size` components.
            init_state = init_state[:, :, :self._mem_size]
        
        return init_state.detach()

    def _multihead_attention(self, memory): # memory: [B, MEM_SLOT, MEM_SIZE]
        # F = total_size
        # mem_slots = MEM_SLOT = N
        mem_slots = memory.size(1)
        
        # [B, MEM_SLOT, MEM_SIZE] -> [B*MEM_SLOT, MEM_SIZE] -> Linear -> [B*MEM_SLOT, F]
        qkv = self._attention_linear(memory.view(-1, memory.size(2)))
        
        # [B*MEM_SLOT, F] -> Layer Norm -> [B*MEM_SLOT, F] -> [B, MEM_SLOT, F]
        qkv = self._attention_layer_norm(qkv).view(memory.size(0), mem_slots, -1)
        
        # H = num_heads
        #qkv_size = 2 * self._key_size + self._head_size
        # [B, N, F] -> [B, N, H, F/H]
        qkv_reshape = qkv.view(-1, mem_slots, self._num_heads, self.qkv_size)
        
        # [B, N, H, F/H] -> [B, H, N, F/H]
        qkv_transpose = qkv_reshape.permute(0, 2, 1, 3)
        # split q, k, v
        q, k, v = torch.split(qkv_transpose, [self._key_size, self._key_size, self._head_size], dim=-1)
        
        q *= self._key_size ** -0.5
        dot_product = torch.matmul(q, torch.transpose(k, 2, 3)) # [B, H, N, N]
        weights = F.softmax(dot_product, dim=-1)
        
        #[B, H, N, V]
        output = torch.matmul(weights, v)
        
        # [B, H, N, V] -> [B, N, H, V]
        output_transpose = output.permute(0, 2, 1, 3)
        
        # [B, N, H, V] -> [B, N, H * V]
        new_memory = output_transpose.contiguous().view(-1, output_transpose.size(1), 
                                                        output_transpose.size(2)*output_transpose.size(3))
        
        return new_memory #[B, MEM_SLOTS, MEM_SIZE]
    
    
    def _attend_over_memory(self, memory):
        # memory: [B, MEM_SLOT, MEM_SIZE]
        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory) # [B, MEM_SLOT, MEM_SIZE]
            
            # add a skip connection the multiheaded attention's input.
            # memory = LN_1(memory + attended_memory) [B*MEM_SLOT, MEM_SIZE]
            memory = self._attend_layer_norm_1((memory + attended_memory).view(-1, memory.size(2)))
            
            # add a skip connection to the attention_mlp's input.
            # memory = LN_2( MLP(memory) + memory)
            memory = self._attend_layer_norm_2(self._attention_mlp(memory) + memory).view(-1, 
                                                                                          attended_memory.size(1),
                                                                                          attended_memory.size(2))
        return memory
    
    def _calculate_gate_size(self):
        if self._gate_style == 'unit':
            return self._mem_size
        elif self._gate_style == 'memory':
          return 1
        else:
            return 0
        
    def _create_gates(self, inputs, memory):
        hidden = torch.tanh(memory)
        
        #inputs [B, 1, MEM_SIZE] -> [B, 1*MEM_SIZE]
        inputs = inputs.view(inputs.size(0), -1)
        
        # [B, 1*MEM_SIZE] -> Linear -> [B, num_gates] -> [B, 1, num_gates]
        gate_inputs = self._gate_inputs_linear(inputs).unsqueeze(1)
        
        # hidden [B, MEM_SLOT, MEM_SIZE] -> [B*MEM_SLOT, MEM_SIZE] -> Linear -> [B*MEM_SLOT, num_gates]
        # -> [B, MEM_SLOT, num_gates]
        gate_hidden = self._gate_hidden_linear(hidden.view(-1, hidden.size(2))).view(hidden.size(0),
                                                                                     hidden.size(1),
                                                                                     -1)
        
        input_gate, forget_gate = torch.chunk(gate_hidden + gate_inputs, 2, dim=2)
        
        input_gate = torch.sigmoid(input_gate + self._input_bias)
        forget_gate = torch.sigmoid(forget_gate + self._forget_bias)
        
        return input_gate, forget_gate #[B, MEM_SLOT, num_gates/2], [B, MEM_SLOT, num_gates/2]
                                                   
    def forward(self, x, memory=None, treat_input_as_matrix=False):
        # x: [B, T, F=input_size]
        # memory: [B, MEM_SLOTS, MEM_SIZE]
        batch_size = x.size(0)
        total_timesteps = x.size(1)
        
        output_accumulator = x.new_zeros(batch_size, total_timesteps, self._mem_slots*self._mem_size)
        
        for index in range(total_timesteps):
            # For each time-step
            # inputs: [B, 1, F=input_size]
            inputs = x[:,index].unsqueeze(1)
            
            if treat_input_as_matrix:
                # [B, 1, F] -> [B*1, F] -> linear ->[B*1, MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs_reshape =  self._linear(inputs.view(-1, inputs.size(2))).view(inputs.size(0), 
                                                                                     -1, 
                                                                                     self._mem_size)
            else:
                # [B, 1, F] -> [B, 1*F] -> linear -> [B, 1*MEM_SIZE] -> [B, 1, MEM_SIZE]
                inputs = inputs.view(inputs.size(0), -1)
                inputs = self._linear(inputs)
                inputs_reshape = inputs.unsqueeze(1)

            # [B, MEM_SLOTS, MEM_SIZE] -> [B, MEM_SLOT+1, MEM_SIZE]
            memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)

            next_memory = self._attend_over_memory(memory_plus_input)
            n = inputs_reshape.size(1)
            # [B, MEM_SLOT+1, MEM_SIZE] -> [B, MEM_SLOT, MEM_SIZE]
            next_memory = next_memory[:, :-n, :]

            if self._gate_style == 'unit' or self._gate_style == 'memory':
                input_gate, forget_gate = self._create_gates(inputs_reshape, memory) #[B, MEM_SLOT, num_gates/2] 
                next_memory = input_gate * torch.tanh(next_memory)
                next_memory += forget_gate * memory
            
            # output: [B, MEM_SLOT, MEM_SIZE] -> [B, MEM_SLOT*MEM_SIZE]
            output = next_memory.view(next_memory.size(0), -1)
            
            output_accumulator[:,index] = output
            memory = next_memory
        
        return output_accumulator, memory

In [7]:
class NthFarthest(Dataset):
    def __init__(self, num_objects, num_features, batch_size, epochs, transform=None, target_transform=None):
        super(NthFarthest, self).__init__()
        
        self._num_objects = num_objects
        self._num_features = num_features
        self._transform = transform
        self._target_transform = target_transform
        
    def _get_single_set(self, num_objects, num_features):
        # Generate random binary vectors
        data = np.random.uniform(-1, 1, size=(num_objects, num_features))

        distances = spdistance.squareform(spdistance.pdist(data))
        distance_idx = np.argsort(distances)

        # Choose random distance
        nth = np.random.randint(0, num_objects)

        # Pick out the nth furthest for each object
        nth_furthest = np.where(distance_idx == nth)[1]

        # Choose random reference object
        reference = np.random.randint(0, num_objects)

        # Get identity of object that is the nth furthest from reference object
        labels = nth_furthest[reference]

        # Compile data
        object_ids = np.identity(num_objects)
        nth_matrix = np.zeros((num_objects, num_objects))
        nth_matrix[:, nth] = 1
        reference_object = np.zeros((num_objects, num_objects))
        reference_object[:, reference] = 1

        inputs = np.concatenate([data, object_ids, reference_object, nth_matrix],
                                axis=-1)
        inputs = np.random.permutation(inputs)
        labels = np.expand_dims(labels, axis=0)
        return inputs.astype(np.float32), labels.astype(np.long)
    
    def __getitem__(self, index):
        inputs, labels = self._get_single_set(self._num_objects, self._num_features)
        
        if self._transform is not None:
            inputs = self._transform(inputs)
        if self._target_transform is not None:
            labels = self._target_transform(labels)

        return inputs, labels

    def __len__(self):
        return batch_size*epochs

In [27]:
mem_slots = 4
head_size = 2048

In [28]:
batch_size = 16
epochs = 1000000

learning_rate = 1e-3

num_objects = 2
num_features = 2
input_size = num_features + 3 * num_objects

mlp_size = 256
mlp_layers = 4

In [29]:
n_furthest = NthFarthest(num_objects=num_objects,
                         num_features=num_features, 
                         batch_size=batch_size, 
                         epochs=epochs)

In [30]:
train_loader = torch.utils.data.DataLoader(batch_size=batch_size, 
                                           dataset=n_furthest)
test_loader = torch.utils.data.DataLoader(batch_size=batch_size,
                                          dataset=n_furthest)

In [31]:
class SequenceModel(nn.Module):
    def __init__(self, input_size, mem_slots, head_size, batch_size,
                 mlp_size, mlp_layers, num_objects):
        super(SequenceModel, self).__init__()
        
        self._core = RelationalMemory(input_size=input_size,
                                      mem_slots=mem_slots, 
                                      head_size=head_size)
        self.initial_memory = self._core.initial_state(batch_size=batch_size)
        
        final_mlp_modules = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(in_features=self._core._mem_size * self._core._mem_slots,
                          out_features=mlp_size),
                nn.ReLU())] +
            [nn.Sequential(
                nn.Linear(in_features=mlp_size,
                          out_features=mlp_size),
                nn.ReLU())] * (mlp_layers - 2) +
            [nn.Linear(in_features=mlp_size,
                       out_features=mlp_size)]
        )
        self._final_mlp = nn.Sequential(*final_mlp_modules)
        
        self._linear = nn.Linear(in_features=mlp_size,
                                 out_features=num_objects)
        
    # inputs: [B, T, F]
    def forward(self, inputs, memory):
        output_sequence, output_memory = self._core(inputs, memory)
        outputs = output_sequence[:, -1, :].unsqueeze(1)
        
        outputs = self._final_mlp(outputs)
        logits = self._linear(outputs)
        
        return logits, output_memory

In [32]:
model = SequenceModel(input_size=input_size,
                      mem_slots=mem_slots, 
                      head_size=head_size, 
                      batch_size=batch_size,
                      mlp_size=mlp_size, 
                      mlp_layers=mlp_layers, 
                      num_objects=num_objects).to(device)

In [33]:
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [34]:
criterion = nn.CrossEntropyLoss()

In [35]:
model.train()
memory = model.initial_memory.to(device)

for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    optimiser.zero_grad()
    output, output_memory = model(inputs, memory)
    
    loss = criterion(output.squeeze(1), labels.squeeze(1))
    loss.backward()
    
    print("loss: ", loss.cpu().item())
    
    memory = output_memory.detach()

loss:  0.6942892074584961
loss:  0.6935383081436157
loss:  0.6925140619277954
loss:  0.6945618987083435
loss:  0.6905921101570129
loss:  0.6945213675498962
loss:  0.6923407912254333
loss:  0.6921948790550232
loss:  0.6932103633880615
loss:  0.694120466709137
loss:  0.6937230825424194
loss:  0.6945542693138123
loss:  0.6945928335189819
loss:  0.6946831941604614
loss:  0.6922829151153564
loss:  0.6932870745658875
loss:  0.6935537457466125
loss:  0.6908676624298096
loss:  0.6940553784370422
loss:  0.6896588802337646
loss:  0.6898375749588013
loss:  0.6892287731170654
loss:  0.6900903582572937
loss:  0.6927318572998047
loss:  0.6960955262184143
loss:  0.6928524971008301
loss:  0.6903606057167053
loss:  0.6910958290100098
loss:  0.6878944039344788
loss:  0.6993302702903748
loss:  0.692302405834198
loss:  0.6904089450836182
loss:  0.6917802691459656
loss:  0.6916359663009644
loss:  0.6857774257659912
loss:  0.6941584944725037
loss:  0.6943195462226868
loss:  0.6914940476417542
loss:  0.68529

loss:  0.6912351250648499
loss:  0.689175009727478
loss:  0.698544979095459
loss:  0.6893028616905212
loss:  0.6962412595748901
loss:  0.6938909888267517
loss:  0.6864912509918213
loss:  0.6960169076919556
loss:  0.6963045597076416
loss:  0.6964896321296692
loss:  0.6982907652854919
loss:  0.6929759383201599
loss:  0.6953871250152588
loss:  0.6995837092399597
loss:  0.6931304335594177
loss:  0.6932345032691956
loss:  0.6984139680862427
loss:  0.6922567486763
loss:  0.6921203136444092
loss:  0.6946203708648682
loss:  0.6945300102233887
loss:  0.7013130187988281
loss:  0.6869336366653442
loss:  0.6879980564117432
loss:  0.6965294480323792
loss:  0.6940479278564453
loss:  0.6900445222854614
loss:  0.6925024390220642
loss:  0.6875962018966675
loss:  0.6960646510124207
loss:  0.6960728168487549
loss:  0.697056770324707
loss:  0.6928101181983948
loss:  0.69049471616745
loss:  0.6898995041847229
loss:  0.6950774192810059
loss:  0.6837461590766907
loss:  0.6899386644363403
loss:  0.69561594724

loss:  0.6887850165367126
loss:  0.6933727264404297
loss:  0.6965416073799133
loss:  0.6918244957923889
loss:  0.6885743737220764
loss:  0.6967795491218567
loss:  0.6963155269622803
loss:  0.6893298625946045
loss:  0.684684157371521
loss:  0.693168580532074
loss:  0.6973685026168823
loss:  0.6848653554916382
loss:  0.699556291103363
loss:  0.6861875057220459
loss:  0.6926037669181824
loss:  0.6986609101295471
loss:  0.70079505443573
loss:  0.7001266479492188
loss:  0.6875305771827698
loss:  0.6969872713088989
loss:  0.6914628744125366
loss:  0.6970230340957642
loss:  0.6944520473480225
loss:  0.6938812732696533
loss:  0.6973568797111511
loss:  0.6905354261398315
loss:  0.6966820955276489
loss:  0.6878684759140015
loss:  0.6967214345932007
loss:  0.6850898861885071
loss:  0.6861116886138916
loss:  0.6928758025169373
loss:  0.6992956399917603
loss:  0.6935221552848816
loss:  0.6884850859642029
loss:  0.6931564807891846
loss:  0.693383514881134
loss:  0.6945047974586487
loss:  0.699177861

loss:  0.6995431184768677
loss:  0.6934179663658142
loss:  0.6922742128372192
loss:  0.692824125289917
loss:  0.6947125196456909
loss:  0.695656418800354
loss:  0.6921481490135193
loss:  0.6966071128845215
loss:  0.6956300139427185
loss:  0.6927282810211182
loss:  0.6858585476875305
loss:  0.6935818195343018
loss:  0.6900894641876221
loss:  0.6900123357772827
loss:  0.6976837515830994
loss:  0.6886906623840332
loss:  0.6959361433982849
loss:  0.6873558163642883
loss:  0.6860365271568298
loss:  0.6904100179672241
loss:  0.6951452493667603
loss:  0.6909766793251038
loss:  0.6932520866394043
loss:  0.6924836039543152
loss:  0.6991391181945801
loss:  0.6923537254333496
loss:  0.6963680982589722
loss:  0.6929829120635986
loss:  0.690061628818512
loss:  0.6930991411209106
loss:  0.6935222148895264
loss:  0.6995489597320557
loss:  0.6941311955451965
loss:  0.6933112740516663
loss:  0.6943793892860413
loss:  0.6911531090736389
loss:  0.6919308304786682
loss:  0.6899223923683167
loss:  0.690049

loss:  0.6914925575256348
loss:  0.6950885653495789
loss:  0.6914606690406799
loss:  0.6920283436775208
loss:  0.6919340491294861
loss:  0.6947298049926758
loss:  0.698607861995697
loss:  0.6943261027336121
loss:  0.6917824149131775
loss:  0.7019763588905334
loss:  0.6904352903366089
loss:  0.6865838170051575
loss:  0.6930618286132812
loss:  0.6960722804069519
loss:  0.6984136700630188
loss:  0.6922347545623779
loss:  0.6981207728385925
loss:  0.6998845338821411
loss:  0.6925809979438782
loss:  0.6943039894104004
loss:  0.6938746571540833
loss:  0.6935257911682129
loss:  0.688777506351471
loss:  0.6884879469871521
loss:  0.6978108286857605
loss:  0.6981193423271179
loss:  0.6847246885299683
loss:  0.6980490684509277
loss:  0.6945961713790894
loss:  0.6947721838951111
loss:  0.6967014670372009
loss:  0.6997168660163879
loss:  0.6909884214401245
loss:  0.6921544075012207
loss:  0.6883059144020081
loss:  0.6995817422866821
loss:  0.7017131447792053
loss:  0.6942703723907471
loss:  0.69642

loss:  0.6892625093460083
loss:  0.6851351857185364
loss:  0.6881669163703918
loss:  0.6966465711593628
loss:  0.6993347406387329
loss:  0.6939584612846375
loss:  0.6954267024993896
loss:  0.6958765983581543
loss:  0.6840575337409973
loss:  0.6946519613265991
loss:  0.6876624822616577
loss:  0.6891136169433594
loss:  0.6930837631225586
loss:  0.6954424977302551
loss:  0.701802134513855
loss:  0.6865566372871399
loss:  0.6937385201454163
loss:  0.6980472803115845
loss:  0.697300910949707
loss:  0.6935340166091919
loss:  0.6957088708877563
loss:  0.6875801086425781
loss:  0.7050045132637024
loss:  0.7007207870483398
loss:  0.683353841304779
loss:  0.6876616477966309
loss:  0.6881241202354431
loss:  0.6931053400039673
loss:  0.6848940849304199
loss:  0.6881552934646606
loss:  0.698803186416626
loss:  0.6976219415664673
loss:  0.6918325424194336
loss:  0.6857380867004395
loss:  0.7007657885551453
loss:  0.6949061751365662
loss:  0.6991812586784363
loss:  0.6898170709609985
loss:  0.6954881

loss:  0.6900596618652344
loss:  0.6945833563804626
loss:  0.6877343058586121
loss:  0.6953093409538269
loss:  0.6940040588378906
loss:  0.6944358944892883
loss:  0.6971844434738159
loss:  0.7014098167419434
loss:  0.6955775022506714
loss:  0.6917064785957336
loss:  0.6893587708473206
loss:  0.6929960250854492
loss:  0.699982225894928
loss:  0.6965975761413574
loss:  0.6893919110298157
loss:  0.6968192458152771
loss:  0.6871902346611023
loss:  0.6861231923103333
loss:  0.6928264498710632
loss:  0.692450225353241
loss:  0.6899158954620361
loss:  0.7022206783294678
loss:  0.6965035200119019
loss:  0.6907980442047119
loss:  0.6901686191558838
loss:  0.6977433562278748
loss:  0.6864941120147705
loss:  0.6961396932601929
loss:  0.6993504166603088
loss:  0.6954563856124878
loss:  0.687395453453064
loss:  0.6868399381637573
loss:  0.6931779980659485
loss:  0.6855236291885376
loss:  0.6852506995201111
loss:  0.69000244140625
loss:  0.6972134113311768
loss:  0.6904947757720947
loss:  0.69240856

KeyboardInterrupt: 