<a href="https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_7_neuromorphic_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

# Training Parameters
batch_size = 128
data_path = '/data/mnist'
num_classes = 10  # MNIST has 10 output classes
num_steps = 100
TAU = 5
THRESHOLD = 0.8
beta = 0.5


def _param_tag(value):
    """Format a hyper-parameter for use in a filename, e.g. 0.8 -> '0_8'."""
    return str(value).replace(".", "_")


# The checkpoint name is derived from the hyper-parameters above so it cannot
# drift out of sync when one of them is changed for an experiment.
PATH = (
    f"fcn_snn_mnist_latency_tau_{_param_tag(TAU)}_thresh_{_param_tag(THRESHOLD)}"
    f"_beta_{_param_tag(beta)}_num_steps_{_param_tag(num_steps)}.pt"
)

# Torch Variables
dtype = torch.float

from torchvision import datasets, transforms

# Define a transform: resize digits to 32x32 so each flattened frame has
# 32*32 = 1024 values, matching the input size of the first nn.Linear layer.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),  # mean 0 / std 1: leaves values unchanged
])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
# Evaluation does not need shuffling; a fixed order keeps test runs reproducible.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# neuron and simulation parameters
spike_grad = surrogate.atan()

# Initialize Network: a fully-connected two-layer SNN. Each 32x32 frame is
# flattened to 1024 inputs (see forward_pass). The final Leaky layer uses
# output=True, so calling the network returns both spikes and membrane.
net = nn.Sequential(
    nn.Linear(32 * 32, 1000),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Linear(1000, num_classes),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)

In [None]:
# this time, we won't return membrane as we don't need it

def forward_pass(net, data):
    """Simulate ``net`` over every time step of ``data`` and collect output spikes.

    ``data`` is consumed one slice at a time along dim 0 (time steps), with the
    batch on dim 1; each slice is flattened to ``(batch, -1)`` before being fed
    to the network. All LIF hidden states are reset first, so calls do not leak
    state into each other. The network's membrane output is discarded.

    Returns:
        The per-step output spikes stacked along a new leading time dimension.
    """
    utils.reset(net)  # clear hidden states left over from a previous call

    collected = []
    for frame in data:  # iterating a tensor walks dim 0, i.e. the time steps
        spikes, _membrane = net(frame.view(data.size(1), -1))
        collected.append(spikes)

    return torch.stack(collected)

In [None]:
# Adam over every trainable parameter of the network.
optimizer = torch.optim.Adam(params=net.parameters(), lr=2e-2, betas=(0.9, 0.999))
# Spike-count based MSE loss from snntorch.functional.
loss_fn = SF.mse_count_loss()
# Alternative: loss_fn = SF.mse_temporal_loss()

In [None]:
num_epochs = 5
num_iters = 50  # maximum number of iterations per epoch (training is truncated)

loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    # (Re-)enable training mode every epoch; other cells (e.g. evaluation) may
    # leave the network in eval mode.
    net.train()

    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Latency-encode the static images into a spike train with time on dim 0
        # (the layout forward_pass expects). Uses the same TAU/THRESHOLD as the
        # rest of the notebook so training and evaluation encode identically.
        # data = spikegen.rate(data, num_steps=num_steps)
        data = spikegen.latency(data, num_steps=num_steps, tau=TAU, threshold=THRESHOLD,
                                clip=True, normalize=True, linear=True)

        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # End the epoch after num_iters iterations (i is 0-based, hence + 1).
        if i + 1 >= num_iters:
            break

In [None]:
import matplotlib.pyplot as plt

# Plot training accuracy per iteration
fig = plt.figure(facecolor="w")
plt.plot(acc_hist)
plt.title("Train Set Accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Plot training loss per iteration (same layout as the accuracy plot)
fig = plt.figure(facecolor="w")
plt.plot(loss_hist)
plt.title("Train Set Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

In [None]:
def batch_accuracy(data_loader, threshold, net, tau=None):
    """Return the sample-weighted accuracy of ``net`` over ``data_loader``.

    Each batch is latency-encoded with ``spikegen.latency`` (using the global
    ``num_steps``) and run through ``forward_pass``. Per-batch
    ``SF.accuracy_rate`` values are weighted by batch size.

    Args:
        data_loader: iterable of ``(data, targets)`` batches.
        threshold: latency-encoding threshold; should match the value used
            during training (``THRESHOLD``) so inputs are encoded identically.
        net: the spiking network to evaluate.
        tau: latency-encoding time constant; defaults to the global ``TAU``.

    Returns:
        The batch-size-weighted mean of ``SF.accuracy_rate`` over all batches.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if tau is None:
        tau = TAU

    was_training = net.training
    net.eval()
    total = 0
    acc = 0
    try:
        with torch.no_grad():
            for data, targets in data_loader:
                data = data.to(device)
                targets = targets.to(device)

                data = spikegen.latency(data, num_steps=num_steps, tau=tau, threshold=threshold,
                                        clip=True, normalize=True, linear=True)

                spk_rec = forward_pass(net, data)

                batch_n = spk_rec.size(1)  # batch is on dim 1, time on dim 0
                acc += SF.accuracy_rate(spk_rec, targets) * batch_n
                total += batch_n
    finally:
        # Restore the caller's mode so a later training cell is not silently
        # run with the network still in eval mode.
        net.train(was_training)

    if total == 0:
        raise ValueError("data_loader yielded no samples")
    return acc / total

In [None]:
# Encode the test set with the same latency threshold used for training; a
# different threshold changes the input spike timing the network learned from.
test_acc = batch_accuracy(test_loader, THRESHOLD, net)
print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

In [None]:
# Persist only the learned parameters/buffers (state_dict) under PATH.
torch.save(obj=net.state_dict(), f=PATH)

In [None]:
# Re-run the network on the last (already latency-encoded) training batch for
# visualisation. This is inference only, so no autograd graph is needed.
with torch.no_grad():
    spk_rec = forward_pass(net, data)

In [None]:
from IPython.display import HTML

# Pick one sample from the last batch to visualise.
idx = 1

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels = [str(digit) for digit in range(10)]  # one bar per MNIST class
print(f"The target label is: {targets[idx]}")

# plt.rcParams['animation.ffmpeg_path'] = 'C:\\path\\to\\your\\ffmpeg.exe'

# Animated spike-count bar chart for the chosen sample (time on dim 0).
anim = splt.spike_count(
    spk_rec[:, idx].detach().cpu(),
    fig,
    ax,
    labels=labels,
    animate=True,
    interpolate=1,
)

HTML(anim.to_html5_video())
# anim.save("spike_bar.mp4")

In [None]:
class SaveOutput:
    """Forward-hook callable that records every ``(input, output)`` it receives.

    Intended to be registered with ``module.register_forward_hook``; each call
    appends the module's input to ``inputs`` and its output to ``outputs``.
    """

    def __init__(self):
        self.clear()

    def __call__(self, module, module_in, module_out):
        # ``module`` is part of the forward-hook signature but is not recorded.
        self.outputs.append(module_out)
        self.inputs.append(module_in)

    def clear(self):
        """Forget all recordings by rebinding both attributes to new empty lists."""
        self.outputs, self.inputs = [], []

In [None]:
# get a data batch of size 1 for simplicity, latency-encoded with the same
# TAU/THRESHOLD used for training
# THRESHOLD = 0.4
# PATH = "snn_mnist_latency_tau_5_thresh_0_4_beta_0_5_num_steps_100.pt"

test_loader_batch_size_one = DataLoader(mnist_test, batch_size=1, shuffle=False)
data, targets = next(iter(test_loader_batch_size_one))  # first test sample
print(data.size())
data = spikegen.latency(data, num_steps=num_steps, tau=TAU, threshold=THRESHOLD,
                        clip=True, normalize=True, linear=True)
print(data.size())
print(targets)

# Number of non-zero input values (input spikes) across all time steps.
input_data_spikes = torch.count_nonzero(data)


In [None]:
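# load a pretrained convolutional model (disabled: this architecture does not match the fully-connected `net` trained and saved to PATH in this notebook)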

# model_save_path = PATH
# net = nn.Sequential(nn.Conv2d(1, 12, 5), #in=1x[32x32] out=12x[28x28]
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.MaxPool2d(2), #in=12x[28x28] out=12x[14x14]
#                     nn.Conv2d(12, 32, 5),#in=12x[14x14] out=32x[10x10]
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.MaxPool2d(2), #in=32x[10x10] out=32x[5x5]
#                     nn.Flatten(),
#                     nn.Linear(32*5*5, 10), 
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
#                     ).to(device)
# net.load_state_dict(torch.load(model_save_path))
# net.eval()

In [None]:
# Attach one SaveOutput recorder to every Leaky (LIF) layer of `net`.
save_output = SaveOutput()

# If this cell is re-run, detach the hooks registered last time; otherwise each
# Leaky layer would record its output more than once per forward pass.
for old_handle in globals().get("hook_handles", []):
    old_handle.remove()

hook_handles = []

for layer in net.modules():
    if isinstance(layer, snn.Leaky):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)

In [None]:
# Run the encoded sample through the network with fresh hook recordings.
save_output.clear()
data = data.to(device)
# Inference only: under no_grad the recorded layer outputs do not keep
# autograd graphs alive.
with torch.no_grad():
    spk_rec = forward_pass(net, data)

In [None]:
# Total output spikes per (sample, output neuron), summed over time (dim 0).
# Summing on spk_rec's own device avoids the CPU/CUDA mismatch of accumulating
# a GPU tensor into a CPU torch.zeros buffer.
added_tensor = spk_rec.sum(dim=0)

added_tensor

In [None]:
# Number of recorded hook calls: forward_pass calls `net` once per time step, and
# each hooked Leaky layer appends one entry per call (assuming hooks were
# registered only once).
len(save_output.outputs)

In [None]:
# Count non-zero spike outputs recorded by the hooks, per layer.
# The hidden Leaky layer returns a spike tensor, whereas the output Leaky layer
# (created with output=True) returns a (spk, mem) pair. Entries are therefore
# attributed by type instead of assuming a strict alternating order, which would
# break (IndexError or misattribution) if hooks were duplicated.
leaky1_none_zero_outputs = 0  # hidden Leaky layer
leaky3_none_zero_outputs = 0  # output Leaky layer (spikes only; membrane ignored)

for recorded in save_output.outputs:
    if isinstance(recorded, (tuple, list)):
        leaky3_none_zero_outputs += torch.count_nonzero(recorded[0])
    else:
        leaky1_none_zero_outputs += torch.count_nonzero(recorded)

print(f'input_data_spikes = {input_data_spikes}')
print(f'leaky1_none_zero_outputs {leaky1_none_zero_outputs}')
print(f'leaky3_none_zero_outputs {leaky3_none_zero_outputs}')

In [None]:
# Shape probe for the (disabled) convolutional architecture: push the first time
# step of the encoded sample through freshly initialised layers and print the
# intermediate sizes. The freshly created layers live on the CPU, so the input
# is moved there too (`data` may already be on the GPU).
data_in = data[0].cpu()
print("in",data_in.size())

c1_out = nn.Conv2d(1, 12, 5)(data_in)  # in=1x[32x32] out=12x[28x28]
# print("Conv2d",c1_out.size())

l1_out = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)(c1_out)
print("Leaky",l1_out.size())

p1_out = nn.MaxPool2d(2)(l1_out)  # in=12x[28x28] out=12x[14x14]
# print("MaxPool2d",p1_out.size())

c2_out = nn.Conv2d(12, 32, 5)(p1_out)  # in=12x[14x14] out=32x[10x10]
# print("Conv2d",c2_out.size())

l2_out = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)(c2_out)
print("Leaky",l2_out.size())

p2_out = nn.MaxPool2d(2)(l2_out)  # in=32x[10x10] out=32x[5x5]
# print("MaxPool2d",p2_out.size())

f1_out = nn.Flatten()(p2_out)
# print("Flatten",f1_out.size())

li1_out = nn.Linear(32*5*5, 10)(f1_out)
# print("Linear",li1_out.size())

l3_out = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)(li1_out)
print("Leaky",l3_out.size())