<a href="https://colab.research.google.com/github/mahi97/NTM/blob/master/NTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

## Memory Module

In [2]:
def _convolve(w, s):
  """Circular convolution implementation."""
  assert s.size(0) == 3
  t = torch.cat([w[-1:], w, w[:1]])
  c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
  return c

class Memory(nn.Module):
  """NTM memory bank of N rows (locations), each an M-dimensional vector.

  The working memory ``self.mem`` is batched, with shape (batch_size, N, M).
  It is created from ``mem_bias`` by :meth:`reset`, which must be called
  before :meth:`read`, :meth:`write` or :meth:`address`.
  """
  def __init__(self, N, M):
    super(Memory, self).__init__()

    self.M = M
    self.N = N

    # Initial memory contents, copied into every batch entry by `reset`.
    # NOTE: this is a buffer, not a Parameter, so it is saved in state_dict
    # but not updated by the optimizer.
    self.register_buffer('mem_bias', torch.empty(N, M))
    std_dev = 1 / np.sqrt(N + M)
    nn.init.uniform_(self.mem_bias, -std_dev, std_dev)

  def reset(self, batch_size):
    """Initialize memory from bias, for start-of-sequence."""
    self.batch_size = batch_size
    self.mem = self.mem_bias.clone().repeat(batch_size, 1, 1)

  def read(self, address):
    """Read a weighted sum of memory rows.

    :param address: (batch_size, N) weighting. Each row is expected to be
                    non-negative and sum to 1.
    :return: (batch_size, M) tensor, ``sum_i address[b, i] * mem[b, i]``.
    """
    # (B, 1, N) @ (B, N, M) -> (B, 1, M) -> (B, M)
    return address.unsqueeze(1).matmul(self.mem).squeeze(1)

  def write(self, address, erase_vector, add_vector):
    """Erase-then-add write (NTM section 3.2) at the given weighting.

    Each row i is updated as
    ``mem[b, i] = prev[b, i] * (1 - address[b, i] * e[b]) + address[b, i] * a[b]``.
    The previous memory is kept in ``self.prev_mem``.

    :param address: (batch_size, N) write weighting.
    :param erase_vector: (batch_size, M) erase vector. The write head passes
                         values squashed into (0, 1).
    :param add_vector: (batch_size, M) add vector.
    """
    self.prev_mem = self.mem
    # Outer products (B, N, 1) x (B, 1, M) -> (B, N, M).
    erase = torch.matmul(address.unsqueeze(-1), erase_vector.unsqueeze(1))
    add = torch.matmul(address.unsqueeze(-1), add_vector.unsqueeze(1))
    self.mem = self.prev_mem * (1 - erase) + add

  def address(self, key_vector, key_strength, gate, shift, sharpen, last_address):
    """NTM Addressing (according to section 3.3).

    Returns a normalized weighting over the N rows of the memory matrix.
    Shapes below are those produced by ReadHead/WriteHead in this notebook.

    :param key_vector: The key vector, (batch_size, M).
    :param key_strength: The key strength (focus), (batch_size, 1).
    :param gate: Interpolation gate with the previous weighting, (batch_size, 1).
    :param shift: Shift weighting, (batch_size, K) with K odd.
    :param sharpen: Sharpen weighting scalar, (batch_size, 1).
    :param last_address: The weighting produced in the previous time step, (batch_size, N).
    :return: (batch_size, N) weighting whose rows sum to ~1.
    """
    # Content addressing: softmax over rows of scaled cosine similarity.
    wc = F.softmax(key_strength * F.cosine_similarity(key_vector.unsqueeze(1), self.mem, dim=2), dim=1)
    # Interpolate with the previous weighting.
    wg = (gate * wc) + (1 - gate) * last_address
    # Circular shift, one batch entry at a time. torch.stack keeps the
    # autograd graph intact; a round trip through numpy would not, and fails
    # outright on tensors that require grad.
    ws = torch.stack([_convolve(w_b, s_b) for w_b, s_b in zip(wg, shift)])
    # Sharpen and renormalize.
    ws = ws ** sharpen
    return ws / (ws.sum(dim=1, keepdim=True) + 1e-16)


## Controller Module
### 1. LSTM


In [213]:
class LSTMController(nn.Module):
  """LSTM controller for the NTM.

  Wraps a sequence-first ``nn.LSTM``. It owns learnable initial hidden and
  cell states, which :meth:`create_new_state` tiles across the batch.
  """
  def __init__(self, num_inputs, num_outputs, num_layers):
    super(LSTMController, self).__init__()

    self.num_inputs  = num_inputs
    self.num_layers  = num_layers
    self.num_outputs = num_outputs

    self.lstm = nn.LSTM(num_inputs, num_outputs, num_layers)

    # Learnable initial (h, c) states of shape (num_layers, 1, num_outputs).
    # Both are drawn from N(0, 1) and scaled by 0.05, so they start close to zero.
    self.lstm_h_state = nn.Parameter(torch.randn(num_layers, 1, num_outputs) * 0.05)
    self.lstm_c_state = nn.Parameter(torch.randn(num_layers, 1, num_outputs) * 0.05)

    self.reset_parameters()

  def create_new_state(self, batch_size):
    """Return the initial (h0, c0) pair for a batch.

    Each tensor has dimension (num_layers * num_directions, batch, hidden_size)
    = (num_layers, batch_size, num_outputs), tiled from the learned initial states.
    """
    lstm_h = self.lstm_h_state.clone().repeat(1, batch_size, 1)
    lstm_c = self.lstm_c_state.clone().repeat(1, batch_size, 1)
    return lstm_h, lstm_c

  def reset_parameters(self):
    """Zero the LSTM biases; init weight matrices uniformly in +/- 5 / sqrt(in + out)."""
    for p in self.lstm.parameters():
      if p.dim() == 1:
        nn.init.constant_(p, 0)
      else:
        stdev = 5 / (np.sqrt(self.num_inputs +  self.num_outputs))
        nn.init.uniform_(p, -stdev, stdev)

  def size(self):
    return self.num_inputs, self.num_outputs

  def forward(self, x, prev_state):
    """Run one LSTM step.

    :param x: (batch, num_inputs) input for a single time step.
    :param prev_state: (h, c) tuple as returned by :meth:`create_new_state`
                       or a previous call.
    :return: ((batch, num_outputs) output, new (h, c) state).
    """
    x = x.unsqueeze(0)  # add a sequence dimension of length 1
    outp, state = self.lstm(x, prev_state)
    return outp.squeeze(0), state


### 2. Feed Forward

In [None]:
class FFController(nn.Module):
  """Feed-forward controller: Linear -> sigmoid -> Linear.

  Maps (..., num_inputs) to (..., num_outputs) through a hidden layer of
  width num_hidden.
  """
  def __init__(self, num_inputs, num_outputs, num_hidden):
    super(FFController, self).__init__()

    self.num_inputs, self.num_hidden, self.num_outputs = num_inputs, num_hidden, num_outputs

    # Submodules are created in this order so parameters register as fc1.*, then fc2.*.
    self.fc1 = nn.Linear(num_inputs, num_hidden)
    self.sigmoid = nn.Sigmoid()
    self.fc2 = nn.Linear(num_hidden, num_outputs)
    self.reset_parameters()

  def reset_parameters(self):
    """Xavier-uniform weights (gain 1.4) and N(0, 0.01) biases for both layers."""
    for layer in (self.fc1, self.fc2):
      nn.init.xavier_uniform_(layer.weight, gain=1.4)
      nn.init.normal_(layer.bias, std=0.01)

  def size(self):
    return (self.num_inputs, self.num_outputs)

  def forward(self, x):
    hidden = self.sigmoid(self.fc1(x))
    return self.fc2(hidden)

## Heads

### 1. Base Head

In [27]:
class BaseHead(nn.Module):
  """Common base for NTM read/write heads.

  Holds the shared memory (``self.memo``) and controller (``self.ctrl``).
  Converts raw controller-derived addressing parameters into a memory weighting.
  """
  def __init__(self, memory, controller):
    super(BaseHead, self).__init__()

    self.memo = memory
    self.ctrl = controller
    _, self.ctrl_size = controller.size()
    self.M = memory.M
    self.N = memory.N

  def create_new_state(self, batch_size):
    raise NotImplementedError

  def register_parameters(self):
    raise NotImplementedError

  def reset_parameters(self):
    raise NotImplementedError

  def is_read_head(self):
    # Must raise, not return: a returned exception object is truthy and
    # would make an unimplemented head look like a read head.
    raise NotImplementedError

  def _address_memory(self, k, B, g, s, L, w_prev):
    """Apply activations to raw addressing outputs, then address memory.

    :param k: key vector (used as-is, cloned).
    :param B: raw key strength; softplus makes it non-negative.
    :param g: raw interpolation gate; sigmoid maps it into (0, 1).
    :param s: raw shift logits; softmax over dim 1 makes a distribution.
    :param L: raw sharpening factor; 1 + softplus makes it >= 1.
    :param w_prev: previous weighting of this head.
    :return: new weighting, as returned by ``self.memo.address``.
    """
    k = k.clone()
    B = F.softplus(B)
    g = torch.sigmoid(g)
    s = F.softmax(s, dim=1)
    L = 1 + F.softplus(L)

    return self.memo.address(k, B, g, s, L, w_prev)


def _split_cols(mat, lengths):
    """Split a 2D matrix to variable length columns."""
    assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns"
    l = np.cumsum([0] + lengths)
    results = []
    for s, e in zip(l[:-1], l[1:]):
        results += [mat[:, s:e]]
    return results

### 2. Read Head

In [28]:
class ReadHead(BaseHead):
  """NTM read head: addresses memory from the controller output, then reads from it."""
  def __init__(self, memory, controller):
    super(ReadHead, self).__init__(memory, controller)

    # Widths of the projected controller output, in order:
    # key (M), key strength (1), gate (1), shift (3), sharpening (1).
    self.read_vector = [self.M, 1, 1, 3, 1]
    self.fc_read = nn.Linear(self.ctrl_size, sum(self.read_vector))
    self.reset_parameters()

  def create_new_state(self, batch_size):
    # The "previous weighting" at the start of a sequence is all zeros.
    return torch.zeros(batch_size, self.N)

  def reset_parameters(self):
    layer = self.fc_read
    nn.init.xavier_uniform_(layer.weight, gain=1.4)
    nn.init.normal_(layer.bias, std=0.01)

  def is_read_head(self):
    return True

  def forward(self, input, last_w):
    """Compute this step's weighting and read vector.

    :param input: controller output.
    :param last_w: this head's weighting from the previous step.
    :return: (read vector, new weighting).
    """
    projected = self.fc_read(input)
    key, strength, gate, shift, sharpen = _split_cols(projected, self.read_vector)
    weights = self._address_memory(key, strength, gate, shift, sharpen, last_w)
    read_out = self.memo.read(weights)
    return read_out, weights

### 3. Write Head

In [29]:
class WriteHead(BaseHead):
  """NTM write head: addresses memory from the controller output, then writes to it."""
  def __init__(self, memory, controller):
    # Name this class in super(): WriteHead is not a ReadHead subclass, so
    # super(ReadHead, self) raises TypeError.
    super(WriteHead, self).__init__(memory, controller)

    #                     K, B, G, S, L, add, erase
    self.write_vector = [self.M, 1, 1, 3, 1, self.M, self.M]
    self.fc_write = nn.Linear(self.ctrl_size, sum(self.write_vector))
    self.reset_parameters()

  def create_new_state(self, batch_size):
    # The "previous weighting" at the start of a sequence is all zeros.
    return torch.zeros(batch_size, self.N)

  def reset_parameters(self):
    nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4)
    nn.init.normal_(self.fc_write.bias, std=0.01)

  def is_read_head(self):
    return False

  def forward(self, input, last_w):
    """Compute this step's weighting and write to memory with it.

    :param input: controller output.
    :param last_w: this head's weighting from the previous step.
    :return: the new weighting.
    """
    out = self.fc_write(input)
    K, B, G, S, L, A, E = _split_cols(out, self.write_vector)
    w = self._address_memory(K, B, G, S, L, last_w)
    # The erase vector is squashed into (0, 1); the add vector is used as-is.
    self.memo.write(w, torch.sigmoid(E), A)
    return w

## DataPath

In [31]:
class DataPath(nn.Module):
  """A DataPath for NTM: wires controller, memory and heads for one time step."""
  def __init__(self, num_inputs, num_outputs, controller, memory, heads):
    """Initialize the DataPath.
    :param num_inputs: External input size.
    :param num_outputs: External output size.
    :param controller: :class:`LSTMController`
    :param memory: :class:`Memory`
    :param heads: list of :class:`ReadHead` or :class:`WriteHead`
    Note: This design allows the flexibility of using any number of read and
          write heads independently. The order in which the heads are
          called is controlled by the user (order in list).
    """
    super(DataPath, self).__init__()

    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.controller = controller
    self.memory = memory
    self.heads = heads

    # Memory exposes its dimensions as N/M attributes; it has no size() method.
    self.N, self.M = memory.N, memory.M
    _, self.controller_size = controller.size()

    # Output layer: [controller_output] -> output.
    # (An alternative design also feeds the current reads, which would need
    # controller_size + num_read_heads * M inputs.)
    self.fc = nn.Linear(self.controller_size, num_outputs)
    self.reset_parameters()

  def create_new_state(self, batch_size):
    """Build the initial (reads, controller_state, heads_states) triple for a batch."""
    # Read the registered buffers, not the self.init_r copies, so the initial
    # reads follow the module across .to()/.cuda() calls.
    init_r = [getattr(self, "read{}_bias".format(i)).clone().repeat(batch_size, 1)
              for i in range(self.num_read_heads)]
    controller_state = self.controller.create_new_state(batch_size)
    heads_state = [head.create_new_state(batch_size) for head in self.heads]

    return init_r, controller_state, heads_state

  def reset_parameters(self):
    """Initialize the output layer and one initial read vector per read head.

    :raises ValueError: if ``self.heads`` contains no read head.
    """
    # Initialize the linear layer
    nn.init.xavier_uniform_(self.fc.weight, gain=1)
    nn.init.normal_(self.fc.bias, std=0.01)

    # Initialize the initial previous read values to small random biases.
    # They are stored as buffers (saved in state_dict, not trained).
    self.num_read_heads = 0
    self.init_r = []
    for head in self.heads:
      if head.is_read_head():
        init_r_bias = torch.randn(1, self.M) * 0.01
        self.register_buffer("read{}_bias".format(self.num_read_heads), init_r_bias.data)
        self.init_r += [init_r_bias]
        self.num_read_heads += 1

    if self.num_read_heads == 0:
      raise ValueError("heads list must contain at least a single read head")

  def forward(self, x, prev_state):
    """DataPath forward function.
    :param x: input vector (batch_size x num_inputs)
    :param prev_state: The previous state of the DataPath, a
                       (reads, controller_state, heads_states) triple.
    :return: (output in (0, 1) of width num_outputs, new state triple).
    """
    # Unpack the previous state
    prev_reads, prev_controller_state, prev_heads_states = prev_state

    # The controller sees the external input plus the previous reads.
    inp = torch.cat([x] + prev_reads, dim=1)
    controller_outp, controller_state = self.controller(inp, prev_controller_state)

    # Read/Write from the list of heads, in list order
    reads = []
    heads_states = []
    for head, prev_head_state in zip(self.heads, prev_heads_states):
      if head.is_read_head():
        r, head_state = head(controller_outp, prev_head_state)
        reads += [r]
      else:
        head_state = head(controller_outp, prev_head_state)
      # Keep every head's state, in head order, so the zip above stays
      # aligned on the next step.
      heads_states += [head_state]

    # Generate Output
    o = torch.sigmoid(self.fc(controller_outp))

    # Pack the current state
    state = (reads, controller_state, heads_states)

    return o, state

## NTM

In [30]:
class NTM(nn.Module):
  """Neural Turing Machine: an LSTM controller plus read/write heads over one memory bank.

  Call :meth:`init_sequence` at the start of each sequence, then
  :meth:`forward` once per time step.
  """

  def __init__(self, num_inputs, num_outputs, controller_size, controller_layers, num_read_heads, num_write_heads, N, M):
    """Initialize an NTM.
    :param num_inputs: External number of inputs.
    :param num_outputs: External number of outputs.
    :param controller_size: The size of the internal representation.
    :param controller_layers: Controller number of layers.
    :param num_read_heads: Number of read heads (DataPath requires at least one).
    :param num_write_heads: Number of write heads.
    :param N: Number of rows in the memory bank.
    :param M: Number of cols/features in the memory bank.
    """
    super(NTM, self).__init__()

    # Save args
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.controller_size = controller_size
    self.controller_layers = controller_layers
    self.num_read_heads = num_read_heads
    self.num_write_heads = num_write_heads
    self.N = N
    self.M = M

    # Create the NTM components
    self.memory = Memory(N, M)
    # DataPath feeds the controller the external input concatenated with the
    # previous read vector (width M) of every read head.
    self.controller = LSTMController(num_inputs + M * num_read_heads, controller_size, controller_layers)
    # NOTE: FFController is not a drop-in replacement here. DataPath calls
    # controller(x, state) and controller.create_new_state(batch_size), and
    # FFController provides neither.
    # DataPath invokes heads in list order: read heads first, then write heads.
    self.heads = nn.ModuleList([ReadHead(self.memory, self.controller) for _ in range(num_read_heads)])
    self.heads.extend([WriteHead(self.memory, self.controller) for _ in range(num_write_heads)])

    self.data_path = DataPath(num_inputs, num_outputs, self.controller, self.memory, self.heads)

  def init_sequence(self, batch_size):
    """Initializing the state: reset memory and the recurrent DataPath state."""
    self.batch_size = batch_size
    self.memory.reset(batch_size)
    self.previous_state = self.data_path.create_new_state(batch_size)

  def forward(self, x=None):
    """Run one time step.

    :param x: (batch_size, num_inputs) input. Zeros are used when omitted.
    :return: (output, new state). The state is also stored in ``self.previous_state``.
    """
    if x is None:
      x = torch.zeros(self.batch_size, self.num_inputs)

    o, self.previous_state = self.data_path(x, self.previous_state)
    return o, self.previous_state

  def calculate_num_params(self):
    """Returns the total number of parameters (shared parameters counted once)."""
    return sum(p.numel() for p in self.parameters())

In [34]:
# Scratch check: `+=` on a list extends it in place, the same as list.extend.
a = [1, 2, 4]
a.extend([1, 4, 6, 7])
a

[1, 2, 4, 1, 4, 6, 7]