In [1]:
%matplotlib inline

In [1]:
# A lot of inspiration from https://github.com/loudinthecloud/pytorch-ntm. Hyperparameters were chosen based
# upon his experiments.

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

import numpy as np
import random

In [2]:
def init_seed(seed=None):
    """Seed the RNGs for predicatability/reproduction purposes."""
    if seed is None:
        seed = int(get_ms() // 1000)

    print("Using seed=%d", seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

In [3]:
class DNC_Memory(nn.Module):
    def __init__(self, address_count, address_dimension):
        super(DNC_Memory, self).__init__()
        self.initial_memory = nn.Parameter(torch.zeros(1, address_count, address_dimension))
        reset_parameters()
        initialize_state()
    
    def reset_parameters(self):
        _, N, M = self.initial_memory.size()
        stdev = 1 / np.sqrt(N + M)
        nn.init.uniform(self.initial_memory, -stdev, stdev)
    
    def initialize_state(self):
        self.memory = self.initial_memory.repeat(self.batch_size, 1, 1)
    
    def content_address_memory(self, key_vec, prev_address_vec, β, γ, sharp_on):
        result = F.cosine_similarity(key_vec.unsqueeze(1).expand_as(self.memory), 
                                     self.memory, dim = 2)
        result = β * result
        result = result.exp() 
        result = result / result.sum()
        
        if sharp_on:
            result = result ** γ
            result = result / result.sum()
        
        return result
    
    def read_memory(self, address_vec):
        return torch.bmm(self.memory.transpose(1,2), address_vec.unsqueeze(2)).squeeze()
    
    def update_memory(self, address_vec, erase_vec, add_vec):
        self.memory = self.memory * (1 - torch.bmm(address_vec.unsqueeze(2), erase_vec.unsqueeze(1)))
        self.memory += torch.bmm(address_vec.unsqueeze(2), add_vec.unsqueeze(1))

In [None]:
class DNC_Usage(nn.Module):
    def __init__(self, address_count, batch_size):
        super(DNC_Usage, self).__init__()
        self.initial_usage = Variable(torch.zeros(1, address_count))
        self.batch_size = batch_size
        self.initialize_state()
        
    def initialize_state(self):
        self.usage = self.initial_usage.repeat(self.batch_size, 1)
    
    def read_update_usage(self, address_vec, rfree_weights):
        self.usage *= 1 - rfree_weights * address_vec
    
    def write_update_usage(self, address_vec):
        self.usage += (1 - self.usage) * address_vec
    
    def allocation_weights(self):
        sorted_usage, indices_usage = torch.sort(self.usage)
        prod_sorted_usage = torch.cumprod(torch.cat((Variable(torch.ones(self.batch_size, 1)), 
                                                     sorted_usage), dim=1), dim=1)[:, :-1]
        sorted_allocation = (1 - sorted_usage) * prod_sorted_usage
        return sorted_allocation.gather(1, indices_usage)

In [4]:
def _split_cols(mat, lengths):
    """Split a 2D matrix to variable length columns."""
    assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns"
    l = np.cumsum([0] + lengths)
    results = []
    for s, e in zip(l[:-1], l[1:]):
        results += [mat[:, s:e]]
    return results

class DNC_Head(nn.Module):
    def __init__(self, address_count, address_dimension, 
                 controller_output_size):
        super(DNC_Head, self).__init__()
        
        self.controller_output_size = controller_output_size
        self.N = address_count
        self.M = address_dimension
    
    def is_read_head(self):
        raise NotImplementedError
    
    def reset_parameters(self):
        raise NotImplementedError
    
    def initialize_state(self):
        raise NotImplementedError

In [5]:
class DNC_Read_Head(DNC_Head):
    def __init__(self, controller_output_size, batch_size, num_write_heads):
        super(DNC_Read_Head, self).__init__(controller_output_size)
        # key_vec, β, γ, read_mode, rfree_gate
        #self.M is the number of rows
        self.num_write_heads = num_write_heads
        self.read_parameters_lengths = [self.M, 1, 1, 2 * self.num_write_heads + 1, self.N]
        self.fc_read_parameters = nn.Linear(controller_output_size, sum(self.read_parameters_lengths))
        
        self.batch_size = batch_size

        self.reset_parameters()
        self.initialize_state()
    
    def reset_parameters(self):
        nn.init.xavier_uniform(self.fc_read_parameters.weight, gain=1.4)
        nn.init.normal(self.fc_read_parameters.bias, std=0.01)
        
        self.initial_address_vector = Variable(1, torch.zeros(self.N))
        self.initial_read = nn.Parameter(torch.randn(1, self.M) * 0.01)
    
    def initialize_state(self):
        self.prev_address_vec = self.initial_address_vec.repeat(self.batch_size, 1)
        self.prev_read = self.initial_read.repeat(self.batch_size, 1)
    
    def is_read_head(self):
        return True
    
    def forward(self, x, memory, usage, linked_matrices):
        read_parameters = self.fc_read_parameters(x)
        
        key_vec, β, γ, read_modes, rfree_gate = _split_cols(read_parameters, self.read_parameters_lengths)
        β = F.softplus(β)
        γ = 1 + F.softplus(γ)
        rfree_gate = F.sigmoid(rfree_gate)
        read_modes = F.softmax(read_modes)
        content_address_vec = memory.content_address_memory(key_vec, self.prev_address_vec, β, γ, False)
        
        forward, backward = []
        
        address_vec = content_address_vec * read_modes[0]
        
        for i, linked_matrix in enumerate(linked_matrices):
            address_vec += read_modes[i+1] + torch.bmm(linked_matrix, self.prev_address_vec.unsquueze(2)).squeeze()
            address_vec += read_modes[i+self.num_write_heads+1] * torch.bmm(linked_matrix.transpose(1,2), 
                                                                            self.prev_address_vec.unsqueeze(2)).squeeze()
        
        new_read = memory.read_memory(address_vec)
        self.prev_read = new_read
        return new_read

In [6]:
class DNC_Write_Head(DNC_Head):
    def __init__(self, controller_output_size, batch_size):
        super(DNC_Write_Head, self).__init__(controller_output_size)
        self.batch_size = batch_size
        self.write_parameters_lengths = [self.M, 1, 1, 1, self.M, self.M]
        self.fc_write_parameters = nn.Linear(controller_output_size, sum(self.write_parameters_lengths))
        
        self.reset_parameters()
        self.initialize_state()
        
        #initialize the linked matrix
        self.linked_matrix = torch.zeros(batch_size, self.M, self.M)
    
    def reset_parameters(self):
        nn.init.xavier_uniform(self.fc_write_parameters.weight, gain=1.4)
        nn.init.normal(self.fc_write_parameters.bias, std=0.01)
        
        self.initial_address_vec = nn.Parameter(torch.zeros(self.N))
    
    def initialize_state(self):       
        self.prev_address_vec = self.initial_address_vec
    
    def is_read_head(self):
        return False
    
    def forward(self, x, memory):
        write_parameters = self.fc_write_parameters(x)
                                       
        key_vec, β, g, γ, erase_vec, add_vec = _split_cols(write_parameters, self.write_parameters_lengths)
        β = F.softplus(β)
        g = F.sigmoid(g)
        γ = 1 + F.softplus(γ)
        erase_vec = F.sigmoid(erase_vec)
                                               
        address_vec = memory.address_memory(key_vec, self.prev_address_vec, β, g, γ)
        memory.upate_memory(address_vec, erase_vec, add_vec)

In [7]:
class DNC(nn.Module):
    def __init__(self, batch_size, controller, controller_output_size, 
                 output_size, address_count, address_dimension, heads):
        super(DNC, self).__init__()
        
        self.batch_size = batch_size
        
        # Initialize controller
        self.controller = controller
        
        # Create output gate. No activation function is used with it because
        # I used BCEWithLogitsLoss which deals with the sigmoid in a more
        # numerically stable manner.
        self.outputGate = nn.Linear(controller_output_size, output_size)
        
        # Initialize memory
        self.memory = DNC_Memory(address_count, address_dimension)

        # Construct the heads.
        self.heads = nn.ModuleList()
        
        # Initialize usage vector, might not need batch size
        self.usage = torch.zeros(batch_size, address_count)
        num_writes = heads.count(1)
        
        for head_id in heads:
            if head_id == 0:
                self.heads.append(DNC_Read_Head(self, controller_output_size, batch_size, num_writes))
            else:
                self.heads.append(DNC_Write_Head(self, controller_output_size))
        
        self.initialize_state()
        
    def initialize_state(self):
        self.prev_reads = []
        
        for head in self.heads:
            head.initialize_state()
            
            if head.is_read_head():
                self.prev_reads.append(head.prev_read)
        
        self.memory.initialize_state()
        
    def reset_parameters(self):
        nn.init.xavier_uniform(self.outputGate.weight)
        nn.init.normal(self.outputGate.bias, std=0.01)
        
    def forward(self, x):
        outputs = []
        
        for current_observation in x.transpose(0,1):
            self.prev_reads.append(current_observation)
            controller_input = torch.cat(self.prev_reads, 1)
            controller_output = self.controller(controller_input).squeeze()

            self.prev_reads = []

            for head in self.heads:                
                if head.is_read_head():
                    self.prev_reads.append(head(controller_output, self.memory))
                else:
                    head(controller_output, self.memory)
                    
            current_output = self.outputGate(controller_output)
            outputs.append(current_output)
        
        return torch.stack(outputs)

In [8]:
# TODO: Have the dataset also return labels.

class CopyTaskDataset(data.Dataset):
    def __init__(self, num_batches, batch_size, lower, upper, seq_size):
        self.inputs_list = []
        self.label_list = []
        
        for _ in range(num_batches):
            
            self.input_list.append()
        
        self.batch_size = batch_size

    def generate_batch(self, batch_size, lower, upper, seq_size):
        seq_length = random.randint(lower, upper)
        seq = torch.from_numpy(
                np.random.binomial(1, 0.5, (seq_length, batch_size, seq_size)))
        end_marker = torch.zeros(seq_length, batch_size, 1)
        seq = torch.cat((seq.float(), end_marker), 2)
        delimiter_column = torch.zeros(1, batch_size, seq_size+1)
        delimiter_column[0, :, seq_size] = 1
        seq = torch.cat((seq, delimiter_column), 0)
        output_time = torch.zeros(seq_length, batch_size, seq_size+1)
        seq = torch.cat((seq, output_time), 0)
        return seq
    
    def __len__(self):
        return len(self.input_list)*self.batch_size
    
    def __getitem__(self, i):
        return self.input_list[i//self.batch_size][:, i % self.batch_size, :]

In [9]:
class EncapsulatedLSTM(nn.Module):
    def __init__(self, batch_size, all_hiddens, *args, **kwargs):
        super(EncapsulatedLSTM, self).__init__()
        self.lstm = nn.LSTM(*args, **kwargs)
        self.all_hiddens = all_hiddens
        self.batch_size = batch_size
                
        self.num_inputs = args[0]
        self.hidden_size = args[1]
        self.num_layers = args[2]
        
        self.reset_parameters()
        self.initialize_state()
          
    def initialize_state(self):
        self.state_tuple = (self.initial_hidden_state.repeat(1, self.batch_size, 1), 
                            self.initial_cell_state.repeat(1, self.batch_size, 1))
    
    def reset_parameters(self):
        self.initial_hidden_state = nn.Parameter(torch.randn(self.num_layers, 1, self.hidden_size) * 0.05)
        self.initial_cell_state = nn.Parameter(torch.randn(self.num_layers, 1, self.hidden_size) * 0.05)
        
        for p in self.lstm.parameters():
            if p.dim() == 1:
                nn.init.constant(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.num_inputs +  self.hidden_size))
                nn.init.uniform(p, -stdev, stdev)
        
    def forward(self, input):
        output, self.state_tuple = self.lstm(input.unsqueeze(0), self.state_tuple)
        
        if self.all_hiddens:
            return self.state_tuple[0]
        else:
            return output

In [10]:
# TODO: Implement the loss function. Just use BCE logits form and cut out the right part of the outputs.
def copy_task_loss(output, label):
    pass

In [11]:
# TODO: Generalize the train_model function a bit to accomodate gradient clipping.

In [12]:
batch_size = 64
hidden_size = 100
num_layers = 3
seq_size = 8
address_size = 20
controller = EncapsulatedLSTM(batch_size, False, # all hiddens
                              seq_size + address_size + 1, hidden_size, 
                              num_layers)

In [13]:
address_count = 128
controller_output_size = hidden_size

dnc = DNC(batch_size, controller, controller_output_size, 
          seq_size, address_count, address_size, [0, 1])   

AttributeError: cannot assign parameters before Module.__init__() call

In [100]:
lower_seq_length = 3
upper_seq_length = 10
num_batches = 600

dataset = CopyTaskDataset(num_batches, batch_size, lower_seq_length, upper_seq_length, seq_size)
data_loader = data.DataLoader(dataset, batch_size=batch_size)

In [105]:
optimizer = optim.RMSprop(dnc.parameters(), momentum=0.9,
                          alpha=0.95, lr=1e-4)

In [93]:
def trainAll(num_batches, batch_size=64, hidden_size=100, 
             num_layers=3, lower_seq_length=3, upper_seq_length=10, seq_size=8,
             address_count=128, address_size=20):
    # controller, controller_output_size, output_size, 
    # address_count, address_dimension, heads
    controller = EncapsulatedLSTM(batch_size, False, # all hiddens
                                  seq_size + address_size + 1, hidden_size, 
                                  num_layers)
    controller_output_size = hidden_size
    
    dnc = DNC(batch_size, controller, controller_output_size, 
              seq_size, address_count, address_size, [0, 1])   
    dataset = CopyTaskDataset(num_batches, batch_size, lower_seq_length, upper_seq_length, seq_size)
    
    data_loader = data.DataLoader(dataset, batch_size=batch_size)
    
    for batch in data_loader:
        print(dnc(batch))

In [94]:
trainAll(1)

Variable containing:
(0 ,.,.) = 
  0.5173  0.5048  0.4749  ...   0.5030  0.5214  0.5072
  0.5027  0.5128  0.4643  ...   0.4834  0.5122  0.5119
  0.4870  0.5245  0.4737  ...   0.4905  0.4966  0.5319
           ...             ⋱             ...          
  0.4862  0.5271  0.4792  ...   0.4905  0.5020  0.5173
  0.4963  0.5169  0.4738  ...   0.5036  0.4933  0.5348
  0.4966  0.5282  0.4860  ...   0.4939  0.5077  0.5253

(1 ,.,.) = 
  0.4688  0.5047  0.4738  ...   0.5163  0.5263  0.5115
  0.4828  0.5057  0.4833  ...   0.4944  0.5436  0.5089
  0.4584  0.5407  0.5048  ...   0.5077  0.5040  0.5155
           ...             ⋱             ...          
  0.4573  0.5416  0.5043  ...   0.5027  0.5165  0.5056
  0.4559  0.5399  0.5036  ...   0.5106  0.4725  0.5086
  0.4671  0.5350  0.4996  ...   0.5046  0.5082  0.5121

(2 ,.,.) = 
  0.4073  0.5189  0.5012  ...   0.5110  0.5611  0.4872
  0.4238  0.4922  0.4849  ...   0.4829  0.5555  0.5075
  0.4273  0.5215  0.5224  ...   0.4918  0.5050  0.5312
      

In [12]:
dataset = CopyTaskDataset(1, 16, 3, 8, 3)

In [13]:
print(dataset.input_list[0])


(0 ,.,.) = 
   0   1   0   0
   1   1   1   0
   1   0   1   0
   1   0   1   0
   0   0   0   0
   1   1   0   0
   1   1   1   0
   0   1   0   0
   1   1   0   0
   1   1   1   0
   0   0   0   0
   0   1   1   0
   0   1   0   0
   1   1   1   0
   0   0   0   0
   0   1   1   0

(1 ,.,.) = 
   1   0   0   0
   0   1   0   0
   1   0   0   0
   1   1   0   0
   0   0   0   0
   1   1   0   0
   1   0   0   0
   0   0   1   0
   1   1   1   0
   1   1   0   0
   0   0   0   0
   0   1   1   0
   0   1   0   0
   1   0   1   0
   1   0   1   0
   0   0   0   0

(2 ,.,.) = 
   1   0   0   0
   1   0   1   0
   1   0   1   0
   1   1   1   0
   0   1   0   0
   0   1   0   0
   0   1   1   0
   1   1   0   0
   1   0   1   0
   1   0   0   0
   1   1   1   0
   0   1   1   0
   1   1   1   0
   1   1   1   0
   1   1   1   0
   1   0   1   0

(3 ,.,.) = 
   1   1   0   0
   1   0   1   0
   1   1   1   0
   0   0   0   0
   1   1   1   0
   0   0   0   0
   0   1   0   0
   0   1   0 