# Code for implementing MNIST learning on kernel_SNN 

# Background 
This code implements a spiking neural network with conductance-based input. The following equations govern the dynamics of the network.
### Transmembrane voltage dynamics
First, we model the transmembrane voltage as
$$\tau_m \frac{dV_i}{dt}= - V_i(t)+ R_m \times I^{syn}_i(t) $$ 
$$ {\tau_a}_i \frac{dB_i(t)}{dt} = b_i^0 -B_i(t)$$ 
where $R_m$ is the membrane resistance, $\tau_m$ is the membrane time constant, ${\tau_a}_i$ is the adaptation time constant, and $B_i$ is the adaptive firing threshold of neuron $i$, which relaxes toward its baseline $b_i^0$.
The synaptic current relates to the synaptic activations in the following way:
$$I^{syn}_i(t)= \sum_j W^{in}_{ij} \times X_j(t) + \sum_j W^{rec}_{ij} \times S_j(t) $$ 

### Neuron firing dynamics 
The firing dynamics of the neuron are modeled as a simple reset. More specifically, 
$$V_i \rightarrow V_{reset} \quad \text{if} \quad V_i \geq B_{i} $$

$B_{i}$ is the threshold voltage and $V_{reset}$ is the reset voltage of the neuron.

### Input dynamics 
Input synapses are the site of learning in the spiking network. Below, a conductance-based formulation is presented. 
First, the time-dependent input conductance to the membrane is calculated as follows:
$$ g_i(t) = \sum_j W_{ij} S_{j}(t) $$

In the current version, $S_{j}(t)$ is equal to the spike emitted at timestep $t$, without any decay dynamics. 
-  TODO: the index $j$ runs over all the neurons that have a synapse onto neuron $i$. The time dependence of the conductance is due to $S(t)$, which represents the spiking activity of the neurons connected to neuron $i$. The spiking activity has the following governing equations:
$$ S_{j} \rightarrow S_{j}+1 \quad if \ neuron\ j\ fires$$
$$ \frac{dS_{j}(t)}{dt} = \frac{-S_{j}(t)}{\tau_s}$$ 
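
The two equations above can be sketched in discrete time as follows. This is a minimal NumPy illustration of the decaying variant, not the code in `kernl_spiking_cell_v3` (which, as noted, currently uses the raw spike without decay). The function name `update_spike_trace`, the step size `dt`, and the value of `tau_s` are assumptions made for the example.

```python
import numpy as np

def update_spike_trace(S, spikes, tau_s=5.0, dt=1.0):
    """One Euler step of dS/dt = -S/tau_s, then S -> S + 1 for neurons that fired.

    tau_s and dt are illustrative values, not taken from the notebook.
    """
    S_decayed = S + dt * (-S / tau_s)   # decay toward zero
    return S_decayed + spikes           # unit jump for every emitted spike

S_trace = np.zeros(3)
S_trace = update_spike_trace(S_trace, np.array([1.0, 0.0, 0.0]))  # neuron 0 fires
S_trace = update_spike_trace(S_trace, np.zeros(3))                # no spikes: trace decays
```

With `tau_s = 5` and `dt = 1`, the trace of neuron 0 jumps to 1 and then decays to 0.8 on the next step.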

### Spike Adaptation dynamics 
The threshold for spiking increases with every spike emitted by a neuron, with the following dynamics:
$$ B_{i}(t) \rightarrow B_{i}(t)+\frac{\beta}{{\tau_a}_i} \quad if \ neuron\ i\ fires$$
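
A minimal discrete-time sketch of the adaptive threshold, combining the relaxation ${\tau_a}_i \frac{dB_i}{dt} = b_i^0 - B_i$ with the jump on each spike. The function name `update_threshold` and the numerical values of `b0`, `beta`, and `tau_a` are assumptions for illustration only.

```python
import numpy as np

def update_threshold(B, spikes, b0=1.0, beta=1.8, tau_a=100.0, dt=1.0):
    """Euler step of tau_a * dB/dt = b0 - B, then B -> B + beta/tau_a on a spike.

    b0, beta, tau_a and dt are illustrative values, not taken from the notebook.
    """
    B_relaxed = B + dt * (b0 - B) / tau_a        # relax toward the baseline b0
    return B_relaxed + (beta / tau_a) * spikes   # raise threshold of neurons that fired

B_next = update_threshold(np.array([1.0, 1.0]), np.array([1.0, 0.0]))
```

Here only the first neuron fires, so its threshold rises by `beta / tau_a` while the second stays at baseline.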


### Implementation in discrete time 
We start with the Euler method for modeling the dynamics.
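
As a hedged illustration, one forward-Euler step of the voltage equation followed by the threshold reset could look like the sketch below. The function name `euler_step`, the values of `R_m`, `V_reset`, and `dt`, and the toy weights are assumptions; `tau_m = 20` matches the value passed to the hidden layer later in this notebook.

```python
import numpy as np

def euler_step(V, X, S, W_in, W_rec, B, tau_m=20.0, R_m=1.0, V_reset=0.0, dt=1.0):
    """Forward-Euler step of tau_m * dV/dt = -V + R_m * I_syn, then reset where V >= B."""
    I_syn = W_in @ X + W_rec @ S                       # synaptic current
    V_new = V + (dt / tau_m) * (-V + R_m * I_syn)      # Euler update of the voltage
    spikes = (V_new >= B).astype(V.dtype)              # threshold crossing
    V_new = np.where(spikes > 0, V_reset, V_new)       # reset neurons that fired
    return V_new, spikes

# toy example: 2 neurons, 1 input, no recurrent connections (all values hypothetical)
W_in_demo = np.array([[30.0], [2.0]])
W_rec_demo = np.zeros((2, 2))
V_demo, spikes_demo = euler_step(np.zeros(2), np.array([1.0]), np.zeros(2),
                                 W_in_demo, W_rec_demo, B=np.ones(2))
```

In this toy example the first neuron is driven above threshold and reset, while the second integrates to 0.1 without firing.
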
### References 
-  Fiete, Ila R., Walter Senn, Claude Z. H. Wang, and Richard H. R. Hahnloser. 2010. “Spike-Time-Dependent Plasticity and Heterosynaptic Competition Organize Networks to Produce Long Scale-Free Sequences of Neural Activity.” Neuron 65 (4): 563–76. 

-  Bellec, Guillaume, Darjan Salaj, Anand Subramoney, Robert Legenstein, and Wolfgang Maass. 2018. “Long Short-Term Memory and Learning-to-Learn in Networks of Spiking Neurons.” arXiv [cs.NE]. arXiv. http://arxiv.org/abs/1803.09574.



In [1]:
# python libraries
import numpy as np 
import matplotlib.pyplot as plt 
import collections
import hashlib
import numbers
import matplotlib.cm as cm
from sys import getsizeof
from datetime import datetime
from pathlib import Path
import os
from IPython.display import HTML
import re

# tensorflow and its dependencies 
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear
from tensorflow.contrib import slim

## user defined modules 
# kernel rnn cell 
import kernl_spiking_cell_v3

### getting MNIST Data

In [2]:
# loading MNIST data 
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_data = mnist.train.images  # Returns np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images  # Returns np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
tf.logging.set_verbosity(old_v)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


### initial parameters for the network 

In [3]:
# Training Parameters
weight_learning_rate = 1e-5
tensor_learning_rate=1e-3 
training_steps = 5000
batch_size = 25
display_step = 25
test_len=128
grad_clip=200
# Network Parameters
# 1-input layer 
num_input = 1 # one pixel per timestep (each 28*28 image is presented as a sequence)
num_context_input=1
MNIST_timesteps = 28*28 # timesteps
context_timesteps=54
timesteps=MNIST_timesteps+context_timesteps
num_unit_input_layer=80 # input layer neurons
num_context_unit=1
noise_std=0.5
# 2-hidden layer 
num_hidden = 200 # hidden layer num of features
# 3-output layer 
num_classes = 10 # MNIST total classes (0-9 digits)

# report batch number 
total_batch = int(mnist.train.num_examples / batch_size)
print("Total number of batches:", total_batch)

Total number of batches: 2200


### define the network

In [4]:
def kernl_SNN_all_states(x,context):
    with tf.variable_scope('context_layer') as scope:
        context_input_layer_cell=kernl_spiking_cell_v3.context_input_spike_cell(num_units=1,context_switch=MNIST_timesteps)
        context_initial_state = context_input_layer_cell.zero_state(batch_size, dtype=tf.float32)
        output_context, states_context = tf.nn.dynamic_rnn(context_input_layer_cell, dtype=tf.float32, inputs=context,initial_state=context_initial_state)
        
    with tf.variable_scope('input_layer') as scope: 
        input_layer_cell=kernl_spiking_cell_v3.input_spike_cell(num_units=num_unit_input_layer)
        input_initial_state = input_layer_cell.zero_state(batch_size, dtype=tf.float32)
        output_l1, states_l1 = tf.nn.dynamic_rnn(input_layer_cell, dtype=tf.float32, inputs=x,initial_state=input_initial_state)
        
    with tf.variable_scope('hidden_layer') as scope: 
        hidden_layer_cell=kernl_spiking_cell_v3.kernl_spike_Cell(num_units=num_hidden,
                                                                 num_inputs=num_unit_input_layer+num_context_unit,
                                                                 time_steps=timesteps,
                                                                 output_is_tuple=True,
                                                                 tau_refract=1.0,
                                                                 tau_m=20,
                                                                 noise_std=noise_std)
        hidden_initial_state = hidden_layer_cell.zero_state(batch_size, dtype=tf.float32)
        output_hidden, states_hidden = tf.nn.dynamic_rnn(hidden_layer_cell, dtype=tf.float32, inputs=tf.concat([output_l1,output_context],-1),initial_state=hidden_initial_state)
    with tf.variable_scope('output_layer') as scope : 
        output_layer_cell=kernl_spiking_cell_v3.output_spike_cell(num_units=num_classes)
        output_voltage, voltage_states=tf.nn.dynamic_rnn(output_layer_cell,dtype=tf.float32,inputs=output_hidden.spike)

    return output_voltage,output_hidden


### Gradient equation 

First, we derive the governing equations for calculating the gradients, the sensitivity tensor, and the temporal filter. To simplify the mathematics, we rewrite the governing differential equations of the network in discrete time as follows: <br>
$$V_{t+1}=\sigma_{threshold}(\delta \times V_{t} + R_m \times W^{in} \times X_{t}+ R_m \times W^{rec} \times S_{t})$$ <br>
$$S_{t+1}= \rho \times S_t + \sigma_{spike}(V_t)$$ <br>
$$Y^{spike}_t=\sigma_{spike}(V_t)$$  <br>
<font color='red'> there is a problem with units in the equation above, look into it. </font> <br>
Consider a cost function that depends on $Y^{spike}_t$ and $X_t$, such as $C=C(X(-),f(Y^{spike}(-)),Y^{target})$, where $f(Y^{spike}(-))$ is the output transformation from the hidden layer to the output layer, in this case a linear layer of neurons. We need to calculate $\frac{dC}{dW}$ for $W^{rec}$ and $W^{in}$. Since the output is determined by $V_t$, we make the following assertion: 
$$\frac{\partial C}{\partial S} = 0 $$ 

In addition, following the chain rule for computing $\frac{\partial C}{\partial W}$, we have  <br>
$$\frac{\partial C}{\partial W^{rec}} = \frac{\partial C}{\partial Y^{spike}} \times \frac{\partial Y^{spike}}{\partial V} \times \frac{\partial V}{\partial W^{rec}}$$ <br> 
$$\frac{\partial C}{\partial W^{in}} = \frac{\partial C}{\partial Y^{spike}} \times \frac{\partial Y^{spike}}{\partial V} \times \frac{\partial V}{\partial W^{in}}$$ <br>

We start by applying the sensitivity lemma and rewriting the equation for the gradients:
$$g_i(t)=\sum_jW_{ij}S_j(t)$$ <br>
$$\frac{dC}{dW_{ij}}=\frac{1}{T} \int_0^T dt \frac{\delta C}{\delta W_{ij}(t)}=\frac{1}{T} \int_0^T dt \frac{\delta C}{\delta g_{i}(t)}\times S_j(t)$$  <br>
$$ \frac{\delta C}{\delta g_{i}(t)}= \sum_k \int_0^T dt' \frac{\delta C}{\delta y^{spike}_{k}(t')} \times \frac{\delta y^{spike}_{k}(t') }{\delta v_{k}(t')} \times \frac{\delta v_k(t')}{\delta g_i(t)}$$ 
<br>
First, we focus on the first two terms in the equation. We can calculate $\frac{\delta C}{\delta y^{spike}_{k}(t')}$ as the error in output $k$ at time $t'$. In addition, $\frac{\delta y^{spike}_{k}(t') }{\delta v_{k}(t')}$ is equivalent to $\sigma'_{spike}$. To avoid the discontinuity in $\sigma'_{spike}$, we use a pseudo-derivative defined as 
$$ \frac{\delta y^{spike}_k(t)}{\delta v_{k}(t)} := \max\{0,1-|v^{norm}_{k}(t)|\}$$
$$ v^{norm}_{k}(t) = \frac{v_k(t)-B_{k}(t)}{B_{k}(t)} $$
<br>
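
The pseudo-derivative defined above is a triangular function of the normalized distance between the voltage and the threshold. A minimal NumPy sketch, with the hypothetical function name `pseudo_derivative`:

```python
import numpy as np

def pseudo_derivative(v, B):
    """Triangular surrogate for d y_spike / d v: max(0, 1 - |(v - B) / B|)."""
    v_norm = (v - B) / B                           # voltage relative to the threshold
    return np.maximum(0.0, 1.0 - np.abs(v_norm))   # peaks at v == B, zero when |v_norm| >= 1

pd_demo = pseudo_derivative(np.array([1.0, 0.5, 3.0]), B=1.0)
```

The surrogate equals 1 exactly at threshold, falls off linearly, and vanishes once the voltage is more than one threshold-width away.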
In the equation above, the term that depends on interactions between units is $\frac{\delta v_k(t')}{\delta g_i(t)}$; it captures how the activity of neuron $i$ at time $t$ affects the activity of neuron $k$ at time $t'$. To estimate this interaction, we make the following assumption: <br>

$$\frac{\delta v_k(t')}{\delta g_i(t)} = m_{ki}(t,t') = M_{ki}\times K(t-t') \times h(s_i(t))$$  
<font color='red'> How can one learn h, M, and K empirically, and what other forms are good estimates of this interaction? </font> <br>

Going back to the gradient definition, we can incorporate this form of the interaction and calculate the gradient as follows:
<br>
$$\frac{dC}{dW_{ij}}=\frac{1}{T} \int_0^T dt \frac{\delta C}{\delta W_{ij}(t)}=\frac{1}{T} \int_0^T dt \frac{\delta C}{\delta g_{i}(t)}\times S_j(t)$$ <br>
$$\frac{dC}{dW_{ij}}=\frac{1}{T} \int_0^T dt \frac{\delta C}{\delta W_{ij}(t)}=\frac{1}{T} \int_0^T dt \sum_k \int_0^T dt' \frac{\delta C}{\delta y^{spike}_{k}(t')} \times \frac{\delta y^{spike}_{k}(t') }{\delta v_{k}(t')} \times M_{ki}\times K(t-t') \times h\big(S_i(t)\big) \times S_j(t)$$
<br>
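
In discrete time, the time average $\frac{1}{T}\int_0^T dt\, \frac{\delta C}{\delta g_i(t)} \times S_j(t)$ reduces to an average of outer products over timesteps. The sketch below assumes that the per-timestep sensitivities $\frac{\delta C}{\delta g_i(t)}$ have already been computed and stored in an array; the function name `weight_gradient` and the array layout are assumptions for illustration.

```python
import numpy as np

def weight_gradient(dC_dg, S):
    """Discrete version of dC/dW_ij = (1/T) * sum_t dC/dg_i(t) * S_j(t).

    dC_dg: array [T, N_post], sensitivity of the cost to each conductance g_i(t)
    S:     array [T, N_pre],  presynaptic activity S_j(t)
    """
    T = dC_dg.shape[0]
    return dC_dg.T @ S / T   # shape [N_post, N_pre]

dC_dg_demo = np.array([[1.0, 0.0],
                       [0.0, 2.0]])
S_demo = np.array([[1.0, 1.0, 0.0],
                   [0.0, 1.0, 1.0]])
grad_demo = weight_gradient(dC_dg_demo, S_demo)
```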
### Learning of $M$ and $K(\tau)$

First, we focus on estimating the two parameters $M_{ki}$ and $K(t)$. We apply a small i.i.d. hidden perturbation $\xi$ to $S(t)$ and track its effect on the output voltage.
<br>
$$\tilde{V}_{t+1}=\sigma_{threshold}\Big(\delta \times \tilde{V}_{t} + R_m \times W^{in} \times X_{t}+ R_m \times W^{rec} \times \big(S_{t}+\xi_t\big)\Big)$$ <br>
We then minimize the following cost function: <br>
$$ C_{M,K}= \Big(\tilde{V_i}(t)-V_i(t)-\sum_j M_{ij} \sum_{\tau}K(\tau)\times\xi_j(t-\tau)\Big)^2$$
<br>
For a single exponential filter $K(\tau)=\exp(-\gamma_j\tau)$, we have  <br>

$$ C_{M,K}= \Big(\tilde{V_i}(t)-V_i(t)-\sum_j M_{ij} \sum_{\tau} exp(-\gamma_j\tau) \times \xi_j(t-\tau)\Big)^2$$
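
A minimal NumPy sketch of this cost for the exponential filter. The inner sum $\sum_{\tau} \exp(-\gamma_j\tau)\,\xi_j(t-\tau)$ can be accumulated recursively, one timestep at a time. The function names `filtered_noise` and `sensitivity_cost` are hypothetical, and summing the squared residual over all neurons $i$ is an assumption; the training graph in this notebook computes a comparable loss from `Theta` and the sensitivity tensor, possibly with a different index order.

```python
import numpy as np

def filtered_noise(xi, gamma):
    """Theta[t, j] = sum_tau exp(-gamma_j * tau) * xi[t - tau, j], computed recursively.

    xi: array [T, N] of perturbations, gamma: array [N] of decay rates.
    """
    decay = np.exp(-gamma)
    trace = np.zeros(xi.shape[1])
    Theta = np.zeros_like(xi)
    for t in range(xi.shape[0]):
        trace = decay * trace + xi[t]   # one step of the exponential filter
        Theta[t] = trace
    return Theta

def sensitivity_cost(v_tilde, v, M, Theta_t):
    """Squared error between the voltage change and its prediction M @ Theta_t."""
    residual = v_tilde - v - M @ Theta_t
    return np.sum(residual ** 2)        # summed over neurons i (an assumption)

xi_demo = np.array([[1.0, 1.0],
                    [0.0, 0.0]])
Theta_demo = filtered_noise(xi_demo, gamma=np.array([0.0, np.log(2.0)]))
cost_demo = sensitivity_cost(Theta_demo[1], np.zeros(2), np.eye(2), Theta_demo[1])
```

In the toy example the voltage change equals the filtered noise and $M$ is the identity, so the prediction is exact and the cost vanishes.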

<font color='red'> add pseudo code here for the spiking neuron cell. </font> <br>
### Pseudo-code for kernl_snn_cell
kernel_snn_v_2.1 <br>
while t<T do: <br>
>             update I_syn = W_in*X + W_rec*S <br>
              find neurons outside refractory period : eligible <br>
              calculate decay factor for v_mem : alpha <br>
              update v_mem for eligible neurons :  v_t+1= alpha* v_t + (1-alpha)* I_syn <br> 
              find spike_t+1 <-- v_mem > Beta <br>
              v_mem <-- v_reset for neurons crossing threshold (emitted spikes) <br>
              update threshold <br>
              update refractory period <br>
              update synaptic input <br>
              update eligibility trace <br>
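
The loop body above can be sketched in NumPy as a single step function. This is an illustration under stated assumptions, not the implementation in `kernl_spiking_cell_v3`: the function name `snn_cell_step`, the state layout, and the values of `tau_s`, `tau_a`, `beta`, `b0`, and `V_reset` are hypothetical, while `tau_m = 20` and `tau_refract = 1.0` match the arguments used for the hidden layer in this notebook. The eligibility-trace update is omitted.

```python
import numpy as np

def snn_cell_step(state, X, W_in, W_rec, tau_m=20.0, tau_refract=1.0, tau_s=5.0,
                  tau_a=100.0, beta=1.8, b0=1.0, V_reset=0.0, dt=1.0):
    """One iteration of the pseudo-code loop (eligibility trace not included)."""
    v, S, B, refract = state
    I_syn = W_in @ X + W_rec @ S                          # update synaptic current
    eligible = refract <= 0.0                             # neurons outside refractory period
    alpha = np.exp(-dt / tau_m)                           # decay factor for v_mem
    v_new = np.where(eligible, alpha * v + (1.0 - alpha) * I_syn, v)
    spikes = (eligible & (v_new > B)).astype(float)       # threshold crossing
    v_new = np.where(spikes > 0, V_reset, v_new)          # reset neurons that fired
    B_new = B + dt * (b0 - B) / tau_a + (beta / tau_a) * spikes   # adaptive threshold
    refract_new = np.where(spikes > 0, tau_refract, np.maximum(refract - dt, 0.0))
    S_new = S - dt * S / tau_s + spikes                   # synaptic trace
    return (v_new, S_new, B_new, refract_new), spikes

# toy run: 2 neurons, 1 input, no recurrence (all values hypothetical)
W_in_toy = np.array([[100.0], [1.0]])
W_rec_toy = np.zeros((2, 2))
state0 = (np.zeros(2), np.zeros(2), np.ones(2), np.zeros(2))
state1, spikes1 = snn_cell_step(state0, np.array([1.0]), W_in_toy, W_rec_toy)
state2, spikes2 = snn_cell_step(state1, np.zeros(1), W_in_toy, W_rec_toy)
```

In the toy run the strongly driven neuron fires on the first step, is reset and held refractory for one step, and its threshold rises by `beta / tau_a`.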




### Computation graph for learning $M$ and $K(\tau)$

In [5]:
tf.reset_default_graph()
graph=tf.Graph()
with graph.as_default():
    # check hardware 
    
    # define weights and inputs to the network
    X = tf.placeholder("float", [None, timesteps, num_input])
    Y = tf.placeholder("float", [None, num_classes])
    Context=tf.placeholder('float',shape=[batch_size,timesteps,num_context_input])
    # build the network and collect its output voltages and hidden states
    kernl_output,kernl_hidden_states=kernl_SNN_all_states(X,Context)
    
    trainables=tf.trainable_variables()
    variable_names=[v.name for v in tf.trainable_variables()]
    # 
    # return the index of the first trainable variable whose name contains both substrings
    def find_joint_index(variables, name_1, name_2):
        return [name_1 in v.name and name_2 in v.name for v in variables].index(True)
    # find trainable parameters for kernl 
    with tf.name_scope('kernl_Trainables') as scope:
        kernl_output_weight_index= find_joint_index(trainables,'output_layer','kernel')
        kernl_temporal_filter_index= find_joint_index(trainables,'kernl','temporal_filter')
        kernl_sensitivity_tensor_index= find_joint_index(trainables,'kernl','sensitivity_tensor')
        kernl_kernel_index= find_joint_index(trainables,'hidden_layer','kernel')
        # variables trained on the state-prediction loss: sensitivity tensor and temporal filter
        kernl_tensor_training_indices=np.asarray([kernl_sensitivity_tensor_index,kernl_temporal_filter_index],dtype=int)
        kernl_tensor_trainables= [trainables[k] for k in kernl_tensor_training_indices]
        # network weights: hidden-layer kernel and output-layer kernel
        kernl_weight_training_indices=np.asarray([kernl_kernel_index,kernl_output_weight_index],dtype=int)
        kernl_weight_trainables= [trainables[k] for k in kernl_weight_training_indices]
    
 
    ##################
    # kernl train ####
    ##################
    with tf.name_scope("kernl_performance") as scope:
        # outputs 
        kernl_logit=tf.reduce_mean(kernl_output[:,-context_timesteps:,:],axis=1)
        kernl_loss_output_prediction=tf.losses.softmax_cross_entropy(onehot_labels=Y,logits=kernl_logit)
        kernl_prediction = tf.nn.softmax(kernl_logit)
        kernl_correct_pred = tf.equal(tf.argmax(kernl_prediction, 1), tf.argmax(Y, 1))
        kernl_accuracy = tf.reduce_mean(tf.cast(kernl_correct_pred, tf.float32))
        
    with tf.name_scope('kernl_train_tensors') as scope: 
        kernl_loss_state_prediction=tf.losses.mean_squared_error(tf.subtract(kernl_hidden_states.v_mem_hat[:,-1,:], kernl_hidden_states.v_mem[:,-1,:]),tf.matmul(kernl_hidden_states.Theta[:,-1,:],trainables[kernl_sensitivity_tensor_index]))
        kernl_tensor_optimizer = tf.train.RMSPropOptimizer(learning_rate=tensor_learning_rate)
        kernl_tensor_grads=tf.gradients(ys=kernl_loss_state_prediction,xs=kernl_tensor_trainables)
        kernl_tensor_grad_and_vars=list(zip(kernl_tensor_grads,kernl_tensor_trainables))
        kernl_tensor_train_op=kernl_tensor_optimizer.apply_gradients(kernl_tensor_grad_and_vars)
        
    
    ##################
    # SUMMARIES ######
    ##################
    
    with tf.name_scope("kernl_tensor_summaries") as scope: 
        # kernl sensitivity tensor 
        tf.summary.histogram('kernl_sensitivity_tensor_grad',kernl_tensor_grads[0]+1e-10)
        tf.summary.histogram('kernl_sensitivity_tensor',trainables[kernl_sensitivity_tensor_index]+1e-10)
        # kernl temporal filter 
        tf.summary.histogram('kernl_temporal_filter_grad',kernl_tensor_grads[1]+1e-10)
        tf.summary.histogram('kernl_temporal_filter',trainables[kernl_temporal_filter_index]+1e-10)
        # kernl loss 
        tf.summary.scalar('kernl_loss_state_prediction',kernl_loss_state_prediction+1e-10)
        # kernl senstivity tensor and temporal filter 
        tf.summary.image('kernl_sensitivity_tensor',tf.expand_dims(tf.expand_dims(trainables[kernl_sensitivity_tensor_index],axis=0),axis=-1))
        tf.summary.image('kernl_sensitivity_tensor_grad',tf.expand_dims(tf.expand_dims(kernl_tensor_grads[0],axis=0),axis=-1))
        tf.summary.image('kernl_temporal_filter',tf.expand_dims(tf.expand_dims(tf.expand_dims(trainables[kernl_temporal_filter_index],axis=0),axis=-1),axis=-1))
        tf.summary.image('kernl_temporal_filter_grad',tf.expand_dims(tf.expand_dims(tf.expand_dims(kernl_tensor_grads[1],axis=0),axis=-1),axis=-1))
        kernl_tensor_merged_summary_op=tf.summary.merge_all(scope="kernl_tensor_summaries")
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

In [6]:
# verify initialization

with tf.Session(graph=graph,) as sess : 
    sess.run(init)
    values,trainable_vars = sess.run([variable_names,trainables])
    for k, v in zip(variable_names,values):
        print(["variable: " , k])
        #print(["value: " , v])
        print(["variable: " , np.unicode_.find(k,'output')]) 
        print(["shape: " , v.shape])
        #print(v) 

['variable: ', 'hidden_layer/rnn/kernl_spike__cell/temporal_filter:0']
['variable: ', -1]
['shape: ', (200,)]
['variable: ', 'hidden_layer/rnn/kernl_spike__cell/sensitivity_tensor:0']
['variable: ', -1]
['shape: ', (200, 200)]
['variable: ', 'hidden_layer/rnn/kernl_spike__cell/kernel:0']
['variable: ', -1]
['shape: ', (281, 200)]
['variable: ', 'output_layer/rnn/output_spike_cell/kernel:0']
['variable: ', 0]
['shape: ', (200, 10)]


In [7]:
# %g keeps fractional learning rates (e.g. 1e-05) in the directory name
log_dir = "/home/eghbal/MyData/KeRNL/logs/kernl_SNN_v3/MNIST_gc_%d_eta_m_%g_eta_%g_batch_%d_run_%s" %(grad_clip,tensor_learning_rate,weight_learning_rate,batch_size, datetime.now().strftime("%Y%m%d_%H%M"))
Path(log_dir).mkdir(exist_ok=True, parents=True)
filelist = [ f for f in os.listdir(log_dir) if f.endswith(".local") ]
for f in filelist:
    os.remove(os.path.join(log_dir, f))

In [8]:
# write graph into tensorboard 
tb_writer = tf.summary.FileWriter(log_dir,graph)
# run a training session 
with tf.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(1,training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x=batch_x.reshape((batch_size,MNIST_timesteps,num_input))
        batch_x_full=np.concatenate([batch_x,np.zeros((batch_size,timesteps-MNIST_timesteps,num_input))],axis=1)
        context_input=np.ones((batch_size,timesteps,num_context_input))
        kernl_tensor_train,kernl_loss_state=sess.run([kernl_tensor_train_op,kernl_loss_state_prediction], feed_dict={X: batch_x_full,Y:batch_y,Context:context_input})
        
        # run summaries 
        kernl_tensor_merged_summary=sess.run(kernl_tensor_merged_summary_op,feed_dict={X:batch_x_full, Y:batch_y,Context:context_input})
        
        tb_writer.add_summary(kernl_tensor_merged_summary, global_step=step)
        # 
        if step % display_step==0 or step==1 : 
            # get batch loss and accuracy 
            print('Step: {}, keRNL tensor Loss {:.3f},'.format(step + 1, kernl_loss_state))


    print("Optimization Finished!")
    #test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    #test_label = mnist.test.labels[:test_len]
    #print("Testing Accuracy:", 
    #    sess.run(loss_output_prediction, feed_dict={X: test_data, Y: test_label}))
    save_path = saver.save(sess, log_dir+"/model.ckpt", global_step=step,write_meta_graph=True)
    print("Model saved in path: %s" % save_path)

Step: 2, keRNL tensor Loss 8360.558,
Step: 26, keRNL tensor Loss 8391.920,
Step: 51, keRNL tensor Loss 8725.039,
Step: 76, keRNL tensor Loss 7907.877,
Step: 101, keRNL tensor Loss 7679.605,
Step: 126, keRNL tensor Loss 8445.836,
Step: 151, keRNL tensor Loss 8075.896,
Step: 176, keRNL tensor Loss 7681.286,
Step: 201, keRNL tensor Loss 7864.392,
Step: 226, keRNL tensor Loss 8248.395,
Step: 251, keRNL tensor Loss 7604.929,
Step: 276, keRNL tensor Loss 7745.477,
Step: 301, keRNL tensor Loss 7417.432,
Step: 326, keRNL tensor Loss 7870.850,
Step: 351, keRNL tensor Loss 7877.712,
Step: 376, keRNL tensor Loss 7769.990,
Step: 401, keRNL tensor Loss 7137.896,
Step: 426, keRNL tensor Loss 7758.786,
Step: 451, keRNL tensor Loss 7499.935,
Step: 476, keRNL tensor Loss 7462.279,
Step: 501, keRNL tensor Loss 7260.229,
Step: 526, keRNL tensor Loss 7388.560,
Step: 551, keRNL tensor Loss 7208.475,
Step: 576, keRNL tensor Loss 7633.149,
Step: 601, keRNL tensor Loss 7532.893,
Step: 626, keRNL tensor Loss 7

In [None]:
# verify initialization
config = tf.ConfigProto()
config.gpu_options.allow_growth = True


with tf.Session(graph=graph, config=config) as sess : 
    sess.run(init)
    values,trainable_vars = sess.run([variable_names,trainables])
    for k, v in zip(variable_names,values):
        print(["variable: " , k])
        #print(["value: " , v])
        print(["variable: " , np.unicode_.find(k,'output')]) 
        print(["shape: " , v.shape])
        #print(v) 