In [1]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

from snn_data_utils import WaveGuideDataset, ToTensor

## Preparing for Surrogate Gradient Descent

In [2]:
# Leak factor handed to every snn.Leaky layer as `beta`.
beta = 0.5
# Surrogate gradient (fast sigmoid, slope 25) used in place of the
# non-differentiable spike function during backpropagation.
spike_grad = surrogate.fast_sigmoid(slope=25)

## Setting up the CSNN

### Dataloaders

In [3]:
# --- data-loading configuration ---
batch_size = 16
data_csv = 'final_image_database_loss_deciles.csv'  # csv_file passed to WaveGuideDataset
data_path = 'root/imgs'  # root_dir passed to WaveGuideDataset

# --- tensor dtype and compute device (GPU when one is available) ---
dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# A Compose wrapper around the project's ToTensor, so further transforms can be
# appended to the list later.
transform = transforms.Compose([
    ToTensor(),
])
wg_dataset = WaveGuideDataset(
    csv_file=data_csv,
    root_dir=data_path,
    transform=transform,
)

In [5]:
# Total number of samples (233 when this notebook was last run); the
# train/validation split in the next cell depends on this value.
len(wg_dataset)

233

In [6]:
# Fixed-size training split; the validation split takes the remainder, so the
# two sizes always sum to len(wg_dataset) even if the dataset changes size.
# The seeded generator makes the split reproducible across runs.
n_train = 176
wg_train, wg_val = torch.utils.data.random_split(
    wg_dataset,
    [n_train, len(wg_dataset) - n_train],
    generator=torch.Generator().manual_seed(42),
)

In [7]:
# drop_last=True keeps every batch at exactly `batch_size` samples; the final
# partial batch of each split is discarded.
# Training batches are reshuffled every epoch.
train_loader = DataLoader(wg_train, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation uses a fixed order: combined with drop_last=True, a shuffled test
# loader would discard a different random subset of validation samples on every
# pass, making accuracies from successive evaluations incomparable.
test_loader = DataLoader(wg_val, batch_size=batch_size, shuffle=False, drop_last=True)

### Define the network (CSNN)
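The convolutional network architecture used here is 12C5-MP2-64C5-MP2-FC10, with a leaky integrate-and-fire (LIF) spiking layer after each pooling stage and after the fully-connected layer:

- 12C5 is a 5x5 convolutional kernel with 12 filters (applied to 3-channel input images)

- MP2 is a 2x2 max-pooling function

- 64C5 is a 5x5 convolutional kernel with 64 filters

- FC10 is a fully-connected layer that maps the flattened 64×107×107 feature map (732,736 neurons) to 10 outputs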


In [8]:
# Simulation parameters.
# The neuron parameters `spike_grad` and `beta` are defined once, in the
# "Preparing for Surrogate Gradient Descent" cell; they are intentionally not
# redefined here so a change made there is not silently overridden.
num_steps = 50  # time steps simulated per input in forward_pass

In [9]:
# Define Network
class Net(nn.Module):
    """Convolutional spiking network: 12C5-MP2-64C5-MP2-FC10, with a Leaky
    (LIF) neuron layer after each pooling stage and after the FC layer.

    One call to ``forward`` simulates a single time step. The LIF layers are
    built with ``init_hidden=True`` so their membrane potentials persist
    between calls (i.e. across the time steps driven by ``forward_pass``) and
    are cleared by ``snntorch.utils.reset(net)``. Re-initialising the state
    inside ``forward`` would make every time step independent, so a static
    input would produce the same spikes at every step.
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(3, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
        # 64 channels x 107 x 107 after two (5x5 conv -> 2x2 max-pool) stages,
        # which implies inputs of roughly 440x440 pixels -- TODO confirm
        # against WaveGuideDataset.
        self.fc1 = nn.Linear(64*107*107, 10)
        # output=True together with init_hidden=True makes this layer return
        # both its spikes and its membrane potential.
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def forward(self, x):
        """Advance the network by one time step.

        Args:
            x: input batch; presumably (batch, 3, H, W) floats -- TODO confirm.

        Returns:
            (spk3, mem3): output-layer spikes and membrane potential, each of
            shape (batch, 10).
        """
        cur1 = F.max_pool2d(self.conv1(x), 2)
        spk1 = self.lif1(cur1)

        cur2 = F.max_pool2d(self.conv2(spk1), 2)
        spk2 = self.lif2(cur2)

        # Flatten with the actual batch dimension instead of the global
        # batch_size, so batches of any size (e.g. a final partial batch) work.
        cur3 = self.fc1(spk2.flatten(start_dim=1))
        spk3, mem3 = self.lif3(cur3)

        return spk3, mem3

# Load the network onto CUDA if available
net = Net().to(device)

### Forward Pass

In [15]:
# Pull one training batch to exercise the network. Each batch is a dict:
# 'image' is the network input and 'deciles' is the source of the class labels.
batch = next(iter(train_loader))
data = batch['image'].to(device, dtype=torch.float)
targets = batch['deciles'].to(device, dtype=torch.long)

In [16]:
def forward_pass(net, num_steps, data):
    """Feed the same input `data` to `net` for `num_steps` time steps.

    The hidden states of all LIF neurons in `net` are reset before the first
    step.

    Returns:
        (spk_rec, mem_rec): the per-step spike and membrane outputs of `net`,
        each stacked along a new leading time dimension.
    """
    utils.reset(net)  # reset hidden states of every LIF neuron in net

    step_outputs = [net(data) for _ in range(num_steps)]
    spikes = torch.stack([spk for spk, _ in step_outputs])
    membranes = torch.stack([mem for _, mem in step_outputs])
    return spikes, membranes

In [17]:
# Demonstration pass only: nothing is backpropagated from these outputs, so
# autograd is disabled to avoid storing activations for every time step.
with torch.no_grad():
    spk_rec, mem_rec = forward_pass(net, num_steps, data)

## Training the SNN

### Loss Definition

In [21]:
# already imported snntorch.functional as SF
loss_fn = SF.ce_rate_loss()

# Class label per sample: element [0, 0] of each sample's 'deciles' tensor.
target_labels = targets[:, 0, 0]
print(spk_rec.size(), target_labels.size())

loss_val = loss_fn(spk_rec, target_labels)
# NOTE(review): a value of ~2.303 = ln(10) means the 10 output classes are
# scored equally (presumably no output spikes yet), as expected when untrained.
print(f"The loss from an untrained network is {loss_val.item():.3f}")

torch.Size([50, 16, 10]) torch.Size([16])
The loss from an untrained network is 2.303


### Accuracy Metrics

In [19]:
# Rate-coded accuracy on the single batch sampled above.
acc = SF.accuracy_rate(spk_rec, targets[:, 0, 0])
# NOTE(review): 100% accuracy from an untrained network whose class scores are
# tied (loss ~ ln 10) suggests every label in the batch equals the class the
# tie-break picks (presumably 0) -- verify the targets[:, 0, 0] label indexing
# against the 'deciles' column layout.
print(f"The accuracy of a single batch using an untrained network is {acc*100:.3f}%")

The accuracy of a single batch using an untrained network is 100.000%


In [24]:
def batch_accuracy(train_loader, net, num_steps):
    """Return the spike-count (rate) accuracy of `net` over every batch of a loader.

    Despite the parameter name, any loader yielding dicts with 'image' and
    'deciles' entries works; the notebook passes `test_loader`.

    Args:
        train_loader: DataLoader to evaluate on.
        net: the spiking network.
        num_steps: number of simulation time steps per batch.

    Returns:
        The sample-weighted mean of ``SF.accuracy_rate`` over all batches.

    Raises:
        ValueError: if the loader yields no batches.
    """
    # Evaluate in eval mode, but restore whatever mode the caller had set.
    was_training = net.training
    net.eval()
    total = 0
    acc = 0
    try:
        with torch.no_grad():
            for batch in train_loader:
                data = batch['image'].to(device, dtype=torch.float)
                targets = batch['deciles'].to(device, dtype=torch.long)
                spk_rec, _ = forward_pass(net, num_steps, data)

                # Weight each batch's accuracy by its size so the result is a
                # per-sample average rather than a per-batch one.
                n_samples = spk_rec.size(1)
                acc += SF.accuracy_rate(spk_rec, targets[:, 0, 0]) * n_samples
                total += n_samples
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("batch_accuracy: loader yielded no batches")
    return acc / total

In [25]:
# Accuracy of the (still untrained) network over the whole test loader.
test_acc = batch_accuracy(train_loader=test_loader, net=net, num_steps=num_steps)
print(f"The total accuracy on the test set is: {test_acc * 100:.2f}%")

The total accuracy on the test set is: 100.00%


### Training Loop

In [30]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1
eval_every = 1  # run the full test-set evaluation every `eval_every` iterations
loss_hist = []
test_acc_hist = []
counter = 0

# NOTE(review): the last recorded run of this cell failed with an out-of-memory
# RuntimeError. The failed 46,895,104-byte allocation equals one float32 tensor
# of shape (16, 64, 107, 107), i.e. the second conv stage's pooled output.
# Autograd keeps such activations for all `num_steps` steps until backward(),
# so reducing num_steps, batch_size or the input resolution may be necessary.

# Outer training loop
for epoch in range(num_epochs):

    # Training loop
    for batch in train_loader:
        data = batch['image'].to(device, dtype=torch.float)
        targets = batch['deciles'].to(device, dtype=torch.long)

        # forward pass
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)

        # rate-coded cross-entropy over all time steps
        loss_val = loss_fn(spk_rec, targets[:, 0, 0])

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set (batch_accuracy itself switches to eval mode and disables
        # gradient tracking)
        if counter % eval_every == 0:
            test_acc = batch_accuracy(test_loader, net, num_steps)
            print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
            # float() accepts Python, NumPy and 0-d tensor scalars alike
            test_acc_hist.append(float(test_acc))

        counter += 1

RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:81] data. DefaultCPUAllocator: not enough memory: you tried to allocate 46895104 bytes.

## Result Analysis

### Plot Test Accuracy

In [None]:
# Plot the test-set accuracy recorded during training (one point per
# evaluation, which the training loop runs at iteration granularity).
fig, ax = plt.subplots(facecolor="w")
ax.plot(test_acc_hist)
ax.set_title("Test Set Accuracy")
ax.set_xlabel("Iteration")
ax.set_ylabel("Accuracy")
plt.show()

### Spike Counter

# Outputs are only visualised, so autograd is disabled to save memory.
with torch.no_grad():
    spk_rec, mem_rec = forward_pass(net, num_steps, data)

In [None]:
from IPython.display import HTML

idx = 0  # index of the batch sample to visualise

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels = [str(digit) for digit in range(10)]  # one label per output neuron

# Rendering the animation as video may require pointing matplotlib at ffmpeg:
# plt.rcParams['animation.ffmpeg_path'] = 'C:\\path\\to\\your\\ffmpeg.exe'

# Animated histogram of the output spike counts for sample `idx`
anim = splt.spike_count(
    spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels,
    animate=True, interpolate=4,
)

HTML(anim.to_html5_video())
# anim.save("spike_bar.mp4")

In [None]:
# Show the same label the loss and accuracy use (element [0, 0] of the
# sample's 'deciles' entry) rather than the whole entry.
print(f"The target label is: {targets[idx, 0, 0].item()}")