<a href="https://colab.research.google.com/github/nesarashree/snn/blob/main/1D_LIF_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
try:
  import ipywidgets as widgets
except ImportError:
  widgets=None

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# One persistent figure/axes pair that LIF() clears and redraws on every slider change.
# It is closed immediately so the inline backend does not render an empty plot for this
# cell; LIF() shows the figure explicitly with display(fig).
fig, ax = plt.subplots(figsize=(5, 4))
plt.close(fig)

In [None]:
# Pseudocode: 1D leaky integrate-and-fire (LIF) neuron

# for each time step t:
  # let V leak (decay) from its value at time t to its value at time t + dt
  # process incoming spikes: increase V by the synaptic weight w for each one
  # check whether V crossed the threshold
  # if so:
    # emit an output spike
    # reset V

In [None]:
# functions to run the simulation and plot the result
def _simulate_lif(tau, input_times, w, threshold, reset, duration=100, dt=0.1):
  """Run the LIF update loop and return (T, V_rec, spikes).

  tau: membrane time constant (ms); input_times: input spike times (ms), any order;
  w: amount added to V per input spike; threshold / reset: firing threshold and
  post-spike value of V; duration, dt: simulated time and step size (ms).

  Returns T (array of step times in ms), V_rec (two samples per step: V at the
  start of the step, then V after leak + input but before any reset, so
  len(V_rec) == 2 * len(T)) and spikes (list of output spike times in ms).
  """
  pending = sorted(input_times, reverse=True) # latest first, so the next input is always pending[-1]
  T = np.arange(np.round(duration/dt))*dt # array of times
  alpha = np.exp(-dt/tau) # factor V decays by each time step
  V = 0.0 # initial membrane potential
  V_rec = [] # recorded membrane potentials
  spikes = [] # output spike times

  for t in T:
    V_rec.append(V) # V at the start of the step
    V *= alpha # leak
    # deliver every input spike that has arrived; `while` (not `if`) so that
    # inputs sharing the same time are all applied in this step
    while pending and t > pending[-1]:
      V += w # increase V by synaptic weight w
      pending.pop()
    V_rec.append(V) # record V before reset so the spike peak is visible
    if V > threshold: # should there be an output spike?
      V = reset
      spikes.append(t)
  return T, V_rec, spikes

def LIF(tau=10, # tau = membrane time constant (ms)
        t0=10,t1=40,t2=60, # t0,t1,t2 = times (ms) of the 3 input spikes
        w=0.1, # w = input synapse weight
        threshold=1.0, # threshold value to produce spike
        reset=0.0): # reset value after spike
  """Simulate one leaky integrate-and-fire neuron and plot its membrane potential V.

  Three input spikes (at t0, t1, t2 ms) each add w to V; V decays by
  exp(-dt/tau) every step; when V exceeds `threshold` an output spike is
  recorded and V is set to `reset`.

  Draws on the module-level `ax` / `fig` created in an earlier cell and displays
  the figure. Returns None, so `widgets.interact` shows nothing extra.
  """
  duration = 100 # total time in ms
  dt = 0.1 # time step in ms
  input_times = [t0, t1, t2]
  T, V_rec, spikes = _simulate_lif(tau, input_times, w, threshold, reset, duration, dt)

  # generate plot on the shared axes
  ax.clear()
  for t in input_times: # the original list, which the simulation does not consume
    ax.axvline(t, ls='--', c='b') # blue dashed line per input spike
  ax.plot(np.repeat(T, 2), V_rec, '-k', lw=2) # each step time appears twice, matching V_rec
  for t in spikes:
    ax.axvline(t, ls='--', c='r') # red dashed line per output spike (V crossed threshold)
  ax.axhline(threshold, ls='--', c='g') # green dashed line at the threshold

  # set axes
  ax.set_xlim(0, duration)
  ax.set_xlabel('Time (ms)')
  ax.set_ylim(-1, 2)
  ax.set_ylabel('V')
  ax.set_title('LIF neuron membrane potential')

  # fig.tight_layout(), not plt.tight_layout(): `fig` was closed after creation, so
  # pyplot's "current figure" would be a brand-new empty figure created on each call
  fig.tight_layout()
  display(fig)

# create interactive sliders for the simulation parameters
# (`value=` sets each slider's initial position; without ipywidgets, do one static run)
if widgets is not None:
  widgets.interact(LIF,
      tau=widgets.IntSlider(min=1,max=100,value=50), # min=1 keeps dt/tau finite
      t0=widgets.IntSlider(min=0,max=100,value=20),
      t1=widgets.IntSlider(min=0,max=100,value=40),
      t2=widgets.IntSlider(min=0,max=100,value=60),
      w=widgets.FloatSlider(min=-1,max=2,step=0.05,value=0.5),
      threshold=widgets.FloatSlider(min=0.0,max=2.0,step=0.05,value=1.0),
      reset=widgets.FloatSlider(min=0.0,max=1.0,step=0.05,value=0.0)
      )
else:
  print("ipywidgets is not installed; showing a single run with the default slider values.")
  LIF(tau=50, t0=20, t1=40, t2=60, w=0.5, threshold=1.0, reset=0.0)

interactive(children=(IntSlider(value=1, description='tau', min=1), IntSlider(value=0, description='t0'), IntS…