# Neural Turing Machines
**Jin Yeom**  
jin.yeom@hudl.com

In this notebook, we attempt to implement a Neural Turing Machine (NTM). While implementations already exist, we aim to build an NTM from scratch as practice in reproducing papers. This means we may consult any resource except existing implementations.

In [1]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from matplotlib import pyplot as plt

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device =", device)

device = cuda


## Tasks

In [None]:
# TODO: implement tasks
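The first experiment in the NTM paper is the copy task. The network reads a sequence of random binary vectors followed by a delimiter flag, and then has to reproduce the sequence. Below is a minimal sketch of a batch generator for it. The function name `copy_task_batch`, the 8-bit vector width, and the use of an extra input channel for the delimiter are assumptions made for this sketch.

```python
import torch

def copy_task_batch(batch_size, seq_len, width=8):
    """Hypothetical copy-task generator.

    Returns (inp, target):
      inp:    (batch, seq_len + 1, width + 1) random bits plus a delimiter step
      target: (batch, seq_len, width) the bits the model must reproduce
    """
    seq = torch.randint(0, 2, (batch_size, seq_len, width)).float()
    # The extra last channel is reserved for the delimiter flag.
    inp = torch.zeros(batch_size, seq_len + 1, width + 1)
    inp[:, :seq_len, :width] = seq
    inp[:, seq_len, width] = 1.0  # delimiter marks the end of the input
    return inp, seq
```

During training, the model would then receive zero vectors for another `seq_len` steps, and its outputs over those steps would be compared against `target`.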

## Memory

According to [this paper](https://arxiv.org/abs/1807.08518), the way memory contents are initialized is crucial to whether an NTM trains successfully. But we'll address (no pun intended) this later.
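The cited paper reports that initializing the memory contents to a small constant converged faster than the other schemes it compared. Here is a minimal sketch of that scheme. The helper name `init_memory` and the value `1e-6` are assumptions and do not come from this notebook. The sketch also shows one consequence: when every row of memory is identical, every location is equally similar to any key, so the initial content-based weighting is uniform.

```python
import torch
from torch.nn import functional as F

def init_memory(batch_size, n, m, value=1e-6, device='cpu'):
    # Hypothetical helper: every location starts as the same small constant
    # vector. The default value 1e-6 is an assumption.
    return torch.full((batch_size, n, m), value, device=device)

M = init_memory(batch_size=2, n=5, m=4)
k = torch.randn(2, 4)  # arbitrary keys
# Cosine similarity between each key and every memory row -> (2, 5)
sim = F.cosine_similarity(M, k.unsqueeze(1).expand_as(M), dim=-1)
# All rows are identical, so the softmax spreads weight evenly (0.2 each).
w_c = F.softmax(2.0 * sim, dim=-1)
```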

In [3]:
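class Memory:
    def __init__(self, n, m):
        self.n = n  # memory size (number of locations)
        self.m = m  # content size (length of each memory vector)
        self.M = None  # memory matrix, (batch, n, m)
        self.w = None  # weighting from the previous step, (batch, n)

    @property
    def state(self):
        M = self.M.detach().cpu()
        w = self.w.detach().cpu()
        return M, w

    def reset(self, batch_size, device):
        # TODO: use a better initialization method
        self.M = torch.zeros(batch_size, self.n, self.m, device=device)
        # The weighting is a distribution over the n locations, not over m.
        self.w = torch.zeros(batch_size, self.n, device=device)

    def read(self):
        # r = sum_i w(i) M(i): (batch, 1, n) @ (batch, n, m) -> (batch, m)
        return torch.bmm(self.w.unsqueeze(1), self.M).squeeze(1)

    def write(self, e, a):
        # e: erase vector, a: add vector, both (batch, m)
        e = e.unsqueeze(1)        # (batch, 1, m)
        a = a.unsqueeze(1)        # (batch, 1, m)
        w = self.w.unsqueeze(-1)  # (batch, n, 1)
        # Erase, then add. Out-of-place updates keep the autograd graph intact.
        self.M = self.M * (1 - torch.bmm(w, e)) + torch.bmm(w, a)

    def _cont_addr(self, k, beta):
        """Content-based addressing"""
        # k: key (batch, m), beta: key strength (batch, 1)
        k = k.unsqueeze(1).expand_as(self.M)
        # Cosine similarity along the content dimension -> (batch, n)
        sim = F.cosine_similarity(self.M, k, dim=-1)
        return F.softmax(beta * sim, dim=-1)

    def _interpolate(self, wc, w, g):
        """Interpolation between w_c and w at t - 1"""
        # g: interpolation gate (batch, 1)
        return g * wc + (1 - g) * w

    def _conv_shift(self, w, s):
        """Convolutional shift"""
        r = []
        padding = s.size(-1) // 2
        # NOTE: F.conv1d applies the same kernel to every sample,
        # but each sample has its own shift kernel, so we loop
        # over the batch here.
        # TODO: find a better way to do this.
        for i, k in zip(w, s):
            # F.conv1d has no padding_mode argument (that belongs to
            # nn.Conv1d), so pad circularly with F.pad first.
            x = F.pad(i.reshape(1, 1, -1), (padding, padding), mode='circular')
            # conv1d computes cross-correlation; flipping the kernel gives
            # the circular convolution from the paper, assuming s is
            # ordered over shifts (-1, 0, +1).
            shifted = F.conv1d(x, k.flip(-1).reshape(1, 1, -1))
            r.append(shifted.reshape(-1))
        return torch.stack(r)

    def _sharpen(self, w, gamma):
        """Sharpening"""
        # gamma: sharpening exponent >= 1, (batch, 1)
        w = w ** gamma
        return w / (w.sum(-1, keepdim=True) + 1e-8)

    def _loc_addr(self, wc, w, g, s, gamma):
        """Location-based addressing"""
        w = self._interpolate(wc, w, g)
        w = self._conv_shift(w, s)
        return self._sharpen(w, gamma)

    def step(self, k, beta, g, s, gamma):
        wc = self._cont_addr(k, beta)
        self.w = self._loc_addr(wc, self.w, g, s, gamma)
        return self.w.detach().cpu()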
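One way to address the TODO in `_conv_shift` is to fold the batch into the channel dimension and run a single grouped convolution (`groups=batch`), which gives each sample its own kernel. The sketch below is offered as an option, not as this notebook's method. It computes the paper's circular convolution w̃(i) = Σ_j w^g(j) s(i − j), assuming `s` has odd length and is ordered over shifts (-1, 0, +1). The function name `batched_conv_shift` is hypothetical.

```python
import torch
from torch.nn import functional as F

def batched_conv_shift(w, s):
    """Circular convolution of each w[b] with its own kernel s[b].

    w: (batch, n) weightings
    s: (batch, k) shift distributions, k odd, ordered over shifts
       (-(k // 2), ..., 0, ..., +(k // 2))  (an assumed convention)
    """
    batch = w.size(0)
    pad = s.size(-1) // 2
    # Treat the batch as channels of a single sample: (1, batch, n + 2 * pad)
    x = F.pad(w.unsqueeze(0), (pad, pad), mode='circular')
    # One kernel per channel: (batch, 1, k). The flip turns conv1d's
    # cross-correlation into a true convolution.
    kernel = s.flip(-1).unsqueeze(1)
    # groups=batch keeps channels independent -> (1, batch, n)
    return F.conv1d(x, kernel, groups=batch).squeeze(0)

# A kernel that puts all its mass on shift +1 moves the focus one slot right.
w = torch.tensor([[0., 1., 0., 0., 0.]])
s = torch.tensor([[0., 0., 1.]])
shifted = batched_conv_shift(w, s)  # tensor([[0., 0., 1., 0., 0.]])
```

Because the grouped convolution handles the whole batch in one call, it could replace the Python loop in `_conv_shift` without changing the result.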

## Controller

The NTM controller can be either a feedforward or a recurrent network. Since this is our first implementation, we use a feedforward network because its behavior is easier to inspect.

In [None]:
# TODO: implement read head
# TODO: implement write head
# TODO: implement controller
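Below is a minimal sketch of a feedforward controller that emits one head's parameters in the shapes `Memory.step` and `Memory.write` expect. Several details are assumptions for this sketch: the class name, the layer sizes, the way the output vector is split, and the activations. The activations are softplus for β, sigmoid for g and e, softmax for s, 1 + softplus for γ, and tanh for k and a. They were chosen to meet the paper's constraints: β ≥ 0, g and the entries of e lie in (0, 1), s is a distribution over shifts, and γ ≥ 1. The input `x` is assumed to be the external input concatenated with the previous read vector. A read head would use only `k`, `beta`, `g`, `s`, and `gamma`.

```python
import torch
from torch import nn
from torch.nn import functional as F

class FeedforwardController(nn.Module):
    """Hypothetical controller producing parameters for a single head."""

    def __init__(self, input_size, hidden_size, m, shift_size=3):
        super().__init__()
        self.sizes = [m, 1, 1, shift_size, 1, m, m]  # k, beta, g, s, gamma, e, a
        self.hidden = nn.Linear(input_size, hidden_size)
        self.params = nn.Linear(hidden_size, sum(self.sizes))

    def forward(self, x):
        h = F.relu(self.hidden(x))
        k, beta, g, s, gamma, e, a = torch.split(self.params(h), self.sizes, dim=-1)
        return {
            'k': torch.tanh(k),               # key, (batch, m)
            'beta': F.softplus(beta),         # key strength >= 0
            'g': torch.sigmoid(g),            # interpolation gate in (0, 1)
            's': F.softmax(s, dim=-1),        # distribution over shifts
            'gamma': 1 + F.softplus(gamma),   # sharpening exponent >= 1
            'e': torch.sigmoid(e),            # erase vector in (0, 1)
            'a': torch.tanh(a),               # add vector
        }

torch.manual_seed(0)
ctrl = FeedforwardController(input_size=10, hidden_size=32, m=6)
p = ctrl(torch.randn(4, 10))
```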

## References

- [Neural Turing Machines](https://arxiv.org/abs/1410.5401)
- [Implementing Neural Turing Machines](https://arxiv.org/abs/1807.08518)