code for creating a Kernel based relu-RNN learning for sequential MNIST
adapted from : Roth, Christopher, Ingmar Kanitscheider, and Ila Fiete. 2018. “Kernel RNN Learning (KeRNL),” September. https://openreview.net/forum?id=ryGfnoC5KQ.

In [1]:
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt 

import collections
import hashlib
import numbers
import matplotlib.cm as cm
from sys import getsizeof
from datetime import datetime
from pathlib import Path
import os

from pandas import DataFrame
from IPython.display import HTML




from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear
from tensorflow.contrib import slim

## user defined modules 
# kernel rnn cell 
import keRNL_cell 

In [4]:
# uplading mnist data 

old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_data = mnist.train.images  # Returns np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images  # Returns np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

tf.logging.set_verbosity(old_v)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [5]:
# Training Parameters
learning_rate = 1e-4
training_steps = 5000
batch_size = 100
display_step = 2
test_len=128
grad_clip=100
# Network Parameters
num_input = 1 # MNIST data input (img shape: 28*28)
timesteps = 28*28 # timesteps
num_hidden = 128 # hidden layer num of features
num_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input


In [10]:
def kernel_RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, timesteps, n_input)
    # Required shape: 'timesteps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'timesteps' tensors of shape (batch_size, n_input)
    with tf.variable_scope('recurrent',initializer=tf.initializers.identity()) as scope: 
        # Define a lstm cell with tensorflow
        keRNL = keRNL_cell.KeRNLCell(num_units=num_hidden,num_inputs=num_input,time_steps=timesteps,sensitivity_initializer=tf.initializers.identity)
        # Get lstm cell output
        kernel_outputs, kernel_states = tf.nn.dynamic_rnn(keRNL , x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(kernel_outputs[:,-1,:], weights['out']) + biases['out'], kernel_states

In [38]:
tf.reset_default_graph()
graph=tf.Graph()
with graph.as_default():
    # Define weights
    weights = {
        'out': tf.Variable(tf.random_normal([num_hidden, num_classes]),name='output_weight')
    }
    biases = {
        'out': tf.Variable(tf.random_normal([num_classes]),name='output_addition')
    }
    X = tf.placeholder("float", [None, timesteps, num_input])
    Y = tf.placeholder("float", [None, num_classes])
    logits,states = kernel_RNN(X, weights, biases)
    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
    logits=logits, labels=Y))
    prediction = tf.nn.softmax(logits)
    variable_names=[v.name for v in tf.trainable_variables()]
    trainables=tf.trainable_variables()
    # Define optimizer
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    
    # compute gradients 
    # get the trainable variables 
    temporal_kernel_index=[np.unicode_.find(k.name,'temporal_filter_coeff')>-1 for k in trainables].index(True)
    sensitivity_tensor_index=[np.unicode_.find(k.name,'sensitivity_tensor')>-1 for k in trainables].index(True)    
    kernel_index=[np.unicode_.find(k.name,'kernel')>-1 for k in trainables].index(True)
    bias_index=[np.unicode_.find(k.name,'bias')>-1 for k in trainables].index(True)
    output_weight_index=[np.unicode_.find(k.name,'output_weight')>-1 for k in trainables].index(True)
    # create a new trainable list 
    #a=np.asarray([temporal_kernel_index,sensitivity_tensor_index,kernel_index,bias_index],dtype=np.int)
    a=np.asarray([sensitivity_tensor_index,temporal_kernel_index,kernel_index],dtype=np.int)
    new_trainables= [trainables[k] for k in a ]
# manually calculate gradients 
    #1- sensitivity tensor gradient 
    sensitivity_tensor_update=tf.matmul(tf.transpose(states.delta_sensitivity),states.Gamma)
    
    #2- temporal kernel coefficient gradient 
    temporal_kernel_update=tf.reduce_mean(tf.multiply(states.delta_sensitivity,tf.matmul(states.Gamma,tf.transpose(trainables[sensitivity_tensor_index]))),axis=0)
    
    #3- gradient for the recurrent weights 
    grad_cost_to_output=tf.gradients(loss_op,logits, name= 'grad_cost_to_y')
    error_in_hidden_state=tf.expand_dims(tf.reduce_mean(tf.matmul(grad_cost_to_output[-1],tf.transpose(trainables[output_weight_index])),axis=0),axis=0)
    weight_update_aux=tf.matmul(trainables[kernel_index],tf.transpose(error_in_hidden_state))
    total_trace=tf.concat([states.input_trace,states.recurrent_trace],axis=2)
    weight_update=tf.transpose(tf.reduce_mean(tf.multiply(total_trace,tf.transpose(tf.tile(weight_update_aux,[1,num_hidden]))),axis=0))
    
    
# zip gradients and vars 
    new_grads_and_vars=list(zip([sensitivity_tensor_update, temporal_kernel_update,weight_update],new_trainables))
    train_op = optimizer.apply_gradients(new_grads_and_vars)
    
    # TODO 
    # Evaluate model (with test logits, for dropout to be disabled)
    correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    # Initialize the variables (i.e. assign their default value)
    init = tf.global_variables_initializer()
        # predictions 
        #prediction=tf.nn.softmax(logits)
    tf.summary.histogram('prediction',prediction+1e-8)
    tf.summary.histogram('logits',logits+1e-8)
    tf.summary.histogram('sensitivity_updates',sensitivity_tensor_update+1e-8)
    tf.summary.histogram('temporal_kernel_updates',temporal_kernel_update+1e-8)
    #tf.summary.scalar('loss',loss_op)
    merged_summary_op=tf.summary.merge_all()
    

In [39]:
# verify initialization 
with tf.Session(graph=graph) as sess : 
    sess.run(init)
    values,trainable_vars = sess.run([variable_names,new_trainables])
    for k, v in zip(variable_names,values):
        print(["variable: " , k])
        #print(["value: " , v])
        print(["variable: " , np.unicode_.find(k,'output')]) 
        print(["shape: " , v.shape])
        #print(v) 
     

['variable: ', 'output_weight:0']
['variable: ', 0]
['shape: ', (128, 10)]
['variable: ', 'output_addition:0']
['variable: ', 0]
['shape: ', (10,)]
['variable: ', 'recurrent/rnn/ke_rnl_cell/temporal_filter_coeff:0']
['variable: ', -1]
['shape: ', (128,)]
['variable: ', 'recurrent/rnn/ke_rnl_cell/sensitivity_tensor:0']
['variable: ', -1]
['shape: ', (128, 128)]
['variable: ', 'recurrent/rnn/ke_rnl_cell/kernel:0']
['variable: ', -1]
['shape: ', (129, 128)]
['variable: ', 'recurrent/rnn/ke_rnl_cell/bias:0']
['variable: ', -1]
['shape: ', (128,)]


In [37]:

weight_update

<tf.Tensor 'transpose_5:0' shape=(129, 128) dtype=float32>

In [43]:
log_dir = "logs/kernel_rnn/batch_learning_gc_%d_eta_%d_batch_%d_run_%s" %(grad_clip,learning_rate,batch_size, datetime.now().strftime("%Y%m%d_%H%M"))
Path(log_dir).mkdir(exist_ok=True, parents=True)
filelist = [ f for f in os.listdir(log_dir) if f.endswith(".local") ]
for f in filelist:
    os.remove(os.path.join(log_dir, f))

In [44]:
# write graph into tensorboard 
tb_writer = tf.summary.FileWriter(log_dir,graph)
# run a training session 
with tf.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(1,100):#range(1,training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x=batch_x.reshape((batch_size,timesteps,num_input))
        # run optimizaer 
        opt, loss_train,acc_train=sess.run([train_op,loss_op,accuracy],feed_dict={X:batch_x, Y:batch_y})
        logits_out, states_out,sensitivity_tensor= sess.run([logits, states,sensitivity_tensor_update],feed_dict={X:batch_x, Y:batch_y})
        merged_summary=sess.run(merged_summary_op,feed_dict={X:batch_x, Y:batch_y})
        tb_writer.add_summary(merged_summary, global_step=step)
        #tb_writer.flush()
        # show interim performance 
        if step % display_step==0 or step==1 : 
            # get batch loss and accuracy 
            print('Step: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(
            step + 1, loss_train, acc_train))
            # write summary 
            #tb_writer.add_summary(acc,global_step=step)
            #tb_writer.flush()
            # evaluate performance on test data 
            test_X=mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
            test_Y=mnist.test.labels[:test_len]

    print("Optimization Finished!")
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", 
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

Step: 2, Train Loss: 2.572, Train Acc: 0.130
Step: 3, Train Loss: 2.539, Train Acc: 0.170
Step: 5, Train Loss: 2.555, Train Acc: 0.140
Step: 7, Train Loss: 2.720, Train Acc: 0.100
Step: 9, Train Loss: 2.614, Train Acc: 0.080
Step: 11, Train Loss: 2.561, Train Acc: 0.120
Step: 13, Train Loss: 2.646, Train Acc: 0.080
Step: 15, Train Loss: 2.565, Train Acc: 0.160
Step: 17, Train Loss: 2.527, Train Acc: 0.140
Step: 19, Train Loss: 2.496, Train Acc: 0.170
Step: 21, Train Loss: 2.614, Train Acc: 0.120
Step: 23, Train Loss: 2.689, Train Acc: 0.090
Step: 25, Train Loss: 2.627, Train Acc: 0.120
Step: 27, Train Loss: 2.614, Train Acc: 0.100
Step: 29, Train Loss: 2.705, Train Acc: 0.120
Step: 31, Train Loss: 2.546, Train Acc: 0.160
Step: 33, Train Loss: 2.446, Train Acc: 0.180
Step: 35, Train Loss: 2.716, Train Acc: 0.080
Step: 37, Train Loss: 2.536, Train Acc: 0.180
Step: 39, Train Loss: 2.569, Train Acc: 0.100
Step: 41, Train Loss: 2.843, Train Acc: 0.090
Step: 43, Train Loss: 2.617, Train Acc:

In [33]:
states.input_trace.shape

TensorShape([Dimension(None), Dimension(128), Dimension(1)])

get the name of trainable variables in the graph