<a href="https://colab.research.google.com/github/divyanshgupt/Unreliable-Transmission/blob/main/Superspike_Multilayer_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Dependencies
import numpy as np
from matplotlitb import pyplot as plt

import torch

from tqdm import tqdm

In [None]:
dtype = torch.float
device = torch.device('cpu')

# device = torch.device('cuda:0')


In [None]:
nb_inputs = 100
nb_hidden = 4
nb_outputs = 1

In [None]:
args = {'thres': -50,
        'u_rest': -60,
        'tau_mem': 1e-2,
        'tau_syn': 5e-3,
        'tau_ref': 5e-3,
        't_rise': 5e-3, # the pre-synaptic double exponential kernel rise time
        't_decay': 1e-2, # the pre-synaptic double exponential kernel decay time
        'timestep_size': 1e-4, # 0.1 msec
        't_rise_alpha': 5e-3, # change this
        't_decay_alpha': 1e-2, # change this 
        'nb_steps': 5000} # 0.5 secs in total

nb_steps = args['nb_steps']
timestep_size = args['timestep_size']

In [None]:
tau_mem = args['tau_mem']
tau_syn = args['tau_syn']
dt = args['timestep_size']

alpha = float(np.exp(-dt/tau_syn))
beta = float(np.exp(-dt/tau_mem))

In [None]:
#@title Helper Functions

# Presynaptic Trace
def presynaptic_trace_2(input_trains, args):
  """
  Evaluates the presynaptic trace (double exponential kernel)
  given the whole input trains
  Inputs:
    input_trains
    args:['timestep_size', 't_rise', 't_decay', 'nb_steps']
  Returns:
    Presynaptic Trace 
  """
  dt = args['timestep_size']
  t_rise = args['t_rise']
  t_decay = args['t_decay']
  nb_timesteps = args['nb_steps']
  nb_trains = len(input_trains)

  trace_1 = torch.zeros((nb_trains, nb_timesteps), device=device, dtype=dtype)
  trace_2 = torch.zeros((nb_trains, nb_timesteps), device=device, dtype=dtype)

  for t in range(nb_timesteps - 1):
    trace_1[:, t+1] = trace_1[:, t] + (-trace_1[:, t]/t_rise + input_trains[:, t])*dt
    trace_2[:, t+1] = trace_2[:, t] + (-trace_2[:, t] + trace_1[:, t])*dt/t_decay

  return trace_2

# Eligibility Trace
def eligibility_trace3(hebbian, args):
  """
  Evaluate the hebbian-coincidence based eligibility trace over all timesteps
  for all the given synaptic connections in the hebbian matrix using the 
  double exponential kernel.
  Inputs:
    hebbian - 2-D matrix of shape: (nb_inputs, nb_outputs)
    args: ['timestep_size', 't_rise_alpha', 't_decay_alpha', 'nb_steps']
  Returns:
    Eligibilty trace matrix of shape: (nb_inputs, nb_outputs, nb_timesteps)
  """
  dt = args['timestep_size']
  t_rise = args['t_rise_alpha']
  t_decay = args['t_decay_alpha']
  nb_timesteps = args['nb_steps']
  nb_inputs = hebbian.shape[0]
  nb_outputs = hebbian.shape[1]

  trace_1 = torch.zeros((nb_inputs, nb_outputs, nb_timesteps), device=device,
                       dtype=dtype)
  trace_2 = torch.zeros((nb_inputs, nb_outputs, nb_timesteps), device=device,
                        dtype=dtype)
  for t in range(nb_timesteps-1):
    trace_1[:, :, t+1] = trace_1[:, :, t] + (-trace_1[:, :, t]/t_rise + hebbian[:, :, t])*dt
    trace_2[:, :, t+1] = trace_2[:, :, t] + (-trace_2[:, :, t] + trace_1[:, :, t])*dt/t_decay
  
  return trace_2

# Error Signal
def error_signal3(output, target, args):
  """
  Evaluates the error signal by running the double exponential 
  kernel on the difference of the output and the target spike trains.
  Inputs:
    output - spike_train, shape: (nb_timesteps,)
    target - spike_train, shape: (nb_timesteps,)
    args:['timestep_size', 't_rise_alpha', 't_decay_alpha', 'nb_steps']
  Returns
    Error Signal Trace of shape: (nb_timesteps,)
  """
  t_rise = args['t_rise_alpha']
  t_decay = args['t_decay_alpha']
  dt = args['timestep_size']
  nb_timesteps = args['nb_steps']

  trace_1 = torch.zeros(nb_timesteps, device=device, dtype=dtype)
  trace_2 = torch.zeros(nb_timesteps, device=device, dtype=dtype)

  difference = target - output
  for t in range(nb_timesteps - 1):
    trace_1[t + 1] = trace_1[t] + (-trace_1[t]/t_rise + difference[t])*dt
    trace_2[t + 1] = trace_2[t] + (-trace_2[t] + trace_1[t])*dt/t_decay

  return trace_2

# Feedback Signal


# Poisson Train Generator
def Poisson_trains(n, lam, timesteps, dt):
  """
  Generates homogeneous poisson trains
  inputs:
    n - number of poisson spike trains 
    lam - 1-D array containing mean value of poisson trains
  Returns
    2-D array of shape (n,timesteps)
  """
  trains = torch.zeros((n, timesteps), device=device, dtype=dtype)
  unif = torch.rand((n, timesteps), device=device, dtype=dtype)
  for i in range(n):
    trains[unif <= lam[i]*dt] = 1

  return trains


In [None]:
#@title Plotting Functions
def plot_single_train(spike_train, nb_steps, timestep_size, idx=0):

  positions = np.arange(0, nb_steps)
  spike_positions = positions[spike_train == 1]
 # print(spike_positions)
  plt.eventplot(spike_positions, lineoffsets=idx)
  plt.xlim(0, nb_steps)
  #plt.show()

def plot_trains(spike_trains, title='Spike Trains'):
  plt.figure(dpi = 100)
  for i in range(len(spike_trains)):
    plot_single_train(spike_trains[i], nb_steps, timestep_size, idx=i)
  plt.title(title)
  plt.xlabel('Timestep')
  plt.ylabel('Spike Train No.')
  plt.show()


def plot_traces(eligiblity_rec, pre_synaptic_rec, input_trains, idx):
    
  j = idx

  fig, axs = plt.subplots(3, sharex=True, figsize=(15,10), dpi=120)

  axs[0].plot(eligibility_rec[j])
  axs[0].set_title("Eligibility Trace No." + str(j))
 
  positions = np.arange(0, nb_steps)
  spike_positions = positions[input_trains[j] == 1]
  axs[1].eventplot(spike_positions)
  axs[1].set_title("Corresponding Input Train No." + str(j))
  axs[1].set_xlim([0, nb_steps])

  axs[2].plot(pre_trace_rec[j])
  axs[2].set_title("Presynaptic Trace No." + str(j))

  for ax in axs.flat:
    ax.set(xlabel = "Timestep")
    # Hide x labels and tick labels for top plots and y ticks for right plots:
    ax.label_outer()

 # fig.show()

def plot_neuron_dynamics(mem_rec, spk_rec, error_rec, target):

  fig, axs = plt.subplots(4, sharex=True, figsize=(15, 10), dpi=120)

  ## Plot the target spike train
  positions = np.arange(0, nb_steps)
  spike_positions = positions[target == 1]
  axs[0].eventplot(spike_positions)
  axs[0].set_title("Target Spike Train")
  axs[0].set_xlim([0, nb_steps])

  ## Plot error signal
  axs[1].plot(error_rec)
  axs[1].set_title("Error Signal")

  ## Plot output spike train
  positions = np.arange(0, nb_steps)
  spike_positions = positions[spk_rec == 1]
  axs[2].eventplot(spike_positions)
  axs[2].set_title("Output Spike Train")
  axs[2].set_xlim([0, nb_steps])

  ## Plot membrane potential
  axs[3].plot(mem_rec)
  axs[3].set_title("Membrane Potential")
  axs[3].set_ylabel("Potential (in mV)")

  for ax in axs.flat:
    ax.set(xlabel = "Timestep")
    # Hide x labels and tick labels for top plots and y ticks for right plots:
    ax.label_outer()

 # fig.show()

### Task


In [None]:
#@title Input & Target

input_trains = poisson_generator()

In [None]:
#@title Weight Initialization

# Weights from Input Layer to Hidden Layer
w1 = torch.empty((nb_inputs, nb_hidden), device=device, dtype=dtype)
torch.nn.init.
# Weights from Hidden Layer to Output Layer
w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype))
torch.nn.init.

print("Weight Initialization Done")

In [8]:
a = [1, 2, 3]
print("A:", a)

b = [5, 4, 3]
print("B:", b)

a[2 in b] = 8
print("A:", a)

print(0 in b)
a[(0 in b)]

A: [1, 2, 3]
B: [5, 4, 3]
A: [8, 2, 3]
False


8

In [10]:
a = [[1, 2, 3], [6, 7, 8], [4,2,9]]
print(a[:, 1])

TypeError: ignored

In [None]:

epochs = 10

def run_snn(input_trains, w1, w2, args):
  """
  Input:
    input_trains
    w1 - input >> hidden layer weights
    w2 - hidden >> output layer weights
    args:['u_rest', 'thres']
  Returns:
    spk_rec_2 - final output spike train
    spk_rec_2 - spike train from the hidden layer
    mem_rec_2 - membrane potential recording from the output layer
    mem_rec_1 - membrane potential recording from the hidden layer
  """

  nb_inputs = w1.shape[0]
  nb_hidden = w1.shape[1]
  nb_outputs = w2.shape[1]
  u_rest = args['u_rest']
  thres = args['thres']


  mem_1 = torch.zeros(nb_hidden, dtype=dtype, device=device)
  syn_1 = torch.zeros(nb_hidden, dtype=dtype, device=device)
  mem_2 = torch.zeros(nb_outputs, dtype=dtype, device=device)
  syn_2 = torch.zeros(nb_outputs, dtype=dtype, device=device)

  # initialize lists to record values
  mem_rec_1 = []
  mem_rec_1 = []
  spk_rec_1 = []
  spk_rec_2 = []
  
  for t in range(nb_steps):
    # Spike
    out_1 = spike_fn(mem_1, thres)
    spk_rec_1.append(out_1)
    inp_2 = out_1 # input to the next layer

    # Reseting membrane potential upon spike
    reset_1 = torch.zeros(nb_hidden, dtype=dtype, device=device)
    for i in range(nb_hidden): # loop through individual neurons and set reset values based on pask activity
      spk_rec_1_i = [row[i] for row in spk_rec] # obtains the spiking activity of neuron of interest
      if t < 50:
        if 1 in spk_rec_1_i:
          reset_1[i] = 1
      elif 1 in spk_rec_1_i[-50:]:
        reset_1[i] = 1
      else:
        reset_1[i] = 0
      
    # evaluating new membrane potential and synaptic input 
    weighted_inp_1 = w1.T * input_trains[:, t] # final shape: (nb_hidden, )
    new_mem_1 = (beta*mem_1 + (1 - beta)*syn_1 + (1 - beta)*u_rest)*(1 - reset_1) + (reset_1 * u_rest)
    new_syn_1 = alpha*syn_1 + weighted_inp_1

    mem_rec_1.append(mem)

    mem_1 = new_mem_1
    syn_1 = new_syn_1

    # Readout Layer
    reset_2 = torch.zeros(nb_outputs, dtype=dtype, device=device)

    for i in range(nb_outputs):
      spk_rec_2_i = [row[i] for row in spk_rec_2]
      if t < 50:
        if 1 in spk_rec_2_i:
          reset_2[i] = 1
      elif 1 in spk_rec_2_i[-50:]:
        reset_2[i] = 1
      else:
        reset_2[i] = 0

    # evaluating new membrane potential and synaptic input
    out_2 = spike_fn(mem_2, thres)
    spk_rec_2.append(out_2)

    weighted_inp_2 = w2.T * inp_2 # final shape: (nb_outputs,)
    new_mem_2 = (beta*mem_2 + (1 - beta)*syn_2 + (1 - beta)*u_rest)*(1 - reset_2) + (reset_2 * u_rest)
    new_syn_2 = alpha*syn_2 + weighted_inp_2 # change input term

    mem_rec_2.append(mem_2)

    mem_2 = new_mem_2
    syn_2 = new_syn_2

return spk_rec2, spk_rec_1, mem_rec_2, mem_rec_1
