In [1]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [2]:
# Data-loading and tensor configuration
batch_size = 128             # samples per minibatch
# NOTE(review): the MNIST datasets in this notebook are created with
# root='data', so this path is currently unused -- confirm which is intended.
data_path = '/data/mnist'
# Everything runs on the CPU (notebook was written on an M2 MacBook, no CUDA GPU)
device = torch.device("cpu")
dtype = torch.float32        # same object as torch.float

## Transforming MNIST Images
Spiking neural networks like 0's and 1's, so we want each image to be in a format that we can work with.
1. We standardize each PIL image to a 28 x 28 grayscale tensor whose values are scaled to lie between 0 and 1
2. Then we gather our train and test data and
3. Instantiate our data loaders

In [3]:
# Preprocessing pipeline applied to every image as it is loaded
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),       # force a 28 x 28 image
        transforms.Grayscale(),            # single channel
        transforms.ToTensor(),             # PIL image -> float tensor
        transforms.Normalize((0,), (1,)),  # mean 0 / std 1
    ]
)

# MNIST train split first, then the test split (downloaded into ./data if missing)
train_data, test_data = (
    datasets.MNIST(root='data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# One DataLoader per split. drop_last=True discards a trailing partial batch,
# so every batch holds exactly `batch_size` samples.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True, drop_last=True)
    for split in (train_data, test_data)
)

## Part 1: Creating network architecture
The input layer (num_inputs) will be 28 x 28 (or 784 pixels) large to match the image size. Each input pixel in an MNIST image will have its own input neuron.
A hidden layer is a layer of neurons in between the input and output layers, and it helps us improve performance. We'll use one hidden layer here, 1,000 neurons large. Finally, since we're doing the MNIST task, there are 10 output neurons, one for each handwritten digit we are trying to identify: 0 through 9.

Each node will be a Leaky Integrate and Fire Neuron, which will "leak" its integrated value if it goes unactivated.

The beta value is the rate at which each neuron's integrated value decays ("leaks") from one time step to the next, and the number of time steps (num_steps) fixes how many simulation steps are run for every forward pass.

In [4]:
# Layer sizes
num_inputs = 28 * 28    # one input per pixel of a 28 x 28 image
num_hidden = 1000       # width of the single hidden layer
num_outputs = 10        # one output per digit class

# Simulation settings
num_steps = 25          # time steps simulated per forward pass
beta = 0.95             # decay rate handed to the snn.Leaky neurons

## Part 2: Defining the network
Here we define an instance of our Spiking Neural Network according to our above specifications.

- fc1 - Input layer (first fully-connected layer)
- lif1 - Leaky Integrate and Fire neurons associated with fc1
- fc2 - Output layer (last fully-connected layer)
- lif2 - Leaky Integrate and Fire neurons associated with fc2

The layers' weights are randomly initialized once, when the network is constructed; the training process will adjust these weights to get the best prediction. At the start of each forward pass, we only re-initialize the hidden states (membrane potentials) of the neurons.
We also record the outputs of the lif2 neurons in spk2_rec and mem2_rec at each time step, which track the spikes and the membrane potentials of the output layer. mem2_rec will allow us to compute our loss and subsequently update our weights via backpropagation later on.

In [5]:
# Define Network
class SNN_MNIST(nn.Module):
    """Fully connected spiking network: Linear -> Leaky -> Linear -> Leaky.

    Layer sizes come from the notebook-level ``num_inputs``, ``num_hidden``
    and ``num_outputs``; both leaky layers use the notebook-level ``beta``.
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` time steps on one batch.

        Args:
            x: batch of flattened images, shape ``(batch, num_inputs)``.
                The same input is presented at every time step.

        Returns:
            Tuple ``(spk2_rec, mem2_rec)``: the two outputs of ``lif2``
            collected at every step and stacked along a new leading
            dimension of length ``num_steps``.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        # `x` is identical at every step, so the first linear layer's output
        # is loop-invariant: compute it once instead of `num_steps` times.
        cur1 = self.fc1(x)

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Place the model on the same device the training loop moves its batches to;
# without this, any non-CPU `device` would cause a device-mismatch error.
net = SNN_MNIST().to(device)

## Loss Function and Optimizer
Here we instantiate our loss function, Cross Entropy Loss, which is a common loss function for classification.
We also define our optimizer, Adam, which updates the network's weights after each backward pass in order to reduce the loss.

In [6]:
# Classification criterion applied to the output layer at each time step
loss = nn.CrossEntropyLoss()

# Adam over every trainable parameter of the network
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
)

In [7]:
def print_batch_accuracy(data, targets, train=False, model=None):
    """Print the classification accuracy of the network on one minibatch.

    The predicted class of each sample is the output neuron with the largest
    value once the model's first output (the recorded output-layer spikes) has
    been summed over the leading (time-step) dimension.

    Args:
        data: batch of inputs; it is flattened to ``(batch, -1)`` before being
            passed to the model.
        targets: integer class labels, one per sample in ``data``.
        train: selects the wording of the message ("Train set" vs "Test set").
        model: callable returning ``(spikes, membrane)`` records; defaults to
            the notebook-level ``net``.
    """
    if model is None:
        model = net  # notebook-level network

    # Flatten with the batch's own size rather than the global `batch_size`,
    # so batches of any size (e.g. a final partial batch) are handled.
    # No gradients are needed for a metric, so skip building the graph.
    with torch.no_grad():
        output, _ = model(data.view(data.size(0), -1))

    # Sum over time steps, then take the arg-max over the class dimension
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print a progress report for the current training iteration.

    Reads the notebook-level names ``epoch``, ``iter_counter``, ``counter``,
    ``loss_hist``, ``test_loss_hist``, ``data``, ``targets``, ``test_data``
    and ``test_targets``; it takes no arguments and returns nothing.
    """
    header = f"Epoch {epoch}, Iteration {iter_counter}"
    print(header)

    latest_train_loss = loss_hist[counter]
    print(f"Train Set Loss: {latest_train_loss:.2f}")

    latest_test_loss = test_loss_hist[counter]
    print(f"Test Set Loss: {latest_test_loss:.2f}")

    # Accuracy on the most recent train batch, then on the most recent test batch
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

## Training our SNN
Here we train our SNN by first doing a forward pass on a batch of our training data, flattened to our specifications, getting the output layer's membrane potentials in mem_rec, and summing the cross-entropy loss computed from them over every time step. We then adjust our weights based on the calculated loss via backpropagation.

We then evaluate our SNN similarly, except we evaluate on a batch of our test data and do not adjust our weights.

In [9]:
# Training driver: for every training minibatch, take one optimizer step and
# then measure the loss on one test minibatch.
#
# NOTE: train_printer() takes no arguments and reads `epoch`, `iter_counter`,
# `counter`, `loss_hist`, `test_loss_hist`, `data`, `targets`, `test_data`
# and `test_targets` as notebook-level names, so the names and the order of
# the statements below must be kept in sync with it.
num_epochs = 2
loss_hist = []        # one summed training loss per iteration
test_loss_hist = []   # one summed test loss per iteration
counter = 0           # iteration index across all epochs (indexes the histories)

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0  # iteration index within the current epoch
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        # view(batch_size, -1) requires every batch to hold exactly
        # `batch_size` samples; the loaders are built with drop_last=True.
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time: the criterion is applied to
        # the second network output at each time step and accumulated
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history
        loss_hist.append(loss_val.item())

        # Test set (no gradients are tracked inside this block)
        with torch.no_grad():
            net.eval()
            # A fresh iterator is created on every iteration, so only its
            # first batch is used; test_loader is built with shuffle=True.
            # NOTE(review): this rebinds `test_data`, which previously named
            # the MNIST test dataset object.
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss, accumulated over time steps like the train loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy every 5th iteration; the
            # counters are advanced only after printing so that
            # loss_hist[counter] is the entry appended above.
            if counter % 5 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0
Train Set Loss: 5.48
Test Set Loss: 8.86
Train set accuracy for a single minibatch: 96.09%
Test set accuracy for a single minibatch: 91.41%


Epoch 0, Iteration 5
Train Set Loss: 9.31
Test Set Loss: 5.32
Train set accuracy for a single minibatch: 94.53%
Test set accuracy for a single minibatch: 97.66%


Epoch 0, Iteration 10
Train Set Loss: 4.82
Test Set Loss: 6.08
Train set accuracy for a single minibatch: 93.75%
Test set accuracy for a single minibatch: 94.53%


Epoch 0, Iteration 15
Train Set Loss: 4.05
Test Set Loss: 3.78
Train set accuracy for a single minibatch: 97.66%
Test set accuracy for a single minibatch: 96.09%


Epoch 0, Iteration 20
Train Set Loss: 3.16
Test Set Loss: 4.82
Train set accuracy for a single minibatch: 99.22%
Test set accuracy for a single minibatch: 96.88%


Epoch 0, Iteration 25
Train Set Loss: 3.82
Test Set Loss: 7.73
Train set accuracy for a single minibatch: 95.31%
Test set accuracy for a single minibatch: 92.19%


Epoch 0, Iteration