# Introduction
In the following notebook, I aim to create a model of a spiking neural network with conductance-based dynamics. The dynamics of each neuron can be separated into membrane dynamics, synapse dynamics and spiking dynamics. In their simplest form, the following equations govern the dynamics of each neuron.

### Membrane dynamics
During integration, the membrane potential of neuron $i$ evolves as
$$ C_m \frac{dV_i}{dt}= -g_L (V_i - V_L) - g_{Ei} (V_i-V_E) - g_{Ii} (V_i-V_I) $$ 

In this equation, $\frac{C_m}{g_L}$ is the membrane time constant, $V_L$ is the leak potential, $g_{Ei}$ and $g_{Ii}$ are the synaptic conductances for excitatory and inhibitory inputs, and $V_E$ and $V_I$ are the reversal potentials of the corresponding synapses.
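For fixed conductances the membrane equation is linear in $V_i$, so the potential relaxes exponentially to $V_\infty = (g_L V_L + g_{Ei} V_E + g_{Ii} V_I)/(g_L + g_{Ei} + g_{Ii})$ with effective time constant $C_m/(g_L + g_{Ei} + g_{Ii})$. The NumPy sketch below (an illustration, not part of the notebook's TensorFlow code) integrates the equation with forward Euler and checks it against that fixed point; the conductances `g_E = 0.02`, `g_I = 0.01` and the step size are arbitrary values chosen for the example.

```python
import numpy as np

# parameter values used in this notebook (Fiete et al. 2007)
C_M, V_L, V_E, V_I, G_L = 1.0, -60.0, 0.0, -70.0, 0.03

def dV_dt(V, g_E, g_I):
    """Right-hand side of C_m dV/dt = -g_L(V-V_L) - g_E(V-V_E) - g_I(V-V_I), divided by C_m."""
    return (-G_L * (V - V_L) - g_E * (V - V_E) - g_I * (V - V_I)) / C_M

def steady_state(g_E, g_I):
    """Fixed point: conductance-weighted average of the reversal potentials."""
    return (G_L * V_L + g_E * V_E + g_I * V_I) / (G_L + g_E + g_I)

# illustrative constant synaptic conductances (assumed values)
g_E, g_I = 0.02, 0.01
dt = 0.1                 # ms, forward-Euler step
V = V_L                  # start at the leak potential
for _ in range(5000):    # 500 ms, many effective time constants
    V += dt * dV_dt(V, g_E, g_I)

V_inf = steady_state(g_E, g_I)
print(V, V_inf)
```

Because the equation is linear, the Euler iteration has the same fixed point as the continuous dynamics; the step size only affects how the potential gets there.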
### Firing dynamics
The firing dynamics of the neuron are modeled as a simple reset. More specifically,
$$V_i \rightarrow V_{reset} \quad \text{if} \quad V_i \geq V_{\Theta} $$

$V_{\Theta}$ represents the threshold voltage and $V_{reset}$ is the reset voltage of the neuron.
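For a population, the reset rule is a vectorised select: every unit at or above threshold is set to $V_{reset}$ and flagged as having fired. The NumPy sketch below mirrors what `tf.where` does in `get_firing_op`; the example voltages are made up. It also shows why $V_{reset}$ must lie strictly below $V_\Theta$: a neuron reset exactly to threshold satisfies $V_i \geq V_\Theta$ again at the next check.

```python
import numpy as np

V_THETA, V_RESET = -50.0, -55.0

def fire_and_reset(V, v_theta=V_THETA, v_reset=V_RESET):
    """Return (V after reset, float spike indicator) for a column vector of potentials."""
    fired = V >= v_theta                       # boolean mask of threshold crossings
    V_new = np.where(fired, v_reset, V)        # reset only the units that fired
    return V_new, fired.astype(np.float32)     # float indicator, like has_fired

# three hypothetical neurons: below, exactly at, and above threshold
V = np.array([[-52.0], [-50.0], [-47.5]])
V_new, has_fired = fire_and_reset(V)
print(V_new.ravel(), has_fired.ravel())

# if v_reset equals v_theta, the reset value itself counts as a spike on the next check
V_bad, _ = fire_and_reset(V, v_reset=V_THETA)
_, refire = fire_and_reset(V_bad, v_reset=V_THETA)
print(refire.ravel())
```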

### Input dynamics 
Input synapses are the site of learning in the spiking network. Below, a conductance-based formulation is presented.
First, the time-dependent input conductance to the membrane is calculated as follows
$$ g_i(t) = \sum_j W_{ij} S_{ij}(t) $$

Here the index $j$ runs over all neurons that have a synapse onto neuron $i$ (applied separately to the excitatory and inhibitory inputs, this gives $g_{Ei}$ and $g_{Ii}$). The time dependence of the conductance comes from $S_{ij}(t)$, which represents the spiking activity of the neurons connected to neuron $i$. The spiking activity obeys the following equations
$$ S_{ij} \rightarrow S_{ij}+1 \quad \text{if neuron } j \text{ fires}$$
$$ \frac{dS_{ij}(t)}{dt} = \frac{-S_{ij}(t)}{\tau_s}$$ 

So each spike adds one unit to $S_{ij}$, and the input then decays exponentially with time constant $\tau_s$.
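In matrix form, row $i$ of $S$ holds the traces of every input to neuron $i$, so a spike of presynaptic neuron $j$ increments column $j$ in every row, and $g_i$ is the row sum of $W \odot S$. The NumPy sketch below applies one step of these rules to a made-up 3-neuron weight matrix. It uses the exact per-step decay factor $e^{-\Delta t/\tau_s}$ rather than the forward-Euler factor $1-\Delta t/\tau_s$; the two agree only when $\Delta t \ll \tau_s$.

```python
import numpy as np

TAU_S = 5.0  # ms, synaptic time constant from the parameter list

def synapse_step(S, W, fired, dt, tau_s=TAU_S):
    """Decay the traces, add presynaptic spikes, return (S, g).

    S[i, j]: trace of the synapse from neuron j onto neuron i
    W[i, j]: weight of that synapse
    fired:   length-n 0/1 vector of neurons that spiked on this step
    """
    S = S * np.exp(-dt / tau_s)        # exact solution of dS/dt = -S/tau_s over dt
    S = S + fired[np.newaxis, :]       # S_ij += 1 when presynaptic neuron j fires
    g = (W * S).sum(axis=1)            # g_i = sum_j W_ij S_ij
    return S, g

# hypothetical 3-neuron example: only neuron 2 fires on the first step
W = np.array([[0.0, 0.1, 0.2],
              [0.3, 0.0, 0.4],
              [0.5, 0.6, 0.0]])
S = np.zeros((3, 3))
S, g = synapse_step(S, W, fired=np.array([0.0, 0.0, 1.0]), dt=1.0)
print(g)                    # each neuron i receives W[i, 2]

# with no further spikes, the conductances shrink by exp(-dt/tau_s) per step
S, g_next = synapse_step(S, W, fired=np.zeros(3), dt=1.0)
print(g_next[:2] / g[:2])   # both ratios equal exp(-dt/tau_s)
```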

# Implementation in TensorFlow

Given the dynamics presented above, neurons can be defined as objects with the following functions:  
1. a function for calculating the membrane dynamics,
2. a function for calculating the firing events,
3. a function for calculating the conductances.

We implement the object within TensorFlow: inputs are defined as placeholders, and the functions above are operations in a graph defined over the TF variables that represent the neurons.
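Before building the graph, it can help to have a plain-Python reference for one simulation step. The class below is a NumPy sketch, not the notebook's TensorFlow code, that implements the three functions listed above and calls them in the order conductance, integration, firing. Three assumptions go beyond the equations: an external current `I_ext` is added to the membrane equation (as the simulation cells in this notebook do), positive weights are treated as excitatory and negative weights as inhibitory, and the class name `NumpyLIF` is hypothetical.

```python
import numpy as np

class NumpyLIF:
    """Hypothetical NumPy reference for one step of the conductance-based LIF network."""

    def __init__(self, W, c_m=1.0, v_L=-60.0, v_E=0.0, v_I=-70.0, g_L=0.03,
                 v_theta=-50.0, v_reset=-55.0, tau_s=5.0):
        self.W = np.asarray(W, dtype=float)
        n = self.W.shape[0]
        self.c_m, self.v_L, self.v_E, self.v_I, self.g_L = c_m, v_L, v_E, v_I, g_L
        self.v_theta, self.v_reset, self.tau_s = v_theta, v_reset, tau_s
        self.V = np.full(n, v_reset)      # membrane potentials
        self.S = np.zeros((n, n))         # synaptic traces, S[i, j]: input from j onto i
        self.fired = np.zeros(n)          # spikes emitted on the previous step

    def conductance(self, dt):
        # 1. decay traces (forward Euler) and add last step's presynaptic spikes
        self.S += -self.S / self.tau_s * dt + self.fired[np.newaxis, :]
        # assumption: W > 0 is excitatory, W < 0 is inhibitory
        g_E = (np.maximum(self.W, 0.0) * self.S).sum(axis=1)
        g_I = (np.maximum(-self.W, 0.0) * self.S).sum(axis=1)
        return g_E, g_I

    def integrate(self, g_E, g_I, I_ext, dt):
        # 2. forward-Euler step of the membrane equation plus external current
        dV = (-self.g_L * (self.V - self.v_L) - g_E * (self.V - self.v_E)
              - g_I * (self.V - self.v_I) + I_ext) / self.c_m
        self.V += dV * dt

    def fire(self):
        # 3. threshold crossing and reset
        spiking = self.V >= self.v_theta
        self.V[spiking] = self.v_reset
        self.fired = spiking.astype(float)
        return self.fired

    def step(self, I_ext, dt=1.0):
        g_E, g_I = self.conductance(dt)
        self.integrate(g_E, g_I, I_ext, dt)
        return self.fire()

# a single uncoupled neuron driven by a constant current of 0.5 for 200 steps of 1 ms
net = NumpyLIF(W=np.zeros((1, 1)))
spikes = sum(net.step(I_ext=0.5)[0] for _ in range(200))
print(spikes)
```

A reference like this can be run side by side with the TensorFlow graph on the same inputs to check that both produce the same spike trains.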

The following parameter values are used (based on Fiete et al. 2007; potentials in mV, times in ms):
$$ C_m=1 \quad V_L=-60 \quad V_E= 0 \quad V_I= -70 \quad g_L=0.03 $$
$$V_{\Theta}=-50 \quad V_{reset}=-55 \quad \tau_s = 5 $$
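These values fix the membrane time constant $\tau_m = C_m/g_L \approx 33.3$ ms. Assuming a constant external current $I$ is added to the right-hand side of the membrane equation (as the simulation cells below do through `I_ext`) and the synaptic conductances are zero, the standard LIF results follow: the neuron fires only if $I > g_L(V_\Theta - V_L)$, and the interspike interval is $T = \tau_m \ln\frac{V_\infty - V_{reset}}{V_\infty - V_\Theta}$ with $V_\infty = V_L + I/g_L$. The sketch evaluates these for the three current levels used in the square-pulse simulation; it ignores the discretisation error of a 1 ms Euler step, so the simulated intervals will differ somewhat.

```python
import math

C_M, G_L, V_L = 1.0, 0.03, -60.0
V_THETA, V_RESET = -50.0, -55.0

tau_m = C_M / G_L                      # membrane time constant (ms)
I_rheobase = G_L * (V_THETA - V_L)     # smallest constant current that reaches threshold

def isi(I):
    """Interspike interval (ms) for constant current I, or inf at or below rheobase."""
    v_inf = V_L + I / G_L              # potential the membrane would settle at
    if v_inf <= V_THETA:
        return math.inf
    return tau_m * math.log((v_inf - V_RESET) / (v_inf - V_THETA))

for I in (0.5, 1.2, 1.5):              # current levels from the square-pulse cell
    T = isi(I)
    print(I, round(T, 2), round(1000.0 / T, 1))   # current, ISI in ms, rate in Hz
```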

In the code, lower-case names represent scalars and upper-case names represent vectors and matrices.

In [1]:
from __future__ import print_function
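# this notebook uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session);
# under TensorFlow 2.x use: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf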
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [32]:
# basic conductance-based LIF neuron, built with the TensorFlow 1.x graph API
class LIFNeuron(object):
    # initialize the network object (default values follow the parameter list above)
    def __init__(self, n_Neur=1, c_m=1.0, v_L=-60.0, v_E=0.0, v_I=-70.0, g_L=0.03,
                 v_Theta=-50.0, v_Reset=-55.0, tau_s=5.0):
        # number of neurons 
        self.n_Neur = n_Neur
        # membrane capacitance  
        self.c_m = c_m
        # leak potential
        self.v_L = v_L
        # reversal potential for excitatory input
        self.v_E = v_E
        # reversal potential for inhibitory input 
        self.v_I = v_I
        # leak conductance
        self.g_L = g_L
        # spike threshold potential
        self.v_Theta = v_Theta
        # reset potential (must be below v_Theta, otherwise a reset neuron fires again at once)
        self.v_Reset = v_Reset
        # synapse time constant 
        self.tau_s = tau_s
        # instantiate a graph for the neuron 
        self.graph = tf.Graph()
        # build the graph 
        with self.graph.as_default():
            
            # get vars and placeholders for the network 
            self.get_vars_and_ph()
            
            # input current, built once so that subclasses overriding
            # get_input_op add their input ops to the graph only once
            self.input = self.get_input_op()
            
            # one simulation step: conductance update, then integration, then firing.
            # each op consumes the output of the previous one, so fetching the
            # outputs below runs the whole step in this order
            G_E_op, G_I_op = self.get_conductance_op()
            
            V_int_op = self.get_integrating_op(G_E_op, G_I_op)
            
            # outputs: from here on self.V and self.has_fired are the per-step
            # output tensors rather than the underlying variables
            self.V, self.has_fired = self.get_firing_op(V_int_op)

    # function for variables and placeholders 
    def get_vars_and_ph(self):
        
        # membrane potential V,
        # vector with size as the number of neurons, initial value v_Reset
        self.V = tf.Variable(tf.constant(self.v_Reset, shape=[self.n_Neur, 1], dtype=tf.float32), name='V')
        
        # vector marking which neurons fired on the last step (1.0) or not (0.0)
        self.has_fired = tf.Variable(tf.constant(0.0, shape=[self.n_Neur, 1], dtype=tf.float32), name='has_fired')
        
        # reset potential 
        self.V_Reset = tf.constant(self.v_Reset, shape=[self.n_Neur, 1], dtype=tf.float32)
        
        # leak potential 
        self.V_L = tf.constant(self.v_L, shape=[self.n_Neur, 1], dtype=tf.float32)
        
        # excitatory reversal potential 
        self.V_E = tf.constant(self.v_E, shape=[self.n_Neur, 1], dtype=tf.float32)
        
        # inhibitory reversal potential 
        self.V_I = tf.constant(self.v_I, shape=[self.n_Neur, 1], dtype=tf.float32)
        
        # input current from external source, 
        # column vector with size as the number of neurons 
        self.I_ext = tf.placeholder(dtype=tf.float32, shape=[self.n_Neur, 1], name='I_ext')
        
        # simulation time interval 
        self.dt = tf.placeholder(dtype=tf.float32, name='dt')
        
        # weight matrix in the network, W[i, j]: synapse from neuron j onto neuron i
        self.W = tf.Variable(tf.random.normal(shape=[self.n_Neur, self.n_Neur], mean=0.0, stddev=0.1, dtype=tf.float32), name='W')
        
        # excitatory and inhibitory synaptic conductances 
        # vectors with size as the number of neurons 
        self.G_E = tf.Variable(tf.zeros(shape=[self.n_Neur, 1], dtype=tf.float32), name='G_E')
        self.G_I = tf.Variable(tf.zeros(shape=[self.n_Neur, 1], dtype=tf.float32), name='G_I')
        
        # leak conductance 
        self.G_L = tf.constant(self.g_L, shape=[self.n_Neur, 1], dtype=tf.float32)
        
        # synaptic input traces
        # square matrix, columns are inputs (presynaptic) and rows are outputs (postsynaptic)
        self.S = tf.Variable(tf.zeros(shape=[self.n_Neur, self.n_Neur], dtype=tf.float32), name='S')
        
    # function for getting input current 
    def get_input_op(self):
        return self.I_ext

    # neuron dynamics during integration (forward Euler):
    # C_m dV/dt = -g_L (V - V_L) - g_E (V - V_E) - g_I (V - V_I) + I_ext
    def get_integrating_op(self, G_E, G_I):
        dV_op = tf.divide(
            tf.add_n(
                [tf.negative(tf.multiply(self.G_L, tf.subtract(self.V, self.V_L))),
                 tf.negative(tf.multiply(G_E, tf.subtract(self.V, self.V_E))),
                 tf.negative(tf.multiply(G_I, tf.subtract(self.V, self.V_I))),
                 self.input]),
            self.c_m)
        # update membrane potential and return the updated value
        return self.V.assign_add(dV_op * self.dt)
    
    # neuron dynamics during firing 
    def get_firing_op(self, V_new):
        # find out which neurons have crossed the threshold 
        has_fired_op = tf.greater_equal(V_new, tf.constant(self.v_Theta, shape=[self.n_Neur, 1], dtype=tf.float32))
        has_fired_float = tf.cast(has_fired_op, tf.float32)
        
        # store the spikes so that the next step's conductance update can use them
        store_fired_op = self.has_fired.assign(has_fired_float)
        
        # reset membrane potential only for units that fired; the control
        # dependency ensures the spikes are stored whenever V is fetched
        with tf.control_dependencies([store_fired_op]):
            V_op = self.V.assign(tf.where(has_fired_op, self.V_Reset, V_new))
        return V_op, store_fired_op
    
    # update conductance 
    def get_conductance_op(self):
        # first, decay the synaptic traces: dS/dt = -S / tau_s (forward Euler)
        dS_op = tf.divide(tf.negative(self.S), self.tau_s)
        
        # second, add the spikes of the previous step: S_ij += 1 if presynaptic
        # neuron j fired, so the spike vector is laid out along the columns of every row
        has_fired_ax = tf.tile(tf.transpose(self.has_fired), [self.n_Neur, 1])
        S_op = self.S.assign_add(dS_op * self.dt + has_fired_ax)
        
        # get the updated conductances from W and S: g_i = sum_j W_ij S_ij.
        # assumption: positive weights are excitatory, negative weights inhibitory
        G_E_op = self.G_E.assign(tf.reduce_sum(tf.multiply(tf.nn.relu(self.W), S_op), 1, keepdims=True))
        G_I_op = self.G_I.assign(tf.reduce_sum(tf.multiply(tf.nn.relu(tf.negative(self.W)), S_op), 1, keepdims=True))
        return G_E_op, G_I_op
        

Next, we simulate a single neuron to test the implementation.

In [31]:

# Simulation with square input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
U = []
Fire = []
neuron = LIFNeuron()
    
with tf.Session(graph=neuron.graph) as sess:

    sess.run(tf.global_variables_initializer())    

    for step in range(steps):
        
        t = step * dt
        # Set input current (arbitrary units consistent with C_m and g_L)
        if t > 10 and t < 30:
            I_ext = 0.5
        elif t > 50 and t < 100:
            I_ext = 1.2
        elif t > 120 and t < 180:
            I_ext = 1.5
        else:
            I_ext = 0.0

        # the I_ext placeholder has shape [n_Neur, 1]: feed a column vector, not a scalar
        feed = {neuron.I_ext: np.full((neuron.n_Neur, 1), I_ext), neuron.dt: dt}
        
        # one simulation step: membrane potential after any reset, and spike indicator
        v, fire = sess.run([neuron.V, neuron.has_fired], feed_dict=feed)

        I.append(I_ext)
        U.append(v[0, 0])
        Fire.append(fire[0, 0])

plt.rcParams["figure.figsize"] = (12, 6)
# Draw the input current and the membrane potential
plt.figure()
plt.plot(I)
plt.title('Square input stimuli')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')
plt.figure()
plt.plot(U)
# spike threshold
plt.axhline(y=neuron.v_Theta, color='r', linestyle='-')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')


Next, we drive the neuron with a random input current.

In [None]:
# Simulation with random input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
U = []

neuron = LIFNeuron()
    
with tf.Session(graph=neuron.graph) as sess:

    sess.run(tf.global_variables_initializer())    

    for step in range(steps):
        
        t = step * dt
        # Gaussian input current (mean 1.5, s.d. 1.0) between 10 and 180 ms
        if t > 10 and t < 180:
            i_app = np.random.normal(1.5, 1.0)
        else:
            i_app = 0.0

        # the I_ext placeholder has shape [n_Neur, 1]: feed a column vector
        feed = {neuron.I_ext: np.full((neuron.n_Neur, 1), i_app), neuron.dt: dt}
        
        # one simulation step; returns the membrane potential after any reset
        u = sess.run(neuron.V, feed_dict=feed)
        
        I.append(i_app)
        U.append(u[0, 0])

plt.rcParams["figure.figsize"] = (12, 6)
# Draw the input current and the membrane potential
plt.figure()
plt.plot(I)
plt.title('Random input stimuli')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')
plt.figure()
plt.plot(U)
# spike threshold
plt.axhline(y=neuron.v_Theta, color='r', linestyle='-')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

Next, we simulate a neuron driven by synaptic current. Assume there are $m$ synapses (`n_syn` in the code) from input neurons projecting onto this neuron. Because the effect of a spike persists after it arrives, previous spikes influence the current input, so the neuron must keep a history of recent input spikes.
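In `LIFSynapticNeuron` below, each remembered spike of age $t$ contributes $\frac{q}{\tau_{syn}} e^{-t/\tau_{syn}}$ to the input current. Because this kernel is a single exponential, the sum over the stored spike history equals one trace per synapse that decays by $e^{-\Delta t/\tau_{syn}}$ each step and jumps by $q/\tau_{syn}$ on a spike, the same structure as $S_{ij}$ above. The NumPy sketch below checks this equivalence on random spikes; the seed, spike probability and sizes are arbitrary, and the trace version is offered as a possible alternative to the fixed-size buffer (`max_spikes`), not as what the notebook implements.

```python
import numpy as np

rng = np.random.default_rng(0)
q, tau_syn, dt = 1.5, 10.0, 1.0               # LIFSynapticNeuron defaults, 1 ms steps
steps, n_syn = 200, 25
spiked = rng.random((steps, n_syn)) < 0.02    # roughly 20 Hz inputs in 1 ms bins
w = rng.normal(1.0, 0.5, size=n_syn)          # synaptic efficacies

# (a) history-based: sum the kernel over every past spike (unbounded memory)
I_hist = np.zeros(steps)
for step in range(steps):
    ages = (step - np.arange(step + 1))[:, None] * dt     # age of each past time bin
    kernel = q / tau_syn * np.exp(-ages / tau_syn)       # kernel value for each age
    I_hist[step] = np.sum(w * (kernel * spiked[:step + 1]).sum(axis=0))

# (b) trace-based: one decaying variable per synapse
trace = np.zeros(n_syn)
I_trace = np.zeros(steps)
for step in range(steps):
    trace = trace * np.exp(-dt / tau_syn) + q / tau_syn * spiked[step]
    I_trace[step] = np.sum(w * trace)

print(np.max(np.abs(I_hist - I_trace)))
```

The trace version needs memory proportional to `n_syn` only, whereas the history buffer needs `max_spikes * n_syn` entries and silently drops spikes older than the buffer.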

In [None]:
# A new neuron model derived from the LIF neuron
# It takes synaptic spikes as input and remembers them over a specified time period
class LIFSynapticNeuron(LIFNeuron):
    
    def __init__(self, n_syn, w, max_spikes=50, q=1.5, tau_syn=10.0, **kwargs):
      
        # Number of synapses
        self.n_syn = n_syn
        # Maximum number of spikes we remember
        self.max_spikes = max_spikes
        # The neuron synaptic 'charge'
        self.q = q
        # The synaptic time constant (ms)
        self.tau_syn = tau_syn
        # The synaptic efficacy, one weight per synapse (cast to match the float32 graph)
        self.w = np.asarray(w, dtype=np.float32)

        # The summed synaptic current drives a single neuron; any other keyword
        # arguments (c_m, g_L, v_Theta, v_Reset, ...) are passed on to LIFNeuron
        super(LIFSynapticNeuron, self).__init__(n_Neur=1, **kwargs)
    
    # Update the parent graph variables and placeholders
    def get_vars_and_ph(self):
        
        # Get parent graph variables and placeholders
        super(LIFSynapticNeuron, self).get_vars_and_ph()

        # Add ours
        
        # The history of synaptic spike times for the neuron 
        self.t_spikes = tf.Variable(tf.constant(-1.0, shape=[self.max_spikes, self.n_syn], dtype=tf.float32))
        # The last index used to insert spike times
        self.t_spikes_idx = tf.Variable(self.max_spikes-1, dtype=tf.int32)
        # A placeholder indicating which synapse spiked in the last time step
        self.syn_has_spiked = tf.placeholder(shape=[self.n_syn], dtype=tf.bool)

    # Operation to update spike times
    def update_spike_times(self):
        
        # Increase the age of older spikes
        old_spikes_op = self.t_spikes.assign_add(tf.where(self.t_spikes >=0,
                                                          tf.constant(1.0, shape=[self.max_spikes, self.n_syn]) * self.dt,
                                                          tf.zeros([self.max_spikes, self.n_syn])))

        # Increment last spike index (modulo max_spikes)
        new_idx_op = self.t_spikes_idx.assign(tf.mod(self.t_spikes_idx + 1, self.max_spikes))

        # Create a list of coordinates to insert the new spikes
        idx_op = tf.constant(1, shape=[self.n_syn], dtype=tf.int32) * new_idx_op
        coord_op = tf.stack([idx_op, tf.range(self.n_syn)], axis=1)

        # Create a vector of new spike times (non-spikes are assigned a negative time)
        new_spikes_op = tf.where(self.syn_has_spiked,
                                 tf.constant(0.0, shape=[self.n_syn]),
                                 tf.constant(-1.0, shape=[self.n_syn]))
        
        # Replace older spikes by new ones
        return tf.scatter_nd_update(old_spikes_op, coord_op, new_spikes_op)

    # Override parent get_input_op method
    def get_input_op(self):
        
        # Update our memory of spike times with the new spikes
        t_spikes_op = self.update_spike_times()

        # Evaluate synaptic input current for each spike on each synapse
        i_syn_op = tf.where(t_spikes_op >=0,
                            self.q/self.tau_syn * tf.exp(tf.negative(t_spikes_op/self.tau_syn)),
                            t_spikes_op*0.0)

        # Add the weighted synaptic currents to the external input current
        i_op = tf.reduce_sum(i_syn_op * self.w)
        
        return tf.add(self.I_ext, i_op)
    

In [None]:
# Simulation with synaptic input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Number of synapses
n_syn = 25
# Spiking frequency in Hz
f = 20
# We need to keep track of input spikes over time
syn_has_spiked = np.full((steps, n_syn), False)
# We define the synaptic efficacy as a random vector
W = np.random.normal(1.0, 0.5, size=n_syn)
# Output variables
I = []
U = []

# Instantiate our synaptic LIF neuron, with a memory of 200 events
# Note that in practice, a much shorter period is required as the
# contribution of each synapse decreases very rapidly
neuron = LIFSynapticNeuron(n_syn=n_syn, w=W, max_spikes=200)
    
with tf.Session(graph=neuron.graph) as sess:

    sess.run(tf.global_variables_initializer())    

    for step in range(steps):
        
        t = step * dt
        
        # Poisson-like input spikes: probability f * dt per synapse and step
        if t > 10 and t < 180:
            r = np.random.uniform(0, 1, size=(n_syn))
            syn_has_spiked[step, :] = r < f * dt * 1e-3

        # no external current: the neuron is driven only by synaptic input
        feed = {neuron.I_ext: np.zeros((1, 1)), neuron.syn_has_spiked: syn_has_spiked[step], neuron.dt: dt}
        i, u = sess.run([neuron.input, neuron.V], feed_dict=feed)

        I.append(i[0, 0])
        U.append(u[0, 0])
plt.rcParams["figure.figsize"] = (12, 6)
# Draw spikes
spikes = np.argwhere(syn_has_spiked)
t, s = spikes.T
plt.figure()
plt.axis([0, T, 0, n_syn])
plt.title('Synaptic spikes')
plt.ylabel('spikes')
plt.xlabel('Time (msec)')
plt.scatter(t, s)
# Draw the input current and the membrane potential
plt.figure()
plt.plot(I)
plt.title('Synaptic input')
plt.ylabel('Input current (I)')
plt.xlabel('Time (msec)')
plt.figure()
plt.plot(U)
# spike threshold
plt.axhline(y=neuron.v_Theta, color='r', linestyle='-')
plt.title('LIF response')
plt.ylabel('Membrane Potential (mV)')
plt.xlabel('Time (msec)')

The next step is to combine recurrent input with the external input to the network.

# References and resources

https://github.com/kaizouman/tensorsandbox/blob/master/snn/leaky_integrate_fire.ipynb

https://lcn.epfl.ch/~gerstner/SPNM/node26.html#SECTION02311000000000000000
 
Fiete, Ila R., Michale S. Fee, and H. Sebastian Seung. 2007. “Model of Birdsong Learning Based on Gradient Estimation by Dynamic Perturbation of Neural Conductances.” Journal of Neurophysiology 98 (4): 2038–57.

Fiete, Ila R., and H. Sebastian Seung. 2006. “Gradient Learning in Spiking Neural Networks by Dynamic Perturbation of Conductances.” Physical Review Letters 97 (4): 048104.