## Goal: Train an SNN on the MNIST task and compare its performance with an ANN

In [32]:
!pip install snntorch --quiet

In [33]:
import torch, torch.nn as nn
import snntorch as snn

## Data Loading
Define variables for data loading.

In [34]:
batch_size = 128   # samples per mini-batch
data_path = "raw"  # directory MNIST is downloaded to / read from (e.g. '/data/mnist')

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Torch Variables
dtype = torch.float

# Seed PyTorch's global RNG so that weight initialisation and DataLoader
# shuffling are repeatable across "Restart & Run All".
# NOTE: some GPU kernels are non-deterministic, so results on CUDA may still
# differ slightly from run to run.
seed = 0
torch.manual_seed(seed)

Load MNIST dataset.

In [35]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Preprocessing applied to every image: force 28x28, single channel,
# convert to a tensor, then normalise with mean 0 / std 1.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Training split first, then the held-out test split (downloaded on first use).
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# One shuffled DataLoader per split, both using the same batch size.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (mnist_train, mnist_test)
)

## Define Network with snnTorch. 
* `snn.Leaky()` instantiates a simple leaky integrate-and-fire neuron.
* `spike_grad` optionally defines the surrogate gradient. If left undefined, the relevant gradient term is simply set to the output spike itself (1/0) by default.


The problem with `nn.Sequential` is that each hidden layer can only pass one tensor to subsequent layers, whereas most spiking neurons return their spikes and hidden state(s). To handle this:

* `init_hidden` initializes the hidden states (e.g., membrane potential) as instance variables to be processed in the background. 

The final layer is not bound by this constraint, and can return multiple tensors:
* `output=True` enables the final layer to return the hidden state in addition to the spike.

In [36]:
from snntorch import surrogate

beta = 0.9  # neuron decay rate
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient used by every LIF layer

# Initialize Network
# Shape comments assume 1x28x28 inputs, as produced by the transform above.
net = nn.Sequential(
    # Block 1: 1x28x28 -> conv(5) -> 8x24x24 -> pool(2) -> 8x12x12
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    # Block 2: 8x12x12 -> conv(5) -> 16x8x8 -> pool(2) -> 16x4x4
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    # Read-out: flatten to 16*4*4 features, one output neuron per digit class
    nn.Flatten(),
    nn.Linear(in_features=16 * 4 * 4, out_features=10),
    # output=True makes the last layer return (spikes, membrane potential)
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)

## Define the Forward Pass
Now define the forward pass over multiple time steps of simulation.

In [37]:
from snntorch import utils 

def forward_pass(net, data, num_steps):
    """Run `net` on the same input for `num_steps` steps and collect its output spikes.

    The hidden states of all LIF neurons in `net` are reset first. The same
    `data` batch is presented at every step; the membrane potential returned
    by the final layer is discarded.

    Returns:
        The per-step output spikes stacked along a new leading dimension,
        so the first dimension has length `num_steps`.
    """
    utils.reset(net)  # start every simulation from freshly reset hidden states

    spikes_per_step = []
    for _ in range(num_steps):
        spikes, _membrane = net(data)
        spikes_per_step.append(spikes)

    return torch.stack(spikes_per_step)

Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts up the total number of output spikes at the end of the simulation run. The correct class has a target firing rate of 80% of all time steps, and incorrect classes are set to 20%. 

In [38]:
import snntorch.functional as SF

# MSE count loss: target firing rate of 80% of time steps for the correct
# class and 20% for every incorrect class.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

# Adam over all parameters of the network.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=2e-3,
    betas=(0.9, 0.999),
)

## Training Loop

Now for the training loop. The predicted class will be set to the neuron with the highest firing rate, i.e., a rate-coded output. During training we only measure accuracy on the current training batch (every 25 iterations); the test set is evaluated afterwards. This training loop follows the same syntax as with PyTorch.

In [39]:
num_epochs = 1
num_steps = 25    # run for 25 time steps
log_every = 25    # print loss/accuracy every `log_every` iterations
max_iters = None  # set to e.g. 150 to cut each epoch short for a faster run

loss_hist = []  # loss of every training iteration
acc_hist = []   # batch accuracy, recorded every `log_every` iterations

# training loop
for epoch in range(num_epochs):
    # Training mode only needs to be set once per epoch, not once per batch.
    net.train()

    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        spk_rec = forward_pass(net, data, num_steps)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_item = loss_val.item()
        loss_hist.append(loss_item)

        if i % log_every == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_item:.2f}")

            # Accuracy on this single batch, using the spikes recorded
            # before the weight update above.
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

        # Optional early termination of the epoch (disabled when max_iters is None).
        if max_iters is not None and i >= max_iters:
            break


Epoch 0, Iteration 0 
Train Loss: 2.45
Accuracy: 12.50%

Epoch 0, Iteration 25 
Train Loss: 0.74
Accuracy: 41.41%

Epoch 0, Iteration 50 
Train Loss: 0.57
Accuracy: 60.16%

Epoch 0, Iteration 75 
Train Loss: 0.38
Accuracy: 87.50%

Epoch 0, Iteration 100 
Train Loss: 0.32
Accuracy: 90.62%

Epoch 0, Iteration 125 
Train Loss: 0.26
Accuracy: 93.75%

Epoch 0, Iteration 150 
Train Loss: 0.27
Accuracy: 92.19%

Epoch 0, Iteration 175 
Train Loss: 0.22
Accuracy: 94.53%

Epoch 0, Iteration 200 
Train Loss: 0.22
Accuracy: 94.53%

Epoch 0, Iteration 225 
Train Loss: 0.21
Accuracy: 95.31%

Epoch 0, Iteration 250 
Train Loss: 0.20
Accuracy: 93.75%

Epoch 0, Iteration 275 
Train Loss: 0.18
Accuracy: 97.66%

Epoch 0, Iteration 300 
Train Loss: 0.18
Accuracy: 96.09%

Epoch 0, Iteration 325 
Train Loss: 0.17
Accuracy: 96.09%

Epoch 0, Iteration 350 
Train Loss: 0.16
Accuracy: 95.31%

Epoch 0, Iteration 375 
Train Loss: 0.19
Accuracy: 94.53%

Epoch 0, Iteration 400 
Train Loss: 0.19
Accuracy: 95.31%

Ep

Let's see the accuracy on the full test set, again using `SF.accuracy_rate`.

In [41]:
def test_accuracy(data_loader, net, num_steps):
    """Return the accuracy of `net` over every batch yielded by `data_loader`.

    Each batch is simulated for `num_steps` steps via `forward_pass`, scored
    with `SF.accuracy_rate`, and weighted by its batch size so that a smaller
    final batch does not skew the mean. Batches are moved to the module-level
    `device`.

    Side effects: switches `net` to eval mode and leaves it there.

    Raises:
        ZeroDivisionError: if `data_loader` yields no batches.
    """
    net.eval()
    weighted_acc = 0
    n_seen = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, targets in data_loader:
            spk_rec = forward_pass(net, data.to(device), num_steps)
            n_batch = spk_rec.size(1)  # dim 1 of the recording is used as the batch size

            weighted_acc += SF.accuracy_rate(spk_rec, targets.to(device)) * n_batch
            n_seen += n_batch

    return weighted_acc / n_seen

In [42]:
# Evaluate once on the full test set, then report it as a percentage.
test_acc = test_accuracy(test_loader, net, num_steps)
print(f"Test set accuracy: {test_acc*100:.3f}%")

Test set accuracy: 97.960%
