<a href="https://colab.research.google.com/github/mahi97/NTM/blob/master/NTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

## Memory Module

In [2]:
def _convolve(w, s):
  """Circularly convolve a weighting with a 3-tap shift kernel.

  The weighting is wrapped around by one element on each side before the
  kernel is applied, so the output has as many elements as ``w``. Because
  ``F.conv1d`` computes a cross-correlation, output element ``i`` equals
  ``w[i-1]*s[0] + w[i]*s[1] + w[i+1]*s[2]`` with indices taken modulo
  ``len(w)``.

  :param w: Weighting to shift (expected to be a 1-D tensor).
  :param s: Shift kernel; its first dimension must have size 3.
  :return: Flattened 1-D tensor holding the shifted weighting.
  """
  assert s.size(0) == 3
  # Wrap-around padding: last element in front, first element at the back.
  wrapped = torch.cat([w[-1:], w, w[:1]])
  # conv1d expects (batch, channels, length) for both signal and kernel.
  signal = wrapped.view(1, 1, -1)
  kernel = s.view(1, 1, -1)
  return F.conv1d(signal, kernel).view(-1)

class Memory(nn.Module):
  """External memory of a Neural Turing Machine.

  Holds one ``N x M`` memory matrix per batch element in ``self.mem`` and
  provides the read, write and addressing operations used by the heads.
  ``reset`` must be called before any other operation so that ``self.mem``
  exists.
  """

  def __init__(self, N, M):
    """
    :param N: Number of memory rows (locations).
    :param M: Width of each memory row.
    """
    super(Memory, self).__init__()

    self.M = M
    self.N = N

    # Initial memory contents. Registered as a buffer, so it is saved with
    # the module state and moved by .to()/.cuda(), but is not a trainable
    # parameter.
    self.register_buffer('mem_bias', torch.Tensor(N, M))
    std_dev = 1 / np.sqrt(N + M)
    # uniform_ is the in-place initializer; the un-suffixed name is deprecated.
    nn.init.uniform_(self.mem_bias, -std_dev, std_dev)

  def reset(self, batch_size):
    """Initialize memory from bias, for start-of-sequence.

    :param batch_size: Number of independent memory copies to create.
    """
    self.batch_size = batch_size
    self.mem = self.mem_bias.clone().repeat(batch_size, 1, 1)

  def read(self, address):
    """Read from memory as an address-weighted sum of its rows.

    :param address: Batched tensor of size batch_size * N, containing values
      between 0 and 1 that sum to 1 along the last dimension.
    :return: Batched tensor of size batch_size * M, the weighted sum over the
      rows of the memory.
    """
    # (B, 1, N) @ (B, N, M) -> (B, 1, M); drop the singleton row dimension.
    return address.unsqueeze(1).matmul(self.mem).squeeze(1)

  def write(self, address, erase_vector, add_vector):
    """Erase-then-add write to memory; the result replaces ``self.mem``.

    The memory before the write is kept in ``self.prev_mem``.

    :param address: Write weighting of size batch_size * N.
    :param erase_vector: Erase vector of size batch_size * M.
    :param add_vector: Add vector of size batch_size * M.
    """
    self.prev_mem = self.mem
    # Outer products (B, N, 1) @ (B, 1, M) -> (B, N, M).
    erase = torch.matmul(address.unsqueeze(-1), erase_vector.unsqueeze(1))
    add = torch.matmul(address.unsqueeze(-1), add_vector.unsqueeze(1))
    self.mem = self.prev_mem * (1 - erase) + add

  def address(self, key_vector, key_strength, gate, shift, sharpen, last_address):
    """NTM Addressing (according to section 3.3).

    Applies, in order: content addressing (softmax over scaled cosine
    similarity), interpolation with the previous weighting, circular shift,
    and sharpening followed by renormalization.

    :param key_vector: The key vector.
    :param key_strength: The key strength (focus).
    :param gate: Scalar interpolation gate (with previous weighting).
    :param shift: Shift weighting; each batch entry must hold 3 values.
    :param sharpen: Sharpen weighting scalar.
    :param last_address: The weighting produced in the previous time step.
    :return: The new weighting over the rows of the memory matrix,
      normalized along dim 1.
    """
    similarity = F.cosine_similarity(key_vector.unsqueeze(1), self.mem, dim=2)
    wc = F.softmax(key_strength * similarity, dim=1)
    wg = (gate * wc) + (1 - gate) * last_address
    # Stay in torch (no NumPy round-trip) so gradients flow through the shift
    # and the tensors keep their device.
    ws = torch.stack([_convolve(wg[b], shift[b]) for b in range(wg.size(0))])
    ws = (ws ** sharpen)
    # The small constant guards against division by zero.
    wt = torch.true_divide(ws, torch.sum(ws, dim=1).view(-1, 1) + 1e-16)

    return wt


## Controller Module
### 1. LSTM


In [213]:
class LSTMController(nn.Module):
  """NTM controller backed by an ``nn.LSTM`` with learned initial state."""

  def __init__(self, num_inputs, num_outputs, num_layers):
    """
    :param num_inputs: Size of the input vector fed to the LSTM.
    :param num_outputs: Hidden size of the LSTM (also its output size).
    :param num_layers: Number of stacked LSTM layers.
    """
    super(LSTMController, self).__init__()

    self.num_inputs  = num_inputs
    self.num_layers  = num_layers
    self.num_outputs = num_outputs

    self.lstm = nn.LSTM(num_inputs, num_outputs, num_layers)

    # Learned initial hidden/cell state, shared across the batch. Both are
    # drawn from a scaled normal distribution; torch.Tensor(...) would leave
    # the memory uninitialized.
    # NOTE(review): the 0.05 scale is unexplained in this file.
    self.lstm_h_state = nn.Parameter(torch.randn(num_layers, 1, num_outputs) * 0.05)
    self.lstm_c_state = nn.Parameter(torch.randn(num_layers, 1, num_outputs) * 0.05)

  def create_new_state(self, batch_size):
    """Build the start-of-sequence LSTM state for a batch.

    :param batch_size: Number of sequences in the batch.
    :return: Tuple ``(h, c)``, each the learned initial state repeated along
      the batch dimension.
    """
    # Dimension: (num_layers * num_directions, batch, hidden_size)
    lstm_h = self.lstm_h_state.clone().repeat(1, batch_size, 1)
    lstm_c = self.lstm_c_state.clone().repeat(1, batch_size, 1)
    return lstm_h, lstm_c

  def reset_parameters(self):
    """Re-initialize the LSTM: 1-D parameters to 0, others uniformly."""
    # Loop-invariant bound for the uniform initialization.
    stdev = 5 / (np.sqrt(self.num_inputs +  self.num_outputs))
    for p in self.lstm.parameters():
      if p.dim() == 1:
        nn.init.constant_(p, 0)
      else:
        nn.init.uniform_(p, -stdev, stdev)

  def size(self):
    """Return ``(num_inputs, num_outputs)``."""
    return self.num_inputs, self.num_outputs

  def forward(self, x, prev_state):
    """Run the LSTM for a single time step.

    :param x: Input for one time step; a leading length-1 dimension is added
      before it is passed to the LSTM.
    :param prev_state: Tuple ``(h, c)`` as returned by ``create_new_state``
      or by a previous call.
    :return: Tuple ``(output, state)`` with the leading length-1 dimension
      removed from ``output``.
    """
    x = x.unsqueeze(0)
    outp, state = self.lstm(x, prev_state)
    return outp.squeeze(0), state


### 2. Feed Forward

In [None]:
class FFController(nn.Module):
  """Feed-forward controller: Linear -> Sigmoid -> Linear."""

  def __init__(self, num_inputs, num_hidden, num_outputs):
    """
    :param num_inputs: Size of the input vector.
    :param num_hidden: Size of the hidden layer.
    :param num_outputs: Size of the output vector.
    """
    super().__init__()

    self.num_inputs, self.num_hidden, self.num_outputs = (
        num_inputs, num_hidden, num_outputs)

    # Layers are created in the order forward() applies them.
    self.fc1 = nn.Linear(num_inputs, num_hidden)
    self.sigmoid = nn.Sigmoid()
    self.fc2 = nn.Linear(num_hidden, num_outputs)

  def size(self):
    """Return ``(num_inputs, num_outputs)``."""
    return (self.num_inputs, self.num_outputs)

  def forward(self, x):
    """Map ``x`` through the first linear layer, the sigmoid, then the second linear layer.

    :param x: Input tensor whose last dimension has size ``num_inputs``.
    :return: Output of the second linear layer.
    """
    hidden = self.sigmoid(self.fc1(x))
    return self.fc2(hidden)

## Heads