In [13]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import itertools

from snn_data_utils import WaveGuideDataset, ToTensor

## Setting up the Static Wave Guide Dataset

In [32]:
batch_size = 4
data_path = 'root/imgs'
data_csv = 'final_image_database_loss_deciles.csv'

dtype = torch.float32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [15]:
transform = transforms.Compose([ToTensor()])
wg_dataset = WaveGuideDataset(root_dir=data_path, csv_file=data_csv, transform=transform)

In [16]:
len(wg_dataset)

233

In [17]:
wg_train, wg_val = torch.utils.data.random_split(wg_dataset, [173, 60], generator=torch.Generator().manual_seed(42))

In [18]:
train_loader = DataLoader(wg_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(wg_val, batch_size=batch_size, shuffle=True, drop_last=True)


## Define the network (A simple Spiked MLP)

In [19]:
# Network Architecture
num_inputs = 442*442*3
num_hidden = 1024
num_outputs = 10

# Temporal Dynamics
num_steps = 50
beta = 0.95

In [56]:
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
net = Net().to(device)

## Traininng the SNN

### Accuracy Metric 

In [96]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    print(idx)
    print() 
    # As our net has 10 outputs, we can extract any column that idx maps into (each line has equal values)
    acc = np.mean(torch.flatten(targets[:, idx][:, :1]).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

### Loss Definition

In [22]:
loss = nn.CrossEntropyLoss()

### Optimizer

In [23]:
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

### One iteration

In [34]:
data, targets = next(enumerate(train_loader))
data = targets['image'].to(device, dtype=torch.float)
targets = targets['deciles'].to(device)

In [41]:
targets

torch.Size([4, 1, 10])

In [57]:
spk_rec, mem_rec = net(data.view(data.size(0), -1))

In [50]:
mem_rec[0].size(), targets[:,0].size()


(torch.Size([4, 10]), torch.Size([4, 10]))

In [51]:
loss(mem_rec[0], targets[:,0])

tensor(2.3049, grad_fn=<DivBackward1>)

In [97]:
print_batch_accuracy(data, targets[:,0])

tensor([4, 4, 4, 4])
tensor([0.0662, 0.1258, 0.0397, 0.1722])
Test set accuracy for a single minibatch: 10.10%


### Training Loop

In [95]:
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = enumerate(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = targets['image'].to(device, dtype=dtype) # Must have dtype explicit param!! Avoid trouble
        targets = targets['deciles'].to(device, dtype=dtype) # Must have dtype explicit param!! Avoid trouble

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets[:,0])

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(enumerate(test_loader))
            test_data = test_targets['image'].to(device, dtype=dtype)
            test_targets = test_targets['deciles'].to(device, dtype=dtype)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets[:,0])
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0
Train Set Loss: 149.74
Test Set Loss: 141.53
tensor([4, 4, 4, 4])


IndexError: index 4 is out of bounds for dimension 0 with size 1