<a href="https://colab.research.google.com/github/open-neuromorphic/fpga-snntorch/blob/main/software/ISFPGA_SNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ISFPGA Workshop
## Who needs neuromorphic hardware? Deploying SNNs to FPGAs via HLS Open-Source Neuromorphic Circuit Design
### By Jason K. Eshraghian (www.ncg.ucsc.edu)


[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width="200">](https://github.com/jeshraghian/snntorch/)

In [None]:
!pip install snntorch --quiet # shift + enter

*What will I learn?*

1. Train an SNN classifier using snnTorch
2. Hardware Friendly Training
  - Weight Quantization with Brevitas
  - Stateful Quantization
3. Handling neuromorphic data with Tonic

# 1. Train an SNN Classifier using snnTorch
## 1.1 Imports


In [None]:
# snntorch imports
import snntorch as snn
from snntorch import functional as SF

# pytorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# data manipulation
import numpy as np
import itertools

# plotting
import matplotlib.pyplot as plt
from IPython.display import HTML

## 1.2 Boilerplate: DataLoading the MNIST Dataset

In [None]:
# Mini-batch size and on-disk location used when building the MNIST loaders
data_path = '/data/mnist'
batch_size = 128

# Tensor precision and compute device shared by the rest of the notebook
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# On an Apple-silicon (M1/M2) GPU, select the Metal backend instead:
# device = torch.device("mps")

In [None]:
# Preprocessing pipeline applied to every image: resize to 28x28, convert to
# a single channel, turn into a tensor, then normalize with mean 0 and std 1
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

# Training and test splits; download=True fetches the files into data_path
mnist_train = datasets.MNIST(root=data_path, train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root=data_path, train=False, transform=transform, download=True)

In [None]:
# Shuffled mini-batch iterators over both splits; drop_last=True discards a
# final batch that would hold fewer than batch_size samples
train_loader = DataLoader(
    dataset=mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    dataset=mnist_test,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

## 1.3 Construct SNN Model

In [None]:
# Network Architecture
# NOTE(review): every right-hand side in this cell is blank in this copy
# (presumably to be filled in during the workshop); the cell is not valid
# Python until each name is given a value.
num_inputs =
num_hidden =
num_outputs =

# Temporal Dynamics
# num_steps is the trip count of the time loop in Net/QuantNet/SquatNet.forward().
# beta is a module-level value here; DVSNet.__init__ assigns its own local beta.
num_steps =
beta =

In [None]:
from snntorch import surrogate

# Define Network
class Net(nn.Module):
    """Skeleton of a spiking network with layers fc1 -> lif1 -> fc2 -> lif2.

    NOTE(review): the layer definitions in ``__init__`` and the call
    arguments in ``forward`` (written as ``...``) are placeholders in this
    copy; the class cannot be used until they are filled in.
    """
    def __init__(self):
        super().__init__()

        # Initialize layers
        # forward() calls fc1/fc2 to produce currents, and expects lif1/lif2
        # to provide init_leaky() and to return a (spike, membrane) pair.
        self.fc1 =
        self.lif1 =
        self.fc2 =
        self.lif2 =

    def forward(self, x):
        """Run ``num_steps`` time steps and record the output layer.

        Returns a tuple ``(spk2_rec, mem2_rec)``: the per-step output spikes
        and membrane potentials of ``lif2``, each stacked along a new
        leading dimension (time-steps x batch x num_out per the inline note).
        Reads the module-level ``num_steps``.
        """

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        # time-loop
        # The `...` arguments below are placeholders to be replaced.
        for step in range(num_steps):
          cur1 = self.fc1(...) # batch: 128 x 784
          spk1, mem1 = self.lif1(...)
          cur2 = self.fc2(...)
          spk2, mem2 = self.lif2(...)

          # store in list
          spk2_rec.append(spk2)
          mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0) # time-steps x batch x num_out

# Load the network onto CUDA if available
net = Net().to(device)

## 1.4 Training the SNN

In [None]:
def training_loop(model, dataloader, num_epochs=1, max_iters=100):
    """Train ``model`` with Adam on a spike-count cross-entropy loss.

    Args:
        model: module whose forward pass returns ``(spk_rec, mem_rec)``.
            ``spk_rec`` is summed over dim 0 before being passed to the
            loss, i.e. dim 0 is treated as the time axis.
        dataloader: iterable yielding ``(data, targets)`` mini-batches.
        num_epochs: maximum number of passes over ``dataloader``.
        max_iters: total number of weight updates, counted across all
            epochs, after which training stops. ``None`` removes the cap.

    Batches are moved to the notebook-level ``device`` before the forward
    pass. The loss is printed every 10th iteration. Returns ``None``.
    """
    if max_iters is not None and max_iters <= 0:
        return

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999))
    counter = 0

    # Training mode is loop-invariant, so enable it once up front
    model.train()

    # Outer training loop
    for epoch in range(num_epochs):

        # Minibatch training loop
        for data, targets in dataloader:
            data = data.to(device)
            targets = targets.to(device)

            # forward pass
            spk_rec, _ = model(data)

            # sum spikes over time, then compare the counts to the labels
            loss_val = loss_fn(spk_rec.sum(0), targets)  # batch x num_out

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Print train loss
            if counter % 10 == 0:
                print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
            counter += 1

            # Stop training altogether once the cap is hit. A plain `break`
            # would only leave this inner loop, and later epochs would then
            # run without any cap because the counter has already passed it.
            if max_iters is not None and counter >= max_iters:
                return

training_loop(net, train_loader)

In [None]:
def measure_accuracy(model, dataloader):
    """Return the fraction of samples in ``dataloader`` classified correctly.

    A sample's predicted class is the output with the largest spike count,
    where counts come from summing the recorded spikes over dim 0. Runs
    with gradients disabled and leaves ``model`` in eval mode. Batches are
    moved to the notebook-level ``device``.
    """
    model.eval()
    samples_seen = 0
    hits = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward pass; only the spike record is needed here
            spikes, _mem = model(inputs)
            counts = spikes.sum(0)  # batch x num_outputs
            predicted = torch.max(counts, 1)[1]

            # accumulate correct predictions and samples for this batch
            hits += (predicted == labels).sum()
            samples_seen += len(labels)

    fraction = hits / samples_seen
    return fraction.item()


In [None]:
print(f"Test set accuracy: {measure_accuracy(net, test_loader)}")

### A Sanity Check

In [None]:
def print_sample(model, dataloader, idx=0):
    """Plot one sample of the first batch together with the model's prediction.

    Args:
        model: module whose forward pass returns ``(spk_rec, mem_rec)``;
            ``spk_rec`` is summed over dim 0 to obtain spike counts.
        dataloader: iterable yielding ``(data, targets)``; only its first
            batch is used.
        idx: position of the sample within that batch.

    The plot title shows the target label and the class with the highest
    spike count. Runs with gradients disabled, leaves ``model`` in eval
    mode, and uses the notebook-level ``device``. Returns ``None``.
    """
    with torch.no_grad():
        model.eval()

        data, targets = next(iter(dataloader))
        data = data.to(device)
        targets = targets.to(device)

        # forward-pass
        spk_rec, _ = model(data)
        spike_count = spk_rec.sum(0)  # batch x num_outputs
        _, max_spike = spike_count.max(1)

        # Plot the sample. The prediction is shown next to the target so the
        # check actually reports what the model decided.
        plt.imshow(data[idx].cpu().squeeze(), cmap='gray')
        plt.title(f'Target: {targets[idx].item()} | Predicted: {max_spike[idx].item()}')
        plt.show()

In [None]:
print_sample(net, test_loader)

# 2. Hardware Friendly Training
## 2.1 Weight Quantization

In [None]:
!pip install brevitas --quiet

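Just replace each `nn.Linear` layer with `qnn.QuantLinear(num_inputs, num_outputs, bias=..., weight_bit_width=...)`. Pass `bias` and `weight_bit_width` as keyword arguments: the third positional argument of `QuantLinear` is `bias`, not the bit width.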

In [None]:
import brevitas.nn as qnn

# Define Network
class QuantNet(nn.Module):
    """Spiking network skeleton (fc1 -> lif1 -> fc2 -> lif2) for weight quantization.

    NOTE(review): the four layer definitions in ``__init__`` are blank in
    this copy; the surrounding text says to use ``qnn.QuantLinear`` where
    ``nn.Linear`` would otherwise be used. The class cannot be instantiated
    until they are filled in.
    """
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 =
        self.lif1 =
        self.fc2 =
        self.lif2 =

    def forward(self, x):
        """Present the same input ``x`` at each of ``num_steps`` steps.

        ``x`` is flattened from dim 1 onward before entering ``fc1``.
        Returns ``(spk2_rec, mem2_rec)``: the output spikes and membrane
        potentials of ``lif2`` for every step, stacked along a new dim 0.
        Reads the module-level ``num_steps``.
        """

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            # x itself is never modified, so every step sees the same input
            cur1 = self.fc1(x.flatten(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
qnet = QuantNet().to(device)

In [None]:
training_loop(qnet, train_loader)
print(f"Test set accuracy: {measure_accuracy(qnet, test_loader)}")

## 2.2 SQUAT: Stateful Quantization-Aware Training



In [None]:
from snntorch.functional import quant

# Define Network
class SquatNet(nn.Module):
    """Spiking network skeleton (fc1 -> lif1 -> fc2 -> lif2) for stateful quantization.

    NOTE(review): the four layer definitions in ``__init__`` are blank in
    this copy; the class cannot be instantiated until they are filled in.
    """
    def __init__(self):
        super().__init__()

        # Define state quantization parameters
        # q_lif is not referenced anywhere in the visible code; presumably it
        # is meant to be handed to the neuron layers below (TODO confirm).
        q_lif = quant.state_quant(num_bits=4, uniform=True)

        # Initialize layers
        self.fc1 =
        self.lif1 =
        self.fc2 =
        self.lif2 =

    def forward(self, x):
        """Present the same input ``x`` at each of ``num_steps`` steps.

        ``x`` is flattened from dim 1 onward before entering ``fc1``.
        Returns ``(spk2_rec, mem2_rec)``: the output spikes and membrane
        potentials of ``lif2`` for every step, stacked along a new dim 0.
        Reads the module-level ``num_steps``.
        """

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            # x itself is never modified, so every step sees the same input
            cur1 = self.fc1(x.flatten(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
sqnet = SquatNet().to(device)

In [None]:
training_loop(sqnet, train_loader)
print(f"Test set accuracy: {measure_accuracy(sqnet, test_loader)}")

# 3. Handling Neuromorphic Data with Tonic

In [None]:
!pip install tonic --quiet

## 3.1 PokerDVS Dataset

The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:

```
Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. "Poker-DVS and MNIST-DVS. Their history, how they were made, and other details." Frontiers in neuroscience 9 (2015): 481.
```

It comprises four classes, one for each suit of a playing-card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols and was collected by flipping poker cards in front of a DVS128 camera.

In [None]:
import tonic

# Build the training split first, then the test split, both stored under ./data
poker_train, poker_test = (
    tonic.datasets.POKERDVS(save_to='./data', train=is_train)
    for is_train in (True, False)
)

# Inspect the first training sample: dump its raw events, then plot them
events, target = poker_train[0]
print(events)
tonic.utils.plot_event_grid(events)

In [None]:
# Use a distinct alias: the notebook-wide name `transforms` is bound to
# torchvision.transforms by the imports cell, and the MNIST transform cell
# relies on it. Rebinding it here would break that cell when it is re-run.
import tonic.transforms as tonic_transforms
from tonic import DiskCachedDataset

# Event preprocessing: denoise the event stream, then bin the events into
# frames over fixed time windows.
# NOTE(review): filter_time and time_window are presumably in the dataset's
# timestamp unit (microseconds for DVS recordings) - confirm in the tonic docs.
frame_transform = tonic_transforms.Compose([
    tonic_transforms.Denoise(filter_time=10000),
    tonic_transforms.ToFrame(
        sensor_size=tonic.datasets.POKERDVS.sensor_size,
        time_window=1000,
    ),
])

# Smaller batches than for MNIST; this rebinds the notebook-level batch_size
batch_size = 8

# Cache the transformed samples on disk under ./cache/pokerdvs
cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')
cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')

# PadTensors collates samples of different lengths into one batch tensor.
# NOTE(review): batch_first=False presumably yields time-major batches
# (time x batch x ...) - check against the size printed below.
train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)

# Sanity-check one batch: tensor size and labels
data, labels = next(iter(train_loader))
print(data.size())
print(labels)

## 3.2 Construct Model

In [None]:
# NOTE(review): torch.functional is a different module from
# torch.nn.functional (which is the one conventionally aliased as F);
# F is not used in the visible code - confirm which import was intended.
import torch.functional as F

# Define Network
class DVSNet(nn.Module):
    """Skeleton of a convolutional spiking network for the event-frame data.

    Layer order: conv1 -> mp1 -> lif1 -> conv2 -> mp2 -> lif2 -> fc -> lif3.

    NOTE(review): every layer definition in ``__init__``, the loop bound in
    ``forward`` (``range(...)``) and the per-step assignments are blank
    placeholders in this copy; the class cannot be used until they are
    filled in.
    """
    def __init__(self):
        super().__init__()

        # Local to __init__; independent of any module-level beta
        beta = 0.9

        # Initialize layers
        self.conv1  =
        self.mp1    =
        self.lif1   =
        self.conv2  =
        self.mp2    =
        self.lif2   =
        self.fc     =
        self.lif3   =


    def forward(self, x):
        """Step through the input and record the output layer.

        Returns ``(spk3_rec, mem3_rec)``: the per-step spikes and membrane
        potentials of ``lif3``, each stacked along a new dim 0.
        """

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        # The loop bound is a placeholder. The loaders for this dataset use
        # PadTensors(batch_first=False), so the number of steps presumably
        # comes from the leading dimension of x (TODO confirm).
        for step in range(...):
            cur1       =
            spk1, mem1 =
            cur2       =
            spk2, mem2 =
            cur3       =
            spk3, mem3 =

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

# Load the network onto CUDA if available
dvsnet = DVSNet().to(device)

In [None]:
training_loop(dvsnet, train_loader, num_epochs=10)
print(f"Test set accuracy: {measure_accuracy(dvsnet, test_loader)}")

That's all folks!