In [2]:
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt 

import collections
import hashlib
import numbers

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear

Create the named-tuple types that hold the cell's state and its (optional) tuple output.

In [5]:
# Full recurrent state carried from one time step to the next (9 fields).
_KernelRNNStateTuple = collections.namedtuple(
    "KernelRNNStateTuple",
    [
        "h",                      # hidden activation
        "h_hat",                  # hidden activation of the noise-perturbed pass
        "Theta",                  # perturbations integrated through the kernel
        "Gamma",                  # derivative of Theta w.r.t. kernel_coeff
        "input_trace",            # eligibility trace for the input weights
        "recurrent_trace",        # eligibility trace for the recurrent weights
        "input_sensitivity",
        "recurrent_sensitivity",
        "kernel_coeff",           # coefficient fed (negated) to the eligibility kernel
    ],
)

# Subset of the state exposed as the cell output (first 6 state fields).
_KernelRNNOutputTuple = collections.namedtuple(
    "KernelRNNOutputTuple",
    [
        "h",
        "h_hat",
        "Theta",
        "Gamma",
        "input_trace",
        "recurrent_trace",
    ],
)

class KernelRNNStateTuple(_KernelRNNStateTuple):
  """State container used by kernel RNN cells.

  Holds 9 elements, in this order: `(h, h_hat, Theta, Gamma, input_trace,
  recurrent_trace, input_sensitivity, recurrent_sensitivity, kernel_coeff)`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    """Return the dtype shared by `h` and `h_hat`.

    Only these two fields are compared; the remaining fields are not checked.

    Raises:
      TypeError: if `h` and `h_hat` report different dtypes.
    """
    clean_dtype = self.h.dtype
    noisy_dtype = self.h_hat.dtype
    if clean_dtype != noisy_dtype:
      message = "Inconsistent internal state: %s vs %s" % (
          str(clean_dtype), str(noisy_dtype))
      raise TypeError(message)
    return noisy_dtype


class KernelRNNOutputTuple(_KernelRNNOutputTuple):
  """Output container used by kernel RNN cells.

  Holds 6 elements, in this order: `(h, h_hat, Theta, Gamma, input_trace,
  recurrent_trace)`. Only used when `output_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    """Return the dtype shared by `h` and `h_hat`.

    Only these two fields are compared; the remaining fields are not checked.

    Raises:
      TypeError: if `h` and `h_hat` report different dtypes.
    """
    clean_dtype = self.h.dtype
    noisy_dtype = self.h_hat.dtype
    if clean_dtype != noisy_dtype:
      message = "Inconsistent internal state: %s vs %s" % (
          str(clean_dtype), str(noisy_dtype))
      raise TypeError(message)
    return noisy_dtype


In [None]:
## expand dimensions for incoming recurrent and input activations for multiplication with current activation
def _tensor_expand_dim(x,y,output_size):
    """Tile recurrent and input activations once per output unit.

    Args:
      x: 2-D tensor with shape `[batch, n]` (recurrent activations). `n` must
        equal `output_size`, because the diagonal mask applied to the tiled
        result is `[output_size, output_size]`.
      y: 2-D tensor with shape `[batch, m]` (input activations).
      output_size: int, number of rows to tile along the new middle axis.

    Returns:
      A pair `(x_expand, y_expand)`:
        x_expand: `[batch, output_size, n]` with `x_expand[b, i, j] = x[b, j]`
          for `i != j` and `0` on the diagonal `i == j`.
        y_expand: `[batch, output_size, m]` with `y_expand[b, i, j] = y[b, j]`.
    """
    # The off-diagonal mask is a fixed constant. It used to be stored in a
    # freshly created trainable tf.Variable on every call, which required
    # initialisation and let an optimizer modify the mask. A plain tensor in
    # the dtype of `x` avoids both problems.
    off_diagonal = tf.subtract(
        tf.ones([output_size, output_size], dtype=x.dtype),
        tf.eye(output_size, dtype=x.dtype))
    # Insert the per-unit axis and repeat along it; unlike the previous
    # tile+reshape this does not need the static shape of `x` / `y`.
    x_tiled = tf.tile(tf.expand_dims(x, 1), [1, output_size, 1])
    # Remove the diagonal (a unit's own activation) from the recurrent copy.
    x_expand = tf.multiply(x_tiled, off_diagonal)
    y_expand = tf.tile(tf.expand_dims(y, 1), [1, output_size, 1])
    return x_expand, y_expand

Define the kernel RNN cell.

In [None]:
class KernelRNNCell(tf.contrib.rnn.RNNCell):
    """Kernel recurrent neural network cell.

    Besides the regular hidden activation `h`, every step also propagates a
    noise-perturbed copy `h_hat`, the kernel-filtered perturbation `Theta`,
    its derivative with respect to `kernel_coeff` (`Gamma`), and eligibility
    traces for the input and recurrent weights.

    Args:
      num_units: int, the number of units in the cell.
      num_inputs: int, the size of the input fed to the cell at each step.
      activation: nonlinearity to use.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
      eligibility_kernel: kernel function applied to `-kernel_coeff` to obtain
        the per-step decay of the traces.  Default: `exp`.
      state_is_tuple: must be `True`; `call` raises `ValueError` otherwise.
      output_is_tuple: if `True` the cell output is a `KernelRNNOutputTuple`,
        otherwise it is the new hidden activation only.
      noise_std: standard deviation of the Gaussian perturbation added to
        `h_hat` at every step.
      batch_KeRNL: stored on the cell; not used by the code in this class yet.
    """

    def __init__(self,
                 num_units,
                 num_inputs,
                 activation=None,
                 reuse=None,
                 eligibility_kernel=None,
                 state_is_tuple=True,
                 output_is_tuple=False,
                 noise_std=1.0,
                 batch_KeRNL=True):
        super(KernelRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._num_inputs = num_inputs
        self._activation = activation or math_ops.tanh
        self._eligibility_kernel = eligibility_kernel or math_ops.exp
        self._noise_std = noise_std
        # Built lazily on the first `call`, once the input size is known.
        self._linear = None
        self._state_is_tuple = state_is_tuple
        self._output_is_tuple = output_is_tuple
        self._batch_KeRNL = batch_KeRNL
        self._tensor_expand_dim = _tensor_expand_dim

    @property
    def state_size(self):
        """Per-example sizes of the state fields (a `KernelRNNStateTuple`)."""
        if not self._state_is_tuple:
            return self._num_units
        return KernelRNNStateTuple(
            self._num_units,                                   # h
            self._num_units,                                   # h_hat
            self._num_units,                                   # Theta
            self._num_units,                                   # Gamma
            np.array([self._num_units, self._num_inputs]),     # input_trace
            np.array([self._num_units, self._num_units]),      # recurrent_trace
            np.array([self._num_units, self._num_inputs]),     # input_sensitivity
            np.array([self._num_units, self._num_units]),      # recurrent_sensitivity
            self._num_units)                                   # kernel_coeff

    @property
    def output_size(self):
        """Per-example sizes of the output (tuple only if `output_is_tuple`)."""
        if not self._output_is_tuple:
            return self._num_units
        return KernelRNNOutputTuple(
            self._num_units,                                   # h
            self._num_units,                                   # h_hat
            self._num_units,                                   # Theta
            self._num_units,                                   # Gamma
            np.array([self._num_units, self._num_inputs]),     # input_trace
            np.array([self._num_units, self._num_units]))      # recurrent_trace

    # call function routine
    def call(self, inputs, state):
        """Kernel RNN cell (KernelRNN).

        Args:
          inputs: `2-D` tensor with shape `[batch_size x input_size]`.
          state: a `KernelRNNStateTuple` of state tensors, shaped as follows
            h:                    `[batch_size x num_units]`
            h_hat:                `[batch_size x num_units]`
            Theta:                `[batch_size x num_units]`
            Gamma:                `[batch_size x num_units]`
            input_trace:          `[batch_size x num_units x input_size]`
            recurrent_trace:      `[batch_size x num_units x num_units]`
            input_sensitivity:    `[batch_size x num_units x input_size]`
            recurrent_sensitivity:`[batch_size x num_units x num_units]`
            kernel_coeff:         `[batch_size x num_units]`

        Returns:
          A pair `(output, new_state)`. `new_state` is a `KernelRNNStateTuple`.
          `output` is the new hidden activation `[batch_size x num_units]`, or,
          when `output_is_tuple=True`, a `KernelRNNOutputTuple` holding the new
          `h, h_hat, Theta, Gamma, input_trace, recurrent_trace`.
          `input_sensitivity`, `recurrent_sensitivity` and `kernel_coeff` are
          passed through unchanged.

        Raises:
          ValueError: if the cell was built with `state_is_tuple=False`.
        """
        if not self._state_is_tuple:
            # Previously this was only logged and the method then failed on an
            # undefined name; fail explicitly instead.
            raise ValueError("State has to be tuple for this type of cell")
        (h, h_hat, Theta, Gamma, input_trace, recurrent_trace,
         input_sensitivity, recurrent_sensitivity, kernel_coeff) = state

        if self._linear is None:
            # One `_Linear` instance owns the weights, so the clean and the
            # noisy forward passes below share the same kernel and bias.
            self._linear = _Linear([inputs, h], self._num_units, True)

        # Dynamic shape so an unknown batch size does not break the sampler.
        psi_new = tf.random_normal(shape=tf.shape(h_hat), mean=0.0,
                                   stddev=self._noise_std, dtype=h_hat.dtype)
        # propagate data forward (keep the pre-activation for the gradient)
        g_new = self._linear([inputs, h])
        h_new = self._activation(g_new)
        # propagate noisy data forward
        h_hat_new = self._activation(self._linear([inputs, h_hat + psi_new]))

        # per-unit decay given by the eligibility kernel, `[batch x num_units]`
        decay = self._eligibility_kernel(-kernel_coeff)
        # integrate over perturbations
        Theta_new = tf.add(tf.multiply(decay, Theta), psi_new)
        # derivative of perturbation w.r.t. kernel_coeff
        Gamma_new = tf.subtract(tf.multiply(decay, Gamma),
                                tf.multiply(decay, Theta))

        # update eligibility traces for input and recurrent units
        # `tf.gradients` returns a list with one entry per `xs`; take the
        # tensor and add a trailing axis so it broadcasts over presynaptic units.
        activation_grad = tf.expand_dims(tf.gradients(h_new, g_new)[0], 2)
        # expand recurrent and input activation
        recurrent_expand, inputs_expand = self._tensor_expand_dim(
            h, inputs, self._num_units)
        input_trace_update = tf.multiply(activation_grad, inputs_expand)
        recurrent_trace_update = tf.multiply(activation_grad, recurrent_expand)
        # The traces are 3-D, so the 2-D decay needs a trailing axis too.
        trace_decay = tf.expand_dims(decay, 2)
        input_trace_new = tf.add(tf.multiply(trace_decay, input_trace),
                                 input_trace_update)
        recurrent_trace_new = tf.add(tf.multiply(trace_decay, recurrent_trace),
                                     recurrent_trace_update)

        # TODO implement online updating for sensitivity and kernel coeff

        new_state = KernelRNNStateTuple(h_new, h_hat_new, Theta_new, Gamma_new,
                                        input_trace_new, recurrent_trace_new,
                                        input_sensitivity, recurrent_sensitivity,
                                        kernel_coeff)
        if self._output_is_tuple:
            new_output = KernelRNNOutputTuple(h_new, h_hat_new, Theta_new,
                                              Gamma_new, input_trace_new,
                                              recurrent_trace_new)
        else:
            new_output = h_new
        return new_output, new_state
        
    
    
    