# snnTorch Test: Making $\alpha$ and $\beta$ learnable parameters
### By Jason K. Eshraghian

## Gradient-based Learning in Spiking Neural Networks

In [None]:
# Install snntorch from TestPyPI (the test instance of the Python Package Index),
# selected via pip's `-i` index-URL option.
# The leading "!" is IPython/Jupyter syntax for running a shell command.
!pip install -i https://test.pypi.org/simple/ snntorch

## 1. Setting up the Static MNIST Dataset
### 1.1. Import packages and setup environment

In [None]:
import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import itertools
import matplotlib.pyplot as plt

### 1.2 Define network and SNN parameters
We will use a 784-1000-10 fully connected network (FCN), simulated over 25 time steps.

* `alpha` is the decay rate of the synaptic current of a neuron
* `beta` is the decay rate of the membrane potential of a neuron

In [None]:
# ---- Network architecture: a 784-1000-10 fully connected stack ----
num_inputs = 28 * 28   # one input per pixel of a 28x28 image
num_hidden = 1000
num_outputs = 10

# ---- Training parameters ----
batch_size = 128
data_path = '/data/mnist'

# ---- Temporal dynamics ----
num_steps = 25         # simulation steps per sample
time_step = 1e-3
tau_mem = 3e-3
tau_syn = 2.2e-3

# Reference decay factors exp(-time_step / tau): alpha from the synaptic time
# constant, beta from the membrane time constant. As originally noted, these
# will be overridden (the network below learns its own alpha and beta).
alpha, beta = (float(np.exp(-time_step / tau)) for tau in (tau_syn, tau_mem))

dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### 1.3 Download MNIST Dataset

In [None]:
# Preprocessing pipeline applied to every MNIST image. Note that
# Normalize((0,), (1,)) subtracts 0 and divides by 1, i.e. it is an identity step.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Build (and download if needed) the training split first, then the test split.
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

### 1.4 Create DataLoaders

In [None]:
# Both loaders shuffle and drop the final, smaller batch: the training code
# reshapes every batch with the global `batch_size`, so partial batches would
# not fit.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True, drop_last=True)
    for split in (mnist_train, mnist_test)
)

## 2. Define Network

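Override `snn.SRM0` so that `alpha` and `beta` can be learnable tensors (the class below subclasses `snn.LIF`):
* `np.log(alpha)` is replaced with `torch.log(alpha)`, so gradients can flow through `alpha` and `beta`
* `tau_srm` is recomputed inside `forward`, so it stays up to date every time `alpha` and `beta` change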

In [None]:
class SRM0(snn.LIF):
    """SRM0 neuron whose decay factors may be learnable tensors.

    Differs from a NumPy-based SRM0 in two ways:
    * it uses ``torch.log``, so ``alpha`` and ``beta`` may be tensors that receive
      gradients;
    * ``tau_srm`` is recomputed on every ``forward`` call, so it follows updates
      to ``alpha`` and ``beta``.

    Args:
        alpha: decay factor applied to the pre-synaptic trace ``syn_pre``.
        beta: decay factor applied to the post-synaptic trace ``syn_post``.
        threshold: firing threshold, forwarded to ``snn.LIF``.
        spike_grad: surrogate spike gradient, forwarded to ``snn.LIF``.
    """

    def __init__(self, alpha, beta, threshold=1.0, spike_grad=None):
        super().__init__(alpha, beta, threshold, spike_grad)

        # Store the exact objects passed in (possibly learnable tensors) so
        # forward() reads them directly.
        self.alpha, self.beta, self.threshold = alpha, beta, threshold

    def forward(self, input_, syn_pre, syn_post, mem):
        """Advance the neuron by one time step.

        Args:
            input_: input current for this step.
            syn_pre: pre-synaptic trace from the previous step.
            syn_post: post-synaptic trace from the previous step.
            mem: membrane potential from the previous step.

        Returns:
            Tuple ``(spk, syn_pre, syn_post, mem)`` holding the updated state.
        """
        # Refresh the SRM0 scaling factor from the current alpha/beta.
        # torch.log is undefined for non-positive alpha or beta.
        log_gap = torch.log(self.beta) - torch.log(self.alpha)
        self.tau_srm = torch.log(self.alpha) / log_gap + 1

        spk, reset = self.fire(mem)

        # Decay each trace, add/subtract the input, then scale by (1 - reset).
        decayed_pre = self.alpha*syn_pre + input_
        syn_pre = decayed_pre * (1 - reset)
        decayed_post = self.beta * syn_post - input_
        syn_post = decayed_post * (1 - reset)

        # Membrane potential: scaled trace sum plus the reset contribution.
        driven = self.tau_srm * (syn_pre + syn_post)*(1-reset)
        reset_term = mem*reset - reset * self.threshold
        mem = driven + reset_term

        return spk, syn_pre, syn_post, mem

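Note: `alpha` and `beta` are each a single learnable scalar. Each is stored as the 1x1 weight of a bias-free `nn.Linear(1, 1)` layer, and both SRM0 layers share them.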

In [None]:
# Define Network
class Net(nn.Module):
    """784-1000-10 spiking network with learnable decay factors shared by both layers.

    ``alpha`` and ``beta`` are each stored as the weight of a bias-free
    ``nn.Linear(1, 1)``, so they are registered as trainable parameters. The
    same two weight tensors are handed to both SRM0 layers.
    """

    def __init__(self):
        super().__init__()

        # initialize layers
        self.alpha = nn.Linear(1, 1, bias=False)
        self.beta = nn.Linear(1, 1, bias=False)
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        # NOTE: a lowered threshold (0.3) induces more spiking, because SRM0
        # otherwise struggles to spike.
        self.lif1 = SRM0(alpha=self.alpha.weight, beta=self.beta.weight, threshold=0.3)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = SRM0(alpha=self.alpha.weight, beta=self.beta.weight, threshold=0.3)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: flattened input batch; its first dimension is the batch size.

        Returns:
            ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane
            potentials, stacked along a new leading time dimension.
        """
        # Size the neuron state from the input itself rather than the global
        # ``batch_size``, so batches of any size (e.g. a final partial batch)
        # work.
        n_batch = x.size(0)
        spk1, pre_syn1, post_syn1, mem1 = self.lif1.init_srm0(n_batch, num_hidden)
        spk2, pre_syn2, post_syn2, mem2 = self.lif2.init_srm0(n_batch, num_outputs)

        # The input is identical at every time step, so project it only once.
        cur1 = self.fc1(x)

        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, pre_syn1, post_syn1, mem1 = self.lif1(cur1, pre_syn1, post_syn1, mem1)
            cur2 = self.fc2(spk1)
            spk2, pre_syn2, post_syn2, mem2 = self.lif2(cur2, pre_syn2, post_syn2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

net = Net().to(device)

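However, `alpha` and `beta` are randomly initialized, so we need to check that alpha > beta and that both alpha and beta are > 0 (`torch.log` in `SRM0.forward` is undefined for non-positive values).
If they are not, re-run the cell above until the values printed below satisfy these conditions.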

In [None]:
print(net.alpha.weight)
print(net.beta.weight)

## 3. Training

In [1]:
def print_batch_accuracy(data, targets, train=False):
    """Print the accuracy of the global ``net`` on a single minibatch.

    The predicted class for each sample is the output neuron with the largest
    spike count summed over all time steps.

    Args:
        data: input batch; it is flattened to ``(data.size(0), -1)`` before the
            forward pass.
        targets: ground-truth class indices for ``data``.
        train: if True the result is labelled as the train set, otherwise as the
            test set.
    """
    # Evaluation only: avoid building an autograd graph for this forward pass.
    with torch.no_grad():
        output, _ = net(data.view(data.size(0), -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).cpu().numpy())

    if train:
        print(f"Train Set Accuracy: {acc}")
    else:
        print(f"Test Set Accuracy: {acc}")

def train_printer():
    """Print a progress report for the current training minibatch.

    Reads the training loop's module-level state (``net``, ``epoch``,
    ``minibatch_counter``, ``counter``, ``loss_hist``, ``test_loss_hist`` and
    the current train/test batches) and prints, in order: the learnable alpha
    and beta weights, the epoch/minibatch position, both losses, and both
    minibatch accuracies.
    """
    for decay_weight in (net.alpha.weight, net.beta.weight):
        print(decay_weight)

    print(f"Epoch {epoch}, Minibatch {minibatch_counter}")
    print(f"Train Set Loss: {loss_hist[counter]}")
    print(f"Test Set Loss: {test_loss_hist[counter]}")

    for batch_data, batch_targets, is_train in (
        (data_it, targets_it, True),
        (testdata_it, testtargets_it, False),
    ):
        print_batch_accuracy(batch_data, batch_targets, train=is_train)

    print("\n")

### 3.1 Training Loop

In [None]:
# Train with Adam, using a per-time-step NLL loss on the output membrane
# potential (summed over time, i.e. backpropagation through time).
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()

loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(3):
    minibatch_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data_it, targets_it in train_batch:
        data_it = data_it.to(device)
        targets_it = targets_it.to(device)

        output, mem_rec = net(data_it.view(batch_size, -1))
        log_p_y = log_softmax_fn(mem_rec)
        loss_val = torch.zeros((1), dtype=dtype, device=device)

        # Sum loss over time steps: BPTT
        for step in range(num_steps):
            loss_val += loss_fn(log_p_y[step], targets_it)

        # Gradient calculation. The graph is backpropagated exactly once, so it
        # does not need to be retained (retaining it only holds on to memory).
        optimizer.zero_grad()
        loss_val.backward()

        # Weight Update
        nn.utils.clip_grad_norm_(net.parameters(), 1)  # gradient clipping
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set: the first batch of a fresh pass over the shuffled test loader
        test_data = itertools.cycle(test_loader)
        testdata_it, testtargets_it = next(test_data)
        testdata_it = testdata_it.to(device)
        testtargets_it = testtargets_it.to(device)

        # Test set forward pass and loss: monitoring only, so no autograd graph
        with torch.no_grad():
            test_output, test_mem_rec = net(testdata_it.view(batch_size, -1))
            log_p_ytest = log_softmax_fn(test_mem_rec)
            log_p_ytest = log_p_ytest.sum(dim=0)
            loss_val_test = loss_fn(log_p_ytest, testtargets_it)
        test_loss_hist.append(loss_val_test.item())

        # Print test/train loss/accuracy
        if counter % 50 == 0:
            train_printer()
        minibatch_counter += 1
        counter += 1

loss_hist_true_grad = loss_hist
test_loss_hist_true_grad = test_loss_hist

## 4. Results
Observations:
* alpha tends to increase towards 1
* beta tends to decrease towards 0
* once beta reached 0, all losses became `nan`
* Using SRM0, accuracy rose to about 60% and then fell back to between 35% and 40%
* `snn.Stein` may be a more promising model to try

### 4.1 Plot Training/Test Loss

In [None]:
# Plot Loss
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(loss_hist)
plt.plot(test_loss_hist)
plt.legend(["Test Loss", "Train Loss"])
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

### 4.2 Test Set Accuracy
This cell iterates over all minibatches of the test set to measure accuracy over the full 10,000 test samples.

In [None]:
# Accuracy over the entire test set.
total = 0
correct = 0
# Re-create the test loader without drop_last so that every test image is scored,
# including those in the final, smaller batch.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
    net.eval()
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # The network may size its internal state from the global `batch_size`,
        # so temporarily set it to this batch's size (the final batch can be
        # smaller) and always restore it afterwards, even if the forward pass
        # raises.
        saved_batch_size = batch_size
        batch_size = images.size(0)
        try:
            outputs, _ = net(images.view(batch_size, -1))
        finally:
            batch_size = saved_batch_size

        # Predicted class = output neuron with the most spikes over all steps.
        _, predicted = outputs.sum(dim=0).max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total}%")