# snnTorch - Tutorial 3
### By Jason K. Eshraghian

# Gradient-based Learning in Convolutional Spiking Neural Networks
In this tutorial, we'll use a convolutional neural network (CNN) to classify the MNIST dataset.
We will use the backpropagation through time (BPTT) algorithm to do so. We first try the SRM0 neuron model, then return to Stein's neuron model from the previous tutorial.

If running in Google Colab:
* Ensure you are connected to GPU by checking Runtime > Change runtime type > Hardware accelerator: GPU
* Next, install the Test PyPi distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`.

In [None]:
# Install the test PyPi Distribution of snntorch
!pip install -i https://test.pypi.org/simple/ snntorch

## 1. Setting up the Static MNIST Dataset
### 1.1. Import packages and setup environment

In [None]:
import snntorch as snn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import itertools
import matplotlib.pyplot as plt

### 1.2 Define network and SNN parameters
We will use a 2conv-2MaxPool-FCN architecture for a sequence of 50 time steps.

* `alpha` is the decay rate of the synaptic current of a neuron
* `beta` is the decay rate of the membrane potential of a neuron

For the SRM0 model, `alpha` > `beta` to ensure the membrane has an excitatory response to positive inputs.
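
To see why the ordering of the two decay rates matters, here is a minimal sketch. It assumes the membrane response to a single input pulse is proportional to the difference of two decaying exponentials, `alpha**t - beta**t`, which is a common way to write an SRM0-style kernel. The kernel actually implemented in snnTorch may be scaled or arranged differently. The time constants match the ones defined in the next cell.

```python
import math

# Time constants as used in this tutorial
time_step = 1e-3
tau_syn = 9e-3
tau_mem = 8e-3

alpha = math.exp(-time_step / tau_syn)  # synaptic current decay rate
beta = math.exp(-time_step / tau_mem)   # membrane potential decay rate

# Assumed (illustrative) impulse response: difference of two exponentials
num_steps = 50
kernel = [alpha**t - beta**t for t in range(num_steps)]
peak_step = max(range(num_steps), key=lambda t: kernel[t])

print(f"alpha = {alpha:.4f}, beta = {beta:.4f}")
print(f"kernel starts at {kernel[0]:.1f} and peaks at step {peak_step}")
print(f"all later values positive: {all(k > 0 for k in kernel[1:])}")
```

Under this assumption, `alpha > beta` makes the kernel rise from zero and stay positive, so a positive input pushes the membrane up. Swapping the two time constants would flip its sign.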

In [None]:
# Network Architecture
num_inputs = 28*28
num_outputs = 10

# Training Parameters
batch_size=128
data_path='/data/mnist'

# Temporal Dynamics
num_steps = 50
time_step = 1e-3
tau_syn = 9e-3
tau_mem = 8e-3
alpha = float(np.exp(-time_step/tau_syn))
beta = float(np.exp(-time_step/tau_mem))

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 1.3 Download MNIST Dataset
To see how to construct a validation set, refer to Tutorial 1.

In [None]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

### 1.4 Create DataLoaders

In [None]:
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

## 2. Define Network
snnTorch treats neurons as activations with recurrent connections. In the last tutorial, we showed how to apply these spiking activations to fully connected layers.
The exact same process can be applied to convolutions without any modifications, other than dimensionality.

As before, we will use a neuron model and surrogate gradient function:
1. `snntorch.neuron.SRM0` is a leaky integrate and fire (LIF) neuron. Specifically, it is a zero$^{th}$ order spike response model (SRM0) where the change in membrane potential lags behind the input activation. This lag motivates simulating across longer timescales than Stein's model.
2. `snntorch.neuron.FastSigmoidSurrogate` defines separate forward and backward functions. The forward function is a Heaviside step function for spike generation. The backward function is the derivative of a fast sigmoid function, to ensure continuous differentiability.
The `FastSigmoidSurrogate` function has been adapted from:

>Neftci, E. O., Mostafa, H., and Zenke, F. (2019) Surrogate Gradient Learning in Spiking Neural Networks. https://arxiv.org/abs/1901.09948

`snn.neuron.slope` defines the slope of the backward surrogate.
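
The forward/backward split described above can be sketched with a custom `torch.autograd.Function`. This is an illustrative stand-in, not the snnTorch source: the class name `FastSigmoidSketch` is hypothetical, and it assumes the input is the membrane potential with the threshold already subtracted. The backward pass uses the derivative of the fast sigmoid $x / (1 + k|x|)$, which is $1 / (1 + k|x|)^2$, with $k$ playing the role of the slope.

```python
import torch

class FastSigmoidSketch(torch.autograd.Function):
    """Hypothetical sketch: Heaviside forward, fast-sigmoid derivative backward."""
    slope = 50.0  # assumed to correspond to snn.neuron.slope

    @staticmethod
    def forward(ctx, mem_shifted):
        # mem_shifted is assumed to be (membrane potential - threshold)
        ctx.save_for_backward(mem_shifted)
        return (mem_shifted > 0).float()  # spike where the threshold is exceeded

    @staticmethod
    def backward(ctx, grad_output):
        (mem_shifted,) = ctx.saved_tensors
        # Smooth replacement for the Heaviside derivative
        return grad_output / (FastSigmoidSketch.slope * mem_shifted.abs() + 1.0) ** 2

x = torch.tensor([-0.5, 0.0, 0.5], requires_grad=True)
spk = FastSigmoidSketch.apply(x)
spk.sum().backward()

print(spk)     # spikes: only the positive entry fires
print(x.grad)  # surrogate gradient: largest at 0, small far from it
```

A larger slope makes the surrogate gradient more sharply peaked around the threshold, which is closer to the true step function but passes less gradient to neurons far from threshold.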

TO-DO: Include visualisation.

In [None]:
from snntorch.neuron import SRM0
from snntorch.neuron import FastSigmoidSurrogate as FSS

spike_grad = FSS.apply
snn.neuron.slope = 50

Now we define our spiking neural network (SNN).
Creating an instance of the `SRM0` neuron requires two arguments and two optional arguments:
1. $I_{syn}$ decay rate, $\alpha$,
2. $V_{mem}$ decay rate, $\beta$,
3. the surrogate spiking function, `spike_grad`=`FSS` (*default*: the gradient of the Heaviside function), and
4. the threshold for spiking, (*default*: 1.0).

`SRM0` requires initialization of its internal states as well as the spike output at $t=0$.
The only difference from the `Stein` model used before is that `SRM0` requires an additional internal state variable to induce a finite rise time of the membrane potential.

A class method `init_srm0` will take care of this.

For rate coding, the final layer of spikes and membrane potential are used to determine accuracy and loss, respectively.
So their historical values are recorded in `spk3_rec` and `mem3_rec`.

Keep in mind, the dataset we are using is just static MNIST. I.e., it is *not* time-varying.
Therefore, we pass the same MNIST sample to the input at each time step.
This is handled in the line `cur1 = F.max_pool2d(self.conv1(x), 2)`, where `x` is the same input over the whole for-loop.


In [None]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # initialize layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.lif1 = SRM0(alpha=alpha, beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=5, stride=1, padding=1)
        self.lif2 = SRM0(alpha=alpha, beta=beta, spike_grad=spike_grad)
        self.fc2 = nn.Linear(64*5*5, 10)
        self.lif3 = SRM0(alpha=alpha, beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        # Initialize LIF state variables and spike output tensors
        spk1, syn_pre1, syn_post1, mem1 = self.lif1.init_srm0(batch_size, 12, 13, 13)
        spk2, syn_pre2, syn_post2, mem2 = self.lif2.init_srm0(batch_size, 64, 5, 5)
        spk3, syn_pre3, syn_post3, mem3 = self.lif3.init_srm0(batch_size, 10)

        spk3_rec = []
        mem3_rec = []

        for step in range(num_steps):
            cur1 = F.max_pool2d(self.conv1(x), 2)
            spk1, syn_pre1, syn_post1, mem1 = self.lif1(cur1, syn_pre1, syn_post1, mem1)
            cur2 = F.max_pool2d(self.conv2(spk1), 2)
            spk2, syn_pre2, syn_post2, mem2 = self.lif2(cur2, syn_pre2, syn_post2, mem2)
            cur3 = self.fc2(spk2.view(batch_size, -1))
            spk3, syn_pre3, syn_post3, mem3 = self.lif3(cur3, syn_pre3, syn_post3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

net = Net().to(device)
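
The state shapes passed to `init_srm0` (`12, 13, 13` and `64, 5, 5`) and the `64*5*5` input size of the linear layer follow from the convolution and pooling arithmetic. The sketch below reproduces those numbers with two small helpers. `conv_out` and `pool_out` are hypothetical names written for this example, not PyTorch or snnTorch functions, and they assume square inputs and the default floor rounding of `F.max_pool2d`.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output width/height of a square convolution."""
    return (size + 2 * padding - kernel_size) // stride + 1

def pool_out(size, kernel_size):
    """Output width/height of max pooling with stride equal to the kernel size."""
    return size // kernel_size

# MNIST input is 28x28
layer1 = pool_out(conv_out(28, kernel_size=5, padding=1), 2)      # 28 -> 26 -> 13
layer2 = pool_out(conv_out(layer1, kernel_size=5, padding=1), 2)  # 13 -> 11 -> 5
fc_inputs = 64 * layer2 * layer2                                  # 64 * 5 * 5 = 1600

print(layer1, layer2, fc_inputs)  # 13 5 1600
```

If you change the kernel size, padding, or number of channels, the arguments to the state initialization and the first argument of `nn.Linear` need to change accordingly.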

## 3. Training
Time for training! Let's first define a couple of functions to print out test/train accuracy for each minibatch.

In [1]:
def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, 1, 28, 28))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train Set Accuracy: {acc}")
    else:
        print(f"Test Set Accuracy: {acc}")

def train_printer():
    print(f"Epoch {epoch}, Minibatch {minibatch_counter}")
    print(f"Train Set Loss: {loss_hist[counter]}")
    print(f"Test Set Loss: {test_loss_hist[counter]}")
    print_batch_accuracy(data_it, targets_it, train=True)
    print_batch_accuracy(testdata_it, testtargets_it, train=False)
    print("\n")

### 3.1 Optimizer & Loss
* *Output Activation*: We'll apply the softmax function to the membrane potentials of the output layer, rather than the spikes.
* *Loss*: This will then be used to calculate the negative log-likelihood loss.
By encouraging the membrane of the correct neuron class to reach the threshold, we expect that neuron will fire more frequently.
The loss could be applied to the spike count as well, but the membrane is continuous whereas spike count is discrete.
* *Optimizer*: The Adam optimizer is used for weight updates.
* *Accuracy*: Accuracy is measured by counting the spikes of the output neurons. The neuron that fires the most frequently will be our predicted class.
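
The loss and accuracy described above can be sketched on toy tensors. The sizes and the random `mem_rec` and `spk_rec` below are stand-ins for the network outputs, chosen only for illustration. The sketch also checks that log-softmax followed by negative log-likelihood matches PyTorch's cross-entropy loss at each step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_steps, batch, num_classes = 5, 4, 10  # toy sizes, not the tutorial's

# Stand-ins for the recorded output-layer membrane potentials and spikes
mem_rec = torch.randn(num_steps, batch, num_classes)
spk_rec = (torch.rand(num_steps, batch, num_classes) > 0.7).float()
targets = torch.randint(0, num_classes, (batch,))

log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()

# Loss: softmax over the membrane potentials, summed over time steps
log_p_y = log_softmax_fn(mem_rec)
loss = sum(loss_fn(log_p_y[step], targets) for step in range(num_steps))

# Same quantity written with cross-entropy, for comparison
ce = sum(F.cross_entropy(mem_rec[step], targets) for step in range(num_steps))

# Accuracy: the class whose output neuron spiked most often is the prediction
_, predicted = spk_rec.sum(dim=0).max(1)
acc = (predicted == targets).float().mean().item()

print(loss.item(), ce.item(), acc)
```

Note that the loss is driven by the membrane potential while the prediction is read from the spike counts, so the two can disagree early in training.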

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()

### 3.2 Training Loop
Let's see how the `SRM0` neuron model compares to the `Stein` model in the last tutorial.

In [None]:
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(3):

    minibatch_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data_it, targets_it in train_batch:
        data_it = data_it.to(device)
        targets_it = targets_it.to(device)

        output, mem_rec = net(data_it.view(batch_size, 1, 28, 28))  # reshape to [batch, channels, height, width]
        log_p_y = log_softmax_fn(mem_rec)
        loss_val = torch.zeros((1), dtype=dtype, device=device)

        # Sum loss over time steps to perform BPTT
        for step in range(num_steps):
          loss_val += loss_fn(log_p_y[step], targets_it)

        # Gradient calculation
        optimizer.zero_grad()
        loss_val.backward(retain_graph=True)

        # Weight Update
        nn.utils.clip_grad_norm_(net.parameters(), 1)
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        test_data = itertools.cycle(test_loader)
        testdata_it, testtargets_it = next(test_data)
        testdata_it = testdata_it.to(device)
        testtargets_it = testtargets_it.to(device)

        # Test set forward pass
        test_output, test_mem_rec = net(testdata_it.view(batch_size, 1, 28, 28))

        # Test set loss
        log_p_ytest = log_softmax_fn(test_mem_rec)
        log_p_ytest = log_p_ytest.sum(dim=0)
        loss_val_test = loss_fn(log_p_ytest, testtargets_it)
        test_loss_hist.append(loss_val_test.item())

        # Print test/train loss/accuracy
        if counter % 50 == 0:
          train_printer()
        minibatch_counter += 1
        counter += 1

loss_hist_true_grad = loss_hist
test_loss_hist_true_grad = test_loss_hist

Although the loss is decreasing, we can tell `SRM0` is struggling based on its low accuracy.
This might indicate to us that `SRM0` is only appropriate for objectives related to spike-timing.
You can try out for yourself to see what happens if you replace `SRM0` with `Stein`. Don't forget to also modify the state initialization.

## 4. Results
### 4.1 Plot Training/Test Loss

In [None]:
# Plot Loss
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(loss_hist)
plt.plot(test_loss_hist)
plt.legend(["Train Loss", "Test Loss"])
plt.xlabel("Minibatch")
plt.ylabel("Loss")
plt.show()

### 4.2 Test Set Accuracy

In [None]:
total = 0
correct = 0
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
  net.eval()
  for data in test_loader:
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)

    # If current batch matches batch_size, just do the usual thing
    if images.size()[0] == batch_size:
      outputs, _ = net(images.view(batch_size, 1, 28, 28))

    # If current batch does not match batch_size (e.g., is the final minibatch),
    # modify batch_size in a temp variable and restore it at the end of the else block
    else:
      temp_bs = batch_size
      batch_size = images.size()[0]
      outputs, _ = net(images.view(images.size()[0], 1, 28, 28))
      batch_size = temp_bs

    _, predicted = outputs.sum(dim=0).max(1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total}%")

That didn't go quite according to plan.
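
A side note on the evaluation loop above: the temporary change to `batch_size` is needed because `forward` reads the global `batch_size`. One possible alternative is to read the batch size from the input itself. The sketch below uses a hypothetical toy module, `BatchAgnosticReadout`, rather than the tutorial's `Net`. It assumes the same idea would carry over to the state initialization calls by passing `x.size(0)` instead of the global.

```python
import torch
import torch.nn as nn

class BatchAgnosticReadout(nn.Module):
    """Hypothetical toy layer: flattens its input using the runtime batch size."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64 * 5 * 5, 10)

    def forward(self, x):
        bs = x.size(0)  # taken from the data, not from a global variable
        return self.fc(x.view(bs, -1))

layer = BatchAgnosticReadout()
full = layer(torch.zeros(128, 64, 5, 5))  # a regular minibatch
last = layer(torch.zeros(16, 64, 5, 5))   # final MNIST test minibatch: 10000 % 128 == 16

print(full.shape, last.shape)  # torch.Size([128, 10]) torch.Size([16, 10])
```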

## 5. Spiking MNIST
Let's try this again, but we'll make two changes:
* Replace `SRM0` with the `Stein` neuron model
* Apply rate-coding to our input data to make it time-variant
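
As a rough picture of what rate coding does to a static image, the sketch below treats each normalized pixel intensity as the probability of a spike at every time step and draws one Bernoulli sample per pixel per step. This is an assumption about the general technique, not the `spikegen.rate` implementation, and the toy image is made up for illustration.

```python
import torch

torch.manual_seed(0)
num_steps = 50

# Toy "image" batch with intensities in [0, 1]: [batch, channels, height, width]
data = torch.zeros(2, 1, 28, 28)
data[:, :, 10:18, 10:18] = 1.0  # bright patch: spikes at every step
data[:, :, 0:5, 0:5] = 0.5      # grey patch: spikes on roughly half the steps

# Repeat along a new leading time dimension, then sample spikes
probs = data.unsqueeze(0).repeat(num_steps, 1, 1, 1, 1)
spikes = torch.bernoulli(probs)

print(spikes.shape)  # torch.Size([50, 2, 1, 28, 28])
print(spikes[:, 0, 0, 0:5, 0:5].mean().item())  # fraction of spikes in the grey patch
```

The result has a new leading time dimension, so the input to the network becomes `[num_steps, batch, 1, 28, 28]` rather than `[batch, 1, 28, 28]`.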

In [None]:
from snntorch import spikegen

# MNIST to spiking-MNIST
spike_data, spike_targets = spikegen.rate(data_it, targets_it, num_outputs=num_outputs, num_steps=num_steps,
                                                      gain=1, offset=0, convert_targets=False, temporal_targets=False)

### 5.1 Visualiser
We'll animate a sample as we've done in previous tutorials.

In [None]:
!pip install celluloid # animating matplotlib plots made easy

In [None]:
from celluloid import Camera
from IPython.display import HTML

# Animator
spike_data_sample = spike_data[:, 0, 0].cpu()

fig, ax = plt.subplots()
camera = Camera(fig)
plt.axis('off')

for step in range(num_steps):
    im = ax.imshow(spike_data_sample[step, :, :], cmap='plasma')
    camera.snap()

# interval=40 specifies 40ms delay between frames
a = camera.animate(interval=40)
HTML(a.to_html5_video())

In [None]:
print(spike_targets[0])

## 6. Define Network
The network is the same as before. The one difference is that the for-loop iterates through the first dimension of the input:
`cur1 = F.max_pool2d(self.conv1(x[step]), 2)`

In [None]:
from snntorch.neuron import Stein
spike_grad = FSS.apply
snn.neuron.slope = 50

# Define a different network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # initialize layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.lif1 = Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=5, stride=1, padding=1)
        self.lif2 = Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)
        self.fc2 = nn.Linear(64*5*5, 10)
        self.lif3 = Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        # Initialize LIF state variables and spike output tensors
        spk1, syn1, mem1 = self.lif1.init_stein(batch_size, 12, 13, 13)
        spk2, syn2, mem2 = self.lif2.init_stein(batch_size, 64, 5, 5)
        spk3, syn3, mem3 = self.lif3.init_stein(batch_size, 10)

        spk3_rec = []
        mem3_rec = []

        for step in range(num_steps):
            cur1 = F.max_pool2d(self.conv1(x[step]), 2) # add max-pooling to membrane or spikes?
            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
            cur2 = F.max_pool2d(self.conv2(spk1), 2)
            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)
            cur3 = self.fc2(spk2.view(batch_size, -1))
            spk3, syn3, mem3 = self.lif3(cur3, syn3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

net = Net().to(device)

## 7. Training
We make a slight modification to our print-out functions to handle the new first dimension of the input:

In [None]:
# Print batch accuracy function

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(num_steps, batch_size, 1, 28, 28))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train Set Accuracy: {acc}")
    else:
        print(f"Test Set Accuracy: {acc}")

def train_printer():
    print(f"Epoch {epoch}, Minibatch {minibatch_counter}")
    print(f"Train Set Loss: {loss_hist[counter]}")
    print(f"Test Set Loss: {test_loss_hist[counter]}")
    print_batch_accuracy(spike_data, spike_targets, train=True)
    print_batch_accuracy(test_spike_data, test_spike_targets, train=False)
    print("\n")


### 7.1 Optimizer & Loss
We'll keep our optimizer and loss the exact same as the static MNIST case.

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()

### 7.2 Training Loop

In [None]:
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(3):
    minibatch_counter = 0
    data = iter(train_loader)

    # Minibatch training loop
    for data_it, targets_it in data:
        data_it = data_it.to(device)
        targets_it = targets_it.to(device)

        # Spike generator
        spike_data, spike_targets = spikegen.rate(data_it, targets_it, num_outputs=num_outputs, num_steps=num_steps,
                                                  gain=1, offset=0, convert_targets=False, temporal_targets=False)

        # Forward pass
        output, mem_rec = net(spike_data.view(num_steps, batch_size, 1, 28, 28))
        log_p_y = log_softmax_fn(mem_rec)
        loss_val = torch.zeros((1), dtype=dtype, device=device)

        # Sum loss over time steps to perform BPTT
        for step in range(num_steps):
          loss_val += loss_fn(log_p_y[step], targets_it)

        # Gradient Calculation
        optimizer.zero_grad()
        loss_val.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(net.parameters(), 1)

        # Weight Update
        optimizer.step()

        # Store Loss history
        loss_hist.append(loss_val.item())

        # Test set
        test_data = itertools.cycle(test_loader)
        testdata_it, testtargets_it = next(test_data)
        testdata_it = testdata_it.to(device)
        testtargets_it = testtargets_it.to(device)

        # Test set spike conversion
        test_spike_data, test_spike_targets = spikegen.rate(testdata_it, testtargets_it, num_outputs=num_outputs,
                                                            num_steps=num_steps, gain=1, offset=0, convert_targets=False,
                                                            temporal_targets=False)

        # Test set forward pass
        test_output, test_mem_rec = net(test_spike_data.view(num_steps, batch_size, 1, 28, 28))

        # Test set loss
        log_p_ytest = log_softmax_fn(test_mem_rec)
        log_p_ytest = log_p_ytest.sum(dim=0)
        loss_val_test = loss_fn(log_p_ytest, test_spike_targets)
        test_loss_hist.append(loss_val_test.item())

        # Print test/train loss/accuracy
        if counter % 50 == 0:
            train_printer()
        minibatch_counter += 1
        counter += 1

loss_hist_true_grad = loss_hist
test_loss_hist_true_grad = test_loss_hist

## 8. Spiking MNIST Results
### 8.1 Plot Training/Test Loss

In [None]:
# Plot Loss
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(loss_hist)
plt.plot(test_loss_hist)
plt.legend(["Train Loss", "Test Loss"])
plt.xlabel("Minibatch")
plt.ylabel("Loss")
plt.show()

### 8.2 Test Set Accuracy

In [None]:
total = 0
correct = 0
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
  net.eval()
  for data in test_loader:
    images, labels = data
    images = images.to(device)
    labels = labels.to(device)

    # If current batch matches batch_size, just do the usual thing
    if images.size()[0] == batch_size:
      spike_test, spike_targets = spikegen.rate(images, labels, num_outputs=num_outputs, num_steps=num_steps,
                                                            gain=1, offset=0, convert_targets=False, temporal_targets=False)

      outputs, _ = net(spike_test.view(num_steps, batch_size, 1, 28, 28))

    # If current batch does not match batch_size (e.g., is the final minibatch),
    # modify batch_size in a temp variable and restore it at the end of the else block
    else:
      temp_bs = batch_size
      batch_size = images.size()[0]
      spike_test, spike_targets = spikegen.rate(images, labels, num_outputs=num_outputs, num_steps=num_steps,
                                                            gain=1, offset=0, convert_targets=False, temporal_targets=False)
      outputs, _ = net(spike_test.view(num_steps, images.size()[0], 1, 28, 28))
      batch_size = temp_bs

    _, predicted = outputs.sum(dim=0).max(1)
    total += spike_targets.size(0)
    correct += (predicted == spike_targets).sum().item()

print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total}%")

Professor Lu has kidnapped my daughter and won't return her until I hit 99.99% accuracy, please help
-JE