## Goal: Compare the Mean Performance of 10 ANNs and 10 SNNs on the MNIST Task
(Set `num_replicates` to 10 for the full comparison; the cells below use 2 for a quick run.)

In [None]:
!pip install snntorch --quiet

In [None]:
#From SNN
import torch, torch.nn as nn
import snntorch as snn
from statistics import mean

#From ANN
import torch.nn.functional as F
import torch.optim as optim


### Data Loading
Define variables for data loading.

In [None]:
# Data-loading configuration.
batch_size = 128
data_path = "raw"  # alternative location: '/data/mnist'

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Torch floating-point dtype.
dtype = torch.float

Load MNIST dataset.

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Preprocessing applied to every image. Normalize((0,), (1,)) computes
# (x - 0) / 1, i.e. it leaves the ToTensor output unchanged.
# NOTE(review): MNIST images are presumably already 28x28 grayscale, which
# would make Resize/Grayscale pass-throughs as well — confirm before removing.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders.
# Only the training data is shuffled. Test accuracy is accumulated over the whole
# test set (order-independent), so a fixed test order keeps evaluation reproducible.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [None]:
# Inspect the shape of one transformed test image (expected [1, 28, 28] given
# the Resize/Grayscale/ToTensor transform).
sample_image, sample_label = test_loader.dataset[0]
sample_image.size()

### Define the Network with snnTorch
* `snn.Leaky()` instantiates a simple leaky integrate-and-fire neuron.
* `spike_grad` optionally defines the surrogate gradient. If left undefined, the relevant gradient term is simply set to the output spike itself (1/0) by default.


The problem with `nn.Sequential` is that each hidden layer can only pass one tensor to subsequent layers, whereas most spiking neurons return their spikes and hidden state(s). To handle this:

* `init_hidden` initializes the hidden states (e.g., membrane potential) as instance variables to be processed in the background. 

The final layer is not bound by this constraint, and can return multiple tensors:
* `output=True` enables the final layer to return the hidden state in addition to the spike.

In [None]:
from snntorch import surrogate

beta = 0.9  # membrane decay rate shared by every Leaky neuron layer
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient for the spike non-linearity


def _conv_lif_stage(in_channels, out_channels):
    """Return [5x5 conv, 2x2 max-pool, hidden-state Leaky layer] for nn.Sequential."""
    return [
        nn.Conv2d(in_channels, out_channels, 5),
        nn.MaxPool2d(2),
        snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    ]


# Initialize Network.
# Spatial size: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, hence 16*4*4
# features into the linear readout. The output Leaky layer uses output=True so
# the network returns (spikes, membrane) each step.
net = nn.Sequential(
    *_conv_lif_stage(1, 8),    # out channels of a stage = in channels of the next
    *_conv_lif_stage(8, 16),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 10),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)

In [None]:
#For Testing Purposes - Not included in experiment
# net1 = nn.Sequential(nn.Conv2d(1, 8, 5),
#                     nn.MaxPool2d(2),
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.Conv2d(8, 16, 5),
#                     nn.MaxPool2d(2),
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.Conv2d(16, 16, 1),
#                     nn.MaxPool2d(2),
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.Flatten(),
#                     nn.Linear(16*2*2, 10),
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
#                     ).to(device)

### Define the Forward Pass
Now define the forward pass over multiple time steps of simulation.

In [None]:
from snntorch import utils 

def forward_pass(net, data, num_steps):
    """Present the same ``data`` to ``net`` for ``num_steps`` simulation steps.

    Hidden states of every snnTorch neuron in ``net`` are cleared first with
    ``utils.reset``. At each step ``net`` returns ``(spikes, membrane)``; only
    the spikes are kept.

    Returns:
        The per-step output spikes stacked along a new leading time dimension.
    """
    utils.reset(net)  # start every sample from a fresh membrane state
    spikes_per_step = []
    for _ in range(num_steps):
        out_spikes, _membrane = net(data)
        spikes_per_step.append(out_spikes)
    return torch.stack(spikes_per_step)

Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts up the total number of output spikes at the end of the simulation run. The correct class has a target firing rate of 80% of all time steps, and incorrect classes are set to 20%. 

In [None]:
import snntorch.functional as SF

# Adam over all SNN parameters.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=2e-3,
    betas=(0.9, 0.999),
)
# Spike-count targets: the correct class should fire on 80% of time steps,
# every other class on 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

Compute the accuracy on the full test set using `SF.accuracy_rate`.

In [None]:
def test_accuracy(data_loader, net, num_steps):
    """Return the spike-rate accuracy of ``net`` over every batch of ``data_loader``.

    Each batch is simulated for ``num_steps`` steps with ``forward_pass``; its
    ``SF.accuracy_rate`` is weighted by the batch size (dimension 1 of the
    stacked spike record), so the result is the overall fraction correct.
    Puts ``net`` in eval mode and does not switch it back.
    """
    with torch.no_grad():
        weighted_correct = 0
        n_seen = 0
        net.eval()

        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            spikes = forward_pass(net, inputs, num_steps)

            n_batch = spikes.size(1)
            weighted_correct += SF.accuracy_rate(spikes, labels) * n_batch
            n_seen += n_batch

    return weighted_correct / n_seen

## Training Loop

Now for the training loop. The predicted class will be set to the neuron with the highest firing rate, i.e., a rate-coded output. During training, accuracy is tracked on individual training batches; after each replicate, the network is evaluated on the full test set. This training loop follows the same syntax as with PyTorch.

In [None]:
# SNN experiment size: epochs per replicate, number of independent replicates,
# and simulation time steps per input sample.
num_epochs, num_replicates, num_steps = 3, 2, 25

In [None]:
loss_hist = []
acc_hist = []
snn_val_list = []


def _reinitialize_snn_weights(model):
    """Re-draw the weights of every Conv2d/Linear layer of ``model`` in place."""
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            module.reset_parameters()


# training loop
for replicate in range(num_replicates):
    # Each replicate must start from fresh weights and a fresh optimizer.
    # Otherwise every "replicate" just continues training the previous one, and
    # the values averaged below are not independent samples. Rebuilding the
    # optimizer here (same hyperparameters as the optimizer cell) also guards
    # against `optimizer` having been rebound by the ANN section of the notebook.
    _reinitialize_snn_weights(net)
    optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))

    for epoch in range(num_epochs):
        net.train()  # test_accuracy() leaves the net in eval mode
        for i, (data, targets) in enumerate(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            spk_rec = forward_pass(net, data, num_steps)
            loss_val = loss_fn(spk_rec, targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(loss_val.item())

            # Every 25 iterations, record accuracy on the current training batch.
            if i % 25 == 0:
                acc = SF.accuracy_rate(spk_rec, targets)
                acc_hist.append(acc)

            # uncomment for faster termination
            # if i == 150:
            #     break

    snn_val_list.append(test_accuracy(test_loader, net, num_steps))

print(snn_val_list)
print(mean(snn_val_list))
print(f"the average performance of this spiking neural network is {mean(snn_val_list)}")


In [None]:
# Re-evaluate the network left by the last replicate on the full test set.
final_acc = test_accuracy(test_loader, net, num_steps)
print(f"Test set accuracy: {final_acc*100:.3f}%")

### Training a non-spiking NN on MNIST (as provided by PyTorch)
We will do the following steps in order:
1. Load the MNIST training and test datasets using torchvision (the DataLoaders defined above are reused)
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data

In [None]:
# ANN experiment configuration.
n_epochs, num_replicates = 3, 2
# NOTE(review): these two batch sizes are not referenced anywhere else in the
# visible notebook; the ANN reuses train_loader/test_loader (batch_size above).
batch_size_train, batch_size_test = 64, 1000
learning_rate, momentum = 0.01, 0.5  # SGD hyperparameters
log_interval = 10  # batches between training-log lines

# Presumably for reproducibility: disable cuDNN and fix the global torch seed.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

In [None]:
# TODO: make this architecture analogous to the SNN above.
class Net(nn.Module):
    """Non-spiking CNN baseline for MNIST, adapted from the PyTorch example.

    conv(1->10, k5) -> maxpool(2) -> ReLU -> conv(10->20, k5) -> Dropout2d ->
    maxpool(2) -> ReLU -> flatten(320) -> fc(320->50) -> ReLU -> fc(50->10) ->
    log-softmax over the 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities, shape (batch, 10)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # 20 channels * 4 * 4 for 28x28 inputs
        x = F.relu(self.fc1(x))
        # The fully connected dropout of the PyTorch example is not applied here.
        x = self.fc2(x)
        # Explicit dim=1 normalizes over the class logits of each sample;
        # omitting dim relies on PyTorch's deprecated implicit-dimension choice.
        return F.log_softmax(x, dim=1)


# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
#         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
#         self.conv2_drop = nn.Dropout2d()
#         self.fc1 = nn.Linear(320, 50)
#         self.fc2 = nn.Linear(50, 10)

#     def forward(self, x):
#         x = F.relu(F.max_pool2d(self.conv1(x), 2))
#         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
#         x = x.view(-1, 320)
#         x = F.relu(self.fc1(x))
#         x = F.dropout(x, training=self.training)
#         x = self.fc2(x)
#         return F.log_softmax(x)

In [None]:
# Baseline ANN and its SGD optimizer.
# NOTE: this rebinds `optimizer`, the name previously holding the SNN's Adam.
network = Net()
optimizer = optim.SGD(
    params=network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [None]:
# Loss logs filled by train() / test(), plus x-axis positions for plotting.
train_losses, train_counter = [], []
test_losses = []
# Training examples seen at each epoch boundary, 0 through n_epochs.
test_counter = [k * len(train_loader.dataset) for k in range(n_epochs + 1)]

In [None]:
def train(epoch):
    """Train the module-level ``network`` with ``optimizer`` for one epoch.

    Every ``log_interval`` batches, prints progress and appends the current
    batch loss to ``train_losses``. It also appends, to ``train_counter``, the
    number of training examples seen since the start of epoch 1.

    Args:
        epoch: 1-based epoch number, used for logging and for the x-axis offset
            recorded in ``train_counter``.
    """
    network.train()
    # Use the loader's real batch size (set by `batch_size`). The PyTorch example
    # hard-coded 64, which mislabels the x-axis when the loader uses another size.
    loader_batch_size = train_loader.batch_size
    n_train = len(train_loader.dataset)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * loader_batch_size, n_train,
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                batch_idx * loader_batch_size + (epoch - 1) * n_train)
            # torch.save(network.state_dict(), '/results/model.pth')
            # torch.save(optimizer.state_dict(), '/results/optimizer.pth')

In [None]:
def test():
    """Evaluate the module-level ``network`` on ``test_loader``.

    Appends the mean per-example negative log-likelihood to ``test_losses``.
    Returns the accuracy as a percentage, as a 0-d tensor (the replicate loop
    converts it with ``float``).
    """
    network.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # reduction="sum" replaces the deprecated size_average=False: the
            # per-example losses are summed, then divided by the dataset size below.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    return 100. * correct / len(test_loader.dataset)

In [None]:
nn_val_list = []
for replicate in range(num_replicates):
    # Fresh model and optimizer for every replicate, so replicates are
    # independent runs rather than continued training of one network.
    # train() and test() read the module-level `network` and `optimizer`,
    # so rebinding them here is what those functions will use.
    network = Net()
    optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in range(1, n_epochs + 1):
        train(epoch)
    nn_val_list.append(float(test()))
print(nn_val_list)
if nn_val_list:
    print(f"the average performance of this artificial neural network is {mean(nn_val_list)}")

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
# Only the training loss is drawn, so the legend lists only that series. test()
# appends one loss per replicate, which does not line up with the per-epoch
# test_counter positions, so the test-loss scatter is not plotted.
# NOTE(review): train_counter restarts at epoch 1 for every replicate, so the
# replicates overlap on the x-axis and the line jumps back between them.
plt.legend(['Train Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')