# Code for learning the adding problem with a KeRNL spiking neural network (SNN)


# Background 
This code implements a spiking neural network with conductance-based input. The following equations govern the dynamics of the network.
### Transmembrane voltage dynamics
First, we model the transmembrane voltage $V_i$ and the spiking threshold $B_i$ of neuron $i$ as
$$\tau_m \frac{dV_i}{dt}= - V_i(t)+ R_m \times I^{syn}_i(t) $$ 
$$ {\tau_a}_i \frac{dB_i(t)}{dt} = b_i^0 -B_i(t)$$ 
where $R_m$ is the membrane resistance, $\tau_m$ is the membrane time constant, ${\tau_a}_i$ is the adaptation time constant, and $B_i(t)$ is the adaptive firing threshold, which relaxes toward its baseline value $b_i^0$.
The synaptic current depends on the input $X(t)$ and on the recurrent spiking activity $S_j(t)$ as follows:
$$I^{syn}_i(t)= \sum_j W^{in}_{ij} \times X_j(t) + \sum_j W^{rec}_{ij} \times S_j(t) $$
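For a batch, the synaptic current reduces to two matrix products. The NumPy sketch below uses hypothetical function names. It also checks that a single stacked kernel applied to the concatenation $[X(t), S(t)]$ gives the same current. The hidden-layer `kernel` printed later in this notebook has shape `(102, 100)`, i.e. `(num_input + num_hidden, num_hidden)`, which is consistent with such a stacking. The exact layout inside `kernl_spiking_cell_v4` (input rows first) is an assumption here.

```python
import numpy as np


def synaptic_current(x_t, s_t, w_in, w_rec):
    """I_i(t) = sum_j W_in[i, j] X_j(t) + sum_j W_rec[i, j] S_j(t), batched.

    x_t:   (batch, num_input)   input at time t
    s_t:   (batch, num_hidden)  recurrent spiking activity at time t
    w_in:  (num_hidden, num_input)
    w_rec: (num_hidden, num_hidden)
    returns (batch, num_hidden)
    """
    return x_t @ w_in.T + s_t @ w_rec.T


def synaptic_current_stacked(x_t, s_t, kernel):
    """Same current from one kernel of shape (num_input + num_hidden, num_hidden).

    Assumes kernel = [W_in^T; W_rec^T], i.e. input rows first (hypothetical layout).
    """
    return np.concatenate([x_t, s_t], axis=1) @ kernel
```

Stacking the two weight matrices lets one matrix product serve both terms. This is the usual design for RNN cells that act on a concatenated `[input, state]` vector.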

### Neuron firing dynamics
Neuron firing is modeled as a simple reset: when the membrane voltage reaches the threshold, the neuron emits a spike and its voltage is reset. More specifically,
$$V_i \rightarrow V_{reset} \quad \text{if} \quad V_i \geq B_{i} $$

Here $B_{i}$ is the (adaptive) threshold voltage and $V_{reset}$ is the reset voltage of the neuron.

### Input dynamics 
Input synapses are the site of learning in the spiking network. Below, a conductance-based formulation is presented.
First, the time-dependent input conductance of the membrane is calculated as follows:
$$ g_i(t) = \sum_j W_{ij} S_{j}(t) $$

In the current version, $S_{j}(t)$ is simply the spike emitted by neuron $j$ at time step $t$, without any decay dynamics.
-  TODO: the index $j$ runs over all neurons that have a synapse onto neuron $i$. The time dependence of the conductance comes from $S(t)$, which represents the spiking activity of the neurons connected to neuron $i$. The spiking activity obeys the following equations:
$$ S_{j} \rightarrow S_{j}+1 \quad \text{if neuron } j \text{ fires}$$
$$ \frac{dS_{j}(t)}{dt} = \frac{-S_{j}(t)}{\tau_s}$$ 
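Between spikes the trace decays exponentially. Over one time step of length $\Delta t$ it is therefore multiplied by $\alpha = e^{-\Delta t/\tau_s}$, the exact solution of the decay equation over that step; forward Euler would use $1 - \Delta t/\tau_s$ instead. The NumPy sketch below uses a hypothetical function name. Setting $\alpha = 0$ recovers the current version, in which $S_j(t)$ is just the spike at time step $t$.

```python
import numpy as np


def spike_trace(spikes, alpha):
    """Low-pass filter spike trains: S[t] = alpha * S[t-1] + z[t].

    spikes: (time_steps, num_neurons) array of spikes z[t]
    alpha:  per-step decay factor exp(-dt / tau_s); alpha = 0 gives S[t] = z[t]
    returns the trace S with the same shape as spikes
    """
    trace = np.zeros_like(spikes, dtype=float)
    s = np.zeros(spikes.shape[1])
    for t in range(spikes.shape[0]):
        # decay the previous trace, then add 1 for every neuron that fired at t
        s = alpha * s + spikes[t]
        trace[t] = s
    return trace
```

For example, the spike train `[1, 0, 0, 1]` with `alpha = 0.5` gives the trace `[1, 0.5, 0.25, 1.125]`.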

### Spike adaptation dynamics
The spiking threshold increases with every spike emitted by a neuron, according to
$$ B_{i}(t) \rightarrow B_{i}(t)+\frac{\beta}{{\tau_a}_i} \quad \text{if neuron } i \text{ fires}$$


### Implementation in discrete time
We start by discretizing the dynamics with the forward Euler method.
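With a time step $\Delta t$, forward Euler turns the two differential equations above into

$$V_i[t+1] = V_i[t] + \frac{\Delta t}{\tau_m}\big(-V_i[t] + R_m I^{syn}_i[t]\big), \qquad B_i[t+1] = B_i[t] + \frac{\Delta t}{{\tau_a}_i}\big(b_i^0 - B_i[t]\big)$$

followed by the reset and the threshold jump for every neuron that fires. The NumPy sketch below performs one such step. The function name and the default values of `tau_a`, `r_m`, `b0`, `beta` and `v_reset` are illustrative assumptions, not the parameters of `kernl_spiking_cell_v4`; only `tau_m=20` matches the value passed to the cell further down. The update order is also a choice of this sketch: the voltage is updated first and then tested against the relaxed threshold.

```python
import numpy as np


def euler_step(v, b, spikes_prev, x, w_in, w_rec, *, dt=1.0, tau_m=20.0,
               tau_a=200.0, r_m=1.0, b0=1.0, beta=1.6, v_reset=0.0):
    """One forward-Euler step of the adaptive-threshold neuron model (sketch).

    v, b, spikes_prev: (num_hidden,) voltage, threshold and previous spikes
    x:                 (num_input,) input at this time step
    w_in:              (num_hidden, num_input) input weights
    w_rec:             (num_hidden, num_hidden) recurrent weights
    tau_a may also be a (num_hidden,) array of per-neuron time constants.
    All default parameter values are illustrative assumptions.
    """
    # I_i = sum_j W_in[i, j] x_j + sum_j W_rec[i, j] S_j
    i_syn = w_in @ x + w_rec @ spikes_prev
    # tau_m dV/dt = -V + R_m I   ->   V += dt / tau_m * (-V + R_m I)
    v = v + (dt / tau_m) * (-v + r_m * i_syn)
    # tau_a dB/dt = b0 - B       ->   B += dt / tau_a * (b0 - B)
    b = b + (dt / tau_a) * (b0 - b)
    # spike wherever V >= B, then reset those neurons to V_reset
    spikes = (v >= b).astype(v.dtype)
    v = np.where(spikes > 0, v_reset, v)
    # adaptation: B -> B + beta / tau_a for every neuron that fired
    b = b + (beta / tau_a) * spikes
    return v, b, spikes
```

To simulate the network, call the step `time_steps` times and feed each step's `spikes` back in as `spikes_prev`.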
### References 
-  Fiete, Ila R., Walter Senn, Claude Z. H. Wang, and Richard H. R. Hahnloser. 2010. “Spike-Time-Dependent Plasticity and Heterosynaptic Competition Organize Networks to Produce Long Scale-Free Sequences of Neural Activity.” Neuron 65 (4): 563–76. 

-  Bellec, Guillaume, Darjan Salaj, Anand Subramoney, Robert Legenstein, and Wolfgang Maass. 2018. “Long Short-Term Memory and Learning-to-Learn in Networks of Spiking Neurons.” arXiv [cs.NE]. arXiv. http://arxiv.org/abs/1803.09574.



```python
# python libraries
import numpy as np
import matplotlib.pyplot as plt
import collections
import hashlib
import numbers
import matplotlib.cm as cm
from sys import getsizeof
from datetime import datetime
from pathlib import Path
import os
from IPython.display import HTML
import re

# tensorflow and its dependencies
# (this notebook targets TensorFlow 1.x: tf.contrib does not exist in TensorFlow 2)
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear
from tensorflow.contrib import slim

# user-defined modules
# KeRNL spiking RNN cell
import kernl_spiking_cell_v4 as kernl_spiking_cell
# data generator for the adding problem
import adding_problem
```

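### Initial parameters for the network

```python
# Training parameters
weight_learning_rate = 1e-5   # RMSProp learning rate for the network weights
tensor_learning_rate = 1e-3   # RMSProp learning rate for M and K(tau)
training_steps = 5000
batch_size = 25
display_step = 50
grad_clip = 200               # maximum norm of the hidden-weight update
# Network parameters
# 1 - input layer
num_input = 2                 # input channels per time step
time_steps = 100              # sequence length (time steps)
noise_std = 1e-4              # noise parameter passed to the spiking cell
# 2 - hidden layer
num_hidden = 100              # number of hidden spiking neurons
# 3 - output layer
num_output = 1
```
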
### Define the network

```python
def kernl_SNN_all_states(x):
    """Spiking hidden layer followed by a readout layer.

    Returns the readout voltages and the hidden-layer outputs over time
    (a tuple with fields such as spike, v_mem, v_mem_hat, Theta and eligibility_trace).
    """
    with tf.variable_scope('hidden_layer'):
        hidden_layer_cell = kernl_spiking_cell.kernl_spike_Cell(num_units=num_hidden,
                                                                num_inputs=num_input,
                                                                time_steps=time_steps,
                                                                output_is_tuple=True,
                                                                tau_refract=1.0,
                                                                tau_m=20,
                                                                noise_param=noise_std)
        hidden_initial_state = hidden_layer_cell.zero_state(batch_size, dtype=tf.float32)
        output_hidden, states_hidden = tf.nn.dynamic_rnn(hidden_layer_cell, inputs=x,
                                                         initial_state=hidden_initial_state,
                                                         dtype=tf.float32)
    with tf.variable_scope('output_layer'):
        # the readout layer is driven by the hidden-layer spike trains
        output_layer_cell = kernl_spiking_cell.output_spike_cell(num_units=num_output)
        output_voltage, voltage_states = tf.nn.dynamic_rnn(output_layer_cell, inputs=output_hidden.spike,
                                                           dtype=tf.float32)
    return output_voltage, output_hidden
```

### Gradient equation 
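The `kernl_train_weights` scope of the computation graph below computes the update applied to the hidden-layer kernel $W$, in which the input and recurrent weights are stacked. Transcribed from that code, it reads as follows. Here $b$ indexes the $B$ examples in a batch, $o$ the output units, $k$ and $n$ the hidden units, and $v$ the rows of $W$ (input and recurrent). $T$ is the last time step, $M$ is the sensitivity tensor, and $\mathcal{E}$ is the eligibility trace returned by the cell. This is a reading of the notebook's code, not a derivation.

$$e_{b,o} = 2\,\big(y_{b,o} - \hat{y}_{b,o}(T)\big), \qquad \varepsilon_{b,k} = \sum_{o} e_{b,o}\, W^{out}_{k o}, \qquad \delta_{b,n} = \sum_{k} \varepsilon_{b,k}\, M_{k n}$$

$$\Delta W_{v n} = \frac{1}{B} \sum_{b=1}^{B} \delta_{b,n}\, \mathcal{E}_{b,n v}(T)$$

$\Delta W$ is clipped to norm `grad_clip` and passed to RMSProp as the gradient of the hidden kernel. The output weights $W^{out}$ receive the ordinary backpropagated gradient of the output mean-squared error, without clipping.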

### Computation graph for learning $M$ and $K(\tau)$
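In the code, $M$ is the variable `sensitivity_tensor` and $K(\tau)$ is `temporal_filter`. $K(\tau)$ is applied inside the cell, presumably when computing `Theta`, so the NumPy sketch below treats `Theta` as given. The sketch mirrors the two objectives of the next cell to make the tensor shapes explicit. It assumes that `v_mem_hat`, `v_mem` and `Theta` have shape `(batch, time_steps, num_hidden)` and that the eligibility trace at the last step has shape `(batch, num_hidden, num_input + num_hidden)`. These shapes are inferred from the `einsum` strings and the printed variable shapes, not read from `kernl_spiking_cell_v4`. The function names are hypothetical.

```python
import numpy as np


def state_prediction_loss(v_mem_hat, v_mem, theta, m):
    """MSE between the voltage difference and its linear estimate Theta @ M.

    v_mem_hat, v_mem, theta: (batch, time, num_hidden); m: (num_hidden, num_hidden)
    """
    state_diff = v_mem_hat - v_mem
    estimate = np.einsum('unv,vk->unk', theta, m)
    return np.mean((state_diff - estimate) ** 2)


def kernl_weight_update(y, y_hat_last, w_out, m, trace_last):
    """Hidden-weight update as assembled in the kernl_train_weights scope.

    y, y_hat_last: (batch, num_output) targets and last-step outputs
    w_out:         (num_hidden, num_output) output weights
    m:             (num_hidden, num_hidden) sensitivity tensor M
    trace_last:    (batch, num_hidden, num_input + num_hidden) eligibility trace at T
    returns        (num_input + num_hidden, num_hidden), the hidden kernel's shape
    """
    err_out = 2.0 * (y - y_hat_last)        # (batch, num_output)
    err_hidden = err_out @ w_out.T          # (batch, num_hidden)
    delta = err_hidden @ m                  # (batch, num_hidden)
    update = np.einsum('un,unv->unv', delta, trace_last)
    return update.mean(axis=0).T            # batch average, then transpose
```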

```python
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    # inputs and targets: X is (batch, time_steps, num_input), Y is (batch, num_output)
    X = tf.placeholder("float", [None, time_steps, num_input])
    Y = tf.placeholder("float", [None, num_output])

    kernl_output, kernl_hidden_states = kernl_SNN_all_states(X)

    trainables = tf.trainable_variables()
    variable_names = [v.name for v in trainables]

    def find_joing_index(variables, name_1, name_2):
        """Index of the first variable whose name contains both name_1 and name_2."""
        return next(i for i, v in enumerate(variables) if name_1 in v.name and name_2 in v.name)

    # find trainable parameters for KeRNL
    with tf.name_scope('kernl_Trainables'):
        kernl_output_weight_index = find_joing_index(trainables, 'output_layer', 'kernel')
        kernl_temporal_filter_index = find_joing_index(trainables, 'kernl', 'temporal_filter')
        kernl_sensitivity_tensor_index = find_joing_index(trainables, 'kernl', 'sensitivity_tensor')
        kernl_kernel_index = find_joing_index(trainables, 'hidden_layer', 'kernel')
        # M and K(tau) are trained on the state-prediction loss
        kernl_tensor_training_indices = [kernl_sensitivity_tensor_index, kernl_temporal_filter_index]
        kernl_tensor_trainables = [trainables[k] for k in kernl_tensor_training_indices]
        # hidden and output weights are trained on the output loss
        kernl_weight_training_indices = [kernl_kernel_index, kernl_output_weight_index]
        kernl_weight_trainables = [trainables[k] for k in kernl_weight_training_indices]

    with tf.name_scope('kernl_train_tensors'):
        # target: difference between the two voltage traces returned by the cell
        state_diff_int = tf.subtract(kernl_hidden_states.v_mem_hat, kernl_hidden_states.v_mem)
        # linear estimate of that difference: Theta @ M
        estimate_state_diff_int = tf.einsum('unv,vk->unk', kernl_hidden_states.Theta,
                                            trainables[kernl_sensitivity_tensor_index])
        kernl_loss_state_prediction = tf.losses.mean_squared_error(state_diff_int, estimate_state_diff_int)
        kernl_tensor_optimizer = tf.train.RMSPropOptimizer(learning_rate=tensor_learning_rate)
        kernl_tensor_grads = tf.gradients(ys=kernl_loss_state_prediction, xs=kernl_tensor_trainables)
        kernl_tensor_grad_and_vars = list(zip(kernl_tensor_grads, kernl_tensor_trainables))
        kernl_tensor_train_op = kernl_tensor_optimizer.apply_gradients(kernl_tensor_grad_and_vars)

    with tf.name_scope('kernl_train_weights'):
        kernl_weight_optimizer = tf.train.RMSPropOptimizer(learning_rate=weight_learning_rate)
        # MSE between the target and the output voltage at the last time step
        kernl_loss_output_prediction = tf.losses.mean_squared_error(Y, kernl_output[:, -1, :])
        # output error, projected back onto the hidden units through the output weights
        kernl_grad_cost_to_output = tf.scalar_mul(2, tf.subtract(Y, kernl_output[:, -1, :]))
        kernl_error_in_hidden_state = tf.matmul(kernl_grad_cost_to_output,
                                                tf.transpose(trainables[kernl_output_weight_index]))
        # mix the hidden errors with the sensitivity tensor M
        kernl_delta_weight = tf.matmul(kernl_error_in_hidden_state, trainables[kernl_sensitivity_tensor_index])
        # weight the eligibility trace at the last time step, then average over the batch
        kernl_weight_update_test = tf.einsum("un,unv->unv", kernl_delta_weight,
                                             kernl_hidden_states.eligibility_trace[:, -1, :, :])
        kernl_weight_update = tf.transpose(tf.reduce_mean(kernl_weight_update_test, axis=0))
        # output layer: ordinary backpropagated gradient
        kernl_grad_cost_to_output_layer = tf.gradients(ys=kernl_loss_output_prediction,
                                                       xs=trainables[kernl_output_weight_index])
        # clip the hidden-weight update (the output-layer gradient is not clipped)
        kernl_weight_grads_and_vars = list(zip([kernl_weight_update, kernl_grad_cost_to_output_layer[0]],
                                               kernl_weight_trainables))
        kernl_cropped_weight_grads_and_vars = [(tf.clip_by_norm(grad, grad_clip), var)
                                               if 'output' not in var.name else (grad, var)
                                               for grad, var in kernl_weight_grads_and_vars]
        # apply gradients
        kernl_weight_train_op = kernl_weight_optimizer.apply_gradients(kernl_cropped_weight_grads_and_vars)

    ##################
    # SUMMARIES ######
    ##################

    with tf.name_scope("kernl_tensor_summaries"):
        # kernl sensitivity tensor
        tf.summary.histogram('kernl_sensitivity_tensor_grad', kernl_tensor_grads[0] + 1e-10)
        tf.summary.histogram('kernl_sensitivity_tensor', trainables[kernl_sensitivity_tensor_index] + 1e-10)
        # kernl temporal filter
        tf.summary.histogram('kernl_temporal_filter_grad', kernl_tensor_grads[1] + 1e-10)
        tf.summary.histogram('kernl_temporal_filter', trainables[kernl_temporal_filter_index] + 1e-10)
        # kernl loss
        tf.summary.scalar('kernl_loss_state_prediction', kernl_loss_state_prediction + 1e-10)
        # kernl sensitivity tensor and temporal filter as images
        #tf.summary.image('kernl_sensitivity_tensor',tf.expand_dims(tf.expand_dims(trainables[kernl_sensitivity_tensor_index],axis=0),axis=-1))
        #tf.summary.image('kernl_sensitivity_tensor_grad',tf.expand_dims(tf.expand_dims(kernl_tensor_grads[0],axis=0),axis=-1))
        #tf.summary.image('kernl_temporal_filter',tf.expand_dims(tf.expand_dims(tf.expand_dims(trainables[kernl_temporal_filter_index],axis=0),axis=-1),axis=-1))
        #tf.summary.image('kernl_temporal_filter_grad',tf.expand_dims(tf.expand_dims(tf.expand_dims(kernl_tensor_grads[1],axis=0),axis=-1),axis=-1))
        kernl_tensor_merged_summary_op = tf.summary.merge_all(scope="kernl_tensor_summaries")

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
```


```python
# verify initialization: print every trainable variable and its shape
with tf.Session(graph=graph) as sess:
    sess.run(init)
    values, trainable_vars = sess.run([variable_names, trainables])
    for k, v in zip(variable_names, values):
        print(["variable: ", k])
        # position of 'output' in the variable name (-1 if absent)
        print(["variable: ", k.find('output')])
        print(["shape: ", v.shape])
```

```
['variable: ', 'hidden_layer/rnn/kernl_spike__cell/temporal_filter:0']
['variable: ', -1]
['shape: ', (100,)]
['variable: ', 'hidden_layer/rnn/kernl_spike__cell/sensitivity_tensor:0']
['variable: ', -1]
['shape: ', (100, 100)]
['variable: ', 'hidden_layer/rnn/kernl_spike__cell/kernel:0']
['variable: ', -1]
['shape: ', (102, 100)]
['variable: ', 'output_layer/rnn/output_spike_cell/kernel:0']
['variable: ', 0]
['shape: ', (100, 1)]
```


```python
# one log directory per run, tagged with the hyperparameters and a timestamp
log_dir = os.environ['HOME'] + "/MyData/KeRNL/logs/kernl_snn_addition/add_eta_tensor_%1.0e_eta_weight_%1.0e_batch_%1.0e_hum_hidd_%1.0e_gc_%1.0e_steps_%1.0e_run_%s" % (
    tensor_learning_rate, weight_learning_rate, batch_size, num_hidden, grad_clip, training_steps,
    datetime.now().strftime("%Y%m%d_%H%M"))
Path(log_dir).mkdir(exist_ok=True, parents=True)
# remove files ending in ".local" (presumably old TensorBoard event files, whose names end with the host name)
filelist = [f for f in os.listdir(log_dir) if f.endswith(".local")]
for f in filelist:
    os.remove(os.path.join(log_dir, f))
```

```python
# write the graph to TensorBoard
tb_writer = tf.summary.FileWriter(log_dir, graph)
# run a training session
with tf.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(1, training_steps + 1):
        batch_x, batch_y = adding_problem.get_batch(batch_size=batch_size, time_steps=time_steps)
        feed = {X: batch_x, Y: batch_y}
        # 1) update M and K(tau) on the state-prediction loss
        kernl_tensor_train, kernl_loss_state = sess.run([kernl_tensor_train_op, kernl_loss_state_prediction],
                                                        feed_dict=feed)
        # 2) update the hidden and output weights on the output loss
        kernl_weight_train, kernl_loss = sess.run([kernl_weight_train_op, kernl_loss_output_prediction],
                                                  feed_dict=feed)
        # 3) write the tensor summaries
        kernl_tensor_merged_summary = sess.run(kernl_tensor_merged_summary_op, feed_dict=feed)
        tb_writer.add_summary(kernl_tensor_merged_summary, global_step=step)

        if step % display_step == 0 or step == 1:
            print('Step: {}, kernl tensor Loss {:.3f}, kernl weight loss {:.3f}'.format(step, kernl_loss_state, kernl_loss))

    print("Optimization Finished!")
    save_path = saver.save(sess, log_dir + "/model.ckpt", global_step=step, write_meta_graph=True)
    print("Model saved in path: %s" % save_path)

# flush pending summaries to disk
tb_writer.close()
```

```
Step: 2, kernl tensor Loss 0.000, kernl weight loss 15.170
Step: 51, kernl tensor Loss 0.001, kernl weight loss 10.483
Step: 101, kernl tensor Loss 0.001, kernl weight loss 13.719
Step: 151, kernl tensor Loss 0.001, kernl weight loss 13.196
Step: 201, kernl tensor Loss 0.004, kernl weight loss 8.336
Step: 251, kernl tensor Loss 0.000, kernl weight loss 14.149
Step: 301, kernl tensor Loss 0.000, kernl weight loss 10.687
Step: 351, kernl tensor Loss 0.000, kernl weight loss 10.887
Step: 401, kernl tensor Loss 0.002, kernl weight loss 10.397
Step: 451, kernl tensor Loss 0.001, kernl weight loss 8.231
Step: 501, kernl tensor Loss 0.002, kernl weight loss 10.696
Step: 551, kernl tensor Loss 0.007, kernl weight loss 8.437
Step: 601, kernl tensor Loss 0.002, kernl weight loss 6.522
Step: 651, kernl tensor Loss 0.000, kernl weight loss 4.914
Step: 701, kernl tensor Loss 0.000, kernl weight loss 7.132
Step: 751, kernl tensor Loss 0.001, kernl weight loss 6.894
Step: 801, kernl tensor Loss 0.035
```