In [1]:
%matplotlib inline

import os
import sys 
import random

import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt
from PIL import Image
from PIL import ImageOps

sys.path.append(os.path.abspath("/home/himani/memory_augmented_neural_net"))
from utils import OmniglotDataLoader, one_hot_decode, five_hot_decode
#from model import NTMOneShotLearningModel

Setup flags

In [2]:
class flags:
    mode='train'
    restore_training=False
    debug=False
    label_type='one_hot'
    n_classes=3
    seq_length=n_classes*10
    augment=True
    model='MANN'
    read_head_num=4
    batch_size=2
    num_epoches=100000
    learning_rate=1e-3
    rnn_size=200
    image_width=20
    image_height=20
    rnn_num_layers=1
    memory_size=128
    memory_vector_dim=40
    test_batch_num=100
    n_train_classes=1200
    n_test_classes=423
    save_dir='/home/himani/checkpoints/run1'
    tensorboard_dir='/home/himani/logs/run1'
    test_frequency=2
    save_frequency=5000

args=flags

Setup dataloader

In [3]:
data_loader = OmniglotDataLoader(
    image_size=(args.image_width, args.image_height),
    n_train_classses=args.n_train_classes,
    n_test_classes=args.n_test_classes)
print("Dataloading complete.")

Dataloading complete...


Build model with default settings: 
* label_type: one_hot
* model: MANN

In [4]:
class MANNCell():
    def __init__(self, rnn_size, memory_size, memory_vector_dim, head_num, gamma=0.95,
                 reuse=False, k_strategy='separate'):
        self.rnn_size = rnn_size
        self.memory_size = memory_size
        self.memory_vector_dim = memory_vector_dim
        self.head_num = head_num                                    # #(read head) == #(write head)
        self.reuse = reuse
        self.controller = tf.nn.rnn_cell.BasicLSTMCell(self.rnn_size)
        self.step = 0
        self.gamma = gamma
        self.k_strategy = k_strategy

    def __call__(self, x, prev_state):
        prev_read_vector_list = prev_state['read_vector_list']      # read vector (the content that is read out, length = memory_vector_dim)
        prev_controller_state = prev_state['controller_state']      # state of controller (LSTM hidden state)

        # x + prev_read_vector -> controller (RNN) -> controller_output

        controller_input = tf.concat([x] + prev_read_vector_list, axis=1)
        with tf.variable_scope('controller', reuse=self.reuse):
            controller_output, controller_state = self.controller(controller_input, prev_controller_state)

        # controller_output (after fully connected layer): 
        #                       -> k (dim = memory_vector_dim, compared to each vector in M)
        #                       -> a (dim = memory_vector_dim, add vector, only when k_strategy='separate')
        #                       -> alpha (scalar, combination of w_r and w_lu)

        if self.k_strategy == 'summary':
            num_parameters_per_head = self.memory_vector_dim + 1
        elif self.k_strategy == 'separate':
            num_parameters_per_head = self.memory_vector_dim * 2 + 1
        total_parameter_num = num_parameters_per_head * self.head_num
        
        with tf.variable_scope("o2p", reuse=(self.step > 0) or self.reuse):
            o2p_w = tf.get_variable('o2p_w', [controller_output.get_shape()[1], total_parameter_num],
                                    initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
            o2p_b = tf.get_variable('o2p_b', [total_parameter_num],
                                    initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
            parameters = tf.nn.xw_plus_b(controller_output, o2p_w, o2p_b)
        head_parameter_list = tf.split(parameters, self.head_num, axis=1) #k,a,alpha

        # k, prev_M -> w_r
        # alpha, prev_w_r, prev_w_lu -> w_w

        prev_w_r_list = prev_state['w_r_list']      # vector of weightings (blurred address) over locations
        prev_M = prev_state['M']
        prev_w_u = prev_state['w_u']
        prev_indices, prev_w_lu = self.least_used(prev_w_u)
        w_r_list = []
        w_w_list = []
        k_list = []
        a_list = []
        # p_list = []   # For debugging
        for i, head_parameter in enumerate(head_parameter_list):
            with tf.variable_scope('addressing_head_%d' % i):
                k = tf.tanh(head_parameter[:, 0:self.memory_vector_dim], name='k')
                if self.k_strategy == 'separate':
                    a = tf.tanh(head_parameter[:, self.memory_vector_dim:self.memory_vector_dim * 2], name='a')
                sig_alpha = tf.sigmoid(head_parameter[:, -1:], name='sig_alpha')
                w_r = self.read_head_addressing(k, prev_M)
                w_w = self.write_head_addressing(sig_alpha, prev_w_r_list[i], prev_w_lu)
            w_r_list.append(w_r)
            w_w_list.append(w_w)
            k_list.append(k)
            if self.k_strategy == 'separate':
                a_list.append(a)
            # p_list.append({'k': k, 'sig_alpha': sig_alpha, 'a': a})   # For debugging

        w_u = self.gamma * prev_w_u + tf.add_n(w_r_list) + tf.add_n(w_w_list)   # eq (20)

        # Set least used memory location computed from w_(t-1)^u to zero
        M_ = prev_M * tf.expand_dims(1. - tf.one_hot(prev_indices[:, -1], self.memory_size), dim=2)

        # Writing
        M = M_
        with tf.variable_scope('writing'):
            for i in range(self.head_num):
                w = tf.expand_dims(w_w_list[i], axis=2)
                if self.k_strategy == 'summary':
                    k = tf.expand_dims(k_list[i], axis=1)
                elif self.k_strategy == 'separate':
                    k = tf.expand_dims(a_list[i], axis=1)
                M = M + tf.matmul(w, k)

        # Reading
        read_vector_list = []
        with tf.variable_scope('reading'):
            for i in range(self.head_num):
                read_vector = tf.reduce_sum(tf.expand_dims(w_r_list[i], dim=2) * M, axis=1)
                read_vector_list.append(read_vector)

        # controller_output -> NTM output
        NTM_output = tf.concat([controller_output] + read_vector_list, axis=1)

        state = {
            'controller_state': controller_state,
            'read_vector_list': read_vector_list,
            'w_r_list': w_r_list,
            'w_w_list': w_w_list,
            'w_u': w_u,
            'M': M,
        }

        self.step += 1
        return NTM_output, state

    def read_head_addressing(self, k, prev_M):
        with tf.variable_scope('read_head_addressing'):

            # Cosine Similarity

            k = tf.expand_dims(k, axis=2)
            inner_product = tf.matmul(prev_M, k)
            k_norm = tf.sqrt(tf.reduce_sum(tf.square(k), axis=1, keep_dims=True))
            M_norm = tf.sqrt(tf.reduce_sum(tf.square(prev_M), axis=2, keep_dims=True))
            norm_product = M_norm * k_norm
            K = tf.squeeze(inner_product / (norm_product + 1e-8))                   # eq (17)

            # Calculating w^c

            K_exp = tf.exp(K)
            w = K_exp / tf.reduce_sum(K_exp, axis=1, keep_dims=True)                # eq (18)

            return w

    def write_head_addressing(self, sig_alpha, prev_w_r, prev_w_lu):
        with tf.variable_scope('write_head_addressing'):

            # Write to (1) the place that was read in t-1 (2) the place that was least used in t-1

            return sig_alpha * prev_w_r + (1. - sig_alpha) * prev_w_lu              # eq (22)

    def least_used(self, w_u):
        _, indices = tf.nn.top_k(w_u, k=self.memory_size)
        w_lu = tf.reduce_sum(tf.one_hot(indices[:, -self.head_num:], depth=self.memory_size), axis=1)
        return indices, w_lu

    def zero_state(self, batch_size, dtype):
        one_hot_weight_vector = np.zeros([batch_size, self.memory_size])
        one_hot_weight_vector[..., 0] = 1
        one_hot_weight_vector = tf.constant(one_hot_weight_vector, dtype=tf.float32)
        with tf.variable_scope('init', reuse=self.reuse):
            state = {
                'controller_state': self.controller.zero_state(batch_size, dtype),
                'read_vector_list': [tf.zeros([batch_size, self.memory_vector_dim])
                                     for _ in range(self.head_num)],
                'w_r_list': [one_hot_weight_vector for _ in range(self.head_num)],
                'w_u': one_hot_weight_vector,
                'M': tf.constant(np.ones([batch_size, self.memory_size, self.memory_vector_dim]) * 1e-6, dtype=tf.float32)
            }
            return state

In [5]:
class NTMOneShotLearningModel():
    def __init__(self, args):
        args.output_dim = args.n_classes
        
        #placeholders for inputs
        self.x_image = tf.placeholder(dtype=tf.float32,
                                      shape=[args.batch_size, args.seq_length, args.image_width * args.image_height])
        self.x_label = tf.placeholder(dtype=tf.float32,
                                      shape=[args.batch_size, args.seq_length, args.output_dim])
        self.y = tf.placeholder(dtype=tf.float32,
                                shape=[args.batch_size, args.seq_length, args.output_dim])
        
        #create cell
        self.cell = MANNCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                    head_num=args.read_head_num)
        
        #step over the entire episode
        state = self.cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]   # For debugging
        self.o = []
        for t in range(args.seq_length):
            output, state = self.cell(tf.concat([self.x_image[:, t, :], self.x_label[:, t, :]], axis=1), state)
            # output, state = self.cell(self.y[:, t, :], state)
            with tf.variable_scope("o2o", reuse=(t > 0)):
                o2o_w = tf.get_variable('o2o_w', [output.get_shape()[1], args.output_dim],
                                        initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
                o2o_b = tf.get_variable('o2o_b', [args.output_dim],
                                        initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)
            
            output = tf.nn.softmax(output, dim=1)

            self.o.append(output)
            self.state_list.append(state)
            
        self.o = tf.stack(self.o, axis=1)
        self.state_list.append(state)

        eps = 1e-8
        self.learning_loss = -tf.reduce_mean(  # cross entropy function
            tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2])
        )

        self.o = tf.reshape(self.o, shape=[args.batch_size, args.seq_length, -1])

        with tf.variable_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.learning_loss)

Setup training variables

In [6]:
print("Initializing model...")
model = NTMOneShotLearningModel(args)

Initializing model...


In [7]:
config=tf.ConfigProto()
config.gpu_options.allow_growth = True
sess=tf.Session(config=config)

if args.restore_training:
    saver = tf.train.Saver(max_to_keep=2)
    ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' + args.model)
    saver.restore(sess, ckpt.model_checkpoint_path)
else:
    saver = tf.train.Saver(tf.global_variables(),max_to_keep=2)
    sess.run(tf.global_variables_initializer())

In [8]:
##assumption, min seq_length is 11

def test_f(args, y, output):
    correct = [0] * args.seq_length
    total = [0] * args.seq_length
    if args.label_type == 'one_hot':
        y_decode = one_hot_decode(y)
        output_decode = one_hot_decode(output)
    elif args.label_type == 'five_hot':
        y_decode = five_hot_decode(y)
        output_decode = five_hot_decode(output)
    for i in range(np.shape(y)[0]):
        y_i = y_decode[i]
        output_i = output_decode[i]
        # print(y_i)
        # print(output_i)
        class_count = {}
        for j in range(args.seq_length):
            if y_i[j] not in class_count:
                class_count[y_i[j]] = 0
            class_count[y_i[j]] += 1
            total[class_count[y_i[j]]] += 1
            if y_i[j] == output_i[j]:
                correct[class_count[y_i[j]]] += 1
    return [float(correct[i]) / total[i] if total[i] > 0. else 0. for i in range(1, 11)]

In [9]:
args.num_epoches=1 #args.num_epoches
args.test_frequency=1

In [10]:
print("batch\tloss\t1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th")


for b in range(args.num_epoches):
    # Test
    if b % args.test_frequency == 0:
        x_image, x_label, y = data_loader.fetch_batch(
            args.n_classes, args.batch_size, args.seq_length,
            type='test',augment=args.augment,label_type=args.label_type)
        
        feed_dict = {model.x_image: x_image, 
                     model.x_label: x_label, 
                     model.y: y}
        
        output, learning_loss = sess.run([model.o, model.learning_loss], feed_dict=feed_dict)

        print('%d\t%.4f\t' % (b, learning_loss)),
        accuracy = test_f(args, y, output)
        for accu in accuracy:
            print('%.4f\t' % accu),
        print('')
        
    if b % args.save_frequency == 0 and b > 0:
        saver.save(sess, args.save_dir + '/' + args.model + '/model.tfmodel', global_step=b)
    
    # Train
    x_image, x_label, y = data_loader.fetch_batch(
        args.n_classes, args.batch_size, args.seq_length,
        type='train',augment=args.augment,label_type=args.label_type)
    
    feed_dict = {model.x_image: x_image, 
                 model.x_label: x_label, 
                 model.y: y}
    
    sess.run(model.train_op, feed_dict=feed_dict)

batch	loss	1st	2nd	3rd	4th	5th	6th	7th	8th	9th	10th
0	33.0854	0.3333	0.3333	0.3333	0.3333	0.3333	0.3333	0.3333	0.3333	0.3333	0.3333	


Outputs

In [16]:
x_image, x_label, y = data_loader.fetch_batch(
    args.n_classes, args.batch_size, args.seq_length,
    type='train',augment=args.augment,label_type=args.label_type)

feed_dict = {model.x_image: x_image, 
             model.x_label: x_label, 
             model.y: y}

eval_outputs=[model.state_list,
             model.o]


In [17]:
print(len(x_image)) #batch-size
print(len(x_image[0])) #episode_length
print(x_image[0][0].shape)

2
30
(400,)


In [18]:
out=sess.run(eval_outputs, feed_dict=feed_dict)

In [19]:
state_list=out[0]
o=out[1]

In [20]:
print(len(state_list))
print(len(state_list[0]))
print(state_list[0].keys())

32
5
['w_u', 'read_vector_list', 'controller_state', 'M', 'w_r_list']


In [21]:
print(state_list[0]['w_u'].shape) #batch_size, mem_size

(2, 128)


In [22]:
print(len(state_list[0]['read_vector_list'])) #num_read_heads
print(state_list[0]['read_vector_list'][0].shape) #batch_size, mem_dim

4
(2, 40)


In [23]:
print(type(state_list[0]['controller_state']))# 2-tuples of the c_state and m_state
print(len(state_list[0]['controller_state'])) 
print(state_list[0]['controller_state'][0].shape) #batch_size,rnn_size

print(model.cell.controller.state_size)

<class 'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple'>
2
(2, 200)
LSTMStateTuple(c=200, h=200)


In [19]:
print(state_list[0]['M'].shape) #batch_size,mem_size,mem_dim

(2, 128, 40)


In [20]:
print(len(state_list[0]['w_r_list'])) #num_read_heads
print(state_list[0]['w_r_list'][0].shape) #batch_size, mem_size

4
(2, 128)


LSTMStateTuple(c=200, h=200)

Calculating w_l_u

In [30]:
wu=state_list[0]['w_u']
print(wu[0])

[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.]


In [39]:
bb=np.random.rand(10)
bb=np.reshape(bb,(2,-1))
print(bb)

[[ 0.25650661  0.20915602  0.29043392  0.34055697  0.87345325]
 [ 0.18170077  0.44975874  0.93630793  0.37495192  0.51294125]]


In [47]:
vals, indices = tf.nn.top_k(bb, k=5)
inds2=indices[:, -2:]
tmp=tf.one_hot(inds2,depth=5)
bb2= tf.reduce_sum(tmp, axis=1)

In [48]:
sess.run([vals,indices,inds2,tmp,bb2])

[array([[ 0.87345325,  0.34055697,  0.29043392,  0.25650661,  0.20915602],
        [ 0.93630793,  0.51294125,  0.44975874,  0.37495192,  0.18170077]]),
 array([[4, 3, 2, 0, 1],
        [2, 4, 1, 3, 0]], dtype=int32),
 array([[0, 1],
        [3, 0]], dtype=int32),
 array([[[ 1.,  0.,  0.,  0.,  0.],
         [ 0.,  1.,  0.,  0.,  0.]],
 
        [[ 0.,  0.,  0.,  1.,  0.],
         [ 1.,  0.,  0.,  0.,  0.]]], dtype=float32),
 array([[ 1.,  1.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  1.,  0.]], dtype=float32)]

Reading memory

In [49]:
wr=state_list[0]['w_r_list']
print(wr[0])

[[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  

In [57]:
bb=np.random.rand(10)
bb=np.reshape(bb,(2,-1))
print(bb)
print(bb.shape)

[[ 0.8994863   0.11661608  0.48097275  0.99441174  0.42735952]
 [ 0.38328854  0.6363465   0.60420785  0.75010454  0.53011838]]
(2, 5)


In [71]:
mem=np.ones((5,3))
print(mem)

[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]]


In [72]:
bb2=tf.expand_dims(bb, dim=2)
mul=bb2*mem
bb3=tf.reduce_sum(mul,axis=1)

In [73]:
a,b,c=sess.run([bb2,mul,bb3])

In [74]:
print(a)
print(len(a))
print(a[0].shape)

[[[ 0.8994863 ]
  [ 0.11661608]
  [ 0.48097275]
  [ 0.99441174]
  [ 0.42735952]]

 [[ 0.38328854]
  [ 0.6363465 ]
  [ 0.60420785]
  [ 0.75010454]
  [ 0.53011838]]]
2
(5, 1)


In [75]:
print(b)
print(b.shape)

[[[ 0.8994863   0.8994863   0.8994863 ]
  [ 0.11661608  0.11661608  0.11661608]
  [ 0.48097275  0.48097275  0.48097275]
  [ 0.99441174  0.99441174  0.99441174]
  [ 0.42735952  0.42735952  0.42735952]]

 [[ 0.38328854  0.38328854  0.38328854]
  [ 0.6363465   0.6363465   0.6363465 ]
  [ 0.60420785  0.60420785  0.60420785]
  [ 0.75010454  0.75010454  0.75010454]
  [ 0.53011838  0.53011838  0.53011838]]]
(2, 5, 3)


In [76]:
print(c)
print(c.shape)

[[ 2.91884639  2.91884639  2.91884639]
 [ 2.90406581  2.90406581  2.90406581]]
(2, 3)


Looping over internal states

In [24]:
x_image, x_label, y = data_loader.fetch_batch(
    args.n_classes, args.batch_size, args.seq_length,
    type='train',augment=args.augment,label_type=args.label_type)

feed_dict = {model.x_image: x_image, 
             model.x_label: x_label, 
             model.y: y}

eval_outputs=[model.state_list]

In [39]:
for i in range(0,3):
    out=sess.run(eval_outputs, feed_dict=feed_dict)
    print('\n ******** step: %d **********'%i)
    print(out[0][0]['controller_state'])


 ******** step: 0 **********
LSTMStateTuple(c=array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
    