<a href="https://colab.research.google.com/github/narduzzi/AMLD2025-SpikingTutorial/blob/master/AMLD2025_TinyML_Workshop_NeuromorphicTutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



```
# Copyright (c) 2025 Simon Narduzzi
# Licensed under the MIT License (https://opensource.org/licenses/MIT)
```



# AMLD 2025 - Neuromorphic Computing Tutorial



This Google Colab notebook provides a tutorial on neuromorphic computing using the Sinabs and Tonic libraries. It covers the following parts:

- Introduction to Spiking Neural Networks (SNNs)
- Event-based data processing
- SNN models (IAF, LIF, ALIF)
- Backpropagation and surrogate gradients
- Training and evaluation of SNNs
- Comparison of computational cost and accuracy.

The notebook also includes hands-on examples for practice. A GPU is recommended but not required to run the code in this notebook.

In [None]:
# First, install the requirements
!pip install matplotlib seaborn numpy sinabs torchvision scikit-learn
!pip install git+https://github.com/neuromorphs/tonic.git@775e1ce5e0ffaeb42b43cc54cfed3ffb490809e7

In [None]:
# import libraries
import tonic
import sinabs
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

# set style for the notebook
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", rc=custom_params)

colors = ["#00a1e5", "#ffcc33", "#bfd100", "#d61e5c", "#878787", "#003264", "#3fbac1"]
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=colors)

# Part 1: Introduction to Spiking Neural Networks

## Event-based signal

Event-based signals are fundamentally different from traditional continuous signals. Instead of representing information as a continuous stream of values, event-based signals, also known as spike trains, represent information as a binary sequence (0 or 1) of discrete events or spikes that occur at specific times.

Event-based signals are usually sparse time series (more than 90% of the time bins are empty). Events trigger downstream computation only when a neuron receives them.
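To make the sparsity claim concrete, the sketch below builds a random binary raster, measures its sparsity, and converts it to a list of (neuron, time) events. The raster shape, the 5% firing probability, and all variable names are illustrative assumptions, not values taken from a real sensor.

```python
import numpy as np

rng = np.random.default_rng(0)

# Dense binary raster: 10 neurons x 200 timesteps, ~5% chance of a spike per bin
# (both numbers are arbitrary choices for this illustration).
raster = rng.random((10, 200)) < 0.05

# Sparsity = fraction of bins that contain no spike.
sparsity = 1.0 - raster.mean()

# Event-list representation: store only the coordinates of the spikes.
neuron_idx, t_idx = np.nonzero(raster)
events = np.stack([neuron_idx, t_idx], axis=1)  # shape (n_events, 2)

print(f"sparsity: {sparsity:.1%}, dense bins: {raster.size}, events: {len(events)}")

# The dense raster can be rebuilt exactly from the event list.
rebuilt = np.zeros_like(raster)
rebuilt[events[:, 0], events[:, 1]] = True
```

The event list grows with the number of spikes rather than with the number of time bins, which is why sparse signals are cheaper to store and process in this form.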


In [None]:
# generate random spikes
np.random.seed(30) # for repeatability
n_neurons = 10
t_max = 200 # ms
stimuli = np.random.random((n_neurons, t_max)) > 0.9

neuron_idx, t_idx = np.where(stimuli)
# plot the sample
plot_trace = False
neuron_of_interest = 3

fig, axes = plt.subplots(1,2 if plot_trace else 1)
axes = [axes] if not plot_trace else axes

fig.set_figwidth(15)
fig.set_figheight(5)
axes[0].scatter(t_idx, neuron_idx, s=10, color="k")

idx_of_neuron_of_interest = np.where(neuron_idx == neuron_of_interest)

axes[0].scatter(t_idx[idx_of_neuron_of_interest], neuron_idx[idx_of_neuron_of_interest], s=10, color=colors[0], label="Neuron of interest: #{}".format(neuron_of_interest))
axes[0].set_title("Random spike trains")
axes[0].set_xlabel("Time (ms)")
axes[0].set_ylabel("Neuron idx")

# plot one single spike train
if plot_trace:
  axes[1].plot(stimuli[neuron_of_interest], color="b")
  axes[1].set_title("Values of neuron #{}".format(neuron_of_interest))
  axes[1].set_xlabel("Time (ms)")
  axes[1].set_ylabel("Value")

plt.legend()
plt.show()

Random spike trains are generated. In the next section, we will focus on a single neuron spike train that will be fed into different neuron types.

## Spiking Neuron Models

In this section, we will explore different models of feed-forward spiking neurons.


In [None]:
%%html
<svg width="600" height="200">
    <!-- Synapse -->
    <line x1="50" y1="100" x2="250" y2="100" stroke="#2c3e50" stroke-width="6"/>
    <!-- Soma -->
    <circle cx="250" cy="100" r="30" stroke="#2c3e50" stroke-width="6" fill="#888888"/>
    <!-- Axon -->
    <line x1="280" y1="100" x2="500" y2="100" stroke="#2c3e50" stroke-width="6"/>
    <polygon points="500,90 520,100 500,110" fill="#2c3e50"/>

    <!-- Synaptic Spikes -->
    <line x1="70" y1="70" x2="70" y2="90" stroke="#00a1e5" stroke-width="4"/>
    <line x1="90" y1="70" x2="90" y2="90" stroke="#00a1e5" stroke-width="4"/>
    <line x1="110" y1="70" x2="110" y2="90" stroke="#00a1e5" stroke-width="4"/>

    <!-- Axonal Spikes -->
    <line x1="420" y1="70" x2="420" y2="90" stroke="#d61e5c" stroke-width="4"/>
    <line x1="440" y1="70" x2="440" y2="90" stroke="#d61e5c" stroke-width="4"/>

    <!-- Synaptic Triangle -->
    <polygon points="160,65 180,85 150,100" fill="#00a1e5"/>

    <!-- Soma Triangle -->
    <polygon points="250,40 270,60 240,75" fill="#ffcc33"/>

    <!-- Labels -->

    <text x="50" y="130" font-size="16" font-family="Arial">Synapse</text>
    <text x="230" y="105" font-size="16" font-family="Arial">Soma</text>
    <text x="480" y="130" font-size="16" font-family="Arial">Axon</text>
</svg>


_**Generic spiking neuron model**_

The spiking neuron abstractions we will consider here have three main parts:

- **Synapse**: the synapse weights and filters the incoming spikes, producing an input current that is fed to the soma. The value $\tau_{syn}$ describes the time scale over which the spikes are filtered.

- **Soma**: the soma is the core of the neuron. The input is accumulated in the membrane of the soma. The membrane has a threshold value and a reset value that trigger the spike dynamics. The membrane can have a leak on a timescale defined by $\tau_{mem}$. It can also include different mechanisms, such as adaptation.

- **Axon**: the axon is responsible for forwarding the spikes to the next neuron's synapses.

$ $


_**IAF, LIF and ALIF**_

Here, we explore three different types of neurons:
- **Integrate-and-Fire (IAF)**: Simple neuron, with only synaptic leak. The membrane has no leak and only accumulates the input.
- **Leaky Integrate-and-Fire (LIF)**: has a membrane leak term
- **Adaptive LIF (ALIF)**: has a threshold adaptation mechanism, which limits the firing rate by increasing the threshold after each spike. In case of bursts of inputs, the output spike frequency will be reduced (slowed-down firing).

Please refer to the [documentation of Sinabs](https://sinabs.readthedocs.io/en/v2.0.2/api/layers.html) for details about the formulas and implementations of these neurons.
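As a rough intuition for the difference between IAF and LIF, here is a minimal discrete-time sketch in plain NumPy. It is a simplified model written for this explanation, not the Sinabs implementation: the function name `simulate_neuron`, the update order, and the reset-by-subtraction rule are assumptions, and the Sinabs layers may differ in these details.

```python
import numpy as np

def simulate_neuron(inputs, tau_syn=10.0, tau_mem=None, threshold=1.0):
    """Simulate one neuron on a 1D input; tau_mem=None means no membrane leak (IAF)."""
    alpha_syn = np.exp(-1.0 / tau_syn)
    alpha_mem = 1.0 if tau_mem is None else np.exp(-1.0 / tau_mem)
    i_syn, v_mem = 0.0, 0.0
    spikes = np.zeros(len(inputs), dtype=int)
    for t, x in enumerate(inputs):
        i_syn = alpha_syn * i_syn + x      # synapse: leaky filter of the input
        v_mem = alpha_mem * v_mem + i_syn  # soma: integrate, with optional leak
        if v_mem >= threshold:             # threshold crossing -> spike
            spikes[t] = 1
            v_mem -= threshold             # reset by subtraction (assumed here)
    return spikes

rng = np.random.default_rng(0)
stimulus = 0.1 * (rng.random(200) > 0.9)   # sparse, weighted input spikes

iaf_spikes = simulate_neuron(stimulus)                # no membrane leak
lif_spikes = simulate_neuron(stimulus, tau_mem=20.0)  # leaky membrane
print("IAF spikes:", iaf_spikes.sum(), "| LIF spikes:", lif_spikes.sum())
```

With non-negative inputs, the leak can only remove charge from the membrane, so in this sketch the LIF neuron never emits more spikes than the IAF neuron.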

In [None]:
# Define IAF, LIF and AdLIF layers. The 'record_states' argument is passed to store the membrane potential.
# The time constants (taus) are chosen empirically for the purpose of this demonstration.
iaf = sinabs.layers.IAF(tau_syn=10., min_v_mem=0, spike_fn=sinabs.activation.SingleSpike,  record_states=True)
lif = sinabs.layers.LIF(tau_mem=20., tau_syn=10.,  min_v_mem=0, spike_fn=sinabs.activation.SingleSpike,  record_states=True, norm_input=False)
adlif = sinabs.layers.ALIF(tau_mem=20., tau_syn=10., tau_adapt=30., spike_fn=sinabs.activation.SingleSpike,  record_states=True, norm_input=False)

In [None]:
# Select the stimuli (spike train) of the neuron of interest
input_stimuli = torch.from_numpy(stimuli[neuron_of_interest])

input_x = input_stimuli.unsqueeze(0).unsqueeze(-1)
print("Input shape: {}".format(input_x.shape))

weight = 0.10 # arbitrary synaptic weight

# perform inference on each neuron type
output_iaf = iaf(input_x * weight)
output_lif = lif(input_x * weight)
output_adlif = adlif(input_x * weight)

# print size (B,T,U): They are time-series of 1 sample, 200 timesteps, 1 neuron.
print("IAF output: {}".format(output_iaf.shape))
print("LIF output: {}".format(output_lif.shape))
print("AdLIF output: {}".format(output_adlif.shape))

In [None]:
# Different neurons have different recording values
iaf.recordings.keys(), lif.recordings.keys(), adlif.recordings.keys()

In [None]:
def plot_recordings(i_spike, o_spike, recordings, keys, title, axes=None, column=None, continuous=False):
    """
    Plot the recordings of a neuron in three stacked panels.
    Bottom: input stimuli. Middle: internal dynamics (the recorded states listed in `keys`).
    Top: output spikes.
    """
    if axes is None:
        # Standalone use: create a single 3x1 column of axes (squeeze=False keeps the array 2D)
        fig, axes = plt.subplots(3, 1, squeeze=False)
        fig.set_figwidth(15)
        fig.set_figheight(5)
        column = 0

    # get spikes indices
    _, input_spike_t, _ = np.where(i_spike.detach().numpy())
    ax = axes[2, column]
    if continuous:
        ax.plot(i_spike.detach().numpy()[0,:,0], label="Input current")
    else:
        ax.scatter(input_spike_t, np.ones(input_spike_t.shape), s=10, label="Input Spikes")
    ax.set_ylabel("Input")
    ax.set_yticks([])
    ax.set_xlabel("Time (ms)")

    _, output_spike_t, _ = np.where(o_spike.detach().numpy())
    ax = axes[0, column]
    ax.set_ylabel("Output")
    ax.set_yticks([])
    ax.set_title(title)
    ax.scatter(output_spike_t, np.ones(output_spike_t.shape)*1.02, s=10, c=colors[3], label="Output Spikes")

    ax = axes[1, column]
    for i,k in enumerate(keys):
        ax.plot(recordings[k].detach().numpy()[0,:,0], label=k, color=colors[i])
    ax.set_ylabel("Cell value")
    ax.legend()
    sns.despine()

In [None]:
# Plot the results
fig, axes = plt.subplots(3,3, sharex=True)
fig.set_figwidth(15)
fig.set_figheight(6)

plot_recordings(input_x, output_iaf, iaf.recordings, ["i_syn","v_mem"], "IAF", axes, column=0)
plot_recordings(input_x, output_lif, lif.recordings, ["i_syn","v_mem"], "LIF", axes, column=1)
plot_recordings(input_x, output_adlif, adlif.recordings, ["i_syn","v_mem", "spike_threshold"], "ALIF", axes, column=2)
plt.tight_layout()
plt.show()


Here, we observe that the IAF neuron has the highest spike rate, because its membrane has no leak and never discharges between inputs.

The LIF membrane discharges when the input is silent, slowly decaying to the resting potential. When a new spike arrives, the membrane potential rises again and can reach the threshold.

The ALIF neuron adapts its threshold, which leads to a lower output firing rate than the other neurons. However, the adaptation comes at the cost of additional operations in the neuron model.
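The effect of threshold adaptation can be sketched in a few lines of NumPy. This is an illustrative toy model, not the Sinabs ALIF layer: the helper `count_spikes`, the reset-to-zero rule, and the values of `adapt_scale` and the constant input are assumptions chosen for the demonstration.

```python
import numpy as np

def count_spikes(inputs, tau_mem=20.0, tau_adapt=None, base_threshold=1.0, adapt_scale=0.5):
    """Count output spikes; tau_adapt=None disables threshold adaptation (plain LIF)."""
    alpha_mem = np.exp(-1.0 / tau_mem)
    alpha_adapt = 0.0 if tau_adapt is None else np.exp(-1.0 / tau_adapt)
    v_mem, b, n_spikes = 0.0, 0.0, 0
    for x in inputs:
        v_mem = alpha_mem * v_mem + x  # leaky integration of the input
        b = alpha_adapt * b            # adaptation variable decays over time
        if v_mem >= base_threshold + adapt_scale * b:
            n_spikes += 1
            v_mem = 0.0                # reset to the resting potential
            if tau_adapt is not None:
                b += 1.0               # each spike raises the threshold
    return n_spikes

constant_input = np.full(200, 0.2)
n_lif = count_spikes(constant_input)
n_alif = count_spikes(constant_input, tau_adapt=30.0)
print("LIF spikes:", n_lif, "| ALIF spikes:", n_alif)
```

Because the adaptive threshold is never below the base threshold, every inter-spike interval of the adaptive neuron is at least as long as that of the plain LIF neuron under the same constant input.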

We can also apply a fixed current instead of binary spikes. This will lead to regular firing in all models.

In [None]:
# Constant input current

# perform inference
# reshape to (batch, timesteps, dim)
input_stimuli = torch.ones(200)  # constant input of 1.0 at every timestep (float32)

input_x = input_stimuli.unsqueeze(0).unsqueeze(-1)
print("Input shape: {}".format(input_x.shape))

weight = 0.008  # synaptic weight

# reset the membrane potentials and synaptic states
iaf.reset_states()
lif.reset_states()
adlif.reset_states()

output_iaf = iaf(input_x * weight)
output_lif = lif(input_x * weight)
output_adlif = adlif(input_x * weight)

# Plot the results
fig, axes = plt.subplots(3,3, sharex=True)
fig.set_figwidth(15)
fig.set_figheight(5)

plot_recordings(input_x, output_iaf, iaf.recordings, ["i_syn","v_mem"], "IAF", axes, column=0, continuous=True)
plot_recordings(input_x, output_lif, lif.recordings, ["i_syn","v_mem"], "LIF", axes, column=1, continuous=True)
plot_recordings(input_x, output_adlif, adlif.recordings, ["i_syn","v_mem", "b"], "ALIF", axes, column=2, continuous=True)
plt.tight_layout()
plt.show()

## Backpropagation in Spiking Neural Networks

When modeling spiking neural networks, a spike is often represented by the Dirac delta function, an infinitely narrow impulse that models an instantaneous event. A spike train with spike times $t_i$ is then written as:

$$
S(t) = \sum_i \delta(t - t_i)
$$

In computer simulations, instead of the infinitely narrow impulse of the Dirac function, we use the Heaviside function as the threshold-crossing function. It outputs $1$ when its argument (the membrane potential minus the threshold) is non-negative, and $0$ otherwise:

$$
\mathrm{H}(x) = \begin{cases}
    0 & x < 0 \\
    1 & x \geq 0
\end{cases}
$$

The reset mechanism takes care of bringing the membrane potential back to the resting potential.

However, such networks cannot be trained directly with backpropagation: the Heaviside function is not differentiable at the threshold, and its derivative is zero everywhere else, so no useful gradient reaches the weights.

**Surrogate Gradients come to the rescue**



Surrogate gradients are a technique used to overcome this challenge. They are differentiable approximations of the non-differentiable spike functions. By using surrogate gradients, we can effectively "trick" backpropagation into working with SNNs, by switching the function on the backward pass.
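The idea of switching the function on the backward pass can be sketched with a custom PyTorch autograd function. This is a minimal illustration under assumed names (`SpikeWithSigmoidSurrogate` is hypothetical), not the way Sinabs implements its surrogate gradients: the forward pass applies the Heaviside step, and the backward pass uses the derivative of the sigmoid instead.

```python
import torch

class SpikeWithSigmoidSurrogate(torch.autograd.Function):
    """Heaviside step in the forward pass, sigmoid derivative in the backward pass."""

    @staticmethod
    def forward(ctx, v_minus_threshold):
        ctx.save_for_backward(v_minus_threshold)
        # Binary spike: 1 where the membrane potential reaches the threshold.
        return (v_minus_threshold >= 0).to(v_minus_threshold.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)
        # Surrogate: pretend the forward pass was a sigmoid.
        return grad_output * s * (1 - s)

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], requires_grad=True)
spikes = SpikeWithSigmoidSurrogate.apply(x)
spikes.sum().backward()

print("spikes:  ", spikes.tolist())
print("gradient:", x.grad.tolist())
```

The outputs stay binary, yet every input receives a non-zero gradient that is largest near the threshold, which is what allows the weights to be updated.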

In [None]:
# Functions that can be used as forward or backward. In spiking networks, Heaviside is usually used for binary thresholding.
def heaviside(x):
  return np.heaviside(x, 1)  # second argument is the value at x == 0, so H(0) = 1 as in the formula above

def sigmoid(x):
  return 1 / (1 + np.exp(-x))

def fast_tanh(x):
  return np.tanh(x)

# Derivatives (can be used as surrogates)
def sigmoid_derivative(x):
  return sigmoid(x) * (1 - sigmoid(x))

def fast_tanh_derivative(x):
  return 1 - np.tanh(x)**2


In [None]:
fig, axes = plt.subplots(1,2, sharex=True, sharey=True)
x = np.linspace(-10, 10, 1000)

axes[0].plot(x, heaviside(x), label=r"$H(x)$")
axes[0].plot(x, sigmoid(x), "--", label=r"$f(x)$")
axes[0].plot(x, sigmoid_derivative(x), label=r"$\frac{df}{dx}(x)$")
axes[0].set_xlabel("x")
axes[0].set_ylabel("Function value")
axes[0].set_title("Sigmoid")
axes[0].grid(True)
axes[0].legend()

axes[1].plot(x, heaviside(x), label=r"$H(x)$")
axes[1].plot(x, fast_tanh(x), "--", label=r"$f(x)$")
axes[1].plot(x, fast_tanh_derivative(x), label=r"$\frac{df}{dx}(x)$")
axes[1].set_xlabel("x")
axes[1].set_title("Fast Tanh")
axes[1].grid(True)
axes[1].legend()
plt.tight_layout()
plt.show()

We can choose a surrogate function for the backward pass. The backpropagation algorithm then behaves as if the forward pass had used the smooth function $f(x)$ instead of the Heaviside step $H(x)$. The selection of the best surrogate gradient is an active topic in neuromorphic research.

# Part 2: Event-based audio dataset

In this section, we explore the Spiking Heidelberg Digits (SHD) dataset, an event-based audio dataset of spoken digits in German and English (20 classes: 10 digits in each language).

## Data loading and exploration

We use the [Tonic](https://github.com/neuromorphs/tonic) library to load the dataset. The library handles the download and also provides transformation functions to prepare the dataset for use with spiking neural networks.

In [None]:
dataset_train = tonic.datasets.SHD(save_to= "./data_shd", train= True)
dataset_test = tonic.datasets.SHD(save_to= "./data_shd", train= False)

print("Training samples:", len(dataset_train))
print("Testing samples:", len(dataset_test))

In [None]:
# plot one audio sample
events0, label0 = dataset_train[0]
events1, label1 = dataset_train[6]

fig, axes = plt.subplots(1,2, sharey=True)
axes[0].scatter(events0["t"], events0["x"], s=10)
axes[0].set_title("Label: {}".format(label0))

axes[1].scatter(events1["t"], events1["x"], s=10)
axes[1].set_title("Label: {}".format(label1))

axes[0].set_ylabel("Neuron idx (channel)")
axes[0].set_xlabel("Time (us)")
axes[1].set_xlabel("Time (us)")

plt.tight_layout()
plt.show()


The samples of the dataset come as lists of events (x, t, p), which reduces the memory footprint of the dataset. However, neural networks cannot consume event lists directly. We therefore have to transform the events into a dense (frame-based) format.
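The conversion from an event list to frames can be sketched with NumPy alone. The toy events, the sensor size of 6 channels, and the 10 ms window are assumptions made for this illustration; this is not how Tonic implements its `ToFrame` transform.

```python
import numpy as np

# Hypothetical event list: 'x' is the channel index, 't' the timestamp in microseconds.
events = np.array(
    [(3, 1_000), (3, 4_000), (0, 12_000), (5, 12_500), (3, 29_999)],
    dtype=[("x", np.int64), ("t", np.int64)],
)

n_channels = 6        # assumed sensor size for this toy example
time_window = 10_000  # 10 ms bins, expressed in microseconds

n_frames = int(events["t"].max() // time_window) + 1
frames = np.zeros((n_frames, n_channels), dtype=np.int64)

# Each event increments the counter of its (time bin, channel) cell.
# np.add.at accumulates correctly even when several events hit the same cell.
frame_idx = events["t"] // time_window
np.add.at(frames, (frame_idx, events["x"]), 1)

print(frames)
```

Each row of `frames` is one time bin and each column one channel, so the result can be stacked into batches like any other dense tensor.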

We first analyse the distribution of events, and trim the sequence to reduce the memory footprint.

In [None]:
# Plot distribution of events
from tqdm.notebook import tqdm

max_t = 0
all_train_t = []
for i in tqdm(range(len(dataset_train))):
  t = dataset_train[i][0]["t"].tolist()
  max_t = max(max_t, max(t))
  all_train_t+=t

all_test_t = []
for i in tqdm(range(len(dataset_test))):
  t = dataset_test[i][0]["t"].tolist()
  max_t = max(max_t, max(t))
  all_test_t+=t

In [None]:
hist_train, bin_edges_train = np.histogram(all_train_t, bins=100)
hist_test, bin_edges_test = np.histogram(all_test_t, bins=100)
# Compute the 99th percentile
percentile = 99
p_train = np.percentile(all_train_t, percentile)
p_test = np.percentile(all_test_t, percentile)

In [None]:
# Compute bin centers
bin_centers_train = (bin_edges_train[:-1] + bin_edges_train[1:]) / 2
bin_centers_test = (bin_edges_test[:-1] + bin_edges_test[1:]) / 2

# Plot histogram
fig, ax = plt.subplots(1,1)
axes = [ax]
# histogram
axes[0].bar(bin_centers_train, hist_train, width=np.diff(bin_edges_train), edgecolor=None, alpha=0.7, label="train")
axes[0].bar(bin_centers_test, hist_test, width=np.diff(bin_edges_test), edgecolor=None, alpha=0.7, label="test")
axes[0].set_xlabel("Time (us)")  # timestamps are in microseconds, as in the scatter plots above
axes[0].set_ylabel("Frequency")
axes[0].set_title("Histogram")

# plot percentile vline
axes[0].axvline(p_train, color=colors[0], linestyle='--', label="{}th percentile (train)".format(percentile))
axes[0].axvline(p_test, color=colors[1], linestyle='--', label="{}th percentile (test)".format(percentile))
# plot max_t
axes[0].axvline(max_t, color=colors[2], linestyle='--', label=r'$t_{max}$')

axes[0].legend()


plt.show()

99% of the events happen before 0.7 s, while the longest sequence lasts 1.4 s. Because networks are trained on batches of equal-length sequences, we trim the tail of each sequence, which carries little information. This reduces the simulation time by 50%.
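The trade-off behind this trimming can be checked on synthetic data. The exponential distribution and its scale below are assumptions used only to mimic "most events early, long thin tail"; they are not the SHD timestamps.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic timestamps in seconds (assumed distribution, for illustration only).
timestamps = rng.exponential(scale=0.15, size=100_000)

t_max = timestamps.max()
t_99 = np.percentile(timestamps, 99)

# Keep only the events that fall before the 99th percentile.
kept = timestamps[timestamps <= t_99]
fraction_kept = len(kept) / len(timestamps)

print(f"t_max = {t_max:.2f} s, 99th percentile = {t_99:.2f} s")
print(f"events kept: {fraction_kept:.1%}, duration kept: {t_99 / t_max:.1%}")
```

Cropping at the percentile discards about 1% of the events while removing a much larger share of the simulated duration.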

## Preprocessing of audio events

We load again the dataset, but this time we apply transforms to trim and transform the events to frames when accessing a sample.

To further reduce the memory footprint and reduce the simulation time, we downscale the number of channels by aggregating the signal, and accumulate the spikes in frames of 10ms. This is called binning. Each frame (bin) contains the spike counts of neighboring neurons that have spiked during the window of 10ms.
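Binning by summation can be written as a reshape followed by a sum. The sketch below uses a random toy tensor with an assumed layout of (timesteps, channels); the bin sizes of 10 timesteps and 5 channels are chosen here to match the numbers in the text.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy frame tensor: 700 timesteps of 1 ms x 700 channels, ~1% of bins active.
frames = (rng.random((700, 700)) < 0.01).astype(np.float32)

bin_t, bin_c = 10, 5  # merge 10 timesteps and 5 neighbouring channels

# Split each axis into (n_bins, bin_size), then sum over the bin_size axes.
binned = frames.reshape(700 // bin_t, bin_t, 700 // bin_c, bin_c).sum(axis=(1, 3))

print(frames.shape, "->", binned.shape)
```

Summing instead of averaging preserves the total spike count, so each binned value is the number of spikes that fell into that 10 ms, 5-channel cell.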

In [None]:
# create transforms
sensor_size = tonic.datasets.SHD.sensor_size # 700x1
downsample = 5
new_sensor_size = (sensor_size[0] // downsample, 1, 1)

max_duration = 0.7 * 1e6

frame_dt = 1e-3 # 1ms


crop_time_transform = tonic.transforms.CropTime(max=max_duration)

to_frame = tonic.transforms.ToFrame(
    sensor_size=sensor_size,
    time_window=frame_dt*1e6,
)

def bin_frame(downsample_t, downsample_c):
  """Downsamples a frame tensor by summing spike counts over time bins and channel bins"""
  def agg(frame):
    """Input frame has shape (T, P, C): timesteps, polarity, channels"""
    reshaped_frame = frame.reshape((frame.shape[0]//downsample_t, downsample_t, frame.shape[1], frame.shape[2]//downsample_c, downsample_c))
    agg_t = reshaped_frame.sum(axis=1)  # sum over the timesteps of each time bin
    agg_c = agg_t.sum(axis=-1)  # sum over the channels of each channel bin
    return agg_c[:, 0, :] # in our case, we drop the polarity axis (SHD has a single polarity)
  return agg


def pad_frame(frame_size):
  """Zero-pads (or crops) the frame to the given size"""
  def pad(frame):
    """Input frame has shape (T, P, C): timesteps, polarity, channels"""
    new_frame = np.zeros((frame_size[0], frame_size[1], frame_size[2]), np.float32)

    shapes = []
    for axis in range(3): # axis
        min_shape = min(frame_size[axis], frame.shape[axis])
        shapes.append(min_shape)

    new_frame[:shapes[0], :shapes[1], :shapes[2]] = frame[:shapes[0], :shapes[1], :shapes[2]]
    return new_frame
  return pad


# Create a composed transform that is applied to the events of the dataset
composed_transforms = tonic.transforms.Compose([
    crop_time_transform,
    to_frame,
    pad_frame((700,1,700)),
    bin_frame(10, 5),
])

dataset_train = tonic.datasets.SHD(save_to= "./data_shd", train= True, transform=composed_transforms)
dataset_test = tonic.datasets.SHD(save_to= "./data_shd", train= False, transform=composed_transforms)

In [None]:
events, label = dataset_train[0]
events.shape

In [None]:
# Transform to frame

events, label = dataset_train[0]

fig, axes = plt.subplots(1,1, sharey=True)
axes = [axes]
pos = axes[0].imshow(events[:,::-1].T, cmap="Blues")
axes[0].set_title("Label: {}".format(label))
axes[0].set_ylabel("Neuron idx (channel)")
axes[0].set_xlabel("Timestep (frame idx)")
fig.colorbar(pos, ax=axes[0], label="Spike count")

plt.tight_layout()
plt.show()


We now have sequences of 70 frames with 140 channels each, which can be fed as input currents to the network.

# Part 3: Audio Classification using event-based data

Tonic also provides a way to cache the dataset on the disk for faster loading during the training time.

## Efficient data loading

In [None]:
# Cache the preprocessed samples on disk so that the transforms run only once per sample.
# drop_last=True keeps every batch at exactly 128 samples, which the fixed-batch-size
# "Squeeze" layers used later require. The last incomplete batch is discarded.

# train dataset caching (shuffled at every epoch)
cached_train = tonic.DiskCachedDataset(dataset_train, cache_path="./cache/fast_dataloading_train")
train_audio_dataloader = DataLoader(cached_train, batch_size=128, shuffle=True, num_workers=2, drop_last=True)

# test dataset caching
cached_test = tonic.DiskCachedDataset(dataset_test, cache_path="./cache/fast_dataloading_test")
test_audio_dataloader = DataLoader(cached_test, batch_size=128, num_workers=2, drop_last=True)

## Feed forward Neural Network

We can now define our first network, which consists of IAF neurons trained with an exponential surrogate gradient function.

In [None]:
# Let's define a first model using Sinabs
from torch import nn
import sinabs.layers as sl

spike_fn = sinabs.activation.SingleSpike
surr_fn = sinabs.activation.SingleExponential(grad_width=0.5, grad_scale=1.0)

model_iaf= nn.Sequential(
      nn.Linear(140, 64, bias=False), # weights (synapses) input-> 64 neurons
      sl.IAF(min_v_mem=0, tau_syn=1, spike_fn=spike_fn, surrogate_grad_fn=surr_fn),
      nn.Linear(64, 64, bias=False),
      sl.IAF(min_v_mem=0, tau_syn=1, spike_fn=spike_fn, surrogate_grad_fn=surr_fn),
      nn.Linear(64, 20, bias=False),
      sl.IAF(min_v_mem=0, tau_syn=1, spike_fn=spike_fn, surrogate_grad_fn=surr_fn),
  )

In [None]:
data, targets = next(iter(train_audio_dataloader))
pred = model_iaf(data)

print("Shapes: input {} - output {}".format(data.shape, pred.shape))

Sinabs also provides a faster implementation of the model that parallelizes the timesteps. This "flattens" the time dimension, so the linear layers process a batch of size $\text{batch size} \times \text{timesteps}$. At the end, the output is transformed back to the original format (B, T, output dimension).
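Why flattening is valid for a stateless layer can be verified with NumPy. The shapes below are assumptions (a small batch of 4 keeps the example light), and the matrix product stands in for a bias-free linear layer; the spiking layers themselves still have to step through time because they carry state.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, timesteps, n_in, n_out = 4, 70, 140, 64  # small batch for illustration

x = rng.random((batch, timesteps, n_in)).astype(np.float32)
weights = rng.standard_normal((n_out, n_in)).astype(np.float32)

# Reference: apply the linear layer one timestep at a time.
looped = np.stack([x[:, t] @ weights.T for t in range(timesteps)], axis=1)

# Flattened: merge batch and time, apply the layer once, then restore (B, T, out).
flat = x.reshape(batch * timesteps, n_in)
flattened = (flat @ weights.T).reshape(batch, timesteps, n_out)

print(flat.shape, "->", flattened.shape)
```

Both paths give the same result, but the flattened version issues a single large matrix product, which is the kind of operation a GPU handles efficiently.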

In [None]:
# Same model, using the time-flattened "Squeeze" layers for faster execution
from torch import nn
import sinabs.layers as sl

optimized = True
# for optimized version, use this:
if optimized:
  model_iaf= nn.Sequential(
      sl.FlattenTime(), # for parallelization on GPU
      nn.Linear(140, 64, bias=False),
      sl.IAFSqueeze(batch_size=128, min_v_mem=0, tau_syn=1, spike_fn=spike_fn, surrogate_grad_fn=surr_fn),
      nn.Linear(64, 64, bias=False),
      sl.IAFSqueeze(batch_size=128, min_v_mem=0, tau_syn=1, spike_fn=spike_fn, surrogate_grad_fn=surr_fn),
      nn.Linear(64, 20, bias=False),
      sl.IAFSqueeze(batch_size=128, min_v_mem=0, tau_syn=1, spike_fn=spike_fn, surrogate_grad_fn=surr_fn),
      sl.UnflattenTime(batch_size=128),
  )

In [None]:
data_input, targets = next(iter(train_audio_dataloader))
pred = model_iaf(data_input)

print("Shapes: input {} - output {}".format(data_input.shape, pred.shape))

Finally, Sinabs provides a few helper functions to estimate the computational cost (Synaptic operations - SynOps) of the model.
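One common way to count synaptic operations is to multiply the number of input spikes by the fan-out of the layer that receives them. The sketch below applies this definition by hand to an assumed random raster; the exact accounting used by the Sinabs helpers may differ, so treat the numbers as an illustration of the principle only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical spike raster entering a fully connected layer:
# 70 timesteps x 64 presynaptic neurons, ~5% firing probability per bin.
input_spikes = (rng.random((70, 64)) < 0.05).astype(int)

fan_out = 20  # each presynaptic neuron connects to 20 postsynaptic neurons

# Every input spike triggers one operation per outgoing connection.
synops = input_spikes.sum() * fan_out

# For comparison, a dense (non-spiking) layer performs one multiply-accumulate
# per connection at every timestep, regardless of activity.
dense_macs = input_spikes.shape[0] * input_spikes.shape[1] * fan_out

print(f"SynOps: {synops}, dense MACs: {dense_macs}, ratio: {synops / dense_macs:.1%}")
```

Under this definition the cost scales with the firing rate, which is why sparser activity translates into fewer operations.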

In [None]:
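def get_stats(model, data):
  # Run on the same device as the model (the model is moved to the GPU during training)
  device = next(model.parameters()).device
  analyzer = sinabs.synopcounter.SNNAnalyzer(model)
  sinabs.reset_states(model)  # start from zeroed states so earlier batches do not leak in
  output = model(data.to(device))  # forward pass, observed by the analyzer
  model_stats = analyzer.get_model_statistics()
  return model_stats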

In [None]:
model_stats_iaf_before = get_stats(model_iaf, data_input)
model_stats_iaf_before

## Network Training

Now that everything is set up, let's train the model for a few epochs using backpropagation.
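The training loop in this section sums the output spikes over time (`y_hat.sum(1)`) and passes the resulting counts to the cross-entropy loss. The NumPy sketch below walks through that readout on random spikes; the tensor shapes and target labels are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical output spikes of shape (batch, timesteps, classes).
output_spikes = (rng.random((2, 70, 20)) < 0.1).astype(np.float64)
targets = np.array([3, 7])

# Rate decoding: the spike count of each output neuron acts as its logit.
logits = output_spikes.sum(axis=1)  # shape (batch, classes)

# Numerically stable log-softmax, then the mean negative log-likelihood.
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(len(targets)), targets].mean()

predicted = logits.argmax(axis=1)
print(f"loss = {loss:.3f}, predicted classes = {predicted}")
```

Minimizing this loss pushes the neuron of the correct class to spike more often than the others, so the predicted class is simply the most active output neuron.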

In [None]:
from tqdm.notebook import tqdm

def train(model, dataloader, n_epochs, optimizer, crit):
  model.train()
  if torch.cuda.is_available():
    model.cuda()
  for epoch in range(n_epochs):
    losses = []
    for data, targets in tqdm(dataloader):  # iterate over the dataloader passed as argument

      # move the batch to the GPU if available
      if torch.cuda.is_available():
        data, targets = data.cuda(), targets.cuda()

      sinabs.reset_states(model)  # synaptic and membrane states are reset to zero
      optimizer.zero_grad()
      y_hat = model(data)
      pred = y_hat.sum(1)  # sum the output spikes over time: spike counts act as logits
      loss = crit(pred, targets)
      loss.backward()
      losses.append(loss.detach())  # detach so the computation graph is not kept in memory
      optimizer.step()

    print(f"Epoch {epoch + 1}/{n_epochs} - loss: {torch.stack(losses).mean():.4f}")

  return model

In [None]:
n_epochs = 3
optimizer = torch.optim.Adam(model_iaf.parameters(), lr=1e-3)
crit = nn.functional.cross_entropy

train(model_iaf, train_audio_dataloader, n_epochs, optimizer, crit)

## Evaluation and comparison

Now that the network is trained, let's evaluate it on both the training set and the test set.

In [None]:
# evaluate model accuracy
from sklearn.metrics import accuracy_score

def eval_model(model, dataloader):
  model.eval()
  if torch.cuda.is_available():
    model.cuda()
  all_targets = []
  all_preds = []
  with torch.no_grad():
    test_losses = []
    for data, targets in tqdm(dataloader):
      if torch.cuda.is_available():
        data, targets = data.cuda(), targets.cuda()

      sinabs.reset_states(model)  # each synapse and neuron membranes are reset to zero
      y_hat = model(data)
      pred = y_hat.sum(1)

      loss = crit(pred, targets)
      test_losses.append(loss)

      all_targets.append(targets.cpu().numpy())
      all_preds.append(pred.cpu().numpy())

  # print the accuracy
  all_targets = np.concatenate(all_targets)
  all_preds = np.concatenate(all_preds)

  total_loss = torch.stack(test_losses).mean()
  acc = accuracy_score(all_targets, np.argmax(all_preds, axis=1))
  return acc, total_loss

In [None]:
acc_train_iaf, loss_train = eval_model(model_iaf, train_audio_dataloader)
acc_test_iaf, loss_test = eval_model(model_iaf, test_audio_dataloader)

print("Train accuracy: {:.2f}%".format(acc_train_iaf*100))
print("Test accuracy: {:.2f}%".format(acc_test_iaf*100))

Not bad, given the few epochs.

In [None]:
model_stats_iaf = get_stats(model_iaf, data_input)
for k, v in model_stats_iaf.items():
  print("{}: before = {:.2f} / after = {:.2f}".format(k, model_stats_iaf_before[k], v))

We can see that the firing rate did not change much, so the number of synaptic operations stayed stable.
Since the accuracy nevertheless improved with training, the information must be carried by the spiking patterns (their spatio-temporal distribution) rather than by the overall spike count.

## Comparison of computational cost of Spiking Models

We can also compare with the other neuron models: LIF and ALIF.

In [None]:
# LIF Model

model_lif = nn.Sequential(
    sl.FlattenTime(),
    nn.Flatten(),
    nn.Linear(140, 64, bias=False),
    sl.LIFSqueeze(batch_size=128, min_v_mem=0, tau_syn=1, tau_mem=2),
    nn.Linear(64, 64, bias=False),
    sl.LIFSqueeze(batch_size=128, min_v_mem=0, tau_syn=1, tau_mem=2),
    nn.Linear(64, 20, bias=False),
    sl.LIFSqueeze(batch_size=128, min_v_mem=0, tau_syn=1, tau_mem=2),
    sl.UnflattenTime(batch_size=128)
)

In [None]:
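# The optimizer must hold the parameters of the model being trained, so we create a new one
optimizer = torch.optim.Adam(model_lif.parameters(), lr=1e-3)
train(model_lif, train_audio_dataloader, n_epochs, optimizer, crit)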
print("Test...")
acc_train_lif, loss_train = eval_model(model_lif, train_audio_dataloader)
acc_test_lif, loss_test = eval_model(model_lif, test_audio_dataloader)

print("Train accuracy: {:.2f}%".format(acc_train_lif*100))
print("Test accuracy: {:.2f}%".format(acc_test_lif*100))

In [None]:
model_alif = nn.Sequential(
    nn.Linear(140, 64, bias=False),
    sl.ALIF(min_v_mem=0., tau_syn=1., tau_mem=2., tau_adapt=1.0),
    nn.Linear(64, 64, bias=False),
    sl.ALIF(min_v_mem=0., tau_syn=1., tau_mem=2., tau_adapt=1.0),
    nn.Linear(64, 20, bias=False),
    sl.ALIF(min_v_mem=0., tau_syn=1., tau_mem=2., tau_adapt=1.0),
)

In [None]:
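# The optimizer must hold the parameters of the model being trained, so we create a new one
optimizer = torch.optim.Adam(model_alif.parameters(), lr=1e-3)
train(model_alif, train_audio_dataloader, n_epochs, optimizer, crit)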

print("Test...")
acc_train_alif, loss_train = eval_model(model_alif, train_audio_dataloader)
acc_test_alif, loss_test = eval_model(model_alif, test_audio_dataloader)

print("Train accuracy: {:.2f}%".format(acc_train_alif*100))
print("Test accuracy: {:.2f}%".format(acc_test_alif*100))

... and plot their computational cost and accuracy.

In [None]:
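# Evaluation of spiking models
def get_stats(model, data):
  # Run on the same device as the model (the model is moved to the GPU during training)
  device = next(model.parameters()).device
  analyzer = sinabs.synopcounter.SNNAnalyzer(model)
  sinabs.reset_states(model)  # start from zeroed states so earlier batches do not leak in
  output = model(data.to(device))  # forward pass, observed by the analyzer
  model_stats = analyzer.get_model_statistics()
  return model_stats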

data_input, labels = next(iter(train_audio_dataloader))

stats_iaf = get_stats(model_iaf, data_input)
stats_lif = get_stats(model_lif, data_input)
stats_alif = get_stats(model_alif, data_input)

In [None]:
# plot SynOps vs Accuracy and Firing rate

fig, axes = plt.subplots(1,2, sharey=True)
# .cpu() is needed before .numpy() when the statistics live on the GPU
axes[0].scatter(stats_lif["synops"].detach().cpu().numpy(), acc_test_lif, label="LIF")
axes[0].scatter(stats_iaf["synops"].detach().cpu().numpy(), acc_test_iaf, label="IAF")
axes[0].scatter(stats_alif["synops"].detach().cpu().numpy(), acc_test_alif, label="ALIF")
axes[0].set_xlabel("SynOps")
axes[0].set_ylabel("Accuracy")
axes[0].set_title("Accuracy vs SynOps")
axes[0].legend()

axes[1].scatter(stats_lif["firing_rate"].detach().cpu().numpy(), acc_test_lif, label="LIF")
axes[1].scatter(stats_iaf["firing_rate"].detach().cpu().numpy(), acc_test_iaf, label="IAF")
axes[1].scatter(stats_alif["firing_rate"].detach().cpu().numpy(), acc_test_alif, label="ALIF")
axes[1].set_xlabel("Firing rate")
axes[1].set_ylabel("Accuracy")
axes[1].set_title("Accuracy vs Firing rate")
axes[1].legend()

plt.tight_layout()
plt.show()


# Conclusion

We have successfully trained different spiking models on an event-based audio dataset using backpropagation. Each model has different properties. Neuromorphic engineers should carefully tune the hyperparameters of the neurons to balance efficiency and accuracy, and select the network model suited to the hardware at hand.

While synaptic operations (SynOps) give insight into model complexity, they can hide significant computational costs, such as the additional neuron-state updates that ALIF adaptation requires. Researchers are currently investigating alternative metrics.

# Bonus - Part 4: Image processing - the DVSGestures event-based dataset

### Dataset loading and exploration

In [None]:
dataset_train = tonic.datasets.DVSGesture(save_to= "./data", train= True)
dataset_test = tonic.datasets.DVSGesture(save_to= "./data", train= False)

print("Training samples:", len(dataset_train))
print("Testing samples:", len(dataset_test))

In [None]:
import warnings
import logging

logging.getLogger("matplotlib").setLevel(logging.ERROR)  # Suppress matplotlib warnings
warnings.filterwarnings("ignore", category=UserWarning)  # Ignores UserWarnings

original_events, label = dataset_train[0]

transform = tonic.transforms.ToFrame(
    sensor_size=tonic.datasets.DVSGesture.sensor_size,
    time_window=10000,
)

frames = transform(original_events)
animation = tonic.utils.plot_animation(frames)

In [None]:
# Display the animation inline in a Jupyter notebook
from IPython.display import HTML
HTML(animation.to_jshtml())

In [None]:
# Plot distribution of events
from tqdm.notebook import tqdm

max_t = 0
all_train_t = []
for i in tqdm(range(len(dataset_train))):
  t = dataset_train[i][0]["t"].tolist()
  max_t = max(max_t, max(t))
  t = t[::10]  # take only every 10 events, because of limited memory in notebook
  all_train_t+=t

all_test_t = []
for i in tqdm(range(len(dataset_test))):
  t = dataset_test[i][0]["t"].tolist()
  max_t = max(max_t, max(t))
  t = t[::10]  # take only every 10 events, because of limited memory in notebook
  all_test_t+=t

In [None]:
hist_train, bin_edges_train = np.histogram(all_train_t, bins=100)
hist_test, bin_edges_test = np.histogram(all_test_t, bins=100)
# Compute the 99th percentile
percentile = 99
p_train = np.percentile(all_train_t, percentile)
p_test = np.percentile(all_test_t, percentile)

In [None]:
# Compute bin centers
bin_centers_train = (bin_edges_train[:-1] + bin_edges_train[1:]) / 2
bin_centers_test = (bin_edges_test[:-1] + bin_edges_test[1:]) / 2

# Plot histogram
fig, ax = plt.subplots(1,1)
axes = [ax]
# histogram
axes[0].bar(bin_centers_train, hist_train, width=np.diff(bin_edges_train), edgecolor=None, alpha=0.7, label="train")
axes[0].bar(bin_centers_test, hist_test, width=np.diff(bin_edges_test), edgecolor=None, alpha=0.7, label="test")
axes[0].set_xlabel("Time (us)")
axes[0].set_ylabel("Frequency")
axes[0].set_title("Histogram")

# plot percentile vline
axes[0].axvline(p_train, color=colors[0], linestyle='--', label="{}th percentile (train)".format(percentile))
axes[0].axvline(p_test, color=colors[1], linestyle='--', label="{}th percentile (test)".format(percentile))
# plot max_t
axes[0].axvline(max_t, color=colors[2], linestyle='--', label=r'$t_{max}$')

axes[0].legend()


plt.show()
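
As a rough illustration of how a crop time can be read off a percentile, the sketch below repeats the computation on synthetic timestamps. The exponential distribution, the seed, and the helper name `crop_threshold` are assumptions made for this example, not properties of DVSGesture.

```python
import numpy as np

# Synthetic stand-in for pooled event timestamps in microseconds (NOT DVSGesture data)
rng = np.random.default_rng(0)
timestamps_us = rng.exponential(scale=2e6, size=100_000)

def crop_threshold(timestamps, percentile=99):
    """Return the time below which `percentile` percent of the events fall."""
    return float(np.percentile(timestamps, percentile))

threshold_us = crop_threshold(timestamps_us, 99)
kept_fraction = float(np.mean(timestamps_us <= threshold_us))
print(f"Cropping at {threshold_us / 1e6:.2f} s keeps {kept_fraction:.1%} of the events")
```

Cropping at such a threshold discards the sparse tail of the recordings while keeping almost all events.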

### Dataset preprocessing

In [None]:
sensor_size = tonic.datasets.DVSGesture.sensor_size

In [None]:
# As shown above, 99% of the events occur within the first 10 s, while the longest recording lasts about 18 s.
# We therefore crop each recording at 10 s and then split it into 1 s slices,
# so that the network classifies 1 s samples.
# For this tutorial we also downsample each spatial axis by a factor of 2 (4x fewer pixels), giving 64x64x2 samples.
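
The crop-then-slice plan can be sketched in plain NumPy to see what it does to an event stream. This is a simplified stand-in, not Tonic's implementation: the field layout, the helper names `crop_time` and `slice_by_time`, and the boundary handling are assumptions, and Tonic may treat window edges and incomplete trailing windows differently.

```python
import numpy as np

# Assumed event layout: structured array with fields x, y, t (microseconds), p
event_dtype = np.dtype([("x", np.int16), ("y", np.int16), ("t", np.int64), ("p", np.int8)])

def crop_time(events, max_us):
    """Keep only events with a timestamp strictly below max_us."""
    return events[events["t"] < max_us]

def slice_by_time(events, window_us):
    """Split time-sorted events into consecutive windows of window_us microseconds."""
    if len(events) == 0:
        return []
    t0 = events["t"][0]
    n_windows = int(np.ceil((events["t"][-1] - t0 + 1) / window_us))
    edges = t0 + window_us * np.arange(n_windows + 1)
    idx = np.searchsorted(events["t"], edges)
    return [events[a:b] for a, b in zip(idx[:-1], idx[1:])]

# Toy recording: one event every 0.25 s for 18 s
t = np.arange(0, 18_000_000, 250_000)
events = np.zeros(len(t), dtype=event_dtype)
events["t"] = t

cropped = crop_time(events, 10_000_000)   # keep the first 10 s
slices = slice_by_time(cropped, 1_000_000)  # 1 s slices
print(len(events), len(cropped), len(slices))
```

Each slice then becomes one training sample that inherits the label of the recording it came from.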

In [None]:
# create transforms
sensor_size = tonic.datasets.DVSGesture.sensor_size # 128x128x2
downsample = 2
new_sensor_size = (sensor_size[0] // downsample, sensor_size[1] // downsample, 2)

crop_time_transform = tonic.transforms.CropTime(max=10e6)
downsample_transform = tonic.transforms.Downsample(spatial_factor=1.0/downsample)

# Create a composed transform that is applied to the events of the dataset
composed_transforms = tonic.transforms.Compose([
    crop_time_transform,
    downsample_transform
])

dataset_train = tonic.datasets.DVSGesture(save_to= "./data", train= True, transform=composed_transforms)
dataset_test = tonic.datasets.DVSGesture(save_to= "./data", train= False, transform=composed_transforms)
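
To make the effect of spatial downsampling concrete, here is a minimal sketch that maps pixel coordinates onto a coarser grid by integer division. The helper `downsample_xy` is hypothetical and only illustrates the idea; Tonic's `Downsample` may round coordinates differently.

```python
import numpy as np

def downsample_xy(x, y, factor):
    """Map pixel coordinates to a grid that is `factor` times coarser per axis."""
    return x // factor, y // factor

sensor_size = (128, 128, 2)  # width, height, polarities
factor = 2

x = np.array([0, 1, 2, 127])
y = np.array([0, 63, 64, 127])
x_ds, y_ds = downsample_xy(x, y, factor)

# The polarity dimension is unaffected by spatial downsampling
new_size = (sensor_size[0] // factor, sensor_size[1] // factor, sensor_size[2])
print(x_ds, y_ds, new_size)
```

Neighbouring pixels collapse onto the same coarse pixel, so the event count per pixel rises while the spatial resolution drops.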

In [None]:
# plot a new sample
transformed_events, label = dataset_train[0]

transform = tonic.transforms.ToFrame(
    sensor_size=new_sensor_size,
    time_window=10e3,
)

frames_transformed = transform(transformed_events)
animation_transformed = tonic.utils.plot_animation(frames_transformed)

In [None]:
transformed_events

In [None]:
from IPython.display import HTML
HTML(animation_transformed.to_jshtml())

In [None]:
from tonic import SlicedDataset

frame_transform = tonic.transforms.ToFrame(
    sensor_size=new_sensor_size,
    time_window=1e5,
)

slicing_time_window = 1e6 # microseconds
slicer = tonic.slicers.SliceByTime(time_window=slicing_time_window)

# slice the train dataset
sliced_dataset_train = SlicedDataset(
    dataset_train, slicer=slicer, metadata_path="./metadata/dvs_train",
    transform=frame_transform,
)

# slice the test dataset
sliced_dataset_test = SlicedDataset(
    dataset_test, slicer=slicer, metadata_path="./metadata/dvs_test",
    transform=frame_transform,
)
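
Conceptually, turning a slice of events into frames is a 4D histogram over time bin, polarity, and pixel position. The following NumPy sketch shows that idea under simplifying assumptions (hypothetical helper `events_to_frames`, integer polarities, bins starting at the first event); it is not Tonic's `ToFrame` code.

```python
import numpy as np

def events_to_frames(x, y, t, p, sensor_size, time_window):
    """Accumulate events into a (T, P, H, W) array of per-bin event counts."""
    width, height, n_polarities = sensor_size
    t = np.asarray(t)
    bin_idx = ((t - t.min()) // time_window).astype(int)
    frames = np.zeros((bin_idx.max() + 1, n_polarities, height, width), dtype=np.int32)
    # np.add.at accumulates correctly even when several events hit the same cell
    np.add.at(frames, (bin_idx, np.asarray(p), np.asarray(y), np.asarray(x)), 1)
    return frames

# Four toy events on a 4x4 sensor, binned into 100 ms windows
x = [0, 0, 1, 3]
y = [0, 0, 2, 3]
t = [0, 50_000, 120_000, 950_000]
p = [1, 1, 0, 1]
toy_frames = events_to_frames(x, y, t, p, sensor_size=(4, 4, 2), time_window=100_000)
print(toy_frames.shape)
```

In this sketch a 1 s slice binned at 100 ms yields at most 10 frames, and empty bins stay all-zero.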

In [None]:
print("Number of samples (original vs. sliced):")
print("Training: {} vs {}".format(len(dataset_train), len(sliced_dataset_train)))
print("Testing: {} vs {}".format(len(dataset_test), len(sliced_dataset_test)))

In [None]:
batch_size = 128
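# Pad every sequence in a batch to the same length.
# drop_last=True keeps every batch at exactly batch_size samples, which the
# IAFSqueeze layers used later (built with a fixed batch_size) rely on.
trainloader = DataLoader(
    sliced_dataset_train,
    shuffle=True,
    batch_size=batch_size,
    drop_last=True,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)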

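# Pad every sequence in a batch to the same length.
# The test loader must read from the sliced *test* set; shuffling is not needed for evaluation.
testloader = DataLoader(
    sliced_dataset_test,
    shuffle=False,
    batch_size=batch_size,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)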
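
Slices can contain different numbers of frames, so they have to be padded before they can be stacked into one batch tensor. A minimal NumPy sketch of batch-first zero padding, with the hypothetical helper `pad_and_stack` standing in for the collate function:

```python
import numpy as np

def pad_and_stack(samples):
    """Zero-pad arrays of shape (T_i, ...) along time and stack them batch-first."""
    max_len = max(s.shape[0] for s in samples)
    batch = np.zeros((len(samples), max_len) + samples[0].shape[1:], dtype=samples[0].dtype)
    for i, s in enumerate(samples):
        batch[i, : s.shape[0]] = s
    return batch

a = np.ones((9, 2, 4, 4))   # sample with 9 frames
b = np.ones((10, 2, 4, 4))  # sample with 10 frames
padded = pad_and_stack([a, b])
print(padded.shape)  # (2, 10, 2, 4, 4)
```

The shorter sample gains one all-zero frame at the end, which corresponds to a time step without events.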

In [None]:
frames, targets = next(iter(trainloader))
frames.shape, targets.shape

### Sample batch

In [None]:
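# plot the first few samples of the batch

n_samples = 3               # rows: samples taken from the batch
n_frames = frames.shape[1]  # columns: time bins per sample

fig, axes = plt.subplots(n_samples, n_frames)
fig.set_figwidth(15)
fig.set_figheight(5)

# create an empty RGB frame holder; the two polarities go to the green and blue channels
frame_holder = np.zeros((frames.shape[0], n_frames, 3, new_sensor_size[0], new_sensor_size[1]))
frame_holder[:, :, 1, :, :] = frames[:, :, 0, :, :]
frame_holder[:, :, 2, :, :] = frames[:, :, 1, :, :]

for b_idx in range(n_samples):
    for f_idx in range(n_frames):
        ax = axes[b_idx, f_idx]
        frame_to_show = frame_holder[b_idx, f_idx, :, :, :]
        ax.imshow(np.moveaxis(frame_to_show, 0, 2))  # channels-first -> channels-last

plt.show()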

# Bonus - Part 5: Training a Spiking Neural Network for Vision Applications

In [None]:
import sinabs.layers as sl
import sinabs.exodus.layers as sel
from torch import nn

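# Choose the layer backend: the second assignment overrides the first,
# so comment it out to use the standard Sinabs layers instead of EXODUS.
backend = sl # Sinabs
backend = sel # Sinabs EXODUS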

model_iaf = nn.Sequential(
    sl.FlattenTime(),
    nn.Conv2d(2, 8, kernel_size=5, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(64, 11, kernel_size=3, padding=0, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    nn.Flatten(),
    sl.UnflattenTime(batch_size=batch_size),
).cuda()
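
The model starts by folding time into the batch dimension and ends by unfolding it again, which is why several layers need to know `batch_size`. The reshapes involved can be sketched in NumPy (toy sizes chosen for illustration; this mimics the idea, not the Sinabs code):

```python
import numpy as np

batch_size, n_steps = 4, 10
x = np.arange(batch_size * n_steps * 2 * 3 * 3, dtype=np.float32)
x = x.reshape(batch_size, n_steps, 2, 3, 3)  # (B, T, C, H, W)

# (B, T, C, H, W) -> (B*T, C, H, W): the layout stateless 2D layers expect
flat = x.reshape(-1, *x.shape[2:])

# (B*T, C, H, W) -> (B, T, C, H, W): recovering time requires the batch size
restored = flat.reshape(batch_size, -1, *flat.shape[1:])

# A batch with only 3 samples cannot be unfolded with batch_size = 4
partial = flat[: 3 * n_steps]
try:
    partial.reshape(batch_size, -1, *partial.shape[1:])
    reshape_ok = True
except ValueError:
    reshape_ok = False

print(flat.shape, restored.shape, reshape_ok)
```

The failing reshape at the end shows why a batch smaller than the configured `batch_size` is a problem for layers that unfold time this way.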

In [None]:
y_hat = model_iaf(frames.cuda())
y_hat.shape
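
The spatial size after each layer can be checked by hand with the usual convolution and pooling formulas. The sketch below assumes 64x64 input frames, stride-1 convolutions, and non-overlapping 2x2 pooling with floor rounding; the helper names are made up for this example.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    """Output length of non-overlapping pooling (stride == kernel, floor rounding)."""
    return size // kernel

# (kernel, padding) of the five convolutions in model_iaf; a 2x2 pool follows the first four
convs = [(5, 1), (3, 1), (3, 1), (3, 1), (3, 0)]

size = 64  # assumed input height/width after downsampling
trace = [size]
for i, (kernel, padding) in enumerate(convs):
    size = conv_out(size, kernel, padding)
    trace.append(size)
    if i < len(convs) - 1:
        size = pool_out(size, 2)
        trace.append(size)

print(trace)  # [64, 62, 31, 31, 15, 15, 7, 7, 3, 1]
```

The last convolution leaves a 1x1 map with 11 channels, so after flattening and restoring the time axis the output should hold 11 values per sample and time step.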

In [None]:
from tqdm.notebook import tqdm


n_epochs = 1
optimizer = torch.optim.Adam(model_iaf.parameters(), lr=1e-3)
crit = nn.functional.cross_entropy

for epoch in range(n_epochs):
    losses = []
    for data, targets in tqdm(trainloader):
        data, targets = data.cuda(), targets.cuda()
        sinabs.reset_states(model_iaf)
        optimizer.zero_grad()
        y_hat = model_iaf(data)
        pred = y_hat.sum(1)
        loss = crit(pred, targets)
        loss.backward()
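        losses.append(loss.detach())  # detach so stored losses do not keep the autograd graph alive
        optimizer.step()
    print(f"Epoch {epoch}: mean loss = {torch.stack(losses).mean().item():.4f}")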
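
The loss above treats the summed output spikes as class scores. A small NumPy sketch of that spike-count readout, together with a cross-entropy and an accuracy computed from it, on hand-made toy spikes (all helper names are hypothetical):

```python
import numpy as np

def spike_count_logits(spikes):
    """Sum output spikes over the time axis: (B, T, classes) -> (B, classes)."""
    return spikes.sum(axis=1)

def cross_entropy(logits, targets):
    """Mean cross-entropy of integer targets, using a numerically stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def accuracy(logits, targets):
    """Fraction of samples whose most active output neuron matches the target."""
    return float((logits.argmax(axis=1) == targets).mean())

# Toy output: 2 samples, 3 time steps, 3 classes
spikes = np.array([
    [[0, 1, 0], [0, 1, 0], [0, 1, 0]],  # class 1 fires at every step
    [[1, 0, 0], [1, 0, 1], [0, 0, 0]],  # class 0 fires twice, class 2 once
], dtype=np.float64)
targets = np.array([1, 0])

logits = spike_count_logits(spikes)
toy_loss = cross_entropy(logits, targets)
toy_acc = accuracy(logits, targets)
print(logits, toy_loss, toy_acc)
```

The same argmax over spike counts can serve as the predicted class when evaluating on the test set.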