Code for building and training a kernel-based relu-RNN (Kernel RNN Learning, KeRNL) on sequential MNIST,
adapted from : Roth, Christopher, Ingmar Kanitscheider, and Ila Fiete. 2018. “Kernel RNN Learning (KeRNL),” September. https://openreview.net/forum?id=ryGfnoC5KQ.

In [20]:
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt 

import collections
import hashlib
import numbers
import matplotlib.cm as cm
from sys import getsizeof
from datetime import datetime
from pathlib import Path
import os

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear
from tensorflow.contrib import slim

In [4]:
# Load the MNIST data set (one-hot labels) with TensorFlow logging
# temporarily reduced to errors only.

old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
try:
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    train_data = mnist.train.images  # Returns np.array
    train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
    eval_data = mnist.test.images  # Returns np.array
    eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
finally:
    # Restore the caller's verbosity even when the import or the load fails,
    # so a failed cell does not leave logging silenced for the session.
    tf.logging.set_verbosity(old_v)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [9]:
# Training parameters
learning_rate = 1e-4
training_steps = 5000
batch_size = 128
display_step = 200  # print progress every `display_step` steps
test_len = 128  # number of test images used for evaluation
grad_clip = 100

# Network parameters
num_input = 1  # features fed to the network per time step (one pixel)
timesteps = 28 * 28  # sequence length: one step per pixel of a 28*28 image
num_hidden = 128  # number of hidden units
num_classes = 10  # MNIST total classes (0-9 digits)

# tf Graph input


In [6]:
# Plain namedtuple bases; both carry the same eight fields in the same order.
_KernelRNNStateTuple = collections.namedtuple(
    "KernelRNNStateTuple",
    [
        "h",
        "h_hat",
        "Theta",
        "Gamma",
        "input_trace",
        "recurrent_trace",
        "sensitivty_tensor",
        "kernel_coeff",
    ],
)
_KernelRNNOutputTuple = collections.namedtuple(
    "KernelRNNOutputTuple",
    [
        "h",
        "h_hat",
        "Theta",
        "Gamma",
        "input_trace",
        "recurrent_trace",
        "sensitivty_tensor",
        "kernel_coeff",
    ],
)


class KernelRNNStateTuple(_KernelRNNStateTuple):
    """State container of the kernel RNN cell.

    Holds eight entries, in this order: `h`, `h_hat`, `Theta`, `Gamma`,
    `input_trace`, `recurrent_trace`, `sensitivty_tensor`, `kernel_coeff`.
    The cell always represents its state with this tuple.
    """
    __slots__ = ()

    @property
    def dtype(self):
        """Return the dtype shared by `h` and `h_hat`.

        Raises:
          TypeError: if the dtypes of `h` and `h_hat` differ.
        """
        if self.h.dtype != self.h_hat.dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" %
                            (str(self.h.dtype), str(self.h_hat.dtype)))
        return self.h_hat.dtype


class KernelRNNOutputTuple(_KernelRNNOutputTuple):
    """Output container of the kernel RNN cell.

    Holds the same eight entries as the state tuple: `h`, `h_hat`, `Theta`,
    `Gamma`, `input_trace`, `recurrent_trace`, `sensitivty_tensor`,
    `kernel_coeff`. Only used when `output_is_tuple=True`.
    """
    __slots__ = ()

    @property
    def dtype(self):
        """Return the dtype shared by `h` and `h_hat`.

        Raises:
          TypeError: if the dtypes of `h` and `h_hat` differ.
        """
        if self.h.dtype != self.h_hat.dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" %
                            (str(self.h.dtype), str(self.h_hat.dtype)))
        return self.h_hat.dtype
#################################################################
def _tensor_expand_dim(x,y,output_size):
    """Tile two 2-D batches along a new middle axis of length `output_size`.

    Args:
      x: 2-D tensor, `[batch, n]`.
      y: 2-D tensor, `[batch, m]`.
      output_size: length of the inserted axis. The mask below subtracts
        `tf.eye(output_size)` from an `[n, n]` matrix of ones, so this is
        expected to equal `n`.

    Returns:
      A pair `(x_expand, y_expand)`:
        x_expand: `x` tiled and reshaped to `[-1, output_size, n]`, then
          multiplied element-wise by the ones-minus-identity mask.
        y_expand: `y` tiled and reshaped to `[-1, output_size, m]`.
    """
    shape_x = x.get_shape()
    shape_y = y.get_shape()
    # Matrix of ones with a zero diagonal, used to remove the diagonal entries
    # of the expanded recurrent input.
    diag_zero = lambda: tf.subtract(tf.constant(1.0, shape=[shape_x[1], shape_x[1]]),
                                    tf.eye(output_size))
    # The mask is a fixed constant, so it must not be exposed to the optimizer:
    # tf.Variable defaults to trainable=True, which would list it among the
    # trainable variables. The callable initial_value is kept from the original
    # (presumably needed because this runs inside the RNN loop - TODO confirm).
    x_diag_fixer = tf.Variable(initial_value=diag_zero, dtype=tf.float32,
                               trainable=False)
    # expand x
    x_temp = tf.reshape(tf.tile(x, [1, output_size]), [-1, output_size, shape_x[1]])
    # remove diagonal
    x_expand = tf.multiply(x_temp, x_diag_fixer)
    # expand y
    y_expand = tf.reshape(tf.tile(y, [1, output_size]), [-1, output_size, shape_y[1]])
    return x_expand, y_expand

def _create_pertubation(x,mean,std):
    """Return a constant tensor of ones with the static 2-D shape of `x`.

    `mean` and `std` are accepted but not used: the random-normal variant is
    disabled and a constant of ones is returned instead. The static shape is
    also reported through a logging warning.
    """
    dims = x.get_shape()
    static_shape = [dims[0].value, dims[1].value]
    logging.warn("%s: Please use float ", static_shape)
    with vs.variable_scope(vs.get_variable_scope()):
        perturbation = tf.constant(1.0, shape=static_shape)
    return perturbation

def _gaussian_noise_perturbation(input_layer, std):
    """Return zero-mean Gaussian noise (stddev `std`) shaped like `input_layer`.

    The input itself is multiplied by zero before the noise is added, so only
    its shape (and graph dependency) contributes to the result.
    """
    gaussian_noise = tf.random_normal(
        shape=tf.shape(input_layer),
        mean=0.0,
        stddev=std,
        dtype=tf.float32,
    )
    zeroed_input = tf.multiply(input_layer, 0)
    return zeroed_input + gaussian_noise


def _kernel_coeff_initializer(shape,dtype=None,partition_info=None,verify_shape=None, max_val=1):
    """Draw initial kernel coefficients uniformly from `[0, max_val)`.

    `dtype` falls back to `tf.float32` when omitted. `partition_info` and
    `verify_shape` are accepted but not used.
    """
    resolved_dtype = tf.float32 if dtype is None else dtype
    return tf.random_uniform(shape, minval=0, maxval=max_val, dtype=resolved_dtype)
###################################################################
# Names of the variables the cell registers in its variable scope.
_KERNEL_COEF_NAME= "kernel_coeff"
_SENSITIVITY_TENSOR_NAME= "sensitivity_tensor"

class KernelRNNCell(tf.contrib.rnn.RNNCell):
    """Kernel recurrent neural network cell.

    Besides the hidden state `h`, the cell propagates a noise-perturbed copy
    `h_hat`, the decayed perturbation sums `Theta` and `Gamma`, and eligibility
    traces for the input and recurrent connections.

      Args:
        num_units: int, the number of units in the cell.
        num_inputs: int, size of the input at each time step (sets the last
          dimension of `input_trace` in `state_size`).
        time_steps: int, the default kernel-coefficient variable is initialised
          to `1 / time_steps`.
        noise_std: float, standard deviation of the Gaussian perturbation added
          to `h_hat` at every step.
        activation: Nonlinearity to use.  Default: `tanh`.
        reuse: (optional) Python boolean describing whether to reuse variables
          in an existing scope.  If not `True`, and the existing scope already
          has the given variables, an error is raised.
        eligibility_kernel: kernel function applied to `-kernel_coeff` to obtain
          the trace decay.  Default: `exp`.
        state_is_tuple: must be `True`; `call` raises `ValueError` otherwise.
        output_is_tuple: if `True`, `call` returns a `KernelRNNOutputTuple`
          instead of just the new hidden state.
        batch_KeRNL: stored on the instance; not read by this class.
        sensitivity_initializer: (optional) initializer of the sensitivity
          variable.  Default: `truncated_normal_initializer`.
        kernel_coeff_initializer: (optional) initializer of the kernel
          coefficient variable.  Default: constant `1 / time_steps`.
        kernel_initializer: (optional) stored on the instance; not read by this
          class.
        bias_initializer: (optional) stored on the instance; not read by this
          class.
    """
    def __init__(self,
                 num_units,
                 num_inputs,
                 time_steps=1,
                 noise_std=1.0,
                 activation=None,
                 reuse=None,
                 eligibility_kernel=None,
                 state_is_tuple=True,
                 output_is_tuple=False,
                 batch_KeRNL=True,
                 sensitivity_initializer=None,
                 kernel_coeff_initializer=None,
                 kernel_initializer=None,
                 bias_initializer=None):

        super(KernelRNNCell,self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._num_inputs= num_inputs
        self._time_steps= time_steps
        self._activation = activation or math_ops.tanh
        self._eligibility_kernel = eligibility_kernel or math_ops.exp
        self._noise_std=noise_std
        self._linear = None
        self._state_is_tuple=state_is_tuple
        self._output_is_tuple= output_is_tuple
        self._batch_KeRNL=batch_KeRNL
        self._tensor_expand_dim=_tensor_expand_dim
        self._gaussian_noise_perturbation=_gaussian_noise_perturbation
        self._sensitivity_initializer=sensitivity_initializer
        self._kernel_coeff_initializer=kernel_coeff_initializer
        self._kernel_initializer=kernel_initializer
        self._bias_initializer=bias_initializer

    # Order: h, h_hat, Theta, Gamma, input_trace, recurrent_trace,
    # sensitivty_tensor, kernel_coeff
    @property
    def state_size(self):
        """Per-example sizes of the eight state entries (or `num_units`)."""
        return (KernelRNNStateTuple(self._num_units,
                                    self._num_units,
                                    self._num_units,
                                    self._num_units,
                                    np.array([self._num_units,self._num_inputs]),
                                    np.array([self._num_units,self._num_units]),
                                    np.array([self._num_units,self._num_units]),
                                    self._num_units)
                if self._state_is_tuple else self._num_units)

    @property
    def output_size(self):
        """Per-example sizes of the output entries (or `num_units`)."""
        return (KernelRNNOutputTuple(self._num_units,
                                    self._num_units,
                                    self._num_units,
                                    self._num_units,
                                    np.array([self._num_units,self._num_inputs]),
                                    np.array([self._num_units,self._num_units]),
                                    np.array([self._num_units,self._num_units]),
                                    self._num_units)
                if self._output_is_tuple else self._num_units)

    # call function routine
    def call(self, inputs, state):
        """Kernel RNN cell (KernelRNN).
        Args:
          inputs: `2-D` tensor with shape `[batch_size x input_size]`.
          state: A `KernelRNNStateTuple` of state tensors, shaped as following
            h:                   [batch_size x self.state_size]`
            h_hat:               [batch_size x self.state_size]`
            Theta:               [batch_size x self.state_size]`
            Gamma:               [batch_size x self.state_size]`
            input_trace          [batch_size x self.state_size x self.input_size]`
            recurrent_trace      [batch_size x self.state_size x self.state_size]`
            sensitivity_tensor   [batch_size x self.state_size x self.state_size]`
            kernel coeff         [batch_size x self.state_size]`
        Returns:
          A pair `(output, new_state)`. `new_state` is a `KernelRNNStateTuple`.
          `output` is the new hidden state `h`, or, when `output_is_tuple=True`,
          a `KernelRNNOutputTuple` with the same entries and shapes as the
          state listed above.
        Raises:
          ValueError: if the cell was built with `state_is_tuple=False`.
        """
        # Fail early with a clear message; the previous behaviour was to log an
        # error and then crash on the first use of the unpacked state.
        if not self._state_is_tuple:
            raise ValueError("State has to be tuple for this type of cell")

        # Register the kernel_coeff variable in the current scope.
        # NOTE(review): the variable's value is not used below - the
        # kernel_coeff entry of `state` is used instead. Confirm whether the
        # variable is meant to feed the computation.
        scope=vs.get_variable_scope()
        if self._kernel_coeff_initializer is None:
            kernel_initializer=init_ops.constant_initializer(1.0/self._time_steps,dtype=tf.float32)
        else:
            kernel_initializer=self._kernel_coeff_initializer
        with vs.variable_scope(scope,initializer=kernel_initializer):
            tf.get_variable(_KERNEL_COEF_NAME,shape=[self._num_units],dtype=tf.float32,trainable=True)

        # Register the sensitivity_tensor variable in the current scope.
        # NOTE(review): as above, the value used below comes from `state`.
        scope=vs.get_variable_scope()
        if self._sensitivity_initializer is None:
            sensitivity_initializer=init_ops.truncated_normal_initializer
        else:
            sensitivity_initializer=self._sensitivity_initializer
        with vs.variable_scope(scope,initializer=sensitivity_initializer):
            tf.get_variable(_SENSITIVITY_TENSOR_NAME,shape=[self._num_units,self._num_units],dtype=tf.float32,trainable=True)

        (h, h_hat, Theta, Gamma, input_trace, recurrent_trace,
         sensitivity_tensor, kernel_coeff) = state

        if self._linear is None:
            self._linear = _Linear([inputs, h], self._num_units, True)
        psi_new=self._gaussian_noise_perturbation(h,self._noise_std)
        # propagate data forward; the linear output is kept because the
        # activation derivative below is taken with respect to it
        g_new=self._linear([inputs,h])
        h_new=self._activation(g_new)
        # propagate noisy data forward (same linear map, perturbed state)
        h_hat_update=tf.add(h_hat,psi_new)
        h_hat_new= self._activation(self._linear([inputs, h_hat_update]))
        # TODO : check of weights get reused
        # decay factor shared by Theta, Gamma and both traces
        decay=self._eligibility_kernel(-kernel_coeff)
        # integrate over perturbations
        Theta_new=tf.add(tf.multiply(decay,Theta),psi_new)
        # derivative of perturbation w.r.t to kernel_coeff
        Gamma_new=tf.subtract(tf.multiply(decay,Gamma),
                              tf.multiply(decay,Theta))
        # update elgibility traces for input and recurrent units
        recurrent_expand,inputs_expand=self._tensor_expand_dim(h,inputs,self._num_units)
        # derivative of the activation with respect to its linear input
        activation_gradients=tf.gradients(h_new,g_new)[0] # convert list to a tensor
        gradient_expansion=tf.expand_dims(activation_gradients,axis=-1)
        input_trace_update=tf.multiply(gradient_expansion,inputs_expand)
        recurrent_trace_update=tf.multiply(gradient_expansion,recurrent_expand)
        kernel_decay=tf.expand_dims(decay,axis=-1)
        # update input trace
        input_trace_decay=tf.multiply(kernel_decay,input_trace)
        input_trace_new=tf.add(input_trace_decay,input_trace_update)

        # update recurrent trace
        recurrent_trace_decay=tf.multiply(kernel_decay,recurrent_trace)
        recurrent_trace_new=tf.add(recurrent_trace_decay,recurrent_trace_update)

        # TODO implement online updating for sensitivity and kernel coeff

        # sensitivity_tensor and kernel_coeff are passed through unchanged
        new_state=KernelRNNStateTuple(h_new,h_hat_new,Theta_new,Gamma_new,input_trace_new,
                                      recurrent_trace_new,sensitivity_tensor,kernel_coeff)
        if self._output_is_tuple:
            new_output=KernelRNNOutputTuple(h_new,h_hat_new,Theta_new,Gamma_new,input_trace_new,
                                          recurrent_trace_new,sensitivity_tensor,kernel_coeff)
        else:
            new_output=h_new

        return new_output, new_state
        

In [15]:
def kernel_RNN(x, weights, biases):
    """Run a KernelRNNCell over `x` and map the last step's output to logits.

    Args:
      x: input batch fed to `tf.nn.dynamic_rnn`; the caller passes a
        placeholder of shape `(batch_size, timesteps, num_input)`.
      weights: dict whose `'out'` entry is the output projection matrix.
      biases: dict whose `'out'` entry is the output bias.

    Returns:
      A pair `(logits, outputs)`: the linear read-out of the output at the
      last time step, and the cell outputs for all time steps.
    """
    # Variables created inside this scope default to the identity initializer.
    with tf.variable_scope('recurrent',initializer=tf.initializers.identity()):
        cell = KernelRNNCell(num_units=num_hidden,num_inputs=num_input,time_steps=timesteps)
        # The final state returned by dynamic_rnn is not needed here.
        rnn_outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

    # Linear read-out applied to the output of the last time step only.
    last_output = rnn_outputs[:,-1,:]
    logits = tf.matmul(last_output, weights['out']) + biases['out']
    return logits, rnn_outputs

In [16]:
# Build the training graph: read-out parameters, placeholders, the kernel RNN,
# the loss, the optimizer step, the accuracy metric and TensorBoard summaries.
tf.reset_default_graph()
graph=tf.Graph()
with graph.as_default():
    # Read-out parameters applied to the last RNN output
    weights = {
        'out': tf.Variable(tf.random_normal([num_hidden, num_classes]),name='output_weight')
    }
    biases = {
        'out': tf.Variable(tf.random_normal([num_classes]),name='output_bias')
    }
    X = tf.placeholder("float", [None, timesteps, num_input])
    Y = tf.placeholder("float", [None, num_classes])
    logits,output = kernel_RNN(X, weights, biases)
    prediction = tf.nn.softmax(logits)
    variable_names=[v.name for v in tf.trainable_variables()]
    # Define loss and optimizer
    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y))
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    # Compute the gradients once and apply exactly those; calling minimize()
    # here would build a second, redundant set of gradient ops.
    grads_and_vars=optimizer.compute_gradients(loss_op)
    # NOTE(review): grad_clip is defined and appears in the log directory name,
    # but no clipping is applied to grads_and_vars. If clipping is intended,
    # clip here (e.g. tf.clip_by_norm per gradient, skipping None gradients)
    # before apply_gradients.
    train_op = optimizer.apply_gradients(grads_and_vars)
    # Evaluate model
    correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    # Initialize the variables (i.e. assign their default value); created
    # after train_op so the optimizer's own variables are included.
    init = tf.global_variables_initializer()
    # TensorBoard summaries
    tf.summary.histogram('prediction',prediction+1e-8)
    tf.summary.histogram('logits',logits+1e-8)
    tf.summary.scalar('loss',loss_op)
    merged_summary_op=tf.summary.merge_all()
        
        #saver=tf.train.Saver()
    

In [17]:
# Verify initialization: print name, value and shape of every trainable
# variable right after running the initializer.
with tf.Session(graph=graph) as sess :
    sess.run(init)
    values = sess.run(variable_names)
    for k, v in zip(variable_names,values):
        print(["variable: " , k])
        print(["value: " , v])
        # Position of 'output' in the variable name (-1 when absent).
        # str.find is used directly: the np.unicode_ alias was removed in
        # NumPy 2.0.
        print(["variable: " , k.find('output')])
        print(["shape: " , v.shape])
     

['variable: ', 'output_weight:0']
['value: ', array([[-0.5562263 , -1.5679277 , -0.32520884, ..., -0.4588559 ,
         1.7650766 ,  1.5129395 ],
       [-1.0631367 ,  0.09719802,  0.49336594, ...,  0.3012817 ,
        -0.00972463,  0.7842659 ],
       [-1.3930802 ,  0.2786544 ,  0.5650163 , ...,  3.3679826 ,
         0.98051184, -0.88528967],
       ...,
       [ 1.3211502 ,  0.14028516, -1.0883702 , ..., -0.4652753 ,
         0.11319174,  0.8593472 ],
       [ 1.277668  ,  0.21948794,  0.18592085, ..., -0.41221184,
        -0.3877456 ,  1.730171  ],
       [ 0.94079083,  0.83154386, -0.53782547, ..., -1.162312  ,
        -0.4911751 ,  1.3081511 ]], dtype=float32)]
['variable: ', 0]
['shape: ', (128, 10)]
['variable: ', 'output_bias:0']
['value: ', array([-0.9497624 , -0.63880116, -0.6099232 , -0.55430096,  1.1011579 ,
        1.337941  , -0.70203036, -1.3880402 , -0.7286829 ,  0.69202995],
      dtype=float32)]
['variable: ', 0]
['shape: ', (10,)]
['variable: ', 'recurrent/rnn/kernel

In [21]:
# Directory for this run's TensorBoard logs, named after the hyper-parameters
# and the start time. The learning rate is formatted with %g: with %d a value
# such as 1e-4 is truncated to 0 and the directory name loses the information.
log_dir = "logs/irnn/bptt_gc_%d_eta_%g_batch_%d_run_%s" %(grad_clip,learning_rate,batch_size, datetime.now().strftime("%Y%m%d_%H%M"))
Path(log_dir).mkdir(exist_ok=True, parents=True)
# Remove leftover '*.local' files from the directory.
filelist = [ f for f in os.listdir(log_dir) if f.endswith(".local") ]
for f in filelist:
    os.remove(os.path.join(log_dir, f))

In [22]:
# write graph into tensorboard
tb_writer = tf.summary.FileWriter(log_dir,graph)
# run a training session
with tf.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(1,50):#range(1,training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x=batch_x.reshape((batch_size,timesteps,num_input))
        feed={X:batch_x, Y:batch_y}
        # run optimizer
        sess.run(train_op,feed_dict=feed)
        # Fetch loss, accuracy and summaries in a single pass over the batch
        # (all evaluated after the update, as before).
        loss_train, acc_train, merged_summary= sess.run(
            [loss_op, accuracy, merged_summary_op],feed_dict=feed)
        tb_writer.add_summary(merged_summary, global_step=step)
        # show interim performance
        if step % display_step==0 or step==1 :
            # `step` is already 1-based, so it is printed as is
            print('Step: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(
            step, loss_train, acc_train))
            # NOTE(review): the test slice is prepared here but never
            # evaluated inside the loop; kept in case later cells use it.
            test_X=mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
            test_Y=mnist.test.labels[:test_len]

    print("Optimization Finished!")
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

# Make sure all queued summaries reach disk.
tb_writer.flush()

Step: 2, Train Loss: 2.804, Train Acc: 0.055
Optimization Finished!
Testing Accuracy: 0.3125


'logs/irnn/irnn/bptt_gc_100_eta_0_run_20190121_1220'

Get the names of the trainable variables in the graph.