# Learning the adding problem in a spiking neural network (SNN) with backpropagation



# Background 
This code implements a spiking neural network with conductance-based input. The following equations govern the dynamics of the network.
### Transmembrane voltage dynamics
First, we model the transmembrane voltage $V_i$ and the firing threshold $B_i$ of neuron $i$ as
$$\tau_m \frac{dV_i}{dt}= - V_i(t)+ R_m \times I^{syn}_i(t) $$ 
$$ {\tau_a}_i \frac{dB_i(t)}{dt} = b_i^0 -B_i(t)$$ 
where $R_m$ is the membrane resistance, $\tau_m$ is the membrane time constant, ${\tau_a}_i$ is the adaptation time constant, and $b_i^0$ is the baseline value that the threshold $B_i$ relaxes to.
The synaptic current depends on the external input $X_j(t)$ and on the spiking activity $S_j(t)$ of the recurrently connected neurons in the following way:
$$I^{syn}_i(t)= \sum_j W^{in}_{ij} \times X_j(t) + \sum_j W^{rec}_{ij} \times S_j(t) $$ 
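
As a rough illustration of the synaptic current equation, the sketch below evaluates $I^{syn}$ for a single timestep with NumPy. The layer sizes, the random weights, and the names `w_in`, `w_rec`, `x_t`, and `s_t` are assumptions made for this example; they are not taken from `spiking_cell_bare`.

```python
import numpy as np

def synaptic_current(w_in, w_rec, x_t, s_t):
    """I_syn[i] = sum_j w_in[i, j] * x_t[j] + sum_j w_rec[i, j] * s_t[j]."""
    return w_in @ x_t + w_rec @ s_t

rng = np.random.default_rng(0)
n_in, n_hidden = 2, 4                            # illustrative sizes only
w_in = rng.normal(size=(n_hidden, n_in))         # input weights W^in
w_rec = rng.normal(size=(n_hidden, n_hidden))    # recurrent weights W^rec
x_t = np.array([0.5, 1.0])                       # external input X(t)
s_t = np.array([0.0, 1.0, 0.0, 0.0])             # only neuron 1 spiked at time t

i_syn = synaptic_current(w_in, w_rec, x_t, s_t)
print(i_syn.shape)
```

Because only neuron 1 spiked, the recurrent term reduces to column 1 of `w_rec`.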

### Neuron firing dynamics 
The firing dynamics of the neuron are modeled as a simple reset. More specifically, 
$$V_i \rightarrow V_{reset} \quad \text{if} \quad V_i \geq B_{i} $$

$B_{i}$ is the threshold voltage of neuron $i$ and $V_{reset}$ is the reset voltage of the neuron.
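
The threshold-and-reset rule can be sketched in NumPy as follows. The example voltages, the thresholds, and the default `v_reset=0.0` are made-up values for illustration.

```python
import numpy as np

def fire_and_reset(v, b, v_reset=0.0):
    """Neurons with v >= b emit a spike and have their voltage set to v_reset."""
    spikes = (v >= b).astype(v.dtype)
    v_after = np.where(spikes > 0, v_reset, v)
    return spikes, v_after

v = np.array([0.2, 1.0, 1.5])    # voltages before the threshold check
b = np.array([1.0, 1.0, 2.0])    # per-neuron thresholds B_i
spikes, v_after = fire_and_reset(v, b)
print(spikes, v_after)
```

Only the second neuron reaches its threshold here, so it is the only one that spikes and is reset.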

### Input dynamics 
Input synapses are the site of learning in the spiking network. Below, a conductance-based formulation is presented. 
First, the time-dependent input conductance to the membrane is calculated as follows: 
$$ g_i(t) = \sum_j W_{ij} S_{j}(t) $$

In the current version, $S_{j}(t)$ is equal to the spike at timestep $t$, without any decay dynamics. 
-  TODO: the index $j$ runs over all the neurons that have a synapse onto neuron $i$. The time dependence of the conductance is due to $S(t)$, which represents the spiking activity of the neurons connected to neuron $i$. The spiking activity has the following governing equations: 
$$ S_{j} \rightarrow S_{j}+1 \quad \text{if neuron } j \text{ fires}$$
$$ \frac{dS_{j}(t)}{dt} = \frac{-S_{j}(t)}{\tau_s}$$ 
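
The two equations for the spiking activity can be sketched with a forward-Euler step. This is the decaying variant described in the TODO, not the no-decay version currently in use. The values `tau_s=5.0` and `dt=1.0`, and the choice to apply the decay before the spike increment, are assumptions for the example.

```python
import numpy as np

def update_trace(s, spikes, tau_s=5.0, dt=1.0):
    """One Euler step of dS/dt = -S / tau_s, then S -> S + 1 for neurons that fired."""
    s = s + dt * (-s / tau_s)    # exponential decay toward zero
    return s + spikes            # unit jump on each spike

s = np.zeros(2)
spike_train = np.array([[1, 0], [0, 0], [0, 1]], dtype=float)  # [time, neuron]
history = []
for spikes in spike_train:
    s = update_trace(s, spikes)
    history.append(s.copy())
history = np.array(history)
print(history)
```

With these values the trace of neuron 0 jumps to 1 and then shrinks by a factor of 0.8 per step.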

### Spike adaptation dynamics 
The spiking threshold increases with every spike emitted by a neuron, with the following dynamics: 
$$ B_{i}(t) \rightarrow B_{i}(t)+\frac{\beta}{{\tau_a}_i} \quad \text{if neuron } i \text{ fires}$$
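
A minimal sketch of the adaptive threshold follows. It combines the jump on each spike with the relaxation toward $b_i^0$ from the Background equations. The parameter values (`b0`, `beta`, `tau_a`, `dt`) are assumptions chosen to make the numbers easy to follow.

```python
import numpy as np

def update_threshold(b, spikes, b0=1.0, beta=2.0, tau_a=10.0, dt=1.0):
    """One Euler step of tau_a * dB/dt = b0 - B, then B -> B + beta / tau_a on a spike."""
    b = b + (dt / tau_a) * (b0 - b)       # relax toward the baseline b0
    return b + (beta / tau_a) * spikes    # jump for neurons that fired

b = np.array([1.0])
trace = []
for spike in [1.0, 0.0, 0.0]:             # one spike, then silence
    b = update_threshold(b, np.array([spike]))
    trace.append(b[0])
print(trace)
```

The threshold rises after the spike and then decays back toward the baseline.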


### Implementation in discrete time 
We start with the Euler method for modeling the dynamics. 
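
As a sketch of what this means for the voltage and threshold equations from the Background section, a forward-Euler step of size $\Delta t$ replaces each derivative with a finite difference. The step size $\Delta t$ and the choice to apply the spike-triggered updates after the continuous update are assumptions here; the ordering used in `spiking_cell_bare` is not shown in this notebook.
$$V_i[t+\Delta t] = V_i[t] + \frac{\Delta t}{\tau_m}\left(-V_i[t] + R_m \times I^{syn}_i[t]\right)$$
$$B_i[t+\Delta t] = B_i[t] + \frac{\Delta t}{{\tau_a}_i}\left(b_i^0 - B_i[t]\right)$$
After this update, every neuron with $V_i \geq B_i$ emits a spike, and for that neuron $V_i \rightarrow V_{reset}$ and $B_i \rightarrow B_i + \frac{\beta}{{\tau_a}_i}$.
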
### References 
-  Fiete, Ila R., Walter Senn, Claude Z. H. Wang, and Richard H. R. Hahnloser. 2010. “Spike-Time-Dependent Plasticity and Heterosynaptic Competition Organize Networks to Produce Long Scale-Free Sequences of Neural Activity.” Neuron 65 (4): 563–76. 

-  Bellec, Guillaume, Darjan Salaj, Anand Subramoney, Robert Legenstein, and Wolfgang Maass. 2018. “Long Short-Term Memory and Learning-to-Learn in Networks of Spiking Neurons.” arXiv [cs.NE]. arXiv. http://arxiv.org/abs/1803.09574.



In [1]:
# python libraries
import numpy as np 
import matplotlib.pyplot as plt 
import collections
import hashlib
import numbers
import matplotlib.cm as cm
from sys import getsizeof
from datetime import datetime
from pathlib import Path
import os
from IPython.display import HTML
import re

# tensorflow and its dependencies 
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _Linear
from tensorflow.contrib import slim

## user defined modules 
# kernel rnn cell 
import spiking_cell_bare as spiking_cell
import adding_problem
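
The `adding_problem` module is not shown in this notebook. For readers without it, the sketch below is a hypothetical stand-in for a batch generator in the style of the standard adding task: one input channel carries random values, a second channel marks two positions, and the target is the sum of the two marked values. The function name `get_batch_sketch`, the value range, and the scaling are assumptions; the real `adding_problem.get_batch` may differ.

```python
import numpy as np

def get_batch_sketch(batch_size, time_steps, rng=None):
    """Hypothetical adding-task batch: x is [batch, time, 2], y is [batch, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    values = rng.uniform(0.0, 1.0, size=(batch_size, time_steps))
    markers = np.zeros((batch_size, time_steps))
    for b in range(batch_size):
        # mark two distinct timesteps per sequence
        idx = rng.choice(time_steps, size=2, replace=False)
        markers[b, idx] = 1.0
    x = np.stack([values, markers], axis=-1)               # channel 0: values, channel 1: markers
    y = np.sum(values * markers, axis=1, keepdims=True)    # sum of the two marked values
    return x.astype(np.float32), y.astype(np.float32)

batch_x, batch_y = get_batch_sketch(batch_size=25, time_steps=100,
                                    rng=np.random.default_rng(0))
print(batch_x.shape, batch_y.shape)
```

The returned shapes match a `[batch, time_steps, 2]` input and a `[batch, 1]` target.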

### initial parameters for the network 

In [2]:
# Training Parameters
weight_learning_rate = 1e-5
training_steps = 5000
batch_size = 25
display_step = 50
grad_clip=200
# Network Parameters
# 1-input layer 
num_input = 2 # 
time_steps = 100 # timesteps
# 2-hidden layer 
num_hidden = 100 # hidden layer num of features
# 3-output layer 
num_output = 1 


### define the network

In [3]:
def bptt_snn_all_states(x):
    """Build the hidden spiking layer and the output layer; return the output voltage and the hidden-layer outputs."""
    # hidden layer: conductance-based spiking cell unrolled over the time dimension of x
    with tf.variable_scope('hidden_layer') as scope: 
        hidden_layer_cell=spiking_cell.conductance_spike_cell(num_units=num_hidden,output_is_tuple=True,tau_refract=1.0,tau_m=20.0)
        hidden_initial_state = hidden_layer_cell.zero_state(batch_size, dtype=tf.float32)
        output_hidden, states_hidden = tf.nn.dynamic_rnn(hidden_layer_cell, dtype=tf.float32, inputs=x,initial_state=hidden_initial_state)
    # output layer: driven by the spikes of the hidden layer
    with tf.variable_scope('output_layer') as scope : 
        output_layer_cell=spiking_cell.output_spike_cell(num_units=num_output)
        output_voltage, voltage_states=tf.nn.dynamic_rnn(output_layer_cell,dtype=tf.float32,inputs=output_hidden.spike)
    return output_voltage,output_hidden


### Computation graph backpropagation 


In [4]:
tf.reset_default_graph()
graph=tf.Graph()
with graph.as_default():
    # check hardware 
    # define weights and inputs to the network
    X = tf.placeholder("float", [None, time_steps, num_input])
    Y = tf.placeholder("float", [None, num_output])
   
    bptt_output,bptt_hidden_states=bptt_snn_all_states(X)
    
    trainables=tf.trainable_variables()
    variable_names=[v.name for v in tf.trainable_variables()]
    # return the index of the first trainable variable whose name contains both substrings
    find_joint_index = lambda variables, name_1, name_2: [name_1 in v.name and name_2 in v.name for v in variables].index(True)
    # find trainable parameters for bptt 
    with tf.name_scope('bptt_Trainables') as scope:
        bptt_output_weight_index= find_joint_index(trainables,'output_layer','kernel')
        bptt_kernel_index= find_joint_index(trainables,'hidden_layer','kernel')
        # order matters below: entry 0 is the hidden kernel, entry 1 is the output kernel
        bptt_weight_training_indices=[bptt_kernel_index,bptt_output_weight_index]
        bptt_weight_trainables= [trainables[k] for k in bptt_weight_training_indices]
        
    with tf.name_scope('bptt_train_weights') as scope: 
        bptt_weight_optimizer = tf.train.RMSPropOptimizer(learning_rate=weight_learning_rate)
        bptt_loss_output_prediction=tf.losses.mean_squared_error(Y,bptt_output[:,-1,:])
        bptt_grad_cost_trainables=tf.gradients(bptt_loss_output_prediction,bptt_weight_trainables)
        bptt_weight_grads_and_vars=list(zip(bptt_grad_cost_trainables,bptt_weight_trainables))
        # clip the hidden-layer gradient by its L2 norm; the output-layer gradient is left unclipped
        bptt_cropped_weight_grads_and_vars=[(tf.clip_by_norm(grad, grad_clip),var) if 'output' not in var.name else (grad,var) for grad,var in bptt_weight_grads_and_vars]
        # apply gradients 
        bptt_weight_train_op = bptt_weight_optimizer.apply_gradients(bptt_cropped_weight_grads_and_vars)
    
    ##################
    # SUMMARIES ######
    ##################
    
    with tf.name_scope("bptt_weight_summaries") as scope: 
        # bptt hidden kernel: gradient and weights
        tf.summary.histogram('bptt_kernel_grad',bptt_grad_cost_trainables[0]+1e-10)
        tf.summary.histogram('bptt_kernel', bptt_weight_trainables[0]+1e-10)
        # bptt output weight: gradient and weights
        tf.summary.histogram('bptt_output_weight_grad',bptt_grad_cost_trainables[1]+1e-10)
        tf.summary.histogram('bptt_output_weights', bptt_weight_trainables[1]+1e-10)
        # bptt loss
        tf.summary.scalar('bptt_loss_output_prediction',bptt_loss_output_prediction+1e-10)
        
        # merge all summaries defined in this scope
        bptt_tensor_merged_summary_op=tf.summary.merge_all(scope="bptt_weight_summaries")
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
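
The graph above clips the hidden-layer gradient with `tf.clip_by_norm` before it is applied. To make that operation concrete, here is a small NumPy sketch of the same rescaling rule: a tensor whose L2 norm exceeds `clip_norm` is scaled down to have exactly that norm, and is otherwise returned unchanged. The helper name `clip_by_norm_sketch` and the example gradient are made up for illustration.

```python
import numpy as np

def clip_by_norm_sketch(grad, clip_norm):
    """Rescale grad so that its L2 norm is at most clip_norm."""
    norm = np.sqrt(np.sum(grad ** 2))
    if norm <= clip_norm:
        return grad
    return grad * (clip_norm / norm)

g = np.array([[300.0, 400.0]])             # L2 norm is 500
clipped = clip_by_norm_sketch(g, 200.0)    # same direction, norm reduced to 200
small = clip_by_norm_sketch(np.array([3.0, 4.0]), 200.0)  # norm 5, left unchanged
print(clipped, small)
```

Clipping by norm preserves the direction of the gradient, unlike clipping each element independently.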

In [5]:
# verify initialization

with tf.Session(graph=graph,) as sess : 
    sess.run(init)
    values,trainable_vars = sess.run([variable_names,trainables])
    for k, v in zip(variable_names,values):
        print(["variable: " , k])
        #print(["value: " , v])
        print(["variable: " , k.find('output')]) 
        print(["shape: " , v.shape])
        #print(v) 

['variable: ', 'hidden_layer/rnn/conductance_spike_cell/kernel:0']
['variable: ', -1]
['shape: ', (102, 100)]
['variable: ', 'output_layer/rnn/output_spike_cell/kernel:0']
['variable: ', 0]
['shape: ', (100, 1)]
['variable: ', 'output_layer/rnn/output_spike_cell/bias:0']
['variable: ', 0]
['shape: ', (1,)]


In [6]:
log_dir = os.environ['HOME']+"/MyData/KeRNL/logs/bptt_snn_addition/add_eta_weight_%1.0e_batch_%1.0e_hum_hidd_%1.0e_gc_%1.0e_steps_%1.0e_run_%s" %(weight_learning_rate,batch_size,num_hidden,grad_clip,training_steps, datetime.now().strftime("%Y%m%d_%H%M"))
Path(log_dir).mkdir(exist_ok=True, parents=True)
filelist = [ f for f in os.listdir(log_dir) if f.endswith(".local") ]
for f in filelist:
    os.remove(os.path.join(log_dir, f))

In [7]:
# write graph into tensorboard 
tb_writer = tf.summary.FileWriter(log_dir,graph)
# run a training session 
with tf.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(1,training_steps+1):
        batch_x, batch_y = adding_problem.get_batch(batch_size=batch_size,time_steps=time_steps)
        # run the training op, the loss, and the summaries in one pass over the batch
        _, bptt_loss, bptt_tensor_merged_summary=sess.run([bptt_weight_train_op,bptt_loss_output_prediction,bptt_tensor_merged_summary_op],feed_dict={X:batch_x, Y:batch_y})
        tb_writer.add_summary(bptt_tensor_merged_summary, global_step=step)
        if step % display_step==0 or step==1 : 
            # report the batch loss; note that the printed index is step + 1
            print('Step: {}, bptt weight loss {:.3f}'.format(step + 1, bptt_loss))


    print("Optimization Finished!")
    #test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    #test_label = mnist.test.labels[:test_len]
    #print("Testing Accuracy:", 
    #    sess.run(loss_output_prediction, feed_dict={X: test_data, Y: test_label}))
    save_path = saver.save(sess, log_dir+"/model.ckpt", global_step=step,write_meta_graph=True)
    print("Model saved in path: %s" % save_path)

Step: 2, bptt weight loss 397.244
Step: 51, bptt weight loss 354.662
Step: 101, bptt weight loss 347.002
Step: 151, bptt weight loss 430.255
Step: 201, bptt weight loss 288.001
Step: 251, bptt weight loss 356.837
Step: 301, bptt weight loss 305.316
Step: 351, bptt weight loss 348.729
Step: 401, bptt weight loss 245.598
Step: 451, bptt weight loss 287.464
Step: 501, bptt weight loss 246.126
Step: 551, bptt weight loss 240.125
Step: 601, bptt weight loss 222.274
Step: 651, bptt weight loss 193.074
Step: 701, bptt weight loss 196.821
Step: 751, bptt weight loss 197.269
Step: 801, bptt weight loss 190.732
Step: 851, bptt weight loss 184.934
Step: 901, bptt weight loss 196.524
Step: 951, bptt weight loss 133.992
Step: 1001, bptt weight loss 137.271
Step: 1051, bptt weight loss 205.615
Step: 1101, bptt weight loss 112.918
Step: 1151, bptt weight loss 106.671
Step: 1201, bptt weight loss 150.113
Step: 1251, bptt weight loss 147.707
Step: 1301, bptt weight loss 129.931
Step: 1351, bptt weight 