[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width="300">](https://github.com/jeshraghian/snntorch/)
[<img src='https://github.com/neuromorphs/tonic/blob/develop/docs/_static/tonic-logo-white.png?raw=true' width="200">](https://github.com/neuromorphs/tonic/)


# Training the DVSGesture Dataset from Tonic + snnTorch Tutorial
##### By Malachi Nguyen (mayanguy@ucsc.edu)
##### A special thank you to Professor Jason Eshraghian and my tutor Giridhar Vadhul for their patient teaching and for inspiring their students.

<a href="https://colab.research.google.com/drive/1P2yQCDmp7TilNrEqj_cBzS7vscIs0L_o?usp=sharing">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

For a comprehensive overview of how SNNs work and what is going on under the hood, [you might be interested in the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)
The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. "Training Spiking Neural Networks Using Lessons From Deep Learning". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>

# Importing Dependencies and Libraries

Please check out the [Tonic](https://tonic.readthedocs.io/en/latest/) and [snnTorch](https://snntorch.readthedocs.io/en/latest/tutorials/index.html) libraries for more information.

In [None]:
!pip install tonic --quiet
!pip install snntorch --quiet


In [None]:
# tonic imports
import tonic
import tonic.transforms as transforms  # Not to be confused with torchvision.transforms
from tonic import DiskCachedDataset # alt: MemoryCachedDataset

# torch imports
import torch
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torchvision
import torch.nn as nn

# snntorch imports
import snntorch as snn
from snntorch import surrogate
import snntorch.spikeplot as splt
from snntorch import functional as SF
from snntorch import utils

# other imports
import matplotlib.pyplot as plt

# Use a CUDA GPU if available, then Apple-silicon MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# 1. The Dataset - Loading and interpreting the DVSGesture Dataset

The dataset used in this tutorial is DVSGesture, recorded by a team of researchers at IBM with a DVS128 event camera.

It consists of 11 classes, each a gesture performed with a person's hands or arms, for example:
1. hand clapping
2. right hand wave
3. left hand wave
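When inspecting predictions, it helps to turn a label index back into a gesture name. A minimal sketch, assuming the standard DVSGesture class order and zero-based labels as used by Tonic (confirm against `tonic.datasets.DVSGesture.classes` in your installed version); `GESTURE_CLASSES` and `class_name` are hypothetical names for illustration:

```python
# Assumed order of the 11 DVSGesture classes; compare with tonic.datasets.DVSGesture.classes
GESTURE_CLASSES = [
    "Hand clapping",
    "Right hand wave",
    "Left hand wave",
    "Right arm clockwise",
    "Right arm counter-clockwise",
    "Left arm clockwise",
    "Left arm counter-clockwise",
    "Arm roll",
    "Air drums",
    "Air guitar",
    "Other gestures",
]


def class_name(label):
    """Map a zero-based label index (assumed Tonic convention) to its gesture name."""
    return GESTURE_CLASSES[int(label)]


print(class_name(2))  # Left hand wave
```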


## 1.1 Loading the Dataset for SNN Torch

1. The dataset stores each recording as a stream of raw events (x and y pixel coordinates, polarity, and timestamp) rather than as images, so it must be converted into a tensor format that can be fed into a model.


2. The code below first removes isolated noise events with `Denoise`, then reduces the resolution from 128x128 pixels to 32x32 with `Downsample` for efficiency and memory, and finally uses `ToFrame` to split each recording into `n_time_bins` = 2000 frames of equal duration (2000 *bins*, not 2000 ms windows). The transformed samples are cached to disk before being fed to the `DataLoader`.


3. One of the challenges in this step is deciding how to transform and load your data. Wrapping the loading code in a `dataloader` function that takes the batch size and other parameters (such as the number of time bins used by the transformations) makes it easy to handle different datasets and training configurations.
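To make the `ToFrame` step concrete: with `n_time_bins=2000`, each recording is divided into 2000 slices of equal duration (a hypothetical 6-second recording would give frames about 3 ms long), and each frame counts the events per pixel and polarity. The NumPy sketch below illustrates the idea on a handful of events. `events_to_frames` is a hypothetical helper that also mimics the 0.25 spatial downsampling; Tonic's actual `Downsample` and `ToFrame` implementations may differ in details such as how boundary events are assigned.

```python
import numpy as np


def events_to_frames(x, y, p, t, sensor_hw=(128, 128), spatial_factor=0.25, n_time_bins=4):
    """Bin DVS events into n_time_bins equal-duration frames of per-polarity event counts.

    Returns an array of shape [n_time_bins, 2, height, width].
    """
    # Spatial downsampling: scale the coordinates and truncate, e.g. 128x128 -> 32x32
    xs = (x * spatial_factor).astype(int)
    ys = (y * spatial_factor).astype(int)
    height = int(sensor_hw[0] * spatial_factor)
    width = int(sensor_hw[1] * spatial_factor)

    # Temporal binning: split the recording's duration into n_time_bins equal slices
    t = t - t.min()
    duration = max(int(t.max()), 1)
    bins = np.minimum(t * n_time_bins // duration, n_time_bins - 1)  # the final event falls in the last bin

    frames = np.zeros((n_time_bins, 2, height, width), dtype=np.int64)
    np.add.at(frames, (bins, p.astype(int), ys, xs), 1)  # count events per (bin, polarity, pixel)
    return frames


# Four events on a 128x128 sensor, timestamps in microseconds
x = np.array([0, 127, 64, 64])
y = np.array([0, 127, 64, 65])
p = np.array([0, 1, 1, 1])
t = np.array([0, 1000, 500, 999])
frames = events_to_frames(x, y, p, t)
print(frames.shape, frames.sum())  # (4, 2, 32, 32) 4
```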

*The code below was provided by the man, myth, and legend himself, Professor Jason Eshraghian.*


In [None]:
def dataloader(config):
    batch_size = config['batch_size']
    # The native DVSGesture sensor size (tonic.datasets.DVSGesture.sensor_size) is (128, 128, 2).
    # Downsample(spatial_factor=0.25) below reduces it to 32x32, keeping the 2 polarity channels.
    sensor_size = (32, 32, 2)

    train_transform = transforms.Compose([transforms.Denoise(filter_time=10000),       # drop isolated events (timestamps in microseconds)
                                          transforms.Downsample(spatial_factor=0.25),  # 128x128 -> 32x32
                                          transforms.ToFrame(sensor_size=sensor_size,  # n_time_bins equal-duration frames per sample
                                                             n_time_bins=config['train_time_bin']),
                                          ])

    test_transform = transforms.Compose([transforms.Denoise(filter_time=10000),
                                         transforms.Downsample(spatial_factor=0.25),
                                         transforms.ToFrame(sensor_size=sensor_size,
                                                            n_time_bins=config['test_time_bin']),
                                         ])

    trainset = tonic.datasets.DVSGesture(save_to=config['data_dir'], transform=train_transform, train=True)
    testset = tonic.datasets.DVSGesture(save_to=config['data_dir'], transform=test_transform, train=False)

    # Transformed samples are cached on disk after the first pass. Delete these folders if you
    # change the transforms, otherwise the previously cached frames are reused.
    cached_trainset = DiskCachedDataset(trainset, cache_path='./data/cache/dvs/train')
    cached_testset = DiskCachedDataset(testset, cache_path='./data/cache/dvs/test')

    # PadTensors(batch_first=False) stacks samples as [time_steps, batch_size, channels, height, width];
    # the training set is shuffled every epoch
    train_loader = DataLoader(cached_trainset, batch_size=batch_size, shuffle=True, collate_fn=tonic.collation.PadTensors(batch_first=False))
    test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))

    return train_loader, test_loader
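The `PadTensors(batch_first=False)` collate function is what produces the time-first batches used in the rest of this tutorial. Conceptually, it zero-pads every sample in a batch to the longest sequence and stacks the samples along a batch dimension in position 1. Because `ToFrame` is given a fixed `n_time_bins`, all samples here already have the same length, so the padding should have no effect. A NumPy sketch of the idea (the hypothetical `pad_time_first` is not Tonic's code):

```python
import numpy as np


def pad_time_first(samples):
    """Zero-pad [T_i, C, H, W] arrays to the longest T and stack them as [T_max, batch, C, H, W]."""
    t_max = max(s.shape[0] for s in samples)
    padded = [
        np.concatenate([s, np.zeros((t_max - s.shape[0],) + s.shape[1:], dtype=s.dtype)])
        for s in samples
    ]
    return np.stack(padded, axis=1)  # batch dimension second: time-first layout


batch = pad_time_first([np.ones((3, 2, 4, 4)), np.ones((5, 2, 4, 4))])
print(batch.shape)  # (5, 2, 2, 4, 4)
```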

Next, we call the `dataloader` function with our configuration. You can, and should, edit these parameters to explore the trade-off between efficiency and accuracy.

In [None]:
# Define configuration parameters
config = {
    'batch_size': 64,
    'data_dir': './data',    # directory where the dataset is downloaded and extracted
    'train_time_bin': 2000,  # number of time bins (frames) per training sample
    'test_time_bin': 2000,   # number of time bins (frames) per test sample
}

# Call the dataloader function
train_loader, test_loader = dataloader(config)

# Check the length of the loaders for more information
print(f"Number of batches in train loader: {len(train_loader)}")
print(f"Number of batches in test loader: {len(test_loader)}")


Extracting ./data/DVSGesture/ibmGestureTrain.tar.gz to ./data/DVSGesture
Extracting ./data/DVSGesture/ibmGestureTest.tar.gz to ./data/DVSGesture
Number of batches in train loader: 17
Number of batches in test loader: 5
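By default a `DataLoader` keeps the final, partially filled batch, so its length is the number of samples divided by the batch size, rounded up. A small sketch of that arithmetic (the `num_batches` helper is hypothetical): for example, 17 training batches at `batch_size=64` means the training set holds between 1,025 and 1,088 samples.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for num_samples samples."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # the incomplete final batch is kept


print(num_batches(1025, 64), num_batches(1088, 64), num_batches(1089, 64))  # 17 17 18
```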


## 1.2 Visualizing your data

It's very important to visualize your data so you know what you're working with. How will you be able to conceptually grasp what to learn if you don't know what it looks like?

The cell below plots one frame from each training batch: the first time bin (polarity channel 0) of the first sample.

In [None]:
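import matplotlib.pyplot as plt

# Define the number of rows and columns for subplots
num_rows = 2
num_cols = 10

fig, ax = plt.subplots(num_rows, num_cols, figsize=(20, 4))
for a in ax.flat:
    a.axis('off')  # hide the axes, including subplots left empty if there are fewer batches than panels

# data has shape [time_steps, batch_size, channels, height, width] because batch_first=False,
# so data[0, 0, 0] is the first time bin of the first sample, polarity channel 0
for i, (data, targets) in enumerate(train_loader):
    if i >= num_rows * num_cols:
        break

    row = i // num_cols
    col = i % num_cols
    ax[row, col].imshow(data[0, 0, 0])

plt.show()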


Another very important thing to check is the size and shape of your data. Because the loader uses `PadTensors(batch_first=False)`, each batch has the shape (time steps, batch size, channels, image height, image width), and at every time step the network receives a (batch size, channels, height, width) slice. This information is essential for choosing the sizes of your convolutional layers. Since the data was loaded and transformed in the previous cells, we can inspect the attributes of the transformed set directly.

In [None]:
# data has shape [time_steps, batch_size, channels, height, width] (batch_first=False)
print("Input batch shape:", data.shape)

sample_index = 1  # any index in range(data.shape[1]) selects one sample from the batch
sample_size = data[:, sample_index].shape  # index dim 1, because dim 0 is time
print(f"Shape of sample {sample_index} in the batch:", sample_size)


Input batch shape: torch.Size([2000, 7, 2, 32, 32])
Shape of sample 1 in the batch: torch.Size([2000, 2, 32, 32])


## 2.1 Defining the Network

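The model is a sequential network built from two convolutional layers with 5x5 kernels, each followed by a Leaky spiking neuron layer and 2x2 max pooling. The resulting feature maps are flattened into an 800-element (32x5x5) vector, which a fully connected layer maps to 11 outputs, one per class, followed by a final Leaky output layer.

The `forward` function resets the hidden states of all Leaky neurons, feeds one batch through the network one time step at a time, and stacks the output spikes into a tensor of shape [time_steps, batch_size, num_classes].

Please see [this link](https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_6.html#define-the-network) for an in-depth explanation of this network.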
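The 800 input features of `nn.Linear(32*5*5, num_classes)` follow from tracking the spatial size through each layer: a convolution without padding shrinks each side by `kernel_size - 1`, and `nn.MaxPool2d(2)` halves it, rounding down. A pure-Python sketch of that bookkeeping (hypothetical `out_size` helper, dilation of 1 assumed), which is a quick check against shape mismatches in the linear layer:

```python
def out_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial dimension for Conv2d/MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 32                    # 32x32 frames after Downsample(spatial_factor=0.25)
size = out_size(size, 5)     # Conv2d(2, 12, 5)   -> 28
size = out_size(size, 2, 2)  # MaxPool2d(2)       -> 14
size = out_size(size, 5)     # Conv2d(12, 32, 5)  -> 10
size = out_size(size, 2, 2)  # MaxPool2d(2)       -> 5
flat_features = 32 * size * size  # 32 channels from the second conv layer
print(flat_features)  # 800
```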


In [None]:
#parameters
num_classes = 11
spike_grad = surrogate.atan() # arctan surrogate gradient function
beta = 0.5

net = nn.Sequential(nn.Conv2d(2, 12, 5), # first conv layer
                        snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                        nn.MaxPool2d(2),
                        nn.Conv2d(12, 32, 5), # second conv layer
                        snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                        nn.MaxPool2d(2),
                        nn.Flatten(),
                        nn.Linear(32*5*5, num_classes), #flattened linear layer
                        snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                        ).to(device)


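# Run the network over every time step and record the output spikes:

def forward(net, data):
    spk_rec = []
    utils.reset(net)  # resets hidden states for all LIF neurons in net
    for step in range(data.size(0)):  # data: [time_steps, batch_size, channels, height, width]
        spk_out, mem_out = net(data[step])  # output=True makes the last Leaky layer return (spikes, membrane)
        spk_rec.append(spk_out)
    return torch.stack(spk_rec)  # [time_steps, batch_size, num_classes]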
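Each `snn.Leaky` layer in the `forward` loop carries a membrane potential from one time step to the next. The sketch below is a simplified, pure-Python single neuron in the spirit of `snn.Leaky` (decay `beta`, threshold 1.0, reset by subtracting the threshold after a spike); it is an illustration under those assumptions, not snnTorch's exact implementation, and `lif_run` is a hypothetical name.

```python
def lif_run(inputs, beta=0.5, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron over a sequence of input currents."""
    mem = 0.0
    spikes, mems = [], []
    for x in inputs:
        reset = 1.0 if mem > threshold else 0.0    # subtract the threshold if the neuron fired last step
        mem = beta * mem + x - reset * threshold    # leak, integrate the new input, then reset
        spikes.append(1 if mem > threshold else 0)  # fire when the membrane crosses the threshold
        mems.append(mem)
    return spikes, mems


spikes, mems = lif_run([0.6] * 8)
print(spikes)  # [0, 0, 1, 0, 0, 0, 1, 0]
```

With `beta=0.5` a constant input of 0.6 needs three steps to cross the threshold, so the neuron fires periodically rather than on every step.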

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=0.003, betas=(0.9, 0.999))  # learning rate = 0.003
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)  # MSE spike-count loss: correct class should fire on 80% of time steps, others on 20%
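As described in the snnTorch documentation, `mse_count_loss` sums each output neuron's spikes over all time steps and compares the counts with a target: the correct class should fire on `correct_rate` of the steps and every other class on `incorrect_rate` of them. With 2000 time bins, that is 1600 spikes for the correct class and 400 for the others. A NumPy sketch of that idea (hypothetical `mse_count_loss_sketch`; the library's exact scaling, for example any normalization by the number of time steps, may differ):

```python
import numpy as np


def mse_count_loss_sketch(spk_rec, targets, correct_rate=0.8, incorrect_rate=0.2):
    """MSE between each output neuron's spike count and a target count.

    spk_rec: [num_steps, batch_size, num_classes] array of 0/1 spikes; targets: [batch_size] class indices.
    """
    num_steps, batch_size, num_classes = spk_rec.shape
    counts = spk_rec.sum(axis=0)  # total spikes per neuron: [batch_size, num_classes]
    target_counts = np.full((batch_size, num_classes), int(num_steps * incorrect_rate), dtype=float)
    target_counts[np.arange(batch_size), targets] = int(num_steps * correct_rate)  # correct class should fire more
    return float(np.mean((counts - target_counts) ** 2))


# 10 time steps, 1 sample, 3 classes: class 0 fires 8 times and the others twice, exactly matching the targets
spk = np.zeros((10, 1, 3))
spk[:8, 0, 0] = 1
spk[:2, 0, 1:] = 1
print(mse_count_loss_sketch(spk, np.array([0])))  # 0.0
```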

## 2.2 Training

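We have loaded our dataset and converted it into iterable tensors that an SNN can consume. We now train the network: for each training batch, we run it through the network over all time steps, compute the spike-count loss, and update the weights. After every iteration, we also measure accuracy on the test set.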

### Training Notes and Challenges:
1.   Training takes a LONG time (when I wrote this, training was still running 41 minutes in). Be patient, and tweak the number of iterations, epochs, and batch sizes to find the balance between speed and accuracy that works for you.
2.   You may run into layer-size and dimension-compatibility issues. An error such as `mat1 and mat2 shapes cannot be multiplied` most likely stems from how you defined the layers in your network: make sure the flattened output of the last convolutional block matches the expected input size of the linear layer.
3.   You may run out of RAM, or hit runtime errors, because of how you load your data. Make the data smaller, split it into chunks, or resize it as needed for your specific dataset. Printing the sizes of your data and plotting its frames is an easy way to see how you may need to load it.
4.   Check out this document for a detailed log of my trials and tribulations in tuning the epochs and iterations. It may help you find good values for training your own dataset! [here](https://docs.google.com/document/d/118qt3fanVPR6KGj5Ed1BF-nnZkWWEWkNAPr9ZGKe628/edit)



In [None]:
num_epochs = 20  # number of complete passes over the entire dataset
num_iters = 6    # stop each epoch after iteration num_iters (i.e. num_iters + 1 batches)

loss_hist = []
acc_hist = []
test_acc_hist = []

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        spk_rec = forward(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # Evaluate on the full test set; no_grad avoids storing activations for backprop and saves memory
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            # separate loop variables so the outer `i` used by the num_iters check is not overwritten
            for test_data, test_targets in test_loader:
                test_data = test_data.to(device)
                test_targets = test_targets.to(device)
                test_spk_rec = forward(net, test_data)
                correct += SF.accuracy_rate(test_spk_rec, test_targets) * test_spk_rec.size(1)  # weight by batch size
                total += test_spk_rec.size(1)

        test_acc = (correct / total) * 100
        test_acc_hist.append(test_acc)
        print(f"========== Test Set Accuracy: {test_acc:.2f}% ==========\n")

        if i == num_iters:
            break
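`SF.accuracy_rate` predicts, for each sample, the class whose output neuron spiked most over all time steps, and returns the fraction of predictions that match the targets. Because the last batch of a loader can be smaller than `batch_size`, the test loop weights each batch's accuracy by its size before dividing by the total. A NumPy sketch of both steps (`accuracy_rate_sketch` and `weighted_accuracy` are hypothetical names, not snnTorch functions):

```python
import numpy as np


def accuracy_rate_sketch(spk_rec, targets):
    """Fraction of samples whose most-active output neuron matches the target class."""
    spike_counts = spk_rec.sum(axis=0)         # [batch_size, num_classes]
    predictions = spike_counts.argmax(axis=1)  # class with the most spikes
    return float(np.mean(predictions == targets))


def weighted_accuracy(batch_accuracies, batch_sizes):
    """Combine per-batch accuracies into one dataset accuracy, weighting by batch size."""
    correct = sum(acc * n for acc, n in zip(batch_accuracies, batch_sizes))
    return correct / sum(batch_sizes)


# Two time steps, two samples, three classes: sample 0 is right (class 2), sample 1 is wrong
spk = np.array([[[0, 0, 1], [1, 0, 0]],
                [[0, 1, 1], [1, 0, 0]]])
print(accuracy_rate_sketch(spk, np.array([2, 1])))  # 0.5
print(weighted_accuracy([1.0, 0.5], [48, 16]))      # 0.875
```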

# 3. Results and Fine Tuning

## 3.1 Visualizing your results

Visualizing your results is an important step in building an accurate SNN. Plotting the loss and accuracy over iterations shows whether training is converging, and comparing the training and test accuracy curves reveals whether the network is overfitting.


In [None]:
# Plot Loss
fig = plt.figure(facecolor="w")
plt.plot(loss_hist)
plt.title("Train Set Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

In [None]:
# Plot Train Accuracy
fig = plt.figure(facecolor="w")
plt.plot(acc_hist)
plt.title("Train Set Accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy (fraction)")
plt.show()

In [None]:
# Plot Test Accuracy
fig = plt.figure(facecolor="w")
plt.plot(test_acc_hist)
plt.title("Test Set Accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy (%)")
plt.show()