<a href="https://colab.research.google.com/github/mahi97/NTM/blob/master/NTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
import numpy as np
import random
from attr import attrs, attrib, Factory
import time

## Memory Module

In [2]:
def _convolve(wg, sg, batch_size):
  """Apply a per-row 3-tap circular convolution to a batch of weightings.

  Each row ``wg[i]`` is padded with its last element on the left and its
  first element on the right (wrap-around) and passed through ``F.conv1d``
  with ``sg[i]`` as the kernel, so that, with indices taken modulo the row
  length::

      result[i][j] = sg[i][0] * wg[i][j-1] + sg[i][1] * wg[i][j] + sg[i][2] * wg[i][j+1]

  :param wg: 2D tensor (batch x N) of weightings.
  :param sg: 2D tensor (batch x 3) of shift kernels, one per batch row.
  :param batch_size: Number of leading rows of ``wg`` to process; any rows
                     beyond it are left as zeros in the result.
  :return: Tensor with the same shape, dtype and device as ``wg``.
  """
  # Allocate the output like the input: a bare torch.zeros(size) would always
  # use the default dtype/device and silently cast (or fail on) other inputs.
  result = torch.zeros_like(wg)
  for i in range(batch_size):
    w = wg[i]
    s = sg[i]
    assert s.size(0) == 3
    # Wrap-around padding turns the plain conv1d into a circular one.
    t = torch.cat([w[-1:], w, w[:1]])
    result[i] = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
  return result

class Memory(nn.Module):
  """External memory bank of the NTM: ``N`` rows of ``M`` features per batch element.

  The working memory lives in ``self.mem`` (batch x N x M). It is created by
  :meth:`reset` from the ``mem_bias`` buffer and replaced on every
  :meth:`write`.
  """

  def __init__(self, N, M):
    """
    :param N: Number of rows (addressable locations) in the memory.
    :param M: Number of columns (features per location).
    """
    super(Memory, self).__init__()

    self.M = M
    self.N = N

    # mem_bias is the content every sequence starts from. It is registered as
    # a buffer, so it is saved/moved with the module but is not a parameter.
    self.register_buffer('mem_bias', torch.empty(N, M))
    std_dev = 1 / np.sqrt(N + M)
    nn.init.uniform_(self.mem_bias, -std_dev, std_dev)

  def reset(self, batch_size):
    """Initialize memory from bias, for start-of-sequence."""
    self.batch_size = batch_size
    self.mem = self.mem_bias.clone().repeat(batch_size, 1, 1)

  def read(self, address):
    """Read from memory with a weighting over its rows.

    :param address: Batched tensor (batch_size x N) of row weights.
    :return: Batched tensor (batch_size x M), the address-weighted sum of the memory rows.
    """
    return address.unsqueeze(1).matmul(self.mem).squeeze(1)

  def write(self, address, erase_vector, add_vector):
    """Erase-then-add update of the memory.

    Computes ``mem = prev_mem * (1 - address ⊗ erase_vector) + address ⊗ add_vector``
    per batch element and stores it in ``self.mem``; the memory before the
    write stays available as ``self.prev_mem``.

    :param address: Batched tensor (batch_size x N) of row weights.
    :param erase_vector: Batched tensor (batch_size x M).
    :param add_vector: Batched tensor (batch_size x M).
    """
    self.prev_mem = self.mem
    # Outer products: (batch, N, 1) @ (batch, 1, M) -> (batch, N, M)
    erase = torch.matmul(address.unsqueeze(-1), erase_vector.unsqueeze(1))
    add = torch.matmul(address.unsqueeze(-1), add_vector.unsqueeze(1))
    # The result must land in self.mem, which is what read()/address() use.
    self.mem = self.prev_mem * (1 - erase) + add

  def address(self, key_vector, key_strength, gate, shift, sharpen, last_address):
    """NTM Addressing (according to section 3.3).

    Returns a normalized weighting over the rows of the memory matrix.
    :param key_vector: The key vector.
    :param key_strength: The key strength (focus).
    :param gate: Scalar interpolation gate (with previous weighting).
    :param shift: Shift weighting.
    :param sharpen: Sharpen weighting scalar.
    :param last_address: The weighting produced in the previous time step.
    """
    # Content addressing: softmax over scaled cosine similarity to every row
    wc = F.softmax(key_strength * F.cosine_similarity(key_vector.unsqueeze(1), self.mem, dim=2), dim=1)
    # Interpolate with the previous weighting
    wg = (gate * wc) + (1 - gate) * last_address
    # Circular shift
    ws = _convolve(wg, shift, self.batch_size)
    # Sharpen and renormalize (the epsilon guards against division by zero)
    ws = (ws ** sharpen)
    wt = torch.true_divide(ws, torch.sum(ws, dim=1).view(-1, 1) + 1e-16)

    return wt

  def size(self):
    """Return the memory dimensions as ``(N, M)``."""
    return self.N, self.M


## Controller Module
### 1. LSTM


In [3]:
class LSTMController(nn.Module):
  """LSTM-based NTM controller with a learned initial hidden/cell state."""

  def __init__(self, num_inputs, num_outputs, num_layers):
    """
    :param num_inputs: Size of the controller input vector.
    :param num_outputs: Hidden size of the LSTM (the controller output size).
    :param num_layers: Number of stacked LSTM layers.
    """
    super(LSTMController, self).__init__()

    self.num_inputs  = num_inputs
    self.num_layers  = num_layers
    self.num_outputs = num_outputs

    self.lstm = nn.LSTM(num_inputs, num_outputs, num_layers)

    # Learned initial states, stored for a batch of one and repeated per batch
    # in create_new_state(). Both start as small random values; torch.randn is
    # used for both because torch.Tensor(...) returns uninitialized memory.
    self.lstm_h = nn.Parameter(torch.randn(num_layers, 1, num_outputs) * 0.05)
    self.lstm_c = nn.Parameter(torch.randn(num_layers, 1, num_outputs) * 0.05)

    # Apply the custom LSTM initialization, as the other modules in this file
    # do from their constructors.
    self.reset_parameters()

  def create_new_state(self, batch_size):
    """Return the initial ``(h, c)`` LSTM state repeated for ``batch_size``."""
    # Dimension: (num_layers * num_directions, batch, hidden_size)
    lstm_h = self.lstm_h.clone().repeat(1, batch_size, 1)
    lstm_c = self.lstm_c.clone().repeat(1, batch_size, 1)
    return lstm_h, lstm_c

  def reset_parameters(self):
    """Zero the LSTM biases and draw its weight matrices uniformly."""
    for p in self.lstm.parameters():
      if p.dim() == 1:
        # 1-D parameters are the bias vectors
        nn.init.constant_(p, 0)
      else:
        stdev = 5 / (np.sqrt(self.num_inputs +  self.num_outputs))
        nn.init.uniform_(p, -stdev, stdev)

  def size(self):
    """Return ``(num_inputs, num_outputs)``."""
    return self.num_inputs, self.num_outputs

  def forward(self, x, prev_state):
    """Run one time step.

    :param x: Input for a single step (batch x num_inputs); a length-1
              sequence dimension is added before calling the LSTM.
    :param prev_state: The ``(h, c)`` tuple from the previous step.
    :return: ``(output, state)`` with the sequence dimension removed from the output.
    """
    x = x.unsqueeze(0)
    outp, state = self.lstm(x, prev_state)
    return outp.squeeze(0), state


### 2. Feed Forward

In [4]:
class FFController(nn.Module):
  """Feed-forward controller: Linear -> Sigmoid -> Linear."""

  def __init__(self, num_inputs, num_outputs, num_hidden):
    """
    :param num_inputs: Size of the input vector.
    :param num_outputs: Size of the output vector.
    :param num_hidden: Width of the single hidden layer.
    """
    super().__init__()

    self.num_inputs, self.num_hidden, self.num_outputs = num_inputs, num_hidden, num_outputs

    self.fc1 = nn.Linear(num_inputs, num_hidden)
    self.sigmoid = nn.Sigmoid()
    self.fc2 = nn.Linear(num_hidden, num_outputs)
    self.reset_parameters()

  def reset_parameters(self):
    """Re-draw both linear layers: Xavier-uniform weights, small normal biases."""
    for layer in (self.fc1, self.fc2):
      nn.init.xavier_uniform_(layer.weight, gain=1.4)
      nn.init.normal_(layer.bias, std=0.01)

  def size(self):
    """Return ``(num_inputs, num_outputs)``."""
    return (self.num_inputs, self.num_outputs)

  def forward(self, x):
    """Map ``x`` through the hidden layer, the sigmoid and the output layer."""
    hidden = self.sigmoid(self.fc1(x))
    return self.fc2(hidden)

## Heads

### 1. Base Head

In [5]:
class BaseHead(nn.Module):
  """Common base of the NTM heads.

  Stores the shared memory and controller and provides the addressing step
  used by concrete heads. Subclasses must implement the methods below that
  raise ``NotImplementedError``.
  """

  def __init__(self, memory, controller):
    """
    :param memory: Object exposing ``M``, ``N`` and ``address(...)``.
    :param controller: Object whose ``size()`` returns ``(inputs, outputs)``;
                       the second value becomes ``self.ctrl_size``.
    """
    super(BaseHead, self).__init__()

    self.memory = memory
    self.controller = controller
    _, self.ctrl_size = controller.size()
    self.M = memory.M
    self.N = memory.N

  def create_new_state(self, batch_size):
    raise NotImplementedError

  def register_parameters(self):
    raise NotImplementedError

  def reset_parameters(self):
    raise NotImplementedError

  def is_read_head(self):
    # Must raise, not return: the exception class itself is truthy, so a
    # subclass that forgot to override this would be treated as a read head.
    raise NotImplementedError

  def _address_memory(self, k, B, g, s, L, w_prev):
    """Squash the raw addressing parameters and query the memory.

    :param k: Key vector (cloned, otherwise passed through unchanged).
    :param B: Raw key strength; mapped through softplus.
    :param g: Raw interpolation gate; mapped through sigmoid.
    :param s: Raw shift weights; softmax over dim 1.
    :param L: Raw sharpening value; mapped to ``1 + softplus(L)``.
    :param w_prev: Weighting from the previous time step.
    :return: Whatever ``self.memory.address`` returns for these values.
    """
    # Handle Activations
    k = k.clone()
    B = F.softplus(B)
    g = torch.sigmoid(g)
    s = F.softmax(s, dim=1)
    L = 1 + F.softplus(L)

    w = self.memory.address(k, B, g, s, L, w_prev)

    return w


def _split_cols(mat, lengths):
    """Split a 2D matrix into consecutive column groups.

    :param mat: 2D tensor (rows x columns).
    :param lengths: Iterable of column-group widths (list, tuple, ...); the
                    widths must add up to the number of columns of ``mat``.
    :return: List with one ``mat[:, start:end]`` slice per width, in order.
    :raises AssertionError: If the widths do not sum to the column count.
    """
    # Materialize once so any iterable (not only a list) is accepted.
    lengths = list(lengths)
    assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns"
    results = []
    start = 0
    for width in lengths:
        results.append(mat[:, start:start + width])
        start += width
    return results

### 2. Read Head

In [6]:
class ReadHead(BaseHead):
  """Head that addresses the memory and reads from it."""

  def __init__(self, memory, controller):
    super().__init__(memory, controller)

    # Column widths of the projected controller output, in order:
    # key (M), key strength, gate, shift kernel (3 taps), sharpening
    self.read_vector = [self.M, 1, 1, 3, 1]
    self.fc_read = nn.Linear(self.ctrl_size, sum(self.read_vector))
    self.reset_parameters()

  def create_new_state(self, batch_size):
    """Return the initial (all-zero) weighting, shaped batch_size x N."""
    return torch.zeros(batch_size, self.N)

  def reset_parameters(self):
    """Re-draw the projection layer: Xavier-uniform weight, small normal bias."""
    layer = self.fc_read
    nn.init.xavier_uniform_(layer.weight, gain=1.4)
    nn.init.normal_(layer.bias, std=0.01)

  def is_read_head(self):
    return True

  def forward(self, input, last_w):
    """Address the memory from the controller output and read at that address.

    :param input: Controller output for this step.
    :param last_w: This head's weighting from the previous step.
    :return: ``(read_vector, weighting)``.
    """
    addressing = _split_cols(self.fc_read(input), self.read_vector)
    weighting = self._address_memory(*addressing, last_w)
    return self.memory.read(weighting), weighting

### 3. Write Head

In [7]:
class WriteHead(BaseHead):
  """Head that addresses the memory and writes to it."""

  def __init__(self, memory, controller):
    super().__init__(memory, controller)

    # Column widths of the projected controller output, in order:
    # key (M), key strength, gate, shift kernel (3 taps), sharpening,
    # add vector (M), erase vector (M)
    self.write_vector = [self.M, 1, 1, 3, 1, self.M, self.M]
    self.fc_write = nn.Linear(self.ctrl_size, sum(self.write_vector))
    self.reset_parameters()

  def create_new_state(self, batch_size):
    """Return the initial (all-zero) weighting, shaped batch_size x N."""
    return torch.zeros(batch_size, self.N)

  def reset_parameters(self):
    """Re-draw the projection layer: Xavier-uniform weight, small normal bias."""
    layer = self.fc_write
    nn.init.xavier_uniform_(layer.weight, gain=1.4)
    nn.init.normal_(layer.bias, std=0.01)

  def is_read_head(self):
    return False

  def forward(self, input, last_w):
    """Address the memory from the controller output and write at that address.

    :param input: Controller output for this step.
    :param last_w: This head's weighting from the previous step.
    :return: The weighting used for the write.
    """
    projected = self.fc_write(input)
    *addressing, add_vec, erase_raw = _split_cols(projected, self.write_vector)
    weighting = self._address_memory(*addressing, last_w)
    # The erase vector is squashed with a sigmoid; the add vector is used as-is.
    self.memory.write(weighting, torch.sigmoid(erase_raw), add_vec)
    return weighting

## DataPath

In [8]:
class DataPath(nn.Module):
  """A DataPath for NTM.

  Wires the controller and the heads together for one time step and maps
  the controller output to the external output.
  """
  def __init__(self, num_inputs, num_outputs, controller, memory, heads):
    """Initialize the DataPath.
    :param num_inputs: External input size.
    :param num_outputs: External output size.
    :param controller: :class:`LSTMController`
    :param memory: :class:`Memory`
    :param heads: list of :class:`ReadHead` or :class:`WriteHead`
    Note: This design allows the flexibility of using any number of read and
          write heads independently, also, the order by which the heads are
          called in controlled by the user (order in list)
    """
    super(DataPath, self).__init__()

    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.controller = controller
    self.memory = memory
    self.heads = heads

    self.N, self.M = memory.size()
    _, self.controller_size = controller.size()

    # Fully connected layer producing the actual output from the controller
    # output only (the previous reads are not concatenated here).
    self.fc = nn.Linear(self.controller_size, num_outputs)
    self.reset_parameters()

  def _read_bias_name(self, index):
    """Name of the buffer holding the initial read vector of read head ``index``."""
    return "read{}_bias".format(index)

  def create_new_state(self, batch_size):
    """Build the start-of-sequence state.

    :return: ``(initial_reads, controller_state, heads_state)`` where
             ``initial_reads`` holds one (batch_size x M) tensor per read head.
    """
    # Take the biases from the registered buffers rather than from a
    # Python-side list, so they follow .to()/.double()/load_state_dict().
    init_r = [getattr(self, self._read_bias_name(i)).clone().repeat(batch_size, 1)
              for i in range(self.num_read_heads)]
    controller_state = self.controller.create_new_state(batch_size)
    heads_state = [head.create_new_state(batch_size) for head in self.heads]

    return init_r, controller_state, heads_state

  def reset_parameters(self):
    """(Re-)initialize the output layer and the initial read biases.

    Safe to call more than once: already registered bias buffers are
    overwritten in place instead of being registered again.
    """
    # Initialize the linear layer
    nn.init.xavier_uniform_(self.fc.weight, gain=1)
    nn.init.normal_(self.fc.bias, std=0.01)

    # Initialize the initial previous read values to random biases
    self.num_read_heads = 0
    self.init_r = []
    for head in self.heads:
      if head.is_read_head():
        init_r_bias = torch.randn(1, self.M) * 0.01
        name = self._read_bias_name(self.num_read_heads)
        if name in self._buffers:
          # register_buffer() would raise KeyError for an existing name
          self._buffers[name].copy_(init_r_bias)
        else:
          self.register_buffer(name, init_r_bias)
        self.init_r += [getattr(self, name)]
        self.num_read_heads += 1

    assert self.num_read_heads > 0, "heads list must contain at least a single read head"

  def forward(self, x, prev_state):
    """DataPath forward function.
    :param x: input vector (batch_size x num_inputs)
    :param prev_state: The previous state of the DataPath
    :return: ``(output, state)``; ``state`` has the same layout as ``prev_state``.
    """
    # Unpack the previous state
    prev_reads, prev_controller_state, prev_heads_states = prev_state

    # The controller sees the external input concatenated with the previous reads
    inp = torch.cat([x] + prev_reads, dim=1)
    controller_outp, controller_state = self.controller(inp, prev_controller_state)

    # Read/Write from the list of heads, in list order
    reads = []
    heads_states = []
    for head, prev_head_state in zip(self.heads, prev_heads_states):
      if head.is_read_head():
        r, head_state = head(controller_outp, prev_head_state)
        reads += [r]
      else:
        head_state = head(controller_outp, prev_head_state)
      heads_states += [head_state]

    # Generate Output (from the controller output only)
    o = torch.sigmoid(self.fc(controller_outp))

    # Pack the current state
    state = (reads, controller_state, heads_states)

    return o, state

## NTM

In [9]:
class NTM(nn.Module):
  """Neural Turing Machine: memory, LSTM controller, heads and data path."""

  def __init__(self, num_inputs, num_outputs, controller_size, controller_layers, num_read_heads, num_write_heads, N, M):
    """Initialize an NTM.
    :param num_inputs: External number of inputs.
    :param num_outputs: External number of outputs.
    :param controller_size: The size of the internal representation.
    :param controller_layers: Controller number of layers.
    :param num_read_heads: Number of read heads.
    :param num_write_heads: Number of write heads.
    :param N: Number of rows in the memory bank.
    :param M: Number of cols/features in the memory bank.
    """
    super().__init__()

    # Keep the constructor arguments around
    self.num_inputs, self.num_outputs = num_inputs, num_outputs
    self.controller_size, self.controller_layers = controller_size, controller_layers
    self.num_read_heads, self.num_write_heads = num_read_heads, num_write_heads
    self.N, self.M = N, M

    # Build the components. The controller input is the external input plus
    # one M-sized read vector per read head.
    self.memory = Memory(N, M)
    controller_inputs = num_inputs + M * num_read_heads
    self.controller = LSTMController(controller_inputs, controller_size, controller_layers)

    # Read heads come first in the list, followed by the write heads.
    read_heads = [ReadHead(self.memory, self.controller) for _ in range(num_read_heads)]
    write_heads = [WriteHead(self.memory, self.controller) for _ in range(num_write_heads)]
    self.heads = nn.ModuleList(read_heads + write_heads)

    self.data_path = DataPath(num_inputs, num_outputs, self.controller, self.memory, self.heads)

  def init_sequence(self, batch_size):
    """Reset the memory and the recurrent state for a new sequence."""
    self.batch_size = batch_size
    self.memory.reset(batch_size)
    self.previous_state = self.data_path.create_new_state(batch_size)

  def forward(self, x=None):
    """Run one time step; with no input an all-zero input vector is used.

    :return: ``(output, state)``; the state is also stored as ``self.previous_state``.
    """
    step_input = torch.zeros(self.batch_size, self.num_inputs) if x is None else x
    o, state = self.data_path(step_input, self.previous_state)
    self.previous_state = state
    return o, state

  def calculate_num_params(self):
    """Returns the total number of parameters."""
    return sum(p.data.view(-1).size(0) for p in self.parameters())

## Tasks
### 1. Copy Task


#### Model and Parameters

In [10]:
"""Copy Task NTM model."""


# Generator of randomized test sequences
def dataloader(num_batches, batch_size, seq_width, min_len, max_len):
  """Yield random batches of bit sequences for the copy task.

  Every sequence inside one batch has the same length, drawn uniformly from
  [`min_len`, `max_len`] (both inclusive); different batches may differ.
  :param num_batches: Total number of batches to generate.
  :param batch_size: Batch size.
  :param seq_width: The width of each item in the sequence.
  :param min_len: Sequence minimum length.
  :param max_len: Sequence maximum length.
  Yields ``(batch_number, input, target)`` with batch numbers starting at 1.
  NOTE: The input width is `seq_width + 1` and its length is one more than
  the target's; the extra channel/step carries the end-of-sequence delimiter.
  """
  for batch_num in range(1, num_batches + 1):
    length = random.randint(min_len, max_len)
    bits = torch.from_numpy(np.random.binomial(1, 0.5, (length, batch_size, seq_width)))

    # Input = the bits, then one extra step whose control channel is set
    source = torch.zeros(length + 1, batch_size, seq_width + 1)
    source[:length, :, :seq_width] = bits
    source[length, :, seq_width] = 1.0

    target = bits.clone()
    yield batch_num, source.float(), target.float()


@attrs
class CopyTaskParams(object):
  """Hyper-parameters of the copy task; each numeric field is coerced by its converter."""
  # Label for this parameter set
  name = attrib(default="copy-task")
  # Controller hidden size and number of layers (passed to NTM)
  controller_size = attrib(default=100, converter=int)
  controller_layers = attrib(default=1,converter=int)
  # Number of read / write heads (passed to NTM)
  num_read_heads = attrib(default=1, converter=int)
  num_write_heads = attrib(default=1, converter=int)
  # Bits per sequence item, and the inclusive range of sequence lengths
  sequence_width = attrib(default=8, converter=int)
  sequence_min_len = attrib(default=1,converter=int)
  sequence_max_len = attrib(default=20, converter=int)
  # Memory dimensions: N rows of M features (passed to NTM)
  memory_n = attrib(default=128, converter=int)
  memory_m = attrib(default=20, converter=int)
  # Number of batches the dataloader produces, and sequences per batch
  num_batches = attrib(default=50000, converter=int)
  batch_size = attrib(default=1, converter=int)
  # RMSprop settings (passed to optim.RMSprop)
  rmsprop_lr = attrib(default=1e-4, converter=float)
  rmsprop_momentum = attrib(default=0.9, converter=float)
  rmsprop_alpha = attrib(default=0.95, converter=float)

@attrs
class CopyTaskModelTraining(object):
  """Bundle of everything needed to train on the copy task.

  Only ``params`` has a plain default; the other attributes default to the
  values built by the decorated methods below, which read ``self.params``
  (and, for the optimizer, ``self.net``).
  """
  params = attrib(default=Factory(CopyTaskParams))
  net = attrib()
  dataloader = attrib()
  criterion = attrib()
  optimizer = attrib()

  @net.default
  def default_net(self):
    """Build the NTM sized from ``self.params``."""
    # We have 1 additional input for the delimiter which is passed on a
    # separate "control" channel
    net = NTM(self.params.sequence_width + 1, self.params.sequence_width,
              self.params.controller_size, self.params.controller_layers,
              self.params.num_read_heads, self.params.num_write_heads,
              self.params.memory_n, self.params.memory_m)
    return net

  @dataloader.default
  def default_dataloader(self):
    """Create the batch generator by calling the module-level ``dataloader``."""
    # NOTE: the module-level dataloader() is a generator function, so the
    # object stored here can be iterated only once.
    return dataloader(self.params.num_batches, self.params.batch_size,
                      self.params.sequence_width,
                      self.params.sequence_min_len, self.params.sequence_max_len)

  @criterion.default
  def default_criterion(self):
    """Binary cross-entropy loss."""
    return nn.BCELoss()

  @optimizer.default
  def default_optimizer(self):
    """RMSprop over the network parameters, configured from ``self.params``."""
    return optim.RMSprop(self.net.parameters(),
                        momentum=self.params.rmsprop_momentum,
                        alpha=self.params.rmsprop_alpha,
                        lr=self.params.rmsprop_lr)

#### Training

Helper Parameters and Functions

In [11]:
def get_ms():
  """Return the current wall-clock time in milliseconds."""
  now_seconds = time.time()
  return now_seconds * 1000

def progress_clean():
  """Blank out the progress-bar line and return the cursor to its start."""
  blanks = " " * 80
  print("\r" + blanks, end='\r')

def progress_bar(batch_num, report_interval, last_loss):
  """Redraw the 40-character progress bar towards the next report.

  :param batch_num: 1-based number of the batch just processed.
  :param report_interval: Number of batches between two reports.
  :param last_loss: Loss value shown next to the bar.
  """
  width = 40
  done_in_interval = ((batch_num - 1) % report_interval) + 1
  fill = int((done_in_interval / report_interval) * width)
  done = "=" * fill
  todo = " " * (width - fill)
  line = "\r[{}{}]: {} (Loss: {:.4f})".format(done, todo, batch_num, last_loss)
  print(line, end='')

### Training Functions

In [12]:
"""Training for the Copy Task in Neural Turing Machines."""

def clip_grads(net):
    """Clamp, in place, every existing parameter gradient of ``net`` to [-10, 10].

    Parameters without a gradient are skipped.
    """
    for param in net.parameters():
        if param.grad is None:
            continue
        param.grad.data.clamp_(-10, 10)

def train_batch(net, criterion, optimizer, X, Y):
    """Train on a single batch and report its loss and bit-error cost.

    :param net: Model exposing ``init_sequence(batch_size)`` and a call that
                returns ``(output, state)``; called without input during the
                output phase.
    :param criterion: Loss applied to the stacked outputs and ``Y``.
    :param optimizer: Optimizer stepping ``net``'s parameters.
    :param X: Input sequence, indexed by time step along dim 0.
    :param Y: Target sequence (steps x batch x width).
    :return: ``(loss, cost)`` as Python floats, where ``cost`` is the number
             of wrong bits after thresholding at 0.5, divided by the batch size.
    """
    optimizer.zero_grad()
    inp_seq_len = X.size(0)
    outp_seq_len, batch_size, _ = Y.size()

    # New sequence
    net.init_sequence(batch_size)

    # Feed the sequence + delimiter
    for i in range(inp_seq_len):
        net(X[i])

    # Read the output (no input given)
    y_out = torch.zeros_like(Y)
    for i in range(outp_seq_len):
        y_out[i], _ = net()

    loss = criterion(y_out, Y)
    loss.backward()
    clip_grads(net)
    optimizer.step()

    # Threshold with a vectorized comparison: Tensor.apply_ with a Python
    # callable is CPU-only and calls back into Python for every element.
    y_out_binarized = (y_out.detach() >= 0.5).to(y_out.dtype)

    # The cost is the number of error bits per sequence
    cost = torch.sum(torch.abs(y_out_binarized - Y.detach()))

    return loss.item(), cost.item() / batch_size

def train_model(model, report_interval=500):
    """Train ``model.net`` on every batch of ``model.dataloader``.

    A progress bar is redrawn after each batch, and every ``report_interval``
    batches the mean loss, mean cost and mean time per sequence of that
    interval are printed. The full loss and cost histories are printed when
    the dataloader is exhausted.

    :param model: Object exposing ``params.batch_size``, ``dataloader``,
                  ``net``, ``criterion`` and ``optimizer``.
    :param report_interval: Number of batches between two reports.
    """
    batch_size = model.params.batch_size

    losses = []
    costs = []
    start_ms = get_ms()

    for batch_num, x, y in model.dataloader:
        loss, cost = train_batch(model.net, model.criterion, model.optimizer, x, y)
        losses += [loss]
        costs += [cost]

        # Update the progress bar
        progress_bar(batch_num, report_interval, loss)

        # Report
        if batch_num % report_interval == 0:
            mean_loss = np.array(losses[-report_interval:]).mean()
            mean_cost = np.array(costs[-report_interval:]).mean()
            mean_time = int(((get_ms() - start_ms) / report_interval) / batch_size)
            progress_clean()
            # These statistics were previously computed and then discarded.
            print("Batch {} Loss: {:.6f} Cost: {:.2f} Time: {} ms/sequence".format(
                batch_num, mean_loss, mean_cost, mean_time))
            start_ms = get_ms()
    print(losses)
    print(costs)

In [None]:
# Registry of available tasks: name -> (training bundle class, params class)
TASK = {'copy': (CopyTaskModelTraining, CopyTaskParams)}
SEED = 1

# Seed every random number generator used by the notebook
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize the Model
model_cls, params_cls = TASK['copy']
default_params = params_cls()
model = model_cls(params=default_params)

# Train Model
train_model(model, report_interval=500)





### Test Model

In [None]:
def evaluate(net, criterion, X, Y):
    """Evaluate a single batch (without training).

    :param net: Model exposing ``init_sequence(batch_size)`` and a call that
                returns ``(output, state)``; called without input during the
                output phase.
    :param criterion: Loss applied to the stacked outputs and ``Y``.
    :param X: Input sequence, indexed by time step along dim 0.
    :param Y: Target sequence (steps x batch x width).
    :return: Dict with ``loss`` (float), ``cost`` (float, wrong bits per
             sequence), ``y_out``, ``y_out_binarized`` and ``states`` (the
             state returned by ``net`` at every input and output step).
    """
    inp_seq_len = X.size(0)
    outp_seq_len, batch_size, _ = Y.size()

    states = []
    # No gradients are needed for evaluation
    with torch.no_grad():
        # New sequence
        net.init_sequence(batch_size)

        # Feed the sequence + delimiter
        for i in range(inp_seq_len):
            _, state = net(X[i])
            states += [state]

        # Read the output (no input given)
        y_out = torch.zeros_like(Y)
        for i in range(outp_seq_len):
            y_out[i], state = net()
            states += [state]

        loss = criterion(y_out, Y)

    # Vectorized threshold (Tensor.apply_ is CPU-only and per-element Python)
    y_out_binarized = (y_out >= 0.5).to(y_out.dtype)

    # The cost is the number of error bits per sequence
    cost = torch.sum(torch.abs(y_out_binarized - Y))

    result = {
        # loss is a 0-dim tensor: indexing it with [0] raises, .item() is the
        # supported way to get the Python number (as train_batch does)
        'loss': loss.item(),
        'cost': cost.item() / batch_size,
        'y_out': y_out,
        'y_out_binarized': y_out_binarized,
        'states': states
    }

    return result

def test_model(model, report_interval=500):
    """Evaluate ``model.net`` on every batch of ``model.dataloader``.

    A progress bar is redrawn after each batch, and every ``report_interval``
    batches the mean loss, mean cost and mean time per sequence of that
    interval are printed.

    NOTE(review): ``model.dataloader`` is iterated directly; if it is a
    one-shot generator that an earlier training run already consumed, this
    loop body never executes -- confirm the caller supplies a fresh one.

    :param model: Object exposing ``params.batch_size``, ``dataloader``,
                  ``net`` and ``criterion``.
    :param report_interval: Number of batches between two reports.
    """
    batch_size = model.params.batch_size

    losses = []
    costs = []
    start_ms = get_ms()

    for batch_num, x, y in model.dataloader:
        res = evaluate(model.net, model.criterion, x, y)
        losses += [res['loss']]
        costs += [res['cost']]

        # Update the progress bar (the loss of this batch lives in `res`;
        # there is no local named `loss` in this function)
        progress_bar(batch_num, report_interval, res['loss'])

        # Report
        if batch_num % report_interval == 0:
            mean_loss = np.array(losses[-report_interval:]).mean()
            mean_cost = np.array(costs[-report_interval:]).mean()
            mean_time = int(((get_ms() - start_ms) / report_interval) / batch_size)
            progress_clean()
            # These statistics were previously computed and then discarded.
            print("Batch {} Loss: {:.6f} Cost: {:.2f} Time: {} ms/sequence".format(
                batch_num, mean_loss, mean_cost, mean_time))
            start_ms = get_ms()

In [None]:
# Test the Model
# NOTE(review): `model.dataloader` is built once per CopyTaskModelTraining
# instance by the generator function `dataloader`, and train_model(model, 500)
# above iterates it to the end. This call therefore presumably sees no
# batches unless a fresh model/dataloader is created first -- confirm.
test_model(model)