So this is a working convolutional SNN with the Architecture proposed by that one paper I suggested.
This is neither optimized code nor is it compatible with our data yet but I thought it is a reasonable starting place.

So this Network receives the raw data as inputs not spikes. It seems easiest to just put in the spectogram data (what we called voltage) into the LIF neurons instead of using the poisson encoding first. (poisson encoding was at least useful for us to see if we can generate sensible spikes from the data).  

Right now there is a convolution on the input data itself, I would start out by trying to avoid and only convolve in the spike domain, but we'll see. I'm also unsere whether we should give the input timestep by timestep or not. With the setup right now, it expects it all at once.

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

2024-08-21 17:08:47.896882: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Define the Spiking Neural Network class
class CSNN(nn.Module):
    def __init__(self, T=8, spike_grad=surrogate.ATan(), threshold=1.0, num_class=10):
        super(CSNN, self).__init__()

        self.T = T  #time steps for temporal integration

        # Define the layers, right now I hardcoded the shapes
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.lif1 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.lif2 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.fc1 = nn.Linear(16 * 5 * 5, 128) # Adjusted input size for 2 poolings
        self.lif3 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.fc2 = nn.Linear(128, 64)
        self.lif4 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)
        self.fc3 = nn.Linear(64, num_class)
        self.lif5 = neuron.LIFNode(surrogate_function=spike_grad, v_threshold=threshold)

        #batch norm and dropout improve test accuracy
        self.bn1 = nn.BatchNorm2d(6)
        self.bn2 = nn.BatchNorm2d(16)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        # Apply convolutions and pooling without the time loop, here we do 1 convolution before spike encoding
        # the order I've seen being used is conv, batchnorm, LIF, pool
        x = self.bn1(self.conv1(x))
        x = self.pool(self.lif1(x)) 
        x = self.bn2(self.conv2(x))
        x = self.pool(self.lif2(x)) 
        x = self.dropout2(x)
        
        x = x.view(-1, 16 * 5 * 5)  # flatten fc layers
        
        # Now apply the time loop to LIF neurons in the fully connected layers
        mem_fc1 = torch.zeros_like(self.fc1(x)) # Initialize membrane potential
        mem_fc2 = torch.zeros_like(self.fc2(mem_fc1))
        mem_fc3 = torch.zeros_like(self.fc3(mem_fc2))

        for step in range(self.T):
            mem_fc1 = self.lif3(self.fc1(x))
            mem_fc2 = self.lif4(self.fc2(mem_fc1))
            #x = self.dropout2(mem_fc2)
            mem_fc3 = self.lif5(self.fc3(mem_fc2))
        
        return mem_fc3  # Return the final membrane potential

In [5]:
# Hyperparameters
batch_size = 64
learning_rate = 0.001 # 0.001 worked well on my tests
num_epochs = 10 # only 10 right now so it doesn't take forever to load
time_steps = 8  # Time steps for temporal integration, might need to increase it

# Load FashionMNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_set = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Initialize model, loss and optimizer
model = CSNN(T=time_steps)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
#optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)

# Training
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        functional.reset_net(model)  # we have to reset after each batch

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# Testing
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        functional.reset_net(model)  
        
print(f'Accuracy on the test set: {100 * correct / total:.2f}%')

Epoch [1/10], Loss: 1.7680
Epoch [2/10], Loss: 1.6347
Epoch [3/10], Loss: 1.6232
Epoch [4/10], Loss: 1.6159
Epoch [5/10], Loss: 1.6099
Epoch [6/10], Loss: 1.6064
Epoch [7/10], Loss: 1.6045
Epoch [8/10], Loss: 1.6009
Epoch [9/10], Loss: 1.5982
Epoch [10/10], Loss: 1.5973
Accuracy on the test set: 76.42%
