In [1]:
##################################
#
# Implementation of linear logic recurrent neural network
#
# The architecture is a modified RNN, see the paper "Linear logic and recurrent neural networks".
# Our inputs are sequences of symbols taken from an alphabet of size num_classes. The length
# of the input sequences is N. Our outputs are sequences over the same alphabet, of length
# N_out (N for the copy task, 2N for the repeat copy and pattern tasks).
#
# Here "symbol" means a one hot vector.

# The next three lines are recommended by TF
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import collections
import six
import math
import time
import random

from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh

# Our libraries
import ntm
import seqhelper
import learnfuncs

###########
# TODO/BUGS
#
# - Running on N = 10 crashes

In [2]:
##############
# GLOBAL FLAGS
#
# Every tunable setting of the experiment is defined in this cell.

# --- Model / task selection ---

# which cell to build: ntm, pattern_ntm, pattern_ntm_alt
use_model = 'ntm'
# which function to learn: copy, repeat copy, pattern
task = 'pattern'
# which optimizer to train with: adam, rmsprop
model_optimizer = 'rmsprop'

# --- Problem size ---

# number of symbols in the alphabet
num_classes = 5
# length of input sequences
N = 20

# --- Training schedule ---

# number of training epochs
epoch = 20
# take a smaller batch size (500 works) on Tesla
batch_size = 500

# --- Architecture ---

# dimension of the internal state space of the controller
controller_state_size = 100
# number of memory locations
memory_address_size = 20
# size of vector stored at a memory location
memory_content_size = 5
# powers available to pattern NTM
powers_ring1 = [0,1,2]

# --- Logging ---

# Directory receiving the TensorBoard event files.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
LOG_DIR = '/Users/murfetd/Coding/deeplinearlogic/log'

# --- Data set sizes ---

# Fraction of all num_classes**N sequences intended for training. It is not
# used to derive num_training, which is set directly on the next line (the
# formula that would use it is kept in the trailing comment there).
training_percent = 0.01
num_training = 20000 # int(training_percent * num_classes**N)
# twice as many test sequences as training sequences
num_test = num_training * 2

In [3]:
#######################
# SETUP TASKS
#
# Selects the function the network has to learn (func_to_learn) and the
# length of its output sequences (N_out) from the global flag `task`.
#
# Our sequences are of one-hot vectors, which we interpret as follows:
#
# [1.0, 0.0, 0.0] = 0
# [0.0, 1.0, 0.0] = 1
# [0.0, 0.0, 1.0] = 2 etc
#
# We write our sequences and functions referring to sequences of integers,
# and then convert to one-hot vectors for integration with TF.

# Exactly one branch runs. An unrecognised task is rejected here instead of
# surfacing later as a NameError on func_to_learn / N_out (or, worse, silently
# reusing values left in the kernel by a previous run).
if( task == 'copy' ):
    ###########
    # COPY TASK
    func_to_learn = learnfuncs.f_identity
    N_out = N
elif( task == 'repeat copy' ):
    ##################
    # REPEAT COPY TASK
    # put n zeros before the 1, for a copy task with n + 1 copies
    pattern = [0,1]
    func_to_learn = lambda s: learnfuncs.f_repetitionpattern(s,pattern)
    N_out = 2 * N
elif( task == 'pattern' ):
    ##############
    # PATTERN TASK
    pattern = [1,0,0,2,0]
    func_to_learn = lambda s: learnfuncs.f_repetitionpattern(s,pattern)
    N_out = 2 * N
else:
    raise ValueError("Unknown task '" + str(task) +
                     "'; expected 'copy', 'repeat copy' or 'pattern'")

# Give an example input/output pair
a = [random.randint(0,num_classes-1) for i in range(N)]
fa = func_to_learn(a)

print("Under the chosen function, the sequence")
print(a)
print("is mapped to")
print(fa)

Under the chosen function, the sequence
[2, 0, 2, 1, 3, 3, 2, 1, 0, 0, 2, 2, 1, 1, 1, 4, 3, 2, 2, 3]
is mapped to
[2, 0, 0, 0, 1, 1, 3, 3, 3, 2, 2, 1, 1, 1, 0, 0, 2, 2, 2, 1, 1, 1, 1, 1, 4, 4, 3, 3, 3, 2, 2, 3, 3, 3, 0, 0, 2, 2, 2, 3]


In [4]:
################
# DEFINE MODEL #
################
#
# Builds the unrolled graph: the cell is applied N times to read the input
# sequence, then N_out more times (fed zeros) to emit the output sequence.
# A final fully connected layer turns each recorded cell output into a
# distribution over symbols.

input_size = num_classes # dimension of the input space I

# inputs, we create N of them, each of shape [None,input_size], one for
# each position in the sequence. targets likewise, one per output position.
inputs = [tf.placeholder(tf.float32, [None,input_size]) for _ in range(N)]
targets = [tf.placeholder(tf.float32, [None,input_size]) for _ in range(N_out)]

# state_size is the width of the state vector handed from one step to the
# next (see the [batch_size, state_size] initial state below).
# NOTE(review): the individual summands presumably correspond to the controller
# state, the address vectors and the memory contents - confirm against ntm.py.
if( use_model == 'ntm' ):
    state_size = controller_state_size + 2*memory_address_size + memory_address_size * memory_content_size
    cell = ntm.NTM(state_size,input_size,controller_state_size,memory_address_size,memory_content_size, [-1,0,1])
elif( use_model == 'pattern_ntm' or use_model == 'pattern_ntm_alt' ):
    # Both pattern models share the same state layout and constructor arguments.
    state_size = controller_state_size + 4*memory_address_size + \
                memory_address_size * memory_content_size + \
                memory_address_size * len(powers_ring1)

    if( use_model == 'pattern_ntm' ):
        pattern_cell_class = ntm.PatternNTM
    else:
        pattern_cell_class = ntm.PatternNTM_alt

    cell = pattern_cell_class(state_size,input_size,controller_state_size,
                              memory_address_size,memory_content_size, powers_ring1, [-1,0,1])
else:
    # Fail here rather than with a NameError on `cell` further down.
    raise ValueError("Unknown use_model '" + str(use_model) +
                     "'; expected 'ntm', 'pattern_ntm' or 'pattern_ntm_alt'")

# Initialise the state with small random values. Its first dimension is fixed
# to batch_size, so batches fed at run time are expected to have batch_size rows.
state = tf.truncated_normal([batch_size, state_size], 0.0, 0.01, dtype=tf.float32)
#state = cell.zero_state(batch_size, tf.float32)

# The first call creates the cell's variables, every later call reuses them.
reuse = False

# Phase 1: feed the input sequence; the outputs of these steps are discarded.
for i in range(N):
    output, state = cell(inputs[i],state,'NTM',reuse)
    reuse = True

# We only start recording the outputs of the controller once we have
# finished feeding in the input. We feed zeros as input in the second phase.
rnn_outputs = []
for i in range(N_out):
    output, state = cell(tf.zeros([batch_size,input_size]),state,'NTM',reuse)
    rnn_outputs.append(output)

# Final fully connected layer
E = tf.Variable(tf.truncated_normal([controller_state_size,input_size]))
F = tf.Variable(tf.constant(0.1, shape=[input_size]))

# prediction is a length N_out list of tensors of shape [None,input_size], where
# the jth row of prediction[d] is, for the jth input sequence in the batch,
# the probability distribution over symbols for the output symbol in position d.
logits = [tf.matmul(rnn_output, E) + F for rnn_output in rnn_outputs]
prediction = [tf.nn.softmax(logit) for logit in logits]

# Log-likelihood of the targets, one scalar per output position. The log is
# taken with log_softmax on the logits rather than tf.log(softmax(...)): the
# two are mathematically equal, but the latter yields -inf/NaN as soon as a
# probability underflows to zero.
ce = [tf.reduce_sum(targets[i] * tf.nn.log_softmax(logits[i])) for i in range(N_out)]

if( model_optimizer == 'adam' ):
    optimizer = tf.train.AdamOptimizer(1e-4)
elif( model_optimizer == 'rmsprop' ):
    optimizer = tf.train.RMSPropOptimizer(1e-4,decay=0.9,momentum=0.9)
else:
    # Fail here rather than with a NameError on `optimizer` below.
    raise ValueError("Unknown model_optimizer '" + str(model_optimizer) +
                     "'; expected 'adam' or 'rmsprop'")

# Total cross entropy, summed over the batch and over all output positions.
cross_entropy = -tf.add_n(ce)
minimize = optimizer.minimize(cross_entropy)

# mistakes[d] marks, per sequence in the batch, whether the most probable
# predicted symbol at position d differs from the target symbol;
# errors[d] is the fraction of such sequences.
mistakes = [tf.not_equal(tf.argmax(targets[i], 1), tf.argmax(prediction[i], 1)) for i in range(N_out)]
errors = [tf.reduce_mean(tf.cast(m, tf.float32)) for m in mistakes]

# Summaries: the error rate averaged over all output positions
mean_error = tf.scalar_mul(np.true_divide(1,N_out), tf.add_n(errors))
tf.summary.scalar('error', mean_error)

[<tf.Tensor 'gradients/NTM_59/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_58/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_57/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_56/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_55/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_54/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_53/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_52/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_51/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_50/split_grad/concat:0' shape=(500, 240) dtype=float32>, None, None]
[<tf.Tensor 'gradients/NTM_49/split_grad

<tf.Tensor 'error:0' shape=() dtype=string>

In [5]:
####################
# INITIALISE MODEL #
####################

# Bundle every summary registered on the graph into a single op, so that one
# sess.run call can evaluate all of them.
merged_summaries = tf.summary.merge_all()

# Open a session and give all variables their initial values.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

# Writer for the summary event files; it is also handed the session's graph.
file_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)

In [6]:
############
# TRAINING #
############
#
# Runs `epoch` passes of `no_of_batches` optimizer steps each. Every batch is
# freshly sampled; after each epoch the error on the last batch is printed.

random.seed()
one_hots = seqhelper.one_hot_vectors(num_classes)

pre_train_time = time.time()

# Training
no_of_batches = int(num_training/batch_size)
if no_of_batches < 1:
    # Without at least one batch, feed_dict would never be built and the
    # per-epoch evaluation below would fail on an undefined (or stale) name.
    raise ValueError("num_training (" + str(num_training) +
                     ") must be at least batch_size (" + str(batch_size) + ")")

# The model has one placeholder per sequence position, and a list of
# placeholders cannot serve as a single feed_dict key, so each position is
# fed separately. The workaround we found on StackOverflow here:
# http://stackoverflow.com/questions/33684657/issue-feeding-a-list-into-feed-dict-in-tensorflow

# epoch is a global var
for i in range(epoch):
    for j in range(no_of_batches):
        inp = []
        out = []

        # We sample each batch on the fly from the set of all sequences
        for z in range(batch_size):
            a = [random.randint(0,num_classes-1) for k in range(N)]
            fa = func_to_learn(a)
            a_onehot = [one_hots[e] for e in a]
            fa_onehot = [one_hots[e] for e in fa]
            inp.append(np.array(a_onehot))
            out.append(np.array(fa_onehot))

        feed_dict = {}
        # inp has dimensions [batch_size, N, num_classes] and we want to extract
        # the 2D Tensor of shape [batch_size, num_classes] obtained by setting the
        # second coordinate to d
        for d in range(N):
            feed_dict[inputs[d]] = np.array([seq[d] for seq in inp])

        # Same extraction for the targets, one per output position.
        for d in range(N_out):
            feed_dict[targets[d]] = np.array([seq[d] for seq in out])

        summary,_ = sess.run([merged_summaries,minimize], feed_dict)
        # Tag the summary with a monotonically increasing step so the logged
        # values are ordered along the x-axis instead of having no step at all.
        file_writer.add_summary(summary, i * no_of_batches + j)

    # Error on the last batch of the epoch, measured after its update step.
    current_mean = np.mean(sess.run(errors, feed_dict))
    print("Epoch - " + str(i+1) + ", Mean error of final batch in epoch - " + str(current_mean))

# Push any buffered summary events to disk now that training is done.
file_writer.flush()

print("")
print("It took", time.time() - pre_train_time, "seconds to train.")

Epoch - 1, Mean error of final batch in epoch - 0.77555
Epoch - 2, Mean error of final batch in epoch - 0.71845
Epoch - 3, Mean error of final batch in epoch - 0.714
Epoch - 4, Mean error of final batch in epoch - 0.6966
Epoch - 5, Mean error of final batch in epoch - 0.68945
Epoch - 6, Mean error of final batch in epoch - 0.679
Epoch - 7, Mean error of final batch in epoch - 0.6697
Epoch - 8, Mean error of final batch in epoch - 0.66875
Epoch - 9, Mean error of final batch in epoch - 0.67085
Epoch - 10, Mean error of final batch in epoch - 0.6507
Epoch - 11, Mean error of final batch in epoch - 0.6405
Epoch - 12, Mean error of final batch in epoch - 0.64365
Epoch - 13, Mean error of final batch in epoch - 0.63695
Epoch - 14, Mean error of final batch in epoch - 0.6237
Epoch - 15, Mean error of final batch in epoch - 0.62785
Epoch - 16, Mean error of final batch in epoch - 0.6246
Epoch - 17, Mean error of final batch in epoch - 0.6154
Epoch - 18, Mean error of final batch in epoch - 0.

In [7]:
###########
# TESTING #
###########
#
# Measures the error rate on freshly sampled batches (no optimizer step is
# run), prints a summary of the experiment and releases the writer and session.

no_of_batches = int(num_test/batch_size)
#print("Number of batches: " + str(no_of_batches))

error_means = []
for j in range(no_of_batches):
    inp = []
    out = []

    # We sample each batch on the fly from the set of all sequences
    for z in range(batch_size):
        a = [random.randint(0,num_classes-1) for k in range(N)]
        fa = func_to_learn(a)
        a_onehot = [one_hots[e] for e in a]
        fa_onehot = [one_hots[e] for e in fa]
        inp.append(np.array(a_onehot))
        out.append(np.array(fa_onehot))

    # One placeholder per sequence position: feed each the [batch_size, num_classes]
    # slice of the batch at that position.
    feed_dict = {}
    for d in range(N):
        feed_dict[inputs[d]] = np.array([seq[d] for seq in inp])

    for d in range(N_out):
        feed_dict[targets[d]] = np.array([seq[d] for seq in out])

    # Average of the per-position error rates for this batch.
    current_mean = np.mean(sess.run(errors, feed_dict))
    error_means.append(current_mean)
    print("Batch - " + str(j+1) + ", Mean error - " + str(current_mean))

final_error = np.mean(error_means)

# The first three digits of this should match the printout for the
# first three test output sequences given earlier
#data = sess.run([tf.argmax(targets[0],1), tf.argmax(prediction[0],1)],feed_dict)

#print("First digits of test outputs (actual)")
#print(data[0])
#print("First digits of test outputs (predicted)")
#print(data[1])

# print the mean of the errors in each digit for the test set.
#incorrects = sess.run(errors, feed_dict)
# print(incorrects)

print("")
print("###########")
print("# Summary #")
print("###########")
print("")
print("model         - " + use_model)
print("task name     - " + task)
print("num_classes   - " + str(num_classes))
print("N             - " + str(N))
print("N_out         - " + str(N_out))
print("ring powers   - " + str(powers_ring1))
print("# epochs      - " + str(epoch))
print("optimizer     - " + str(model_optimizer))
print("# weights     - " + str(ntm.count_number_trainable_params()))
print("(css,mas,mcs) - (" + str(controller_state_size) + "," + str(memory_address_size) + "," + str(memory_content_size) + ")")
print("train percent - " + str(training_percent))
print("num_training  - " + str(num_training) + "/" + str(num_classes**N))
print("num_test      - " + str(num_test) + "/" + str(num_classes**N))
print("")
print("")
print("error         - " + str(final_error))

# Close the summary writer first so buffered events reach the event file,
# then release the session.
file_writer.close()
sess.close()

Batch - 1, Mean error - 0.60605
Batch - 2, Mean error - 0.60385
Batch - 3, Mean error - 0.6078
Batch - 4, Mean error - 0.60905
Batch - 5, Mean error - 0.59985
Batch - 6, Mean error - 0.6045
Batch - 7, Mean error - 0.60605
Batch - 8, Mean error - 0.60725
Batch - 9, Mean error - 0.60735
Batch - 10, Mean error - 0.60595
Batch - 11, Mean error - 0.6105
Batch - 12, Mean error - 0.59895
Batch - 13, Mean error - 0.60205
Batch - 14, Mean error - 0.61105
Batch - 15, Mean error - 0.60795
Batch - 16, Mean error - 0.6069
Batch - 17, Mean error - 0.616
Batch - 18, Mean error - 0.60825
Batch - 19, Mean error - 0.6093
Batch - 20, Mean error - 0.60225
Batch - 21, Mean error - 0.5956
Batch - 22, Mean error - 0.611
Batch - 23, Mean error - 0.6069
Batch - 24, Mean error - 0.60235
Batch - 25, Mean error - 0.60495
Batch - 26, Mean error - 0.6133
Batch - 27, Mean error - 0.6094
Batch - 28, Mean error - 0.6043
Batch - 29, Mean error - 0.6042
Batch - 30, Mean error - 0.60625
Batch - 31, Mean error - 0.6012
Ba