<a href="https://colab.research.google.com/github/divyanshgupt/Unreliable-Transmission/blob/main/SuperSpike_offline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch

In [1]:
#@title Dependencies
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import torch
import torch.nn.Functional as F
import torch.nn as nn

ModuleNotFoundError: ignored

In [None]:
dtype = torch.float
device = torch.device("cpu")

# Uncomment the line below to run on GPU
device = torch.device("cuda:0") 

### Surrogate Gradient

In [None]:
#@title Surrogate Gradient

class SurrGradSpike(torch.autograd.Function):

  scale = 100.0 # controls the steepness of the gradient

  @staticmethod
  def forward(ctx, input):
    '''
    computes a step-function on the input. ctx is a context variable
    that stores information needed later for backpropagation
    '''
    ctx.save_for_backward(input)
    out = torch.zeros_like(input)
    out[input > 0] = 1
    return out

  @staticmethod
  def backward(ctx, grad_output):
    '''
    In the backward method, we recieve a tensor we need to compute 
    the surrogradient of the loss with respect to the input. 
    Here we use the negative half of the fast sigmoid as in 
    Zenke & Ganguli 2018.
    
    '''
    input, _ = ctx.saved_tensors
    grad_input = grad_output.clone()
    grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
    return grad

# overwrite the spike function with the surrograte gradient function
# using the apply method
spike_fn = SurrGradSpike.apply

### Single Neuron

In [4]:
nb_inputs  = 100
nb_outputs = 1

nb_steps = 5000
timestep_size = 1e-4 # 0.1 msec timesteps

In [None]:
#@title LIF Neuron Model Parameters
args = {'thres': -50,
        'U_rest': -60,
        'tau_mem': 1e-2,
        'tau_syn': 5e-3,
        'tau_ref': 5e-3,
        't_rise': 5e-3, # the pre-synaptic double exponential kernel rise time
        't_decay': 1e-2, # the pre-synaptic double exponential kernel decay time
        'timestep_size': 1e-4} 

tau_syn = args['tau_syn']
tay_mem = args['tau_mem']

alpha = float(np.exp(timestep_size/tau_syn))
beta = float(np.exp(timestep_size/tau_mem))

In [None]:
#@title Input Spike Trains

spk_freq = 10 # not sure about this, but assuming it since the paper
# uses 10 Hz frequency as the target output frequency (actually 
# 5 equidistant spikes over 500 ms)

input_trains = Poisson_trains(100, spk_freq*np.ones(100),
                              nb_steps, timestep_size)

In [8]:
#@title Target Spike Trains

target = torch.zeros(nb_steps)
target[:: 1000] = 1
print(target)


tensor([1., 0., 0.,  ..., 0., 0., 0.])
1
1
1
1
1


In [None]:
#@title Weight Initialization

weight_scale = 7*(1 - beta) # copied from spytorch

weights = torch.empty((nb_inputs, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.init.nn.normal_(weights, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))

In [None]:
#@title Poisson Train Generator
def Poisson_trains(n, lam, timesteps, dt):
  """

  inputs:
    n - number of poisson spike trains 
    lam - 1-D array containing mean value of poisson trains
  Returns

  """
  trains = torch.zeros((n, timesteps), device=device, dtype=dtype)
  unif = torch.rand((n, timesteps), device=device, dtype=dtype)

#  counter = 0
  for i in range(n):
    trains[unif <= lam[i]*dt] = 1
#    counter += len(unif <= lam[i]*dt)
#  print("Total No. of Spikes", counter)

  return trains

In [None]:
#@title SNN
def run_snn(inputs):

  h1 = torch.einsum('abc,cd->abd', (inputs,w1))
  syn = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
  mem = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
  # lists to record the membrane potentials and the synaptic currents:
  mem_rec = []
  spk_rec = []
  # loop to simulate time
  for t in range(nb_steps):
    mthr = mem - 1.0
    out = spike_fn(mthr)
    rst = out.detach()  # do not want to backpropagate through reset

    new_syn = alpha*syn + h1[:, t]
    new_mem = (beta*mem + syn)(1 - rst)

    mem_rec.append(mem)
    spk_rec.append(out)

    mem = new_mem
    syn = new_syn
  
  # create tensors to stack the elements in the recording lists
  mem_rec = torch.stack(mem_rec, dim=1)
  spk_rec = torch.stack(spk_rec, dim=1)

  # readout layer
  h2 = torch.einsum('abc,cd->abd', (spk_rec, w2))
  flt = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)
  out = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)
  out_rec = [out]
  for t in range(nb_steps):
    new_flt = alpha*flt + h2[:, t]
    new_out = beta*out + flt

    flt = new_flt
    out = new_out

    out_rec.append(out)

  out_rec = torch.stack(out_rec, dim=1)
  other_recs = [mem_rec, spk_rec]
  return out_rec, other_recs
