In [None]:
%matplotlib inline

In [45]:
import torch
import torch.nn as nn
from torch.autograd import Variable

# Copy task (pg 10, https://arxiv.org/pdf/1410.5401.pdf): the controller's
# input and output vectors have the same width.
n_input, n_output = 8, 8
# Controller hidden width; chosen somewhat arbitrarily.
n_hidden = 50

class NTM(nn.Module):
    """Neural Turing Machine (Graves et al. 2014, https://arxiv.org/pdf/1410.5401.pdf).

    Work in progress: ``parse_lstm_output`` is still a stub, so ``forward`` cannot
    run end-to-end yet. The memory primitives (read/erase/add/write) and the
    addressing mechanism (content_focus/location_focus) work on their own.

    Notation follows the paper: the memory matrix is N x M, with N memory
    locations (rows, ``mem_length``), each holding an M-wide vector
    (``mem_width``). Every head weighting is a length-N vector.

    Args:
        mem_length: number of memory locations N.
        mem_width: width M of each memory location.
        read_heads: number of read heads.
        write_heads: number of write heads.
        key_strength: fixed key strength (beta) used by content addressing for
            every head.
    """

    def __init__(self, mem_length, mem_width, read_heads, write_heads, key_strength):
        super(NTM, self).__init__()

        # Save useful params
        self.read_heads = read_heads
        self.write_heads = write_heads
        self.key_strength = key_strength
        # Paper notation: N = number of locations (rows), M = location width (columns).
        self.N = mem_length
        self.M = mem_width

        # Controller. forward() feeds it the external input concatenated with one
        # read vector per read head, so its input width must include those.
        self.lstm = myLSTM(n_input + read_heads * mem_width, [n_hidden, n_hidden], n_output, 2)

        # Memory matrix, N x M.
        self.memory = Variable(torch.zeros(mem_length, mem_width))
        # Previous weighting of every head, one length-N row per head (Eq. 7 needs w_{t-1}).
        self.read_weights = Variable(torch.zeros(read_heads, mem_length))
        self.write_weights = Variable(torch.zeros(write_heads, mem_length))
        # Most recent read vectors, one length-M row per read head; fed back to the controller.
        self.read_vecs = Variable(torch.zeros(read_heads, mem_width))

        # Functions we'll need later.
        # dim=1: compare the key against each memory row -> one score per location.
        self.cosine_similarity = nn.CosineSimilarity(dim=1)
        # Eq. 5 normalises with a plain softmax: a weighting must be non-negative and sum to 1.
        self.softmax = nn.Softmax(dim=0)

    # Equation 2
    def read(self, weight_vec):
        """Read from memory: r = sum_i w(i) * M(i).

        Args:
            weight_vec: length-N weighting over memory locations.

        Returns:
            Length-M read vector.
        """
        return torch.matmul(weight_vec, self.memory)

    def write(self, weight_vec, erase_vec, add_vec):
        """Apply one head's erase (Eq. 3) and add (Eq. 4) and store the result in ``self.memory``.

        Args:
            weight_vec: length-N write weighting.
            erase_vec: length-M erase vector.
            add_vec: length-M add vector.
        """
        erase_matrix = self.erase(weight_vec, erase_vec)
        self.memory = self.add(weight_vec, add_vec, erase_matrix)

    # Equation 3
    def erase(self, weight_vec, erase_vec):
        """Return the memory after erasing, M(i) * (1 - w(i) * e); ``self.memory`` is not modified.

        Args:
            weight_vec: length-N weighting.
            erase_vec: length-M erase vector.

        Returns:
            N x M erased memory matrix.
        """
        # Outer product (N x 1) @ (1 x M) -> N x M erase weighting.
        erase_weighting = torch.matmul(weight_vec.unsqueeze(1), erase_vec.unsqueeze(0))
        return self.memory * (1 - erase_weighting)

    # Equation 4
    def add(self, weight_vec, add_vec, erase_matrix):
        """Return ``erase_matrix`` plus the add term w(i) * a; ``self.memory`` is not modified.

        Args:
            weight_vec: length-N weighting.
            add_vec: length-M add vector.
            erase_matrix: N x M matrix produced by ``erase``.

        Returns:
            N x M updated memory matrix.
        """
        # Outer product (N x 1) @ (1 x M) -> N x M add weighting.
        add_weighting = torch.matmul(weight_vec.unsqueeze(1), add_vec.unsqueeze(0))
        return erase_matrix + add_weighting

    def read_head(self, key_vecs, interpolation_gates, gammas, shift_weights):
        """Address memory with every read head and read from it.

        Element i of each argument belongs to read head i:
            key_vecs: length-M keys for content addressing (Eqs. 5, 6).
            interpolation_gates: gates g (Eq. 7).
            gammas: sharpening exponents (Eq. 9).
            shift_weights: length-N circular shift weightings (Eq. 8).

        Returns:
            (read_heads x M) tensor of read vectors, also stored in ``self.read_vecs``.
            ``self.read_weights`` is replaced by the new weightings.
        """
        new_weights = []
        new_read_vecs = []
        for i in range(self.read_heads):
            content_weights = self.content_focus(key_vecs[i], self.key_strength)
            combined_weights = self.location_focus(interpolation_gates[i],
                                                   gammas[i],
                                                   self.read_weights[i],
                                                   content_weights,
                                                   shift_weights[i])
            new_weights.append(combined_weights)
            new_read_vecs.append(self.read(combined_weights))
        if new_weights:
            # Rebuild rather than assign rows in place: the old rows may be saved
            # for backward by location_focus, and in-place edits would invalidate them.
            self.read_weights = torch.stack(new_weights)
            self.read_vecs = torch.stack(new_read_vecs)
        return self.read_vecs

    def write_head(self, key_vecs, interpolation_gates, gammas, shift_weights, erase_vecs, add_vecs):
        """Address memory with every write head and write to it.

        Takes the same per-head arguments as ``read_head``, plus length-M
        ``erase_vecs`` and ``add_vecs``. Updates ``self.write_weights`` and
        ``self.memory``.

        Heads are processed one after another, so a later head addresses the
        memory already modified by earlier heads.
        NOTE(review): for more than one write head the paper applies all
        erasures before all additions; confirm whether that is wanted here.
        """
        new_weights = []
        for i in range(self.write_heads):
            content_weights = self.content_focus(key_vecs[i], self.key_strength)
            combined_weights = self.location_focus(interpolation_gates[i],
                                                   gammas[i],
                                                   self.write_weights[i],
                                                   content_weights,
                                                   shift_weights[i])
            new_weights.append(combined_weights)
            # Write with this head's own weighting.
            self.write(combined_weights, erase_vecs[i], add_vecs[i])
        if new_weights:
            self.write_weights = torch.stack(new_weights)

    # Equations 5, 6
    def content_focus(self, key_vec, key_strength):
        """Content-based weighting: softmax over locations of key_strength * cosine(key, M(i)).

        Args:
            key_vec: length-M key.
            key_strength: scalar key strength (beta).

        Returns:
            Length-N weighting.
        """
        # key_vec.unsqueeze(0) is 1xM and broadcasts against the NxM memory;
        # the similarity is taken along dim 1, giving one score per location.
        sim_vec = self.cosine_similarity(key_vec.unsqueeze(0), self.memory)
        return self.softmax(key_strength * sim_vec)

    def location_focus(self, g, gamma, old_weight, content_weight, shift_weights):
        """Location-based addressing: interpolation (Eq. 7), circular shift (Eq. 8), sharpening (Eq. 9).

        Args:
            g: interpolation gate; 1 keeps the content weighting, 0 keeps the previous weighting.
            gamma: sharpening exponent.
            old_weight: the head's previous length-N weighting.
            content_weight: length-N content weighting from ``content_focus``.
            shift_weights: length-N shift weighting. Entry k weights a shift by k
                locations, with indices taken modulo N (entry N-1 is a shift by -1).

        Returns:
            Length-N weighting normalised to sum to 1.
        """
        # Equation 7
        gated_weight = g * content_weight + (1 - g) * old_weight

        # Equation 8: w_bar(i) = sum_j gated(j) * s((i - j) mod N), computed as a
        # matrix-vector product with shift_matrix[i, j] = s((i - j) mod N).
        n = gated_weight.size(0)
        positions = torch.arange(n, device=gated_weight.device)
        shift_index = torch.remainder(positions.unsqueeze(1) - positions.unsqueeze(0), n)
        w_bar = torch.matmul(shift_weights[shift_index], gated_weight)

        # Equation 9
        weights_to_power = torch.pow(w_bar, gamma)
        return weights_to_power / torch.sum(weights_to_power)

    def parse_lstm_output(self, output):
        """Split the controller output into head parameters. Not implemented yet.

        Expected to return ``(read_args, write_args)`` matching the parameters of
        ``read_head`` and ``write_head`` (rh = read_heads, wh = write_heads):
            read_args  = (key_vecs (rh x M), interpolation_gates (rh), gammas (rh),
                          shift_weights (rh x N))
            write_args = (key_vecs (wh x M), interpolation_gates (wh), gammas (wh),
                          shift_weights (wh x N), erase_vecs (wh x M), add_vecs (wh x M))

        Raises:
            NotImplementedError: always, until the parsing is written.
        """
        # TODO - the controller is built with only n_output outputs; it presumably
        # needs extra outputs to produce these head parameters.
        raise NotImplementedError("NTM.parse_lstm_output is not implemented yet")

    def backProp(self):
        """Placeholder (no-op).

        Every operation in this module is a differentiable torch op, so gradients
        can be obtained with autograd (``loss.backward()``).
        """
        pass

    def forward(self, input, label):
        """Run one NTM time step.

        Args:
            input: external input for this step. It is flattened, so it is assumed
                to hold n_input values (TODO confirm the intended shape).
            label: currently unused.

        Returns:
            The controller output for this step.
        """
        # Concatenate the external input with the previous step's read vectors.
        combined_input = torch.cat((input.reshape(-1), self.read_vecs.reshape(-1)))

        # Call the controller. myLSTM expects (seq_len, batch, features): one step, batch of 1.
        # NOTE(review): MyLSTMLayer.forward re-initialises its state on every call,
        # so the controller carries no recurrent state across NTM steps.
        output = self.lstm(combined_input.view(1, 1, -1))

        # Parse the output into per-head parameters.
        read_args, write_args = self.parse_lstm_output(output)

        # Write, then read from the updated memory.
        self.write_head(*write_args)
        self.read_head(*read_args)
        return output
        
        
# TODO: add the training function below.
        

In [26]:
import torch.nn as nn
from torch.autograd import Variable

class MyLSTMLayer(nn.Module):
    """Single recurrent layer with LSTM-style forget/input gating.

    NOTE: unlike a standard LSTM there is no output gate. The hidden state is
    ``relu(Linear(C_t))``, computed from the cell state alone.

    Args:
        input_size: features per time step.
        hidden_size: width of the hidden and cell states.
    """
    def __init__(self, input_size, hidden_size):
        super(MyLSTMLayer, self).__init__()

        self.hidden_size = hidden_size

        # Each gate sees the current input concatenated with the previous hidden state.
        self.forgetGate = nn.Linear(input_size + hidden_size, hidden_size)               # forget gate f_t
        self.incorporatePositionGate = nn.Linear(input_size + hidden_size, hidden_size)  # input gate i_t
        self.incorporateValueGate = nn.Linear(input_size + hidden_size, hidden_size)     # candidate cell values
        self.hiddenValueGate = nn.Linear(hidden_size, hidden_size)                       # cell state -> hidden

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward_step(self, input, hidden, cell_state):
        """Advance one time step.

        Args:
            input: (batch, input_size) input for this step.
            hidden: (batch, hidden_size) previous hidden state.
            cell_state: (batch, hidden_size) previous cell state.

        Returns:
            ``(hidden, cell_state)`` for this step, both (batch, hidden_size).
        """
        combined = torch.cat((input, hidden), 1)

        f = self.sigmoid(self.forgetGate(combined))
        i = self.sigmoid(self.incorporatePositionGate(combined))
        C_new = self.tanh(self.incorporateValueGate(combined))

        cell_state = f * cell_state + i * C_new

        hidden = self.relu(self.hiddenValueGate(cell_state))

        return hidden, cell_state

    def forward(self, input):
        """Run the layer over a whole sequence, starting from zero states.

        Args:
            input: (seq_len, batch, input_size) tensor.

        Returns:
            (seq_len, batch, hidden_size) tensor holding the hidden state of every step.
        """
        hidden, cell_state = self.initAll(input.size(1))

        outputs = []
        for step_input in input:
            hidden, cell_state = self.forward_step(step_input, hidden, cell_state)
            outputs.append(hidden)

        return torch.stack(outputs)

    def _zeros(self, batch_size):
        # Allocate on the same device and dtype as the layer's weights instead of
        # forcing .cuda(). This lets the layer run on CPU-only machines and follow .to()/.cuda().
        return Variable(self.forgetGate.weight.new_zeros(batch_size, self.hidden_size))

    def initHidden(self, batch_size=1):
        """Zero initial hidden state of shape (batch_size, hidden_size)."""
        return self._zeros(batch_size)

    def initCellState(self, batch_size=1):
        """Zero initial cell state of shape (batch_size, hidden_size)."""
        return self._zeros(batch_size)

    def initAll(self, batch_size=1):
        """Return ``(initHidden(batch_size), initCellState(batch_size))``."""
        return self.initHidden(batch_size), self.initCellState(batch_size)

class myLSTM(nn.Module):
    """Stack of ``MyLSTMLayer``s followed by a linear read-out and a log-softmax.

    Args:
        input_size: features per time step of the input sequence.
        hidden_sizes: hidden width of each stacked layer; its length sets the depth.
        output_size: number of output classes.
        layers: unused, because the depth comes from ``len(hidden_sizes)``. Kept so
            existing callers keep working.
    """
    def __init__(self, input_size, hidden_sizes, output_size, layers):
        super(myLSTM, self).__init__()

        self.lstm_layers = nn.ModuleList()
        # Layer k consumes the output of layer k-1; the first layer consumes the raw input.
        # list() so that a tuple of hidden sizes works too.
        layer_input_sizes = [input_size] + list(hidden_sizes[:-1])

        for layer_input_size, hidden_size in zip(layer_input_sizes, hidden_sizes):
            self.lstm_layers.append(MyLSTMLayer(layer_input_size, hidden_size))

        self.outputGate = nn.Linear(hidden_sizes[-1], output_size)
        # The read-out is (batch, output_size); normalise over the class dimension.
        # An explicit dim avoids the deprecated implicit choice (which was also dim=1 for 2-D input).
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        """Run the stack over a sequence and classify from the last time step.

        Args:
            input: (seq_len, batch, input_size) tensor.

        Returns:
            (batch, output_size) log-probabilities computed from the final time step.
        """
        hiddens = input
        for lstm_layer in self.lstm_layers:
            hiddens = lstm_layer(hiddens)

        return self.softmax(self.outputGate(hiddens[-1]))
        