# MNIST training notebook: a spiking neural network (SNN) with snnTorch

In [44]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import snntorch as snn 
from snntorch import spikegen
import snntorch.spikeplot as splt

## Load Datasets and DataLoaders

In [22]:
# Load Datasets
# Resize and Grayscale leave 28x28 single-channel MNIST images unchanged; they only
# matter if the image source changes. Normalize((0,), (1,)) computes (x - 0) / 1,
# i.e. it is currently an identity placeholder for real normalisation statistics.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)
datapath = "../data"
# download=True is skipped by torchvision when the files already exist under
# `datapath`, so a fresh "Restart & Run All" works without editing this cell.
train_dataset = MNIST(datapath, train=True, transform=transform, download=True)
test_dataset = MNIST(datapath, train=False, transform=transform, download=True)

In [28]:
# Load Dataloaders
BATCH_SIZE = 128

# Training: reshuffle every epoch; drop_last keeps every batch exactly BATCH_SIZE
# (60000 % 128 = 96 training samples are skipped each epoch).
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Evaluation needs no shuffling. Keeping it off makes test metrics deterministic,
# because the same final 10000 % 128 = 16 samples are always the ones dropped.
# NOTE(review): switch to drop_last=False if downstream code tolerates a partial batch.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [36]:
train_dataset.data.shape, test_dataset.data.shape  # raw (untransformed) image tensors per split

(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))

## Define Network

In [39]:
# NOTE(review): TAU and DT are stored on the network but not used by its forward pass as written.
TAU = 5e-3
DT = 1  # ms
BETA = 0.9  # membrane decay rate passed to snn.Leaky
THRESHOLD = 1.0  # firing threshold passed to snn.Leaky
NUM_STEPS = 25  # simulation steps per forward pass

# Derive the input size from a *transformed* sample, so it always matches what the
# network receives (e.g. if the Resize target in the dataset cell changes).
NUM_INPUTS = train_dataset[0][0].numel()
NUM_HIDDENS = 1000  # Design choice
NUM_OUTPUTS = 10  # Number of output classes
print(f"NUM_INPUTS: {NUM_INPUTS}, NUM_HIDDENS: {NUM_HIDDENS}, NUM_OUTPUTS: {NUM_OUTPUTS}")

NUM_INPUTS: 784, NUM_HIDDENS: 1000, NUM_OUTPUTS: 10




In [47]:
class Net(nn.Module):
    """Simple 3-layer, feed-forward SNN: fc1 -> lif1 -> fc2 -> lif2.

    The same input ``x`` is presented at every one of ``num_steps`` simulation
    steps, and the output layer's spikes and membrane potentials are recorded
    at each step.

    Args:
        num_inputs: Number of input features fed to ``fc1``.
        num_hiddens: Number of hidden LIF neurons.
        num_outputs: Number of output LIF neurons (one per class).
        tau: Stored as ``self.tau``; not used by ``forward``.
        dt: Stored as ``self.dt``; not used by ``forward``.
        beta: Membrane decay rate for both ``snn.Leaky`` layers.
        threshold: Firing threshold for both ``snn.Leaky`` layers.
        num_steps: Number of simulation steps per forward pass.
    """

    def __init__(self, num_inputs: int, num_hiddens: int, num_outputs: int,
                 tau: float, dt, beta: float, threshold: float, num_steps: int):
        """Store hyperparameters and build the two Linear + Leaky layer pairs."""
        super().__init__()  # Register this as a torch.nn.Module (parameters, .to(), etc.)

        # Hyperparams
        self.tau = tau
        self.dt = dt
        self.beta = beta
        self.threshold = threshold
        self.num_steps = num_steps  # No. simulation steps per forward pass

        # Architecture
        self.fc1 = nn.Linear(in_features=num_inputs, out_features=num_hiddens)
        self.lif1 = snn.Leaky(beta=beta, threshold=threshold)
        self.fc2 = nn.Linear(in_features=num_hiddens, out_features=num_outputs)
        self.lif2 = snn.Leaky(beta=beta, threshold=threshold)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on input ``x``.

        Args:
            x: Input tensor whose last dimension equals ``num_inputs``. ``fc1``
                is applied directly, so image batches must be flattened by the
                caller. NOTE(review): assumes shape (batch, num_inputs); confirm
                against the training loop.

        Returns:
            ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane potentials,
            each stacked along a new leading time dimension of size ``num_steps``.
        """
        # Reset/initialise membrane potentials at the start of every forward pass
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Per-step records of the output layer
        spk2_rec = []
        mem2_rec = []

        # Iterate over all timesteps for this batch
        for _ in range(self.num_steps):
            # NOTE: the same static x drives fc1 at every step (direct/current input).
            # TODO: consider rate-coded spike input (e.g. spikegen.rate) indexed as x[step].
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)  # hidden spikes -> output layer
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)  # Store spike outputs & membrane voltage
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
    
# Build the network from the hyperparameters defined in the previous cell.
net = Net(
    num_inputs=NUM_INPUTS,
    num_hiddens=NUM_HIDDENS,
    num_outputs=NUM_OUTPUTS,
    tau=TAU,
    dt=DT,
    beta=BETA,
    threshold=THRESHOLD,
    num_steps=NUM_STEPS,
)

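* `fc1` applies a linear transformation to the 784 input pixels of each MNIST image;
* `lif1` is a layer of leaky integrate-and-fire (LIF) neurons. It integrates the weighted input over time and emits a spike whenever a neuron's membrane potential crosses the threshold;
* `fc2` applies a linear transformation to the output spikes of `lif1`;
* `lif2` is a second LIF layer that integrates the weighted spikes over time. Its spikes and membrane potentials are recorded at every time step and returned by `forward`.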

## Training