# In [1]:
import torch
import torch.nn as nn

# In [11]:
class Addressing:
    """Compute a weighting over memory rows from head parameters.

    The weighting is produced in four stages, applied in this order by
    :meth:`get_weights`: content addressing, interpolation with the previous
    weighting, circular shift, and sharpening.  The class keeps no state;
    every method is a function of its arguments only.
    """

    @staticmethod
    def similarity(key, value):
        """Return the cosine similarity of two 1-D tensors.

        The result is NaN when either vector has zero norm.
        """
        norm_product = torch.sqrt(torch.dot(key, key) * torch.dot(value, value))
        return torch.dot(key, value) / norm_product

    def address_by_content(self, key_strength, key, memory):
        """Weight each memory row by its cosine similarity to ``key``.

        Args:
            key_strength: scalar multiplying the similarities before the
                softmax; larger values give a more peaked weighting.
            key: 1-D tensor of length ``m``.
            memory: 2-D tensor of shape ``(n, m)``.

        Returns:
            1-D tensor of length ``n`` whose entries sum to 1.
        """
        memory_normalized = nn.functional.normalize(memory, dim=1)
        # The key is a single vector, so it has to be normalised along its
        # only axis; the default ``dim=1`` raises on 1-D input.
        key_normalized = nn.functional.normalize(key, dim=0)

        scores = torch.einsum("nm, m->n", memory_normalized, key_normalized)
        return nn.functional.softmax(key_strength * scores, dim=0)

    def interpolate(self, gate, weights_old, weights):
        """Blend ``weights`` with ``weights_old``.

        ``gate == 1`` returns ``weights`` and ``gate == 0`` returns
        ``weights_old``.
        """
        return gate * weights + (1 - gate) * weights_old

    def shift_by_convolution(self, shift, weights):
        """Circularly convolve ``weights`` with the ``shift`` kernel.

        ``shift[k]`` is the weight given to rotating ``weights`` forward by
        ``k`` positions, with indices wrapping modulo ``len(weights)``::

            result[i] = sum_k shift[k] * weights[(i - k) mod n]

        Args:
            shift: 1-D tensor of shift weights (any length).
            weights: 1-D tensor of length ``n``.

        Returns:
            1-D tensor of length ``n``.
        """
        n = weights.shape[0]
        positions = torch.arange(n, device=weights.device).unsqueeze(1)
        offsets = torch.arange(shift.shape[0], device=weights.device).unsqueeze(0)
        # source[i, k] is the index that lands on position i after a rotation
        # by k; torch.remainder keeps the index non-negative.
        source = torch.remainder(positions - offsets, n)
        return (weights[source] * shift).sum(dim=1)

    def sharpen(self, gamma, weights):
        """Raise ``weights`` to the power ``gamma`` and renormalise.

        The normaliser is the sum of the *sharpened* values, so a
        non-negative input yields an output that sums to 1.
        """
        powered = torch.pow(weights, gamma)
        return powered / torch.sum(powered)

    def get_weights(self, weights, memory, key, key_strength, gate, shift, gamma):
        """Run the full addressing pipeline.

        Args:
            weights: previous weighting, 1-D of length ``n``.
            memory: 2-D tensor of shape ``(n, m)``.
            key: 1-D tensor of length ``m``.
            key_strength: scalar for :meth:`address_by_content`.
            gate: scalar for :meth:`interpolate`.
            shift: 1-D kernel for :meth:`shift_by_convolution`.
            gamma: scalar exponent for :meth:`sharpen`.

        Returns:
            The new weighting, 1-D of length ``n``.
        """
        result = self.address_by_content(key_strength, key, memory)
        result = self.interpolate(gate, weights, result)
        result = self.shift_by_convolution(shift, result)
        result = self.sharpen(gamma, result)

        return result
    
class Linear(nn.Module):
    """Affine layer ``y = x @ W.T + b``.

    Both parameters are initialised from a standard normal distribution.

    Args:
        d_in: size of the last axis of the input.
        d_out: size of the last axis of the output.
    """

    def __init__(self, d_in=None, d_out=None):
        # nn.Module must be initialised before any Parameter is assigned,
        # otherwise the assignment raises AttributeError.
        super().__init__()
        self.W = nn.Parameter(torch.randn(d_out, d_in))
        self.b = nn.Parameter(torch.randn(d_out))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(n, d_in)``; returns ``(n, d_out)``."""
        return torch.einsum("oi, ni -> no", self.W, x) + self.b

class Controller(nn.Module):
    """Feed-forward network: input layer, hidden layers, output layer.

    A ReLU follows the input layer and every hidden layer; the output layer
    is left linear.

    Args:
        n_layers: number of activated layers, i.e. the input layer plus
            ``n_layers - 1`` hidden layers.
        d_input: width of the input.
        d_hidden: width of every hidden representation.
        d_output: width of the output.
    """

    def __init__(self, n_layers=None, d_input=None, d_hidden=None, d_output=None):
        super().__init__()
        self.n_layers = n_layers
        self.activation = nn.ReLU()
        self.input_layer = Linear(d_in=d_input, d_out=d_hidden)
        # ModuleList (not a plain list) so the hidden layers' parameters are
        # registered with this module and reach the optimiser.
        self.hidden_layers = nn.ModuleList(
            [Linear(d_in=d_hidden, d_out=d_hidden) for _ in range(n_layers - 1)]
        )
        self.output_layer = Linear(d_in=d_hidden, d_out=d_output)

    def forward(self, x):
        """Map ``x`` through the input, hidden and output layers in order."""
        out = self.activation(self.input_layer(x))

        for layer in self.hidden_layers:
            out = self.activation(layer(out))

        return self.output_layer(out)
            

class NTM(nn.Module):
    """Neural Turing Machine with one feed-forward controller and one memory.

    Each call to :meth:`forward` runs the controller twice: once with a zero
    read vector to address and read the memory, and once with that read
    vector to address the memory for writing and to produce the output.

    Args:
        n_memory: number of memory rows.
        d_memory: width of each memory row.
        d_input: width of the external input.
        d_hidden: hidden width of the controller.
        d_output: width of the returned output.
        n_layers: passed through to :class:`Controller`.
    """

    def __init__(self, n_memory=None, d_memory=None, d_input=None, d_hidden=None, d_output=None, n_layers=None):
        super().__init__()
        self.d_memory = d_memory
        self.addressing = Addressing()
        # key (d_memory) + key_strength + gate + shift (d_memory) + gamma
        self.n_address_params = 2 * d_memory + 3
        self.d_controller_input = d_input + d_memory
        self.d_controller_output = self.n_address_params + d_output
        self.controller = Controller(n_layers=n_layers, d_input=self.d_controller_input, d_hidden=d_hidden, d_output=self.d_controller_output)

        # Buffers rather than plain attributes so they follow .to()/.cuda()
        # and are saved in the state dict.
        self.register_buffer("memory", torch.zeros(n_memory, d_memory))
        # The previous weighting is a distribution over memory rows, so it
        # has one entry per row; start from the uniform distribution.
        self.register_buffer("weights", torch.full((n_memory,), 1.0 / n_memory))

        # Maps the full controller output to the concatenated erase/add vectors.
        self.W_write = nn.Parameter(torch.randn(self.d_controller_output, 2 * self.d_memory))

    def read(self, weights):
        """Return the ``weights``-weighted sum of the memory rows (length ``d_memory``)."""
        return torch.einsum("n, nm->m", weights, self.memory)

    def write(self, weights, erase, add):
        """Erase from, then add to, the memory in place of ``self.memory``.

        Args:
            weights: 1-D tensor with one entry per memory row.
            erase: 1-D tensor of length ``d_memory``; row ``n`` is scaled
                element-wise by ``1 - weights[n] * erase``.
            add: 1-D tensor of length ``d_memory``; ``weights[n] * add`` is
                added to row ``n``.
        """
        retained = self.memory * (1 - torch.einsum("n, m->nm", weights, erase))
        # Outer product: every row receives its own scaled copy of ``add``.
        self.memory = retained + torch.einsum("n, m->nm", weights, add)

    def _run_controller(self, x, read):
        """Feed ``[x, read]`` to the controller as a batch of one; return a 1-D output."""
        controller_in = torch.cat((x, read), dim=0).unsqueeze(0)
        return self.controller(controller_in).squeeze(0)

    def _address(self, head):
        """Turn the leading ``n_address_params`` entries of ``head`` into a weighting.

        The raw controller outputs are squashed to the ranges the addressing
        stages need (positive key strength, gate in (0, 1), normalised shift
        kernel, gamma >= 1); unconstrained values easily produce negative
        weights and NaNs in the sharpening step.
        """
        M = self.d_memory
        key = head[:M]
        key_strength = nn.functional.softplus(head[M])
        gate = torch.sigmoid(head[M + 1])
        shift = nn.functional.softmax(head[M + 2:2 * M + 2], dim=0)
        gamma = 1 + nn.functional.softplus(head[2 * M + 2])
        return self.addressing.get_weights(self.weights, self.memory, key, key_strength, gate, shift, gamma)

    def forward(self, x):
        """Read from memory, write to memory, and return the output vector.

        Args:
            x: a single input of ``d_input`` values (1-D, or any shape that
                flattens to that length); the memory is not batched.

        Returns:
            1-D tensor of length ``d_output``.
        """
        x = x.reshape(-1)

        # Pass 1: nothing has been read yet, so the controller sees zeros.
        blank_read = torch.zeros(self.d_memory, dtype=x.dtype, device=x.device)
        head = self._run_controller(x, blank_read)
        read = self.read(self._address(head))

        # Pass 2: the controller sees what was read and decides what to write.
        head = self._run_controller(x, read)
        write_weights = self._address(head)
        # NOTE(review): self.weights is never updated, so both passes
        # interpolate against the initial weighting -- confirm whether the
        # heads should keep their own previous weightings.

        M = self.d_memory
        write = torch.einsum("cw, c->w", self.W_write, head)
        erase = torch.sigmoid(write[:M])  # erase values must lie in (0, 1)
        add = write[M:]
        self.write(write_weights, erase, add)

        return head[self.n_address_params:]
        
            
        
        