In [1]:
%matplotlib inline
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as d
from torch.autograd import Variable

In [None]:
def _split_cols(mat, lengths):
    """Split a 2D matrix to variable length columns."""
    assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns"
    l = np.cumsum([0] + lengths)
    results = []
    for s, e in zip(l[:-1], l[1:]):
        results += [mat[:, s:e]]
    return results

class NTM_Head(nn.Module):
    """Base class for NTM read/write heads over a memory matrix shared by reference.

    Subclasses provide ``is_read_head``, ``reset_parameters``,
    ``initialize_state`` and ``forward``; ``initialize_state`` must set
    ``self.prev_address_vector`` before ``address_memory`` is first called.
    """

    def __init__(self, memory, controller_output_size, sharp_on):
        """
        Args:
            memory: 2D memory tensor of shape ``(N, M)``; stored by reference,
                so every head built from the same tensor sees the same memory.
            controller_output_size: width of the controller output fed to the head.
            sharp_on: if True, ``address_memory`` sharpens weights with exponent γ.
        """
        super(NTM_Head, self).__init__()

        self.memory = memory
        self.controller_output_size = controller_output_size
        self.N, self.M = memory.size()  # N addresses, each of width M
        self.sharp_on = sharp_on

    def is_read_head(self):
        """Return True for read heads, False for write heads."""
        raise NotImplementedError

    def reset_parameters(self):
        """(Re)initialise the head's learnable weights."""
        raise NotImplementedError

    def initialize_parameters(self):
        """Backward-compatible alias for ``reset_parameters``."""
        return self.reset_parameters()

    def initialize_state(self):
        """Reset the running addressing state (at least ``prev_address_vector``)."""
        raise NotImplementedError

    def address_memory(self, key_vec, β, g, γ):
        """Compute address weights over the N memory rows and remember them.

        Content addressing, then interpolation with the previous weighting,
        then optional sharpening. No convolutional shift step is applied.

        Args:
            key_vec: lookup keys, assumed ``(batch, M)``.
            β: key strength, ``(batch, 1)``; expected positive.
            g: interpolation gate, ``(batch, 1)``; expected in [0, 1].
                1 keeps only the content weights, 0 only the previous weights.
            γ: sharpening exponent, ``(batch, 1)``; expected >= 1. Used only
                when ``self.sharp_on``.

        Returns:
            Weights of shape ``(batch, N)``. When sharpening is on each row is
            renormalised to sum to 1. The result is also stored in
            ``self.prev_address_vector`` for the next call.
        """
        # Cosine similarity of every key with every memory row:
        # (batch, 1, M) against (1, N, M) broadcasts to (batch, N).
        similarity = F.cosine_similarity(
            key_vec.unsqueeze(1), self.memory.unsqueeze(0), dim=-1)
        # Softmax over addresses: exp normalised by the sum of exponentials.
        weights = F.softmax(β * similarity, dim=-1)
        # Blend with the weighting produced by the previous call.
        weights = g * weights + (1 - g) * self.prev_address_vector
        if self.sharp_on:
            weights = weights ** γ
            # Per-row renormalisation; epsilon avoids 0/0 if every weight underflows.
            weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-16)
        self.prev_address_vector = weights
        return weights


In [None]:
class NTM_Read_Head(NTM_Head):
    """NTM read head: addresses the shared memory and returns a weighted read."""

    def __init__(self, memory, controller_output_size, sharp_on):
        """
        Args:
            memory: shared ``(N, M)`` memory tensor.
            controller_output_size: width of the controller output.
            sharp_on: enable sharpening in ``address_memory``.
        """
        super(NTM_Read_Head, self).__init__(memory, controller_output_size, sharp_on)

        # The controller output is projected to [key (M), β (1), g (1), γ (1)].
        # The key is compared against memory rows, so its width is M.
        self.read_parameters_lengths = [self.M, 1, 1, 1]
        self.fc_read_parameters = nn.Linear(
            controller_output_size, sum(self.read_parameters_lengths))

        # Learned initial state, created once; initialize_state() copies it into
        # the running state so resetting does not re-create parameters.
        self.initial_address_vector = nn.Parameter(torch.zeros(self.N))
        self.initial_read = nn.Parameter(torch.zeros(self.M))

        self.reset_parameters()
        self.initialize_state()

    def reset_parameters(self):
        """Initialise the projection from the controller output."""
        nn.init.xavier_uniform_(self.fc_read_parameters.weight, gain=1.4)
        nn.init.normal_(self.fc_read_parameters.bias, std=0.01)

    def initialize_state(self):
        """Reset the previous weighting and previous read to the learned initial values."""
        # clone() keeps gradients flowing into the initial parameters
        # (Variable(param) returns a detached tensor) and yields plain tensors
        # that later assignments may replace.
        self.prev_address_vector = self.initial_address_vector.clone()
        self.prev_read = self.initial_read.clone()

    def is_read_head(self):
        return True

    def forward(self, x):
        """Read from memory.

        Args:
            x: controller output, assumed ``(batch, controller_output_size)``.

        Returns:
            Read vectors of shape ``(batch, M)``; also stored in ``self.prev_read``.
        """
        read_parameters = self.fc_read_parameters(x)

        key_vec, β, g, γ = _split_cols(read_parameters, self.read_parameters_lengths)
        β = F.softplus(β)        # key strength > 0
        g = torch.sigmoid(g)     # gate in (0, 1)
        γ = 1 + F.softplus(γ)    # sharpening exponent > 1

        address_vec = self.address_memory(key_vec, β, g, γ)
        # r = sum_i w(i) * memory[i]: (batch, N) @ (N, M) -> (batch, M)
        new_read = torch.matmul(address_vec, self.memory)
        self.prev_read = new_read
        return new_read

In [None]:
class NTM_Write_Head(NTM_Head):
    """NTM write head: addresses the shared memory and applies erase-then-add."""

    def __init__(self, memory, controller_output_size, sharp_on):
        """
        Args:
            memory: shared ``(N, M)`` memory tensor, updated in place by forward().
            controller_output_size: width of the controller output.
            sharp_on: enable sharpening in ``address_memory``.
        """
        super(NTM_Write_Head, self).__init__(memory, controller_output_size, sharp_on)

        # The controller output is projected to
        # [key (M), β (1), g (1), γ (1), erase (M), add (M)].
        # The key is compared against memory rows, so its width is M.
        self.write_parameters_lengths = [self.M, 1, 1, 1, self.M, self.M]
        self.fc_write_parameters = nn.Linear(
            controller_output_size, sum(self.write_parameters_lengths))

        # Learned initial weighting, created once; initialize_state() copies it
        # into the running state so resetting does not re-create parameters.
        self.initial_address_vector = nn.Parameter(torch.zeros(self.N))

        self.reset_parameters()
        self.initialize_state()

    def reset_parameters(self):
        """Initialise the projection from the controller output."""
        nn.init.xavier_uniform_(self.fc_write_parameters.weight, gain=1.4)
        nn.init.normal_(self.fc_write_parameters.bias, std=0.01)

    def initialize_state(self):
        """Reset the previous weighting to the learned initial value."""
        # clone() keeps gradients flowing into initial_address_vector
        # (Variable(param) returns a detached tensor) and yields a plain tensor
        # that later assignments may replace.
        self.prev_address_vector = self.initial_address_vector.clone()

    def is_read_head(self):
        return False

    def forward(self, x):
        """Write to the shared memory in place.

        Args:
            x: controller output, assumed ``(batch, controller_output_size)``.

        Returns:
            None; the effect is the in-place update of ``self.memory``.
        """
        write_parameters = self.fc_write_parameters(x)

        key_vec, β, g, γ, erase_vec, add_vec = _split_cols(
            write_parameters, self.write_parameters_lengths)
        β = F.softplus(β)                   # key strength > 0
        g = torch.sigmoid(g)                # gate in (0, 1)
        γ = 1 + F.softplus(γ)               # sharpening exponent > 1 (raw γ may be negative)
        erase_vec = torch.sigmoid(erase_vec)  # erase fractions in (0, 1)

        address_vec = self.address_memory(key_vec, β, g, γ)
        # Outer products w e^T and w a^T: (batch, N, 1) x (batch, 1, M) -> (batch, N, M).
        erase = torch.bmm(address_vec.unsqueeze(2), erase_vec.unsqueeze(1))
        add = torch.bmm(address_vec.unsqueeze(2), add_vec.unsqueeze(1))

        # Memory is one (N, M) matrix shared by the whole batch, so only a batch
        # of one can be written; reshape raises for larger batches instead of
        # silently broadcasting.
        # NOTE(review): the in-place update is how read heads sharing this tensor
        # see the write, but autograd rejects in-place ops on a leaf that requires
        # grad, and earlier ops that saved the memory for backward will fail the
        # version check. End-to-end training likely needs memory threaded as
        # per-step state instead — TODO.
        memory_shape = self.memory.shape
        self.memory.mul_(1 - erase.reshape(memory_shape))
        self.memory.add_(add.reshape(memory_shape))

In [None]:
class EncapsulatedLSTM(nn.Module):
    """``nn.LSTM`` that carries its ``(h, c)`` state between forward() calls.

    The initial state is a learned parameter with batch dimension 1, expanded
    to ``batch_size`` whenever the state is reset.
    """

    def __init__(self, batch_size, *args, **kwargs):
        """
        Args:
            batch_size: batch dimension of the inputs passed to forward().
            *args, **kwargs: forwarded unchanged to ``nn.LSTM``.
        """
        super(EncapsulatedLSTM, self).__init__()
        self.lstm = nn.LSTM(*args, **kwargs)
        self.batch_size = batch_size

        # Read sizes from the constructed LSTM so keyword arguments work too.
        num_directions = 2 if self.lstm.bidirectional else 1
        state_layers = self.lstm.num_layers * num_directions
        hidden_size = self.lstm.hidden_size
        # With proj_size > 0 the hidden state has width proj_size; the cell keeps hidden_size.
        proj_size = getattr(self.lstm, "proj_size", 0)
        h_size = proj_size if proj_size > 0 else hidden_size

        # Assigned as attributes so they are registered, trained, saved and moved.
        self.initial_hidden_state = nn.Parameter(torch.randn(state_layers, 1, h_size))
        self.initial_cell_state = nn.Parameter(torch.randn(state_layers, 1, hidden_size))
        self.reset_state()

    def reset_state(self):
        """Start a new sequence from the learned initial state, expanded to the batch."""
        self.state_tuple = (
            self.initial_hidden_state.expand(-1, self.batch_size, -1).contiguous(),
            self.initial_cell_state.expand(-1, self.batch_size, -1).contiguous(),
        )

    def forward(self, input):
        """Run the LSTM from the stored state and keep the final state.

        The state persists across calls; presumably reset_state() should be
        called at sequence boundaries before backpropagating again — TODO confirm.
        """
        output, self.state_tuple = self.lstm(input, self.state_tuple)
        return output

In [None]:
class DNC(nn.Module):
    """Memory-augmented network: a controller plus NTM-style read/write heads.

    Each step the controller receives the previous reads concatenated with the
    input; its output drives every head and a softmax output layer.
    """

    def __init__(self, controller, controller_output_size, output_size,
                 address_count, address_dimension, heads, sharp_on=True):
        """
        Args:
            controller: module mapping ``cat([prev reads..., x], dim=1)`` to
                features of width ``controller_output_size``.
            controller_output_size: width of the controller output.
            output_size: width of the softmax output.
            address_count: N, number of memory rows.
            address_dimension: M, width of each memory row.
            heads: iterable of head ids; 0 builds a read head, anything else a
                write head.
            sharp_on: forwarded to every head; enables address sharpening.
        """
        super(DNC, self).__init__()

        # Memory is state that the write heads update in place. Autograd rejects
        # in-place ops on a leaf that requires grad, so it is not trainable.
        self.memory = nn.Parameter(torch.zeros(address_count, address_dimension),
                                   requires_grad=False)
        self.sharp_on = sharp_on

        # Every head gets the same memory tensor so writes are visible to reads.
        self.heads = nn.ModuleList()
        for head_id in heads:
            if head_id == 0:
                self.heads.append(NTM_Read_Head(self.memory, controller_output_size, sharp_on))
            else:
                self.heads.append(NTM_Write_Head(self.memory, controller_output_size, sharp_on))

        self.controller = controller
        self.outputGate = nn.Linear(controller_output_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

        self.initialize_state()
        self.reset_parameters()

    def initialize_state(self):
        """Collect each read head's current read as the previous reads for the next step."""
        self.prev_reads = [head.prev_read for head in self.heads if head.is_read_head()]

    def reset_parameters(self):
        """Fill memory uniformly in [-1/sqrt(N + M), 1/sqrt(N + M)]."""
        address_count, address_dimension = self.memory.size()
        stdev = 1 / np.sqrt(address_count + address_dimension)
        nn.init.uniform_(self.memory, -stdev, stdev)

    def forward(self, x):
        """Run one step.

        Args:
            x: input for this step, assumed ``(batch, input_size)`` — TODO
                confirm against the controller's expected input.

        Returns:
            Softmax over ``output_size`` along the last dimension.
        """
        batch_size = x.size(0)
        # Initial reads are unbatched (M,) vectors; broadcast them to the batch.
        prev_reads = [read.expand(batch_size, -1) if read.dim() == 1 else read
                      for read in self.prev_reads]
        controller_output = self.controller(torch.cat(prev_reads + [x], dim=1))

        self.prev_reads = []
        for head in self.heads:
            if head.is_read_head():
                self.prev_reads.append(head(controller_output))
            else:
                head(controller_output)
        return self.softmax(self.outputGate(controller_output))