# Neural Turing Machines
**Jin Yeom**  
jin.yeom@hudl.com

In this notebook, we attempt to implement a Neural Turing Machine (NTM). While implementations already exist, we aim to build one from scratch as practice in reproducing papers. This means that we may reference anything except existing implementations.

In [5]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from matplotlib import pyplot as plt

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device =", device)

device = cuda


## Memory

According to [this paper](https://arxiv.org/abs/1807.08518), how the memory contents are initialized is crucial for the success of an NTM. We'll address (no pun intended) this later.

In [55]:
class Memory(nn.Module):
    def __init__(self, mem_size, cont_size):
        super(Memory, self).__init__()
        self.mem_size = mem_size
        self.cont_size = cont_size
        self.M = None
        self.w = None
        
    def reset(self, batch_size):
        # TODO: use a better initialization method
        self.M = torch.zeros(batch_size, self.mem_size, self.cont_size)
        self.w = torch.zeros(batch_size, self.mem_size)
    
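    def read(self):
        # weighted sum of memory rows per batch element:
        # (B, 1, N) @ (B, N, C) -> (B, 1, C) -> (B, C)
        return torch.matmul(self.w.unsqueeze(1), self.M).squeeze(1)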
    
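    def write(self, e, a):
        # e: erase vector (B, C), a: add vector (B, C)
        e = e.unsqueeze(1)
        a = a.unsqueeze(1)
        w = self.w.unsqueeze(-1)  # (B, N, 1)
        # erase, then add; rebinding self.M instead of updating it
        # in place avoids interfering with autograd
        self.M = self.M * (1 - torch.matmul(w, e))
        self.M = self.M + torch.matmul(w, a)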
        
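    def _content_addr(self, k, beta):
        # k: key (B, C); beta: key strength, assumed to have shape (B, 1)
        k = k.unsqueeze(1)
        k = k.expand_as(self.M)
        # compare along the content dimension -> (B, N)
        sim = F.cosine_similarity(self.M, k, dim=-1)
        return F.softmax(beta * sim, dim=-1)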
        
    def _interpolate(self, wc, g):
        return g * wc + (1 - g) * self.w
        
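    def _conv_shift(self, s):
        r = []
        # assumes an odd number of shift weights, centered on shift 0
        padding = s.size(-1) // 2
        # NOTE: since a kernel for conv1d cannot have
        # a shape with the batch size, unfortunately,
        # we have to loop through each batch to compute
        # convolutional shift.
        # TODO: find a better way to do this.
        for w, k in zip(self.w, s):
            # F.conv1d has no padding_mode argument, so wrap the
            # weighting around explicitly with F.pad
            w = F.pad(w.view(1, 1, -1), (padding, padding), mode='circular')
            # conv1d computes cross-correlation; flipping the kernel
            # turns it into a circular convolution
            shifted = F.conv1d(w, k.flip(0).view(1, 1, -1))
            r.append(shifted.view(-1))
        return torch.stack(r)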
    
    def _sharpen(self, gamma):
        w = self.w ** gamma
        return w / (w.sum(1, keepdim=True)+1e-8)
        
    def _location_addr(self, wc, g, s, gamma):
        self.w = self._interpolate(wc, g)
        self.w = self._conv_shift(s)
        self.w = self._sharpen(gamma)
        
    def forward(self, k, beta, g, s, gamma):
        # the addressing helpers are defined with a leading underscore
        wc = self._content_addr(k, beta)
        self._location_addr(wc, g, s, gamma)
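
The per-sample loop in `_conv_shift` can probably be avoided. Below is a minimal sketch of one way to do it, assuming an odd number of shift weights with the center entry meaning "no shift": the batch is treated as the channel dimension of a grouped convolution, so each weighting gets its own kernel. The function names `circular_shift_batched` and `circular_shift_reference` are hypothetical and not part of the class above; the second one is only a slower cross-check built on `torch.roll`.

```python
import torch
from torch.nn import functional as F


def circular_shift_batched(w, s):
    """Circularly convolve each weighting in w with its own kernel in s.

    w: (batch, mem_size) attention weightings
    s: (batch, n_shifts) shift weightings; n_shifts is assumed odd,
       with the center entry corresponding to a shift of 0
    """
    batch_size, n_shifts = s.shape
    padding = n_shifts // 2
    # treat the batch as channels: (1, batch, mem_size + 2 * padding)
    x = F.pad(w.unsqueeze(0), (padding, padding), mode='circular')
    # one kernel per channel; flipped because conv1d cross-correlates
    kernel = s.flip(-1).unsqueeze(1)  # (batch, 1, n_shifts)
    out = F.conv1d(x, kernel, groups=batch_size)
    return out.squeeze(0)


def circular_shift_reference(w, s):
    """Same computation, written as a sum of rolled copies of w."""
    padding = s.size(-1) // 2
    out = torch.zeros_like(w)
    for d in range(-padding, padding + 1):
        # torch.roll(w, d) moves the entry at index i to index i + d
        out = out + s[:, padding + d].unsqueeze(1) * torch.roll(w, d, dims=1)
    return out
```

With this convention, a one-hot weighting at index 2 combined with a shift weighting that puts all of its mass on the last entry (shift +1) should come out as a one-hot weighting at index 3, and a shift of -1 from index 0 should wrap around to the last memory location.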

## Controller

The controller of our NTM can be either a feedforward or a recurrent network; since this is our first implementation, we're going to use a feedforward network, whose behavior is more transparent.
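
As a rough sketch of what such a feedforward controller might look like (not a finished design), the module below takes the external input and the previous read vector and emits the five addressing parameters that `Memory.forward` expects. The class name `FeedforwardController`, its constructor arguments, the single hidden layer, and the choice of activations are all assumptions made for illustration; the activations are picked so that `beta` is positive, `g` lies in (0, 1), `s` sums to 1, and `gamma` is at least 1.

```python
import torch
from torch import nn
from torch.nn import functional as F


class FeedforwardController(nn.Module):
    """Hypothetical controller: maps an input and the last read vector
    to a hidden state and to one head's addressing parameters."""

    def __init__(self, in_size, cont_size, hidden_size, n_shifts=3):
        super().__init__()
        self.cont_size = cont_size
        self.n_shifts = n_shifts
        self.hidden = nn.Linear(in_size + cont_size, hidden_size)
        # k (cont_size) + beta (1) + g (1) + s (n_shifts) + gamma (1)
        self.head = nn.Linear(hidden_size, cont_size + n_shifts + 3)

    def forward(self, x, r):
        # x: (B, in_size) external input, r: (B, cont_size) last read
        h = torch.relu(self.hidden(torch.cat([x, r], dim=-1)))
        k, beta, g, s, gamma = torch.split(
            self.head(h), [self.cont_size, 1, 1, self.n_shifts, 1], dim=-1)
        beta = F.softplus(beta)        # key strength > 0, shape (B, 1)
        g = torch.sigmoid(g)           # interpolation gate in (0, 1)
        s = F.softmax(s, dim=-1)       # shift weighting sums to 1
        gamma = 1 + F.softplus(gamma)  # sharpening exponent >= 1
        return h, (k, beta, g, s, gamma)
```

Keeping `beta`, `g`, and `gamma` with a trailing dimension of 1 lets them broadcast against the `(B, N)` weightings in the addressing steps. A separate output layer would still be needed to produce the erase and add vectors for writing.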

In [None]:
# TODO: implement read head
# TODO: implement write head
# TODO: implement controller