You might need to modify the `%cd` line (the third line) of the code cell below so that it changes into the directory in which your .ipynb file is actually located.

**Caution**: due to the nature of this project's setup, every time you want to rerun a code cell below, please click **Runtime -> Restart and run all**. This operation clears the computational graphs and the local variables but allows the training and testing data that are already loaded from Google Drive to stay in the Colab runtime space. Please do **not** do the following if you just wish to rerun code: click Runtime -> Reset all runtimes, and then click Runtime -> Run all; this will remount your Google Drive and remove the training and testing data already loaded in your Colab runtime space. **Runtime -> Restart and run all** automatically avoids remounting the drive after the first time you run the notebook file; the loaded data can usually stay in your Colab runtime space for many hours.

Loading the training and testing data after remounting your Google Drive takes 30–40 minutes.

In [0]:
from google.colab import drive  # Colab helper module used below to mount Google Drive
drive.mount("/content/gdrive/", force_remount=True)  # NOTE(review): force_remount=True presumably remounts on every run of this cell -- confirm against the Colab docs
%cd gdrive/My Drive/Neural_Turing_Machine/NTM_small
# The %cd line above uses a path relative to the current working directory;
# edit it so that it points at the folder containing this notebook.

In [0]:
from utils import OmniglotDataLoader, one_hot_decode, five_hot_decode
import tensorflow as tf
import argparse
import numpy as np
%tensorflow_version 1.x
print(tf.__version__)


Already implemented, no need to change.

This class is part of the training loop.

In [0]:
class NTMOneShotLearningModel():
  """TF 1.x graph for episodic one-shot classification with an LSTM controller.

  Building an instance creates the placeholders, unrolls the recurrent
  controller over the whole episode, and defines the loss and training op.

  Attributes created by the constructor:
    x_image: float32 placeholder [batch_size, seq_length, image_width * image_height];
      x_image[i, j, :] is the flattened j-th image of the i-th episode.
    x_label: float32 placeholder [batch_size, seq_length, n_classes]; the
      one-hot label fed to the controller together with the j-th image.
    y: float32 placeholder [batch_size, seq_length, n_classes]; one-hot targets.
    o: tensor [batch_size, seq_length, n_classes] of softmax class probabilities.
    state_list: controller states (initial state, one per time step, and the
      final state once more).
    learning_loss: scalar cross entropy, summed over time steps and classes
      and averaged over the batch.
    learning_loss_summary: scalar summary of learning_loss.
    optimizer / train_op: Adam optimizer and its minimize op for learning_loss.

  Args:
    model: controller type; only 'LSTM' is implemented here.
    n_classes: number of classes per episode (size of the one-hot labels).
    batch_size: number of episodes per batch.
    seq_length: number of samples in an episode.
    image_width, image_height: image size; images are fed flattened.
    rnn_size: number of units of each LSTM layer.
    rnn_num_layers: number of stacked LSTM layers.
    num_memory_slots, read_head_num, write_head_num, memory_vector_dim:
      accepted but not used by the code in this class.
    learning_rate: learning rate of the Adam optimizer.

  Raises:
    ValueError: if `model` is not 'LSTM'.
  """

  def __init__(self, model, n_classes, batch_size, seq_length, image_width, image_height,
                rnn_size, num_memory_slots, rnn_num_layers, read_head_num, write_head_num, memory_vector_dim, learning_rate):
    # Fail fast, before any graph node is created: without this check an
    # unsupported model would only surface later as a NameError on `state`.
    if model != 'LSTM':
      raise ValueError('Unsupported model %r: only "LSTM" is implemented' % (model,))

    self.output_dim = n_classes

    # Note: the images are flattened to 1D tensors
    # The input data structure is of the following form:
    # self.x_image[i,j,:] = jth image in the ith sequence (or, episode)
    self.x_image = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, image_width * image_height])
    # Label input is one-hot encoded
    # The data structure is of the following form:
    # self.x_label[i,j,:] = one-hot label of the jth image in
    #             the ith sequence (or, episode)
    self.x_label = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, self.output_dim])
    # Target label is one-hot encoded
    self.y = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, self.output_dim])

    # The dense layer for mapping the controller output to class logits
    self.controller_output_to_ntm_output = tf.keras.layers.Dense(units=self.output_dim, use_bias=True)

    # Using a (stacked) LSTM to serve as the controller, no external memory
    def rnn_cell(rnn_size):
      return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(rnn_size) for _ in range(rnn_num_layers)])
    # Initial (all-zero) controller state
    state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)

    self.state_list = [state]
    # Per-time-step outputs of the model
    self.o = []

    # Now iterate over every sample in the sequence
    for t in range(seq_length):
      # The controller input at step t is the flattened image concatenated
      # with the label supplied for that step.
      output, state = cell(tf.concat([self.x_image[:, t, :], self.x_label[:, t, :]], axis=1), state)
      # Map the controller output to class probabilities with an affine
      # layer followed by a softmax over the class axis
      output = self.controller_output_to_ntm_output(output)
      output = tf.nn.softmax(output, axis=1)
      self.o.append(output)
      self.state_list.append(state)
    # Stack the per-step outputs into [batch_size, seq_length, n_classes]
    self.o = tf.stack(self.o, axis=1)
    # NOTE: the final state is appended a second time here, so state_list has
    # seq_length + 2 entries whose last two are identical. Kept unchanged in
    # case other code relies on that length.
    self.state_list.append(state)

    # eps keeps log() finite when a predicted probability is exactly 0
    eps = 1e-8
    # cross entropy, between model output labels and target labels
    self.learning_loss = -tf.reduce_mean(
        tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2])
    )

    self.o = tf.reshape(self.o, shape=[batch_size, seq_length, -1])
    self.learning_loss_summary = tf.summary.scalar('learning_loss', self.learning_loss)

    with tf.variable_scope('optimizer'):
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.learning_loss)

The training and testing functions

In [0]:
def train(learning_rate, image_width, image_height, n_train_classes, n_test_classes, restore_training, \
         num_epochs, n_classes, batch_size, seq_length, num_memory_slots, augment, save_dir, model_path, tensorboard_dir):
  """Build the model, then alternate periodic evaluation with training steps.

  Every 100 iterations a test batch is evaluated, its loss is written to
  TensorBoard and the per-presentation accuracies returned by `test` are
  printed. Every 2000 iterations (except iteration 0) a checkpoint is saved.
  Each iteration ends with one optimizer step on a freshly fetched train batch.

  Note: `rnn_size`, `rnn_num_layers`, `read_head_num`, `write_head_num` and
  `memory_vector_dim` are NOT parameters; they are read from module-level
  globals and must be defined before this function is called.

  Args:
    learning_rate: learning rate passed to the model's optimizer.
    image_width, image_height: image size used by the model and data loader.
    n_train_classes, n_test_classes: class counts passed to the data loader.
    restore_training: if true, restore variables from the latest checkpoint
      in `save_dir/model_path` instead of initializing them.
    num_epochs: number of loop iterations (one training batch each).
    n_classes: number of classes per episode.
    batch_size: number of episodes per batch.
    seq_length: number of samples per episode.
    num_memory_slots: forwarded to the model constructor.
    augment: forwarded to the data loader's fetch_batch.
    save_dir: root directory for checkpoints.
    model_path: used both as the model type passed to the model constructor
      and as the sub-directory name under `save_dir` / `tensorboard_dir`.
    tensorboard_dir: root directory for TensorBoard summaries.

  Raises:
    FileNotFoundError: if `restore_training` is true but no checkpoint exists.
  """
  # We always use one-hot encoding of the labels in this experiment
  label_type = "one_hot"

  # Initialize the model
  model = NTMOneShotLearningModel(model=model_path, n_classes=n_classes,\
                    batch_size=batch_size, seq_length=seq_length,\
                    image_width=image_width, image_height=image_height, \
                    rnn_size=rnn_size, num_memory_slots=num_memory_slots,\
                    rnn_num_layers=rnn_num_layers, read_head_num=read_head_num,\
                    write_head_num=write_head_num, memory_vector_dim=memory_vector_dim,\
                    learning_rate=learning_rate)
  print("Model initialized")
  # NOTE(review): the keyword `n_train_classses` (three s) is kept exactly as
  # in the original call; presumably it matches the loader's signature in
  # utils.py -- verify there before "fixing" the spelling.
  data_loader = OmniglotDataLoader(
      image_size=(image_width, image_height),
      n_train_classses=n_train_classes,
      n_test_classes=n_test_classes
  )
  print("Data loaded")

  checkpoint_dir = save_dir + '/' + model_path
  # Note: our training loop is in the tensorflow 1.x style
  with tf.Session() as sess:
    if restore_training:
      saver = tf.train.Saver()
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      # get_checkpoint_state yields None when nothing has been saved yet;
      # report that clearly instead of failing on an attribute of None.
      if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in %s' % checkpoint_dir)
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      saver = tf.train.Saver(tf.global_variables())
      tf.global_variables_initializer().run()

    train_writer = tf.summary.FileWriter(tensorboard_dir + '/' + model_path, sess.graph)
    try:
      print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tepoch\tloss")
      for b in range(num_epochs):
        # Test the model
        if b % 100 == 0:
          # Note: the images are flattened to 1D tensors
          # The input data structure is of the following form:
          # x_image[i,j,:] = jth image in the ith sequence (or, episode)
          # The seq_length images x_image[i,:,:] constitute one episode
          # See the details in utils.py, OmniglotDataLoader class
          x_image, x_label, y = data_loader.fetch_batch(n_classes, batch_size, seq_length,
                                    type='test',
                                    augment=augment,
                                    label_type=label_type)
          feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
          # Fetch predictions, loss and summary in a single run so the
          # forward pass over the test batch is computed only once.
          output, learning_loss, merged_summary = sess.run(
              [model.o, model.learning_loss, model.learning_loss_summary],
              feed_dict=feed_dict)
          train_writer.add_summary(merged_summary, b)
          accuracy = test(seq_length, y, output)
          for accu in accuracy:
            print('%.4f' % accu, end='\t')
          print('%d\t%.4f' % (b, learning_loss))

        # Save model per 2000 epochs
        if b % 2000 == 0 and b > 0:
          saver.save(sess, checkpoint_dir + '/model.tfmodel', global_step=b)

        # Train the model
        x_image, x_label, y = data_loader.fetch_batch(n_classes, batch_size, seq_length, \
                                  type='train',
                                  augment=augment,
                                  label_type=label_type)
        feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
        sess.run(model.train_op, feed_dict=feed_dict)
    finally:
      # Flush pending summaries and release the event file even if the
      # loop is interrupted or raises.
      train_writer.close()
      
def test(seq_length, y, output, max_presentations=10):
  """Return classification accuracy per presentation index of a class.

  Within each episode, the k-th time a given true class appears is its
  "k-th presentation". Entry k-1 of the result is the fraction of all k-th
  presentations (over all classes and all episodes in the batch) for which
  the predicted class equals the true class. This matches the
  "1st ... 10th" columns printed by the training loop.

  Args:
    seq_length: length of an episode. Not needed, because the length is
      taken from the arrays themselves; kept so existing callers still work.
    y: true labels, one-hot, array-like of shape (batch_size, seq_length, n_classes).
    output: the network's class scores/probabilities, same shape as `y`;
      the predicted class is the arg-max over the last axis.
    max_presentations: how many presentation indices to report (default 10).
      Later presentations are ignored.

  Returns:
    A list of `max_presentations` floats in [0, 1]. An index that never
    occurs in the batch gets 0.0.
  """
  true_labels = np.argmax(np.asarray(y), axis=-1)
  predicted_labels = np.argmax(np.asarray(output), axis=-1)

  correct = [0] * max_presentations
  total = [0] * max_presentations
  for episode_truth, episode_pred in zip(true_labels, predicted_labels):
    # Number of times each class has already appeared in this episode.
    times_seen = {}
    for truth, pred in zip(episode_truth.tolist(), episode_pred.tolist()):
      k = times_seen.get(truth, 0)
      times_seen[truth] = k + 1
      if k >= max_presentations:
        continue
      total[k] += 1
      if truth == pred:
        correct[k] += 1

  return [correct[k] / total[k] if total[k] else 0.0 for k in range(max_presentations)]


# This is an evaluation helper, not a unit test: keep pytest from collecting
# it (its name starts with "test") if this code is ever imported by a test.
test.__test__ = False

In [0]:
# ---- Experiment configuration ----
# All values below are module-level globals. train() receives most of them
# as arguments, but it also reads rnn_size, rnn_num_layers, read_head_num,
# write_head_num and memory_vector_dim directly from this scope.

# Episode / task layout
n_classes = 5
seq_length = 50
label_type = "one_hot"
augment = True
image_width = 20
image_height = 20
n_train_classes = 220
n_test_classes = 60

# Controller and memory sizes
model_path = 'LSTM'
rnn_size = 200
rnn_num_layers = 1
num_memory_slots = 128
memory_vector_dim = 40
read_head_num = 4
write_head_num = 4
shift_range = 1

# Optimisation schedule
batch_size = 16
num_epochs = 100000
learning_rate = 1e-3
test_batch_num = 100
restore_training = False

# Output locations
save_dir = './save/one_shot_learning'
tensorboard_dir = './summary/one_shot_learning'

# ---- Launch training ----
train(
    learning_rate=learning_rate,
    image_width=image_width,
    image_height=image_height,
    n_train_classes=n_train_classes,
    n_test_classes=n_test_classes,
    restore_training=restore_training,
    num_epochs=num_epochs,
    n_classes=n_classes,
    batch_size=batch_size,
    seq_length=seq_length,
    num_memory_slots=num_memory_slots,
    augment=augment,
    save_dir=save_dir,
    model_path=model_path,
    tensorboard_dir=tensorboard_dir,
)
