In [1]:
%matplotlib inline

In [2]:
# A lot of inspiration from https://github.com/loudinthecloud/pytorch-ntm. Hyperparameters were chosen based
# upon his experiments.

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import torch.utils.data as d
import random

In [20]:
def _split_cols(mat, lengths):
    """Split a 2D matrix to variable length columns."""
    assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns"
    l = np.cumsum([0] + lengths)
    results = []
    for s, e in zip(l[:-1], l[1:]):
        results += [mat[:, s:e]]
    return results

class NTM_Head(nn.Module):
    def __init__(self, memory, controller_output_size):
        super(NTM_Head, self).__init__()
        
        self.memory = memory
        self.controller_output_size = controller_output_size
        _, self.N, self.M = memory.size()
    
    def is_read_head(self):
        raise NotImplementedError
    
    def initialize_parameters(self):
        raise NotImplementedError
    
    def initialize_state(self):
        raise NotImplementedError
        
    

In [21]:
class NTM_Read_Head(NTM_Head):
    def __init__(self, memory, controller_output_size, batch_size):
        super(NTM_Read_Head, self).__init__(memory, controller_output_size)
        #TODO: Get rid of g in read and write mode, to be replaced with usage vec
        # key_vec, β, g, γ, read_mode, rfree_gate
        #self.M is the number of rows
        self.read_parameters_lengths = [self.M, 1, 1, 1, 3, self.M]
        self.fc_read_parameters = nn.Linear(controller_output_size, sum(self.read_parameters_lengths))
        
        self.batch_size = batch_size
        
        self.initial_address_vector = nn.Parameter(torch.zeros(self.N))
        self.initial_read = nn.Parameter(torch.zeros(1, self.M))
        
        self.reset_parameters()
        self.initialize_state()
    
    def reset_parameters(self):
        nn.init.xavier_uniform(self.fc_read_parameters.weight, gain=1.4)
        nn.init.normal(self.fc_read_parameters.bias, std=0.01)
    
    def initialize_state(self):
        self.prev_address_vector = self.initial_address_vector
        self.prev_read = self.initial_read.repeat(self.batch_size, 1)
    
    def is_read_head(self):
        return True
    
    def forward(self, x):
        read_parameters = self.fc_read_parameters(x)
        # key_vec, β, g, γ, read_mode, rfree_gate
        key_vec, β, g, γ, read_mode, rfree_gate = _split_cols(read_parameters, self.read_parameters_lengths)
        β = F.softplus(β)
        g = F.sigmoid(g)
        read_mode = F.softmax(read_mode)
        rfree_gate = F.sigmoid(rfree_gate)
        γ = 1 + F.softplus(γ)
        
        print('H')
        print(key_vec)
        print('I')
        print(β)
        print('J')
        print(g)
        print('K1')
        print(read_mode)
        print('K2')
        print(rfree_gate)
        print('L')
        print(γ)
                                               
        address_vec = self.address_memory(key_vec, β, g, γ, read_mode, rfree_gate)
        
        print('M')
        print(address_vec)
        new_read = torch.bmm(self.memory.transpose(1,2), address_vec.unsqueeze(2)).squeeze()
        
        print('N')
        print(new_read)
        self.prev_read = new_read
        return new_read
    
    def address_memory(self, key_vec, β, g, γ, read_mode, rfree_gate):
        result = F.cosine_similarity(key_vec.unsqueeze(1).expand_as(self.memory), self.memory, dim = 2)
        
        print('R')
        print(result)
        
        result = β * result
        
        print('S')
        print(result)
        
        result = result.exp() / result.sum()
        
        print('T')
        print(result)
        
#        result = g * result + (1 - g) * self.prev_address_vector
        print('U')
        print(result)
        

        print('V')
        print(result)
        
        result = result ** γ
        print('W')
        print(result)
        
        result = result / result.sum()
        print('X')
        print(result)
        
        return result

In [22]:
class NTM_Write_Head(NTM_Head):
    def __init__(self, memory, controller_output_size):
        super(NTM_Write_Head, self).__init__(memory, controller_output_size)
        #write_parameters are, in order: key_vec, β, g, write_gate, γ, erase_vec, add_vec
        self.write_parameters_lengths = [self.M, 1, 1, 1, 1,  self.M, self.M]
        self.fc_write_parameters = nn.Linear(controller_output_size, sum(self.write_parameters_lengths))
        self.initial_address_vector = nn.Parameter(torch.zeros(self.N))
        
        self.reset_parameters()
        self.initialize_state()
    
    def reset_parameters(self):
        nn.init.xavier_uniform(self.fc_write_parameters.weight, gain=1.4)
        nn.init.normal(self.fc_write_parameters.bias, std=0.01)
    
    def initialize_state(self):       
        self.prev_address_vector = Variable(self.initial_address_vector.data)
    
    def is_read_head(self):
        return False
    
    def forward(self, x):
        write_parameters = self.fc_write_parameters(x)
                                       
        key_vec, β, g, write_gate, γ, erase_vec, add_vec = _split_cols(write_parameters, self.write_parameters_lengths)
        β = F.softplus(β)
        g = F.sigmoid(g)
        write_gate = F.softmax(write_gate)
        γ = 1 + F.softplus(γ)
        erase_vec = F.sigmoid(erase_vec)
        
        print('A')
        print(key_vec)
        print('B')
        print(β)
        print('C')
        print(g)
        print('D')
        print(s)
        print('E')
        print(γ)
        print('F')
        print(erase_vec)
        print(key_vec)
        print(β)
        print(g)
        print(write_gate)
        print(γ)
        print('G')
        print(add_vec)
                                               
        address_vec = self.address_memory(key_vec, β, g, write_gate, γ)
        self.memory = self.memory * (1 - torch.bmm(address_vec.unsqueeze(2), erase_vec.unsqueeze(1)))
        self.memory += torch.bmm(address_vec.unsqueeze(2), add_vec.unsqueeze(1))
        
        print('O')
        print(self.memory)
        print('P')
        print(address_vec)
        
    def address_memory(self, key_vec, β, g, write_gate, γ):
        result = F.cosine_similarity(key_vec.unsqueeze(1).expand_as(self.memory), self.memory, dim = 2)
        
        print('R')
        print(result)
        
        result = β * result
        
        print('S')
        print(result)
        
        result = result.exp()
        result = result / result.sum()
        
        print('T')
        print(result)
        
        result = g * result + (1 - g) * self.prev_address_vector
        print('U')
        print(result)
        
        result = torch.cat((result[:, 1:], result[:, :1]), 1) * s[:, 0:1] + result * s[:, 1:2] + \
                 torch.cat((result[:, -1:], result[:, :-1]), 1) * s[:, 2:3]
        print('V')
        print(result)
        
        result = result ** γ
        print('W')
        print(result)
        
        result = result / result.sum()
        print('X')
        print(result)
        ######## WHAT I CHANGED ############
        result = result * write_gate
        return result

In [15]:
class NTM(nn.Module):
    def __init__(self, batch_size, controller, controller_output_size, 
                 output_size, address_count, address_dimension, heads):
        super(NTM, self).__init__()
        
        self.batch_size = batch_size
        
        # Initialize memory
        self.initial_memory = nn.Parameter(torch.zeros(1, address_count, address_dimension))
        self.memory = self.initial_memory.repeat(batch_size, 1, 1)
        
        # Construct the heads.
        self.heads = nn.ModuleList()
        
        for head_id in heads:
            if head_id == 0:
                self.heads.append(NTM_Read_Head(self.memory, controller_output_size, batch_size))
            else:
                self.heads.append(NTM_Write_Head(self.memory, controller_output_size))
        
        # Initialize controller
        self.controller = controller
        self.outputGate = nn.Linear(controller_output_size, output_size)
        self.softmax = nn.Softmax()
        
        self.initialize_state()
        self.reset_parameters()
    
    def initialize_state(self):
        self.prev_reads = []
        
        for head in self.heads:
            head.initialize_state()
            
            if head.is_read_head():
                self.prev_reads.append(head.prev_read)
        
        self.memory = self.initial_memory.repeat(self.batch_size, 1, 1)
        
    def reset_parameters(self):
        _, N, M = self.initial_memory.size()
        stdev = 1 / np.sqrt(N + M)
        nn.init.uniform(self.initial_memory, -stdev, stdev)
        
    def forward(self, x):
        self.prev_reads.append(x)
        controller_output = self.controller(torch.cat(self.prev_reads, 1)).squeeze()
        print('Q')
        print(controller_output)
                
        self.prev_reads = []
        
        for head in self.heads:
            if head.is_read_head():
                self.prev_reads.append(head(controller_output))
            else:
                head(controller_output)
        
        return self.softmax(self.outputGate(controller_output))

In [7]:
class CopyTaskDataset(d.Dataset):
    def __init__(self, num_batches, batch_size, lower, upper, seq_size):
        self.input_list = []
        for _ in range(num_batches):
            self.input_list.append(self.generate_batch(batch_size, lower, upper, seq_size))
        
        self.batch_size = batch_size

    def generate_batch(self, batch_size, lower, upper, seq_size):
        seq_length = random.randint(lower, upper)
        batch = torch.from_numpy(
            np.random.binomial(1, 0.5, (seq_length, batch_size, seq_size)))
        end_marker = torch.zeros(seq_length - 1, batch_size, 1)
        end_row = torch.ones(1, batch_size, 1)
        end_marker = torch.cat((end_marker, end_row), 0)
        batch = torch.cat((batch.float(), end_marker), 2)
        return batch 
    
    def __len__(self):
        return len(self.input_list)*self.batch_size
    
    def __getitem__(self, i):
        return self.input_list[i//self.batch_size][:, i % self.batch_size, :]

In [8]:
class EncapsulatedLSTM(nn.Module):
    def __init__(self, batch_size, all_hiddens, *args, **kwargs):
        super(EncapsulatedLSTM, self).__init__()
        self.lstm = nn.LSTM(*args, **kwargs)
        self.all_hiddens = all_hiddens
        self.batch_size = batch_size
                
        hidden_size = args[1]
        num_layers = args[2]
        
        self.initial_hidden_state = nn.Parameter(torch.randn(num_layers, 1, hidden_size))
        self.initial_cell_state = nn.Parameter(torch.randn(num_layers, 1, hidden_size))
        self.initialize_state()
          
    def initialize_state(self):
        self.state_tuple = (self.initial_hidden_state.repeat(1, self.batch_size, 1), 
                            self.initial_cell_state.repeat(1, self.batch_size, 1))
        
    def forward(self, input):
        output, self.state_tuple = self.lstm(input.unsqueeze(0), self.state_tuple)
        
        if self.all_hiddens:
            return self.state_tuple[0]
        else:
            return output

In [9]:
def trainAll(num_batches, batch_size=64, hidden_size=100, 
             num_layers=3, lower_seq_length=3, upper_seq_length=10, seq_size=8,
             address_count=128, address_size=20):
    # controller, controller_output_size, output_size, 
    # address_count, address_dimension, heads
    controller = EncapsulatedLSTM(batch_size, False, # all hiddens
                                  seq_size + address_size + 1, hidden_size, 
                                  num_layers)
    controller_output_size = hidden_size
    
    ntm = NTM(batch_size, controller, controller_output_size, 
              seq_size, address_count, address_size, [0, 1])   
    dataset = CopyTaskDataset(num_batches, batch_size, lower_seq_length, upper_seq_length, seq_size)
    
    data_loader = d.DataLoader(dataset, batch_size=batch_size)
    
    for batch in data_loader:
        batch = batch.squeeze()
        sequence_length = batch.size()[1]

        # Pass in one element of the sequence per time step
        for time_step in range(sequence_length):
            ntm(batch[:, time_step])
            break
                
#             # Now, don't pass in any elements.
#             for time_step in range(len(batch)):
#                 ntm(None, batch[time_step][:,:-1])
#     print("DONE!")

In [23]:
trainAll(1)

Q
Variable containing:
 5.0101e-03  4.3148e-02 -3.0836e-05  ...   3.3804e-01 -1.5142e-01  1.5596e-01
 5.0960e-03  4.2850e-02  3.4896e-06  ...   3.3847e-01 -1.5143e-01  1.5555e-01
 5.1468e-03  4.2708e-02  1.0764e-04  ...   3.3880e-01 -1.5143e-01  1.5553e-01
                ...                   ⋱                   ...                
 5.2387e-03  4.2841e-02  1.4981e-04  ...   3.3864e-01 -1.5182e-01  1.5579e-01
 5.2937e-03  4.2686e-02 -3.2832e-05  ...   3.3874e-01 -1.5121e-01  1.5601e-01
 5.1746e-03  4.2826e-02 -4.3817e-05  ...   3.3818e-01 -1.5140e-01  1.5634e-01
[torch.FloatTensor of size 64x100]

H
Variable containing:
-0.2219 -0.1801  1.0618  ...   0.0314  0.2972  0.3293
-0.2217 -0.1812  1.0620  ...   0.0311  0.2957  0.3295
-0.2212 -0.1812  1.0619  ...   0.0313  0.2959  0.3291
          ...             ⋱             ...          
-0.2213 -0.1808  1.0618  ...   0.0312  0.2973  0.3289
-0.2215 -0.1807  1.0610  ...   0.0303  0.2945  0.3287
-0.2217 -0.1802  1.0605  ...   0.0308  0.2963  0

Variable containing:
( 0 ,.,.) = 
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
     ...       ⋱       ...    
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan

( 1 ,.,.) = 
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
     ...       ⋱       ...    
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan

( 2 ,.,.) = 
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
     ...       ⋱       ...    
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
... 

(61 ,.,.) = 
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
     ...       ⋱       ...    
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan
 nan nan nan  ...  nan nan nan

(62 ,.,.) = 
 nan nan nan  ...  nan nan nan
 nan n