# Tutorial for using Spiking Neural Networks to Read Lips!

By Noah Baldonado

You should read the paper https://arxiv.org/pdf/2109.12894:
J. K. Eshraghian et al., "Training Spiking Neural Networks Using Lessons From Deep Learning," in Proceedings of the IEEE, vol. 111, no. 9, pp. 1016-1054, Sept. 2023, doi: 10.1109/JPROC.2023.3308088.

And the paper where this dataset is from: https://openaccess.thecvf.com/content/CVPR2022/papers/Tan_Multi-Grained_Spatio-Temporal_Features_Perceived_Network_for_Event-Based_Lip-Reading_CVPR_2022_paper.pdf

I based a lot of the code on [this tutorial](https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_7_neuromorphic_datasets.ipynb) by Gregor Lenz and Jason K. Eshraghian.

## The Problem:
The researchers behind this paper ([here](https://openaccess.thecvf.com/content/CVPR2022/papers/Tan_Multi-Grained_Spatio-Temporal_Features_Perceived_Network_for_Event-Based_Lip-Reading_CVPR_2022_paper.pdf)) used an event camera, which captures changes in brightness at individual pixels at a high rate. A standard frame-based camera, by contrast, has a lower rate and records redundant information, because every frame includes pixels that have not changed. They had people say different words and recorded almost 20,000 samples of people's faces as they said them, in the hope of using this data to train a model to read lips.

After creating this dataset, they built and trained a fairly complicated model, a "Multi-grained Spatio-Temporal Features Perceived Network" (MSTP), to determine what words people were saying. It uses a multi-branch architecture with a different frame rate in each of its two branches, plus a message flow module that passes information between them. Each branch does a 3D convolution followed by several residual blocks, and the branch with more frames has fewer channels. The branches eventually lead to a sequence model, which predicts the class.

In this tutorial, instead of making this model, you will learn how to use spiking neural networks to do this task! First, you will see a short summary of how spiking neural networks work. Then, for the main part of the tutorial, you will learn how to use them on this dataset (DVS-Lip), to create a spiking neural network to accomplish the task of reading lips. I will then also explore some other options.

# What Are Spiking Neural Networks?

Spiking Neural Networks (SNNs) are neural networks that use spikes to pass along information, just like the neurons in the brain do with action potentials. This contrasts with a normal artificial neural network (ANN), where each neuron passes a continuous value as its output. SNNs also work over time, just like the brain (so they are very different from a binarized neural network). As a neuron receives spikes, its "membrane potential" increases, and over time it decays. When the membrane potential reaches a threshold, the neuron outputs a spike and the membrane potential jumps back down. This time-based behavior makes an SNN similar to an RNN: each neuron keeps a state (its membrane potential), and its output spike is fed back in to reset that state.
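
To make the membrane potential idea concrete, here is a minimal sketch of a single leaky integrate-and-fire neuron in plain numpy. It is a simplified illustration, not snnTorch's implementation: the function name `simulate_lif`, the decay `beta=0.9`, the threshold of 1.0, and the reset-by-subtraction rule are all assumptions chosen for this example.

```python
import numpy as np

def simulate_lif(inputs, beta=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron (illustrative helper, not a library call)."""
    mem = 0.0
    spikes, trace = [], []
    for current in inputs:
        mem = beta * mem + current       # decay the old potential, then add the new input
        spike = int(mem > threshold)     # fire when the potential crosses the threshold
        mem -= spike * threshold         # after a spike, the potential jumps back down
        spikes.append(spike)
        trace.append(mem)
    return np.array(spikes), np.array(trace)

# A constant input of 0.3 slowly charges the neuron until it fires, then it starts over.
spikes, trace = simulate_lif(np.full(10, 0.3))
print(spikes)  # [0 0 0 1 0 0 0 1 0 0]
```

With a weaker input or a smaller `beta`, the leak wins and the neuron never reaches the threshold, which is why the timing and strength of incoming spikes both matter.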

# Why are spiking neural networks good?

SNNs can be very efficient when you combine them with specialized hardware, and they are also attractive because they more closely emulate how the brain works. For this dataset, we have information about how each pixel changes, and the recordings vary in length, so it is the perfect chance to use an SNN!

# How will this model work?
We will convert the data to a sequence of frames that record how (and whether) each pixel changes. At each time step, the next frame is fed into the network. For the final outputs, there are a few different ways you could choose to do it, but in this case it will be rate encoding, so the class whose output neuron spikes the most is the predicted class. There are lots more details, which will be explained during this tutorial, so now let's start coding!
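
As a quick illustration of reading out the prediction from spike counts, here is a small numpy sketch. The array layout `(num_steps, batch, num_classes)` and the helper name `decode_by_rate` are assumptions made for this example.

```python
import numpy as np

def decode_by_rate(spk_rec):
    """spk_rec: array of 0/1 spikes with shape (num_steps, batch, num_classes)."""
    counts = spk_rec.sum(axis=0)            # total spikes per sample and class
    return counts.argmax(axis=1), counts    # predicted class = most spikes

# Toy record: 4 time steps, 2 samples, 3 classes
spk_rec = np.zeros((4, 2, 3), dtype=int)
spk_rec[[0, 1, 3], 0, 2] = 1   # sample 0: class 2 spikes three times
spk_rec[2, 0, 0] = 1           # sample 0: class 0 spikes once
spk_rec[[1, 2], 1, 1] = 1      # sample 1: class 1 spikes twice

predictions, counts = decode_by_rate(spk_rec)
print(predictions)  # [2 1]
```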


# Data

The first step is to get the data. You will have to download it from [here](https://drive.google.com/file/d/1dBEgtmctTTWJlWnuWxFtk8gfOdVVpkQ0/view) and put it in your Google Drive. You can edit the following code cells with your file and folder paths. This is how I organized mine; if you do something different, you will need to edit the paths:


 - In my drive, I have a folder called lip.
 - In lip, place this tutorial, then make folders called models, data, cache.
 - In data, make a folder called DVSLip.
 - In there, place DVS-Lip, the unzipped version of the zip file.

# Ok the coding tutorial is starting below!

## Google Drive
First, you'll mount your Google Drive and then go into the lip folder.

In [None]:
# This is to mount your google drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# This is the google drive folder where I put my notebook, models, data, and everything.
# You should make one like this too, and if it has a different name, change it here
%cd /content/drive/MyDrive/lip

## Tonic
We will use the tonic library to handle the dataset.

In [None]:
# We will use tonic to get this dataset and perform things on it.
!pip install tonic
import tonic
from tonic import MemoryCachedDataset
# import numpy also
import numpy as np

Load in the data...

In [None]:
# We use tonic.datasets.DVSLip. First, let's just load in the data to see how it is.
trainset = tonic.datasets.DVSLip(save_to="./data", train=True)

# How many samples?
print(f'{len(trainset)} samples in the training dataset')

Now, let's look at the first sample

In [None]:
sample = trainset[0]
print(f'The type of a sample is {type(sample)}')
print(f'Length: {len(sample)}')

The first item is the input, the second is the output. First, let's look at the input

In [None]:
sample_input = sample[0]
print(f'Type of input: {type(sample_input)}')
print(f'Length of input: {len(sample_input)}')
print(f'Input: {sample_input}')

Ok, so the input for this sample is a numpy array of length 1719. Each element is a tuple, so let's look at the first and last ones.

In [None]:
print(sample_input[0])
print(sample_input[-1])

What is this? Each tuple has the form (x, y, p, t). x and y say which pixel the event belongs to. p is the polarity, which says how that pixel is changing (getting brighter or darker). Finally, t is the timestamp of the event (in microseconds, following tonic's convention). You can see here that it goes all the way up to 728283 for this example.
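
To show how a list of (x, y, p, t) events can become frames, here is a small numpy sketch that counts events per pixel inside fixed time windows. It is only an illustration of the idea: the field types in `event_dtype` and the helper `events_to_frames` are hypothetical, and tonic's own `ToFrame` transform may handle details such as the final partial window differently.

```python
import numpy as np

# Hypothetical field layout for this sketch; check `sample_input.dtype` for the real one.
event_dtype = np.dtype([("x", np.int16), ("y", np.int16), ("p", np.int8), ("t", np.int64)])

def events_to_frames(events, sensor_size, time_window):
    """Count events per (polarity, y, x) within consecutive time windows."""
    width, height, polarities = sensor_size
    num_frames = int(events["t"].max() // time_window) + 1
    frames = np.zeros((num_frames, polarities, height, width), dtype=np.int16)
    frame_idx = events["t"] // time_window   # which window each event falls into
    # np.add.at accumulates correctly even when several events hit the same pixel
    np.add.at(frames, (frame_idx, events["p"], events["y"], events["x"]), 1)
    return frames

# Four toy events on a tiny 3x2 sensor with 2 polarities
events = np.array(
    [(0, 0, 1, 10), (1, 0, 0, 20), (0, 0, 1, 25), (2, 1, 1, 130)],
    dtype=event_dtype,
)
frames = events_to_frames(events, sensor_size=(3, 2, 2), time_window=100)
print(frames.shape)  # (2, 2, 2, 3): frames, polarities, height, width
```

The first three events land in the first frame and the last one in the second, so a longer time window means fewer, denser frames.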

## Now, let's look at the output

In [None]:
sample_output = sample[1]
print(f'Type of output: {type(sample_output)}')
print(f'Output: {sample_output}')

It is an integer, which is for what class the sample's input corresponds to. Let's look at this

In [None]:
classes = trainset.classes
print(f'# of classes: {len(classes)}')
print(f'Classes: {classes}')
print(f'Sample output {sample_output} means class {classes[sample_output]}')

Now that you have a sense for how the data is formatted right now, let's load it in for the purpose of our task. Also, we will start working with PyTorch, and with snnTorch, the library for spiking neural networks created by Professor Eshraghian. So let's download and import some things we will need.

## Downloading and Importing more


In [None]:
# pyTorch
import torch
import torch.nn as nn
# This is not used until we explore other options later
import torch.nn.functional as F
# This is used for lots of vision related stuff
import torchvision
# Dataloader
from torch.utils.data import DataLoader

# snnTorch
!pip install snntorch
# imports
import snntorch as snn
from snntorch import surrogate # for the surrogate function
from snntorch import functional as SF

# Some hyperparameters
Before loading the data we need to set the batch size, so let's also set the number of epochs

In [10]:
num_epochs = 200
# you might need to make the batch size smaller if this does not fit in memory
batch_size = 99

Now load the data correctly

In [None]:
# Get the sensor_size
sensor_size = tonic.datasets.DVSLip.sensor_size # should be (128, 128, 2)
print(f'Sensor size: {sensor_size}')

# The tonic.transforms.Compose is to apply transforms sequentially.
# Denoise is used to reduce the noise. If a nearby pixel does not have an event within filter_time from this event, it gets ignored.
# ToFrame is super important! The data is formatted right now as a sequence of events which is a tuple, but we want a sequence of frames
# The time window is the window across which events are combined to make a single frame. Right now, there are hundreds of thousands of timesteps which you saw when you looked at the last event in the first sample.
frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=50000),
                                            tonic.transforms.ToFrame(sensor_size=sensor_size,
                                                                     time_window=30000)])

# Now, we create the trainset (and testset) applying these transforms
trainset = tonic.datasets.DVSLip(save_to="./data", transform=frame_transform, train=True)
testset = tonic.datasets.DVSLip(save_to="./data", transform=frame_transform, train=False)

# Ok now time to switch to pytorch

# To convert to float
class ToFloat(object):
  def __call__(self, sample):
    return sample.float()

# The first transform converts the input, which is a numpy array, into a torch tensor.
# The next one converts it to float, because Normalize expects floats.
# Normalize with mean 0 and std 1 leaves the values unchanged; it is a placeholder you can fill in with dataset statistics.
torch_transform = torchvision.transforms.Compose([
    torch.from_numpy,
    ToFloat(),
    torchvision.transforms.Normalize((0, 0), (1, 1))
    ])

# By using this to cache the data, it will speed up training.
# It may still take a long time for the first epoch, but then after that, it will go way faster.
cached_trainset = MemoryCachedDataset(trainset, transform=torch_transform)
cached_testset = MemoryCachedDataset(testset, transform=torch_transform)

# Here, we finally create the PyTorch dataloaders we will be using
# shuffle is needed because the data is not shuffled at all right now
# drop_last is important because otherwise the last batch will be the wrong size and will cause an error
trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True, drop_last=True)
# shuffle is also used on the testloader because it might take a while to test, and this way you can see more representative results right away
testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True, drop_last=True)

classes = trainset.classes
print(f'Training batches: {len(trainloader)}')
print(f'Testing batches: {len(testloader)}')

# Done loading data
Great, now we can get to making the model! The plan is to make a convolutional spiking neural network. Each neuron remembers its membrane potential from one timestep to the next.

At a timestep, the neural network will perform convolutions on the input, after which spiking neurons are applied so it is a spiking neural network. The convolutions will hopefully let the model extract important features. Then, after some convolutions, it's time for a fully connected layer (with spikes as well). Overall, this sequence of convolutions then a fully connected layer should seem very normal, it's just that spikes are added! SNNs and ANNs can have the same architecture, it is just that the neurons are different.
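
One practical question with this kind of architecture is how many values come out of the convolutions, since the first fully connected layer needs that number. Here is a quick bit of arithmetic, assuming three 5x5 convolutions (stride 1, no padding), each followed by 2x2 max pooling, with 64 channels at the end, which is the setup used in this tutorial. The helper `conv_pool_size` is made up for this sketch.

```python
def conv_pool_size(size, kernel=5, pool=2):
    """Side length after an unpadded, stride-1 convolution followed by max pooling."""
    after_conv = size - kernel + 1   # a 5x5 kernel trims 4 pixels from each side length
    return after_conv // pool        # pooling halves the size, rounding down

size = 128                           # the frames are 128 x 128
for _ in range(3):                   # three conv + pool stages
    size = conv_pool_size(size)

flat_features = 64 * size * size     # 64 channels after the last convolution
print(size, flat_features)  # 12 9216
```

If you change the kernel sizes, the pooling, or the number of channels, rerun this arithmetic so the size of the first linear layer still matches.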

# How do they learn?
You might be wondering, how does the spiking neural network learn?

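It uses backpropagation, just like ANNs, but the derivative of a spike is a problem. A spike is a step function of the membrane potential, so its derivative is 0 everywhere except right at the threshold, where it is infinite! With a gradient of zero almost everywhere, nothing would be learned.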

The solution is a surrogate gradient. On the forward pass, spikes are still used, but on the backward pass, the spike (the Heaviside function) is replaced with a smoothed-out version of it. There are several options for this; in this case you will use the fast sigmoid function.
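
To see what the surrogate looks like numerically, here is a small numpy sketch. It assumes the usual form of the fast sigmoid surrogate, where the gradient is 1 / (1 + slope * |U|)^2 and U is the membrane potential minus the threshold. The function names are made up for this example, and this is not how snnTorch hooks into autograd internally.

```python
import numpy as np

def heaviside(mem_shifted):
    """Forward pass: spike (1) when the membrane potential is above the threshold."""
    return (mem_shifted > 0).astype(float)

def fast_sigmoid_grad(mem_shifted, slope=25.0):
    """Backward pass: smooth stand-in for the derivative of the step function."""
    return 1.0 / (1.0 + slope * np.abs(mem_shifted)) ** 2

u = np.array([-0.5, -0.04, 0.0, 0.04, 0.5])   # membrane potential minus threshold
spikes = heaviside(u)
grads = fast_sigmoid_grad(u)
print(spikes)
print(np.round(grads, 4))
```

The gradient is largest right at the threshold and falls off quickly on both sides, so neurons that are close to firing (or have just fired) receive most of the learning signal. A larger `slope` makes the peak narrower.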

## Creating the model

Choose the sizes used in the network

In [12]:
# sizes
input_size = sensor_size[0] * sensor_size[1] * sensor_size[2] # 128 * 128 * 2 = 32768
# the size of the hidden layers is arbitrary and could be something else too
hidden_size = 100
output_size = len(classes) # 100

Now, we can define the model

In [13]:
# to define one in PyTorch, it is a subclass of nn.Module
class SpkNet(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, beta, spike_grad):
    super().__init__()

    # Convolutions
    self.pool = nn.MaxPool2d(2, 2)

    self.conv1 = nn.Conv2d(2, 16, 5)
    # after each layer, there is a spiking layer. This is how snnTorch is used to make this a spiking neural network.
    # snn.Leaky is for a leaky integrate-and-fire neuron (LIF). This is a standard type of spiking neuron.
    # beta is the decay rate, and spike_grad is the function that is used for the gradient.
    self.lifc1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.conv2 = nn.Conv2d(16, 32, 5)
    self.lifc2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.conv3 = nn.Conv2d(32, 64, 5)
    self.lifc3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    # After performing all the convolutions, the result will be of size 9216
    self.fc_input_size = 9216

    # FC
    self.lin1 = nn.Linear(self.fc_input_size, hidden_size)
    self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.lin2 = nn.Linear(hidden_size, hidden_size)
    self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.lin3 = nn.Linear(hidden_size, hidden_size)
    self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.lin4 = nn.Linear(hidden_size, output_size)
    self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad)

  # this defines the forward pass, where it takes an input and outputs the spikes over time
  def forward(self, x):
    # The spikes will be recorded in spk_rec
    spk_rec = []
    # Number of frames
    num_steps = x.shape[0]

    # These keep track of the membrane potential
    # init_leaky() is used to initialize the LIF neurons.
    mem_c1 = self.lifc1.init_leaky()
    mem_c2 = self.lifc2.init_leaky()
    mem_c3 = self.lifc3.init_leaky()
    mem_1 = self.lif1.init_leaky()
    mem_2 = self.lif2.init_leaky()
    mem_3 = self.lif3.init_leaky()
    mem_4 = self.lif4.init_leaky()

    # Now, it will run through the sample, like how an RNN works, and the membrane potentials are kept to be used in the next step.
    print(f'/{num_steps}')
    for step in range(num_steps):
      if (step + 1) % 5 == 0:
        print(step + 1, end=', ')

      # Convolutions
      # It performs the convolution on the current frame, then pools.
      # There is no ReLU activation function needed, because it uses spikes!
      out = self.conv1(x[step])
      out = self.pool(out)
      spk_c1, mem_c1 = self.lifc1(out, mem_c1)
      out = self.conv2(spk_c1)
      out = self.pool(out)
      spk_c2, mem_c2 = self.lifc2(out, mem_c2)
      out = self.conv3(spk_c2)
      out = self.pool(out)
      spk_c3, mem_c3 = self.lifc3(out, mem_c3)

      # FC
      out = spk_c3.view(-1, self.fc_input_size) # Flatten
      out = self.lin1(out)
      spk_1, mem_1 = self.lif1(out, mem_1)
      out = self.lin2(spk_1)
      spk_2, mem_2 = self.lif2(out, mem_2)
      out = self.lin3(spk_2)
      spk_3, mem_3 = self.lif3(out, mem_3)
      out = self.lin4(spk_3)
      spk_4, mem_4= self.lif4(out, mem_4)
      # Keep track of the spikes from the last layer.
      spk_rec.append(spk_4)

    # Return the record of the spikes
    return torch.stack(spk_rec, dim=0)

Now that the model is defined, you can create it.

In [None]:
spike_grad = surrogate.fast_sigmoid(slope=25) # the surrogate gradient
beta = 0.9 # decay rate

# this is for using cuda
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Device: {device}')
model = SpkNet(input_size, hidden_size, output_size, beta, spike_grad).to(device)

# Loss and optimizer

Yay you created the model! Now, it needs a loss function, and also an optimizer.

The correct output class should be the one with the most spikes, so it might seem natural for the loss to be minimized when only the correct output neuron spikes. However, pushing every other neuron toward silence could drive weights toward zero and cause dead neurons. Instead, the loss aims for target firing rates: the neuron for the correct class should spike on a certain fraction of the time steps, and each neuron for an incorrect class should spike on a smaller fraction.

The loss will be mean squared error and the optimizer will be Adam.
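
Here is a simplified numpy sketch of that idea, to show what the loss is comparing. It is not snnTorch's implementation, which may differ in details such as how the error is averaged; the function name `mse_count_loss_sketch` and the `(num_steps, batch, num_classes)` layout are assumptions for this example.

```python
import numpy as np

def mse_count_loss_sketch(spk_rec, targets, correct_rate=0.8, incorrect_rate=0.2):
    """Mean squared error between spike counts and target spike counts."""
    num_steps, batch, num_classes = spk_rec.shape
    counts = spk_rec.sum(axis=0).astype(float)                   # spikes per sample and class
    target_counts = np.full((batch, num_classes), incorrect_rate * num_steps)
    target_counts[np.arange(batch), targets] = correct_rate * num_steps
    return np.mean((counts - target_counts) ** 2)

targets = np.array([1])                  # one sample whose correct class is 1

# The correct neuron fires on 8 of 10 steps, the other two on 2 of 10 steps each
spk_rec = np.zeros((10, 1, 3))
spk_rec[:8, 0, 1] = 1
spk_rec[:2, 0, 0] = 1
spk_rec[:2, 0, 2] = 1

on_target = mse_count_loss_sketch(spk_rec, targets)
silent = mse_count_loss_sketch(np.zeros((10, 1, 3)), targets)
print(on_target, silent)
```

A network that hits the target rates gets a loss of zero, while one that never spikes is penalized, which is what keeps the neurons from going silent.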

In [15]:
# Learning rate
learning_rate = 0.0001
target_correct = 0.8

# loss and optimizer
criterion = SF.mse_count_loss(correct_rate=target_correct, incorrect_rate=1-target_correct)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.99))

## Training
Now, the model is created and you are ready to train it.

In [None]:
import datetime # so we can see how long each batch is taking
save_checkpoints = True

# loading and saving the model:
# you can import a model that is saved already
load_model_path = None # './models/snn/model'
# you'll save checkpoints to ./models/snn/epoch{epoch}, and save the most recent to ./models/snn/model
save_model_path = './models/snn'


# function for saving
import os
def save(epoch, model, optimizer, path):
  checkpoint = {
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict()
  }
  if not os.path.exists(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
  torch.save(checkpoint, path)
  print(f'Saved to {path}')

starting_epoch = 0
# load existing model
if load_model_path is not None:
  checkpoint = torch.load(load_model_path, map_location=torch.device('cpu'))
  starting_epoch = checkpoint['epoch']
  model.load_state_dict(checkpoint['model_state'])
  optimizer.load_state_dict(checkpoint['optim_state'])
  print(f'Starting from after epoch {starting_epoch}')
else:
  print('New model')

And now you can do a training loop. Also, display what a sample is like

In [17]:
from IPython.display import HTML
# @title define function for animation
def show_sample(data, interval=30):
  animation = plot_animation(interval, frames=data)
  return animation

# from https://github.com/neuromorphs/tonic/blob/develop/tonic/utils.py
from typing import Tuple
def plot_animation(interval, frames: np.ndarray, figsize: Tuple[int, int] = (5, 5)):
  try:
    import matplotlib.pyplot as plt
    from matplotlib import animation
  except ImportError:
    raise ImportError(
      "Please install the matplotlib package to plot events. This is an optional dependency."
    )
  fig = plt.figure(figsize=figsize)
  if frames.shape[1] == 2:
    rgb = np.zeros((frames.shape[0], 3, *frames.shape[2:]))
    rgb[:, 1:, ...] = frames
    frames = rgb
  if frames.shape[1] in [1, 2, 3]:
    frames = np.moveaxis(frames, 1, 3)
  ax = plt.imshow(frames[0])
  plt.axis("off")

  def animate(frame):
    ax.set_data(frame)
    return ax

  anim = animation.FuncAnimation(fig, animate, frames=frames, interval=interval)
  return anim

In [None]:
model.train() # this is important to set it to training mode

# record the accuracy
accuracies = []
for epoch in range(starting_epoch, num_epochs):
  print(datetime.datetime.now())

  # trainloader gives one batch at a time
  for i, (data, targets) in enumerate(trainloader):
    # let's plot the first sample in the first batch of the first epoch:
    if epoch == starting_epoch and i == 0:
      data_animate = data.clone() / 1.5
      data_animate[data_animate > 1] = 1
      print(f'Target: {classes[targets[0]]}')
      anim = show_sample(data_animate[:, 0, :, :, :], interval=30)
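      # a bare expression inside a loop is not rendered, so display it explicitly
      from IPython.display import display
      display(HTML(anim.to_jshtml()))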

    # forward pass
    # move the tensors onto the device
    data = data.to(device)
    targets = targets.to(device)
    # run the model on the input to get the spikes record
    spk_rec = model(data)
    # get the loss
    loss = criterion(spk_rec, targets)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    accuracy = SF.accuracy_rate(spk_rec, targets)
    accuracies.append(accuracy)

    if i % 1 == 0:
      print(f'Epoch {epoch + 1}/{num_epochs}: Batch {i + 1}/{len(trainloader)}: Loss: {loss.item():.2f}, Accuracy: {accuracy * 100:.2f}%')

  if save_checkpoints:
    save(epoch + 1, model, optimizer, save_model_path + f'/epoch{epoch + 1}')
    save(epoch + 1, model, optimizer, save_model_path + '/model')

In [None]:
# @title If the animation didn't show up, try this
anim = show_sample(data_animate[:, 0, :, :, :], interval=30)
HTML(anim.to_jshtml())

Let's see how the training accuracy changed over training

In [None]:
import matplotlib.pyplot as plt

# accuracies holds one value per training batch, so plot it against the batch index
plt.plot(accuracies)
plt.title('Training Accuracy')
plt.xlabel('Batch')
plt.ylabel('Accuracy')

Congrats, you trained a spiking neural network to read lips! Now you can evaluate it with the test dataset to see how good it is.

## Evaluation

In [None]:
# switch to eval mode
model.eval()

# it doesn't need to keep track of gradients because it is not learning here
with torch.no_grad():
  # to keep track of its accuracy
  correct = 0
  total = 0
  # this is to see how well it did on each word
  classes_correct = [0] * len(classes)
  classes_total = [0] * len(classes)
  # iterate one batch at a time
  for i, (data, targets) in enumerate(testloader):
    data = data.to(device)
    targets = targets.to(device)
    # get the spiking record
    spk_rec = model(data)
    loss = criterion(spk_rec, targets)
    total += batch_size
    for j in range(batch_size):
      single_spk_rec = spk_rec[:, j, :]
      # get the prediction (the class that had the most spikes)
      _, prediction = torch.max(sum(single_spk_rec), 0)
      target = targets[j]
      classes_total[target] += 1
      if prediction == target:
        classes_correct[target] += 1
        correct += 1
    if i % 1 == 0:
      print(f'Batch {i + 1}/{len(testloader)}: Loss: {loss.item():.2f}')
      print(f'Current total: {total}. Current Accuracy: {100 * correct / total:.2f}%')

  print(f'Final Test Accuracy: {100 * correct / total:.2f}%')
  for i in range(len(classes)):
    if classes_total[i] != 0:
      print(f'{classes[i]}: {100 * classes_correct[i] / classes_total[i]:.2f}%', end=', ')

Hopefully it did ok! I trained it for 200 epochs and it got around 16% accuracy on the test set. Can this be improved though?

## Other ideas
Now that you have made a spiking neural network for this task, you can try other methods to see how they compare! For example, you could make a 3D CNN, which I will show next. A disadvantage is that it takes fixed-length input, so all the videos need to be padded or cropped to the same length, which is not ideal here because a video could be any length. However, when I trained it, it got to 28% accuracy, and in fewer epochs than the spiking neural network.

Let's define the model for the 3D CNN

In [18]:
class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()

    # Convolutions
    # The numbers I chose here are so the output is not too big but still works
    self.pool = nn.MaxPool3d(3)
    self.pool2 = nn.MaxPool3d(2)

    self.conv1 = nn.Conv3d(2, 4, 3, padding='same')
    self.conv2 = nn.Conv3d(4, 8, 3, padding='same')
    self.conv3 = nn.Conv3d(8, 16, 3, padding='same')
    self.conv4 = nn.Conv3d(16, 32, 3, padding='same')


    # FC
    self.fc_input_size = 288
    self.hidden_size = 150
    self.lin1 = nn.Linear(self.fc_input_size, self.hidden_size)
    self.lin2 = nn.Linear(self.hidden_size, self.hidden_size)
    self.lin3 = nn.Linear(self.hidden_size, 100)


  def forward(self, x):
    # Now, there is no loop over the timesteps. It is doing convolutions over time and the image at once, like a 3D volume
    # 3d convs
    x = F.relu(self.pool(self.conv1(x)))
    x = F.relu(self.pool(self.conv2(x)))
    x = F.relu(self.pool2(self.conv3(x)))
    x = F.relu(self.pool2(self.conv4(x)))

    # fc
    x = x.view(x.shape[0], self.fc_input_size)  # use the actual batch size, not the global variable
    x = F.relu(self.lin1(x))
    x = F.relu(self.lin2(x))
    x = self.lin3(x)
    return x


model2 = ConvNet().to(device)

# loss and optimizer
# now the loss is just cross entropy loss, as there's not spiking anymore.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model2.parameters(), lr=learning_rate, betas=(0.9, 0.99))


# loading a model
load_model_path = None # './models/cnn/model'
save_model_path = './models/cnn'

starting_epoch = 0
if load_model_path is not None:
  checkpoint = torch.load(load_model_path, map_location=torch.device('cpu'))
  starting_epoch = checkpoint['epoch']
  model2.load_state_dict(checkpoint['model_state'])
  optimizer.load_state_dict(checkpoint['optim_state'])
  print(f'Starting from after epoch {starting_epoch}')

We can't just run this yet, though, because the data has to be in a different format. The CNN takes fixed-length input, and the batch dimension has to come before the timestep dimension. Here is a function that pads or crops all the videos to 70 frames.

In [19]:
def custom_pad(batch, max_length=70, image_size=128):
  # This is an empty frame
  pad = torch.zeros(2, image_size, image_size)
  x = []
  y = []
  for i in range(len(batch)):
    y.append(batch[i][1])
    if len(batch[i][0]) > max_length: # too long: just take the first 70 frames
      x.append(batch[i][0][:max_length])
    elif len(batch[i][0]) < max_length:  # too short: append empty frames
      x.append(torch.cat((batch[i][0], pad.repeat(max_length - len(batch[i][0]), 1, 1, 1)), 0))
    else:
      x.append(batch[i][0]) # just right length
  return (torch.stack(x, dim=0), torch.LongTensor(y))

Now, recreate the trainloader and testloader, using this padding.

In [18]:
trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=custom_pad, shuffle=True, drop_last=True)
testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=custom_pad, shuffle=True, drop_last=True)

Now, you can train the model!
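
One detail worth understanding in the training loop is the dimension reorder before the data goes into the 3D CNN. The padded batches have the layout (batch, time, channels, height, width), while a 3D convolution expects (batch, channels, time, height, width). Here is a tiny numpy sketch of that swap, using a small 8x8 frame size just to keep the arrays light.

```python
import numpy as np

# (batch, time, channels, height, width): 2 samples, 70 frames, 2 polarities, 8x8 pixels
batch = np.zeros((2, 70, 2, 8, 8), dtype=np.float32)
batch[1, 5, 0, 3, 4] = 1.0            # mark one event: sample 1, frame 5, polarity 0

# Swap the time and channel axes; the batch axis stays first
reordered = batch.transpose(0, 2, 1, 3, 4)

print(reordered.shape)            # (2, 2, 70, 8, 8)
print(reordered[1, 0, 5, 3, 4])   # 1.0, the same event at its new index
```

`torch.permute(data, (0, 2, 1, 3, 4))` does the same reordering on a tensor.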

In [None]:
# switch to training mode
model2.train()
print('Training')
for epoch in range(starting_epoch, num_epochs):
  print(datetime.datetime.now())
  for i, (data, targets) in enumerate(trainloader):
    # forward pass
    # swap the timestep and channel dimensions: (B, T, C, H, W) -> (B, C, T, H, W)
    data = torch.permute(data, (0, 2, 1, 3, 4))
    data = data.to(device)
    targets = targets.to(device)
    out = model2(data)
    loss = criterion(out, targets)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 1 == 0:
      print(f'Epoch {epoch + 1}/{num_epochs}: Batch {i + 1}/{len(trainloader)}: Loss: {loss.item():.2f}')

  if save_checkpoints:
    save(epoch + 1, model2, optimizer, save_model_path + f'/epoch{epoch + 1}')
    save(epoch + 1, model2, optimizer, save_model_path + '/model')

# Combining 3D CNN with SNN
Given that both a 3D CNN and a spiking neural network worked, it might be interesting to try combining them! The input still needs to be fixed length because of the 3D CNN, but the SNN is then applied to the CNN's output. Also, let's add more frames, because the 3D CNN is going to shrink the time dimension. You might need to decrease the batch size further because the samples are bigger.

In [20]:
batch_size = 50

In [21]:
# now each sample will have 350 frames
def custom_pad(batch, max_length=350, image_size=128):
  # This is an empty frame
  pad = torch.zeros(2, image_size, image_size)
  x = []
  y = []
  for i in range(len(batch)):
    y.append(batch[i][1])
    if len(batch[i][0]) > max_length: # too long: just take the first 350 frames
      x.append(batch[i][0][:max_length])
    elif len(batch[i][0]) < max_length:  # too short: append empty frames
      x.append(torch.cat((batch[i][0], pad.repeat(max_length - len(batch[i][0]), 1, 1, 1)), 0))
    else:
      x.append(batch[i][0]) # just right length
  return (torch.stack(x, dim=0), torch.LongTensor(y))

Let's redo changing the events to frames, this time shrinking the time window to give it more resolution
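
To get a feel for the numbers, here is some rough arithmetic. It assumes the timestamps start near zero, uses the last timestamp of the first sample (728283) as the recording length, and assumes three pooling layers of size 2 as in the combined model defined later in this section. Both helper functions are made up for this sketch, and tonic may count the final partial window differently.

```python
def approx_num_frames(duration, time_window):
    """Rough frame count: how many full time windows fit in the recording."""
    return duration // time_window

def pooled_length(length, pool=2, times=3):
    """Length of one dimension after repeated max pooling (rounding down each time)."""
    for _ in range(times):
        length //= pool
    return length

duration = 728283                                  # last timestamp of the first sample
frames_30 = approx_num_frames(duration, 30000)     # time window used for the first model
frames_6 = approx_num_frames(duration, 6000)       # smaller window, about 5x more frames
print(frames_30, frames_6)  # 24 121

steps = pooled_length(350)     # time steps left from 350 padded frames
side = pooled_length(128)      # height and width left from 128 pixels
print(steps, side, 8 * side * side)  # 43 16 2048
```

So 350 input frames become 43 time steps for the spiking layers, each with 8 x 16 x 16 = 2048 features.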

In [None]:
# same code from before, but with lower time window and using the custom function above
# Get the sensor_size
sensor_size = tonic.datasets.DVSLip.sensor_size # should be (128, 128, 2)
print(f'Sensor size: {sensor_size}')

# The tonic.transforms.Compose is to apply transforms sequentially.
# Denoise is used to reduce the noise. If a nearby pixel does not have an event within filter_time from this event, it gets ignored.
# ToFrame is super important! The data is formatted right now as a sequence of events which is a tuple, but we want a sequence of frames
# The time window is the window across which events are combined to make a single frame. Right now, there are hundreds of thousands of timesteps which you saw when you looked at the last event in the first sample.
frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=50000),
                                            tonic.transforms.ToFrame(sensor_size=sensor_size,
                                                                     time_window=6000)]) # the time window is 5x smaller

# Now, we create the trainset (and testset) applying these transforms
trainset = tonic.datasets.DVSLip(save_to="./data", transform=frame_transform, train=True)
testset = tonic.datasets.DVSLip(save_to="./data", transform=frame_transform, train=False)

# Ok now time to switch to pytorch

# To convert to float
class ToFloat(object):
  def __call__(self, sample):
    return sample.float()

# The first is to convert the input which is a numpy array into a torch tensor.
# The next one converts it to float, because the next one after that is normalizing the data and it takes floats.
# The last transform is done to make the data normalized
torch_transform = torchvision.transforms.Compose([
    torch.from_numpy,
    ToFloat(),
    torchvision.transforms.Normalize((0, 0), (1, 1))
    ])

# By using this to cache the data, it will speed up training.
# It may still take a long time for the first epoch, but then after that, it will go way faster.
cached_trainset = MemoryCachedDataset(trainset, transform=torch_transform)
cached_testset = MemoryCachedDataset(testset, transform=torch_transform)

trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=custom_pad, shuffle=True, drop_last=True)
testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=custom_pad, shuffle=True, drop_last=True)

classes = trainset.classes
print(f'Training batches: {len(trainloader)}')
print(f'Testing batches: {len(testloader)}')

In [23]:
class CSNN(nn.Module):
  def __init__(self, beta, spike_grad):
    super().__init__()


    # Convolutions
    # The numbers I chose here are so the output is not too big but still works
    self.pool = nn.MaxPool3d(3)
    self.pool2 = nn.MaxPool3d(2)

    self.conv1 = nn.Conv3d(2, 4, 3, padding='same')
    self.conv2 = nn.Conv3d(4, 8, 3, padding='same')
    self.conv3 = nn.Conv3d(8, 8, 3, padding='same')


    # FC
    hidden_size = 100
    output_size = 100
    self.lin1 = nn.Linear(2048, hidden_size)
    self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.lin2 = nn.Linear(hidden_size, hidden_size)
    self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.lin3 = nn.Linear(hidden_size, hidden_size)
    self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
    self.lin4 = nn.Linear(hidden_size, output_size)
    self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad)


  def forward(self, x):
    # The 3D convolutions process time and space together, like a 3D volume; the spiking layers below then loop over the reduced time steps
    # 3d convs
    x = F.relu(self.pool2(self.conv1(x)))
    x = F.relu(self.pool2(self.conv2(x)))
    x = F.relu(self.pool2(self.conv3(x)))

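    # x is now [batch, channels, time, height, width], e.g. [50, 8, 43, 16, 16]
    new_x = torch.permute(x, (2, 0, 1, 3, 4))
    # new_x is [time, batch, channels, height, width], e.g. [43, 50, 8, 16, 16]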

    # spiking fc part
    spk_rec = []
    mem_1 = self.lif1.init_leaky()
    mem_2 = self.lif2.init_leaky()
    mem_3 = self.lif3.init_leaky()
    mem_4 = self.lif4.init_leaky()

    # loop over however many timesteps are left after pooling (43 for this data)
    for step in range(new_x.shape[0]):
      out = new_x[step].reshape(-1, 8 * 16 * 16)
      out = self.lin1(out)
      spk_1, mem_1 = self.lif1(out, mem_1)
      out = self.lin2(spk_1)
      spk_2, mem_2 = self.lif2(out, mem_2)
      out = self.lin3(spk_2)
      spk_3, mem_3 = self.lif3(out, mem_3)
      out = self.lin4(spk_3)
      spk_4, mem_4 = self.lif4(out, mem_4)
      spk_rec.append(spk_4)

    return torch.stack(spk_rec, dim=0)

model3 = CSNN(beta, spike_grad).to(device)

# loss and optimizer
criterion = SF.mse_count_loss(correct_rate=target_correct, incorrect_rate=1-target_correct)
# make sure the optimizer updates model3's parameters, not an earlier model's
optimizer = torch.optim.Adam(model3.parameters(), lr=learning_rate, betas=(0.9, 0.99))


# loading a model
load_model_path = None # './models/csnn/model'
save_model_path = './models/csnn'

starting_epoch = 0
if load_model_path is not None:
  checkpoint = torch.load(load_model_path, map_location=torch.device('cpu'))
  starting_epoch = checkpoint['epoch']
  model3.load_state_dict(checkpoint['model_state'])
  optimizer.load_state_dict(checkpoint['optim_state'])
  print(f'Starting from after epoch {starting_epoch}')
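To see where the `2048` in `lin1` and the 43 timesteps in the loop come from, here is a small sketch of the shape arithmetic in plain Python. It assumes each sample is padded to 350 timesteps at 128x128 pixels (those input sizes are my assumption, chosen to match the shape comments in `forward`), and that `MaxPool3d(2)` uses its default stride and no padding. The `'same'` convolutions keep the sizes unchanged, so only the three pooling layers matter.

```python
def pooled_size(size, kernel=2, layers=3):
    """Size of one dimension after `layers` max-pools with stride == kernel and no padding."""
    for _ in range(layers):
        size = (size - kernel) // kernel + 1
    return size

# Hypothetical input sizes (assumptions, not read from the dataset)
time_steps, height, width = 350, 128, 128
out_channels = 8  # output channels of conv3

t_out = pooled_size(time_steps)
h_out = pooled_size(height)
w_out = pooled_size(width)
features = out_channels * h_out * w_out  # flattened size fed into lin1 at each timestep

print(t_out, h_out, w_out, features)  # 43 16 16 2048
```

If you change the pooling, the number of channels, or the input size, redo this calculation and update `nn.Linear(2048, hidden_size)` and the reshape in `forward` to match.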

Now, try training!

In [None]:
# switch to training mode
model3.train()
accuracies = []

print('Training')
for epoch in range(starting_epoch, num_epochs):
  print(datetime.datetime.now())
  for i, (data, targets) in enumerate(trainloader):
    # forward pass
    # swap the timestep and channel dimensions so Conv3d gets [batch, channels, time, height, width]
    data = torch.permute(data, (0, 2, 1, 3, 4))
    data = data.to(device)
    targets = targets.to(device)
    spk_rec = model3(data)
    loss = criterion(spk_rec, targets)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    accuracy = SF.accuracy_rate(spk_rec, targets)
    accuracies.append(accuracy)

    if i % 1 == 0:
      print(f'Epoch {epoch + 1}/{num_epochs}: Batch {i + 1}/{len(trainloader)}: Loss: {loss.item():.2f}, Accuracy: {accuracy * 100:.2f}%')

  if save_checkpoints:
    save(epoch + 1, model3, optimizer, save_model_path + f'/epoch{epoch + 1}')
    save(epoch + 1, model3, optimizer, save_model_path + '/model')

It's not really learning though, so maybe the other models were better.

Plot accuracy

In [None]:
import matplotlib.pyplot as plt


# accuracies has one value per training batch, so plot it against the batch index
plt.plot(accuracies)
plt.title('Training Accuracy')
plt.xlabel('Batch')
plt.ylabel('Accuracy')
plt.show()
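The training loop appends one accuracy per batch, so `accuracies` has `len(trainloader)` entries for every epoch. If you would rather plot one point per epoch, you can average the list in chunks. This is a minimal sketch in plain Python; `epoch_means` and the example numbers are made up for illustration.

```python
def epoch_means(batch_values, batches_per_epoch):
    """Average per-batch values in consecutive chunks of one epoch each."""
    means = []
    for start in range(0, len(batch_values), batches_per_epoch):
        chunk = batch_values[start:start + batches_per_epoch]
        means.append(sum(chunk) / len(chunk))
    return means

# Hypothetical run: 2 epochs with 3 batches each
example_accuracies = [0.0, 0.25, 0.5, 0.5, 0.75, 1.0]
per_epoch = epoch_means(example_accuracies, batches_per_epoch=3)
print(per_epoch)
```

In the notebook you would call it as `epoch_means(accuracies, len(trainloader))` and plot the result with 'Epoch' on the x-axis.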

And now, evaluate:

In [None]:
# switch to eval mode
model3.eval()

# it doesn't need to keep track of gradients because it is not learning here
with torch.no_grad():
  # to keep track of its accuracy
  correct = 0
  total = 0
  # this is to see how well it did on each word
  classes_correct = [0] * len(classes)
  classes_total = [0] * len(classes)
  # iterate one batch at a time
  for i, (data, targets) in enumerate(testloader):
    data = torch.permute(data, (0, 2, 1, 3, 4))
    data = data.to(device)
    targets = targets.to(device)
    # get the spiking record
    spk_rec = model3(data)
    loss = criterion(spk_rec, targets)
    total += batch_size
    for j in range(batch_size):
      single_spk_rec = spk_rec[:, j, :]
      # get the prediction (the class that had the most spikes)
      prediction = single_spk_rec.sum(dim=0).argmax()
      target = targets[j]
      classes_total[target] += 1
      if prediction == target:
        classes_correct[target] += 1
        correct += 1
    if i % 1 == 0:
      print(f'Batch {i + 1}/{len(testloader)}: Loss: {loss.item():.2f}')
      print(f'Current total: {total}. Current Accuracy: {100 * correct / total:.2f}%')

  print(f'Final Test Accuracy: {100 * correct / total:.2f}%')
  for i in range(len(classes)):
    if classes_total[i] != 0:
      print(f'{classes[i]}: {100 * classes_correct[i] / classes_total[i]:.2f}%', end=', ')
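The evaluation loop picks the class whose output neuron spiked the most over all timesteps. Here is a toy sketch of that rate-coded decision in plain Python, with nested lists standing in for the spike tensor. The function name and the data are made up for illustration, and ties go to the lowest class index here, which may not match what the tensor version does.

```python
def predict_from_spikes(spk_rec):
    """spk_rec[t][c] is 1 if output neuron c spiked at timestep t, else 0."""
    num_classes = len(spk_rec[0])
    # count the spikes of each output neuron across all timesteps
    counts = [sum(step[c] for step in spk_rec) for c in range(num_classes)]
    # the prediction is the class with the highest spike count
    prediction = max(range(num_classes), key=lambda c: counts[c])
    return prediction, counts

# Toy record: 4 timesteps, 3 classes
toy_spk_rec = [
    [0, 1, 0],
    [0, 1, 1],
    [1, 1, 0],
    [0, 0, 1],
]
prediction, counts = predict_from_spikes(toy_spk_rec)
print(prediction, counts)  # 1 [1, 3, 2]
```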

# The End
You finished the tutorial! I hope you enjoyed it and learned something!

Group contribution statement: I worked on this tutorial on my own