In [1]:
import tonic
import snntorch as snn
from snntorch import surrogate
import torch.nn as nn
import torch
import torchvision
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils

In [2]:
import os

# Directory where tonic stores/reads the N-MNIST download. Set the NMNIST_ROOT environment
# variable to run the notebook on another machine; otherwise the original local path is used.
root = os.environ.get("NMNIST_ROOT", "/Users/matteogiardina/Desktop/BEMACS 2/SNNs/BainsaSNNs/N-MNIST/n-mnist_data")

In [3]:
import tonic.transforms as transforms

# Sensor geometry of the N-MNIST recordings; ToFrame needs it to size each frame.
sensor_size = tonic.datasets.NMNIST.sensor_size

# Per-sample preprocessing pipeline:
#   1. Denoise: drop events with no neighbouring event within `filter_time`.
#   2. ToFrame: bin the surviving events into consecutive frames of `time_window` each
#      (both values are in the recording's timestamp units — presumably microseconds).
frame_transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10000),
        transforms.ToFrame(sensor_size=sensor_size, time_window=1000),
    ]
)

In [4]:
# Load both N-MNIST splits from `root`, applying the denoise + framing pipeline
# to every sample. The train split is constructed first, then the test split.
trainset, testset = (
    tonic.datasets.NMNIST(save_to=root, transform=frame_transform, train=split_is_train)
    for split_is_train in (True, False)
)

In [12]:
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

# Cache the denoised/framed training samples on disk so the event -> frame conversion
# only has to be computed once per sample.
cached_trainset = DiskCachedDataset(trainset, cache_path = "./cache/nmnist/train")
cached_dataloader = DataLoader(cached_trainset)

batch_size = 128
# PadTensors must be *instantiated*: passing the class itself makes DataLoader call
# PadTensors(batch), which constructs a collator object instead of returning a padded batch.
# batch_first=False lays batches out as [T, B, C, H, W], the time-first input the model expects.
trainloader = DataLoader(cached_trainset,
                         batch_size = batch_size,
                         collate_fn = tonic.collation.PadTensors(batch_first = False))

def load_sample_batched():
    """Fetch one (events, target) pair from the unbatched cached loader.

    Doubles as a smoke test that the disk cache can be written and read.

    Returns:
        Tuple ``(events, target)`` for the first sample yielded by ``cached_dataloader``.
    """
    events, target = next(iter(cached_dataloader))
    return events, target

sample_events, sample_target = load_sample_batched()

In [13]:
# Training-time augmentation: convert each cached numpy frame stack to a tensor, then
# rotate it by a random angle in [-10, 10] degrees.
transform = tonic.transforms.Compose([torch.from_numpy,
                                      torchvision.transforms.RandomRotation([-10,10])])

# Train and test MUST use separate cache directories: DiskCachedDataset names its cache
# files by sample index, so a shared path makes the test set silently return the cached
# *training* samples (inflating any test metric).
cached_trainset = DiskCachedDataset(trainset, transform=transform, cache_path = "./cache/nmnist/train")
# No augmentation (transform) for the test set
cached_testset = DiskCachedDataset(testset, cache_path = "./cache/nmnist/test")

batch_size = 128
# batch_first=False -> batches are [T, B, C, H, W] (time first), as the model's forward pass iterates.
trainloader = DataLoader(cached_trainset, batch_size = batch_size,
                         collate_fn = tonic.collation.PadTensors(batch_first = False), shuffle = True)
testloader = DataLoader(cached_testset, batch_size = batch_size,
                        collate_fn = tonic.collation.PadTensors(batch_first = False))

## Synaptic (second-order LIF) CNN — commented out, kept for reference

In [31]:
# # --- Hyperparameters ---
# spike_grad = surrogate.atan()
# # We need TWO decay rates now: one for current (alpha) and one for membrane (beta)
# # Let's start with some common defaults. You can tune these.
# alpha = 0.9  # Synaptic current decay
# beta = 0.95   # Membrane potential decay

# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# # --- Define the Synaptic CNN Class ---
# class SynapticCNN(nn.Module):
#     def __init__(self, alpha, beta, spike_grad):
#         super().__init__()

#         # --- Layer 1 ---
#         self.conv1 = nn.Conv2d(2, 12, 5)
#         self.pool1 = nn.MaxPool2d(2)
#         self.lif1 = snn.Synaptic(alpha=alpha, beta=beta, spike_grad=spike_grad)
        
#         # --- Layer 2 ---
#         self.conv2 = nn.Conv2d(12, 32, 5)
#         self.pool2 = nn.MaxPool2d(2)
#         self.lif2 = snn.Synaptic(alpha=alpha, beta=beta, spike_grad=spike_grad)
        
#         # --- Output Layer ---
#         self.flatten = nn.Flatten()
#         self.fc1 = nn.Linear(32*5*5, 10)
#         self.lif_out = snn.Synaptic(alpha=alpha, beta=beta, spike_grad=spike_grad, output=True)

#     def forward(self, x_seq):
#         # x_seq has shape [T, Batch, C, H, W]

#         # --- Initialize Hidden States ---
#         # We must do this manually outside the loop
#         syn1, mem1 = self.lif1.init_synaptic()
#         syn2, mem2 = self.lif2.init_synaptic()
#         syn_out, mem_out = self.lif_out.init_synaptic()
        
#         # This will record the output spikes at each step
#         spk_rec = []
#         mem_rec = []
        
#         # --- Temporal Loop (now inside the model) ---
#         for step in range(x_seq.size(0)):
#             x_step = x_seq[step] # Get input at current time step
            
#             # 1. Pass through Layer 1
#             cur1 = self.pool1(self.conv1(x_step))
#             spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1) # Pass states back in

#             # 2. Pass through Layer 2
#             cur2 = self.pool2(self.conv2(spk1))
#             spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)

#             # 3. Pass through Output Layer
#             flat = self.flatten(spk2)
#             cur_out = self.fc1(flat)
#             spk_out, syn_out, mem_out = self.lif_out(cur_out, syn_out, mem_out)
            
#             spk_rec.append(spk_out)
#             mem_rec.append(mem_out)

#         # Return all output spikes [T, Batch, 10]
#         return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)

## Leaky (first-order LIF) CNN — the model trained below

In [37]:
# --- Hyperparameters ---
# Arctangent surrogate: stands in for the non-differentiable spike step during backprop.
spike_grad = surrogate.atan()
beta = 0.7       # membrane potential decay factor applied each time step
threshold = 0.2  # firing threshold, deliberately low to encourage spiking

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# --- Define the Leaky CNN Class ---
class LeakyCNN(nn.Module):
    """Convolutional spiking network built from first-order leaky integrate-and-fire neurons.

    Per time step:
        conv(2 -> 12, 5x5) -> maxpool(2) -> Leaky
        conv(12 -> 32, 5x5) -> maxpool(2) -> Leaky
        flatten -> linear(32*5*5 -> 10) -> Leaky (output layer)

    Args:
        beta: membrane decay factor shared by every Leaky layer.
        spike_grad: surrogate gradient used for the spike nonlinearity.
        threshold: firing threshold shared by every Leaky layer.
    """

    def __init__(self, beta, spike_grad, threshold):
        super().__init__()
        # All spiking layers share the same neuron parameters.
        neuron_cfg = {"beta": beta, "spike_grad": spike_grad, "threshold": threshold}

        # Stage 1: 2 input channels -> 12 feature maps, halved spatially by pooling.
        self.conv1 = nn.Conv2d(2, 12, 5)
        self.pool1 = nn.MaxPool2d(2)
        self.lif1 = snn.Leaky(**neuron_cfg)

        # Stage 2: 12 -> 32 feature maps, halved again.
        self.conv2 = nn.Conv2d(12, 32, 5)
        self.pool2 = nn.MaxPool2d(2)
        self.lif2 = snn.Leaky(**neuron_cfg)

        # Readout: 32 maps of 5x5 (what a 34x34 input shrinks to) -> 10 class neurons.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 5 * 5, 10)
        self.lif_out = snn.Leaky(**neuron_cfg, output=True)

    def forward(self, x_seq):
        """Run the network over every time step of ``x_seq``.

        Args:
            x_seq: time-first input frames, ``[T, Batch, C, H, W]``.

        Returns:
            Tuple ``(spikes, membranes)``: output-layer spikes and membrane potentials,
            each stacked along a new leading time dimension.
        """
        # Fresh hidden state per call; init_leaky() yields only the membrane potential.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem_out = self.lif_out.init_leaky()

        spikes, membranes = [], []
        for frame in x_seq:  # iterating a tensor walks dim 0, i.e. one time step at a time
            spk1, mem1 = self.lif1(self.pool1(self.conv1(frame)), mem1)
            spk2, mem2 = self.lif2(self.pool2(self.conv2(spk1)), mem2)
            spk_out, mem_out = self.lif_out(self.fc1(self.flatten(spk2)), mem_out)
            spikes.append(spk_out)
            membranes.append(mem_out)

        return torch.stack(spikes), torch.stack(membranes)

In [38]:
# Instantiate the Leaky CNN with the hyperparameters defined above and move it to `device`.
# (The SynapticCNN variant is commented out above and cannot be instantiated here.)
net = LeakyCNN(
    beta=beta,
    spike_grad=spike_grad,
    threshold=threshold,
)
net = net.to(device)

In [33]:
# Adam keeps references to the parameter tensors it is given, so this cell must be
# (re-)run after `net` is created; otherwise optimizer.step() updates a stale model.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=2e-3,
    betas=(0.9, 0.999),  # PyTorch's default Adam betas, stated explicitly
)
# NOTE(review): confirm `SF.mse_` exists in the installed snntorch version and which
# (prediction, target) inputs it expects; the training loop feeds it time-summed membranes.
loss_fn = SF.mse_()

In [39]:
num_epochs = 1  # the iteration cap below stops training long before one full epoch completes
num_iters = 50  # number of training batches to run per epoch

loss_hist = []
acc_hist = []

# Adam holds references to the parameters it was constructed with. If `net` was re-created
# after the optimizer cell last ran, optimizer.step() would update a stale model and training
# would silently stall at chance accuracy, so fail loudly instead.
_optim_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
assert all(id(p) in _optim_param_ids for p in net.parameters()), (
    "`optimizer` does not hold the current `net`'s parameters; re-run the optimizer cell after creating `net`."
)

net.train()  # loop-invariant: nothing below switches the mode, so set it once

for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(trainloader):
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass over all time steps; outputs are stacked over time along dim 0.
        spk_rec, mem_rec = net(data)

        # Sum the output membrane potential over time -> one score per class per sample.
        mem_sum = torch.sum(mem_rec, dim = 0)
        loss_val = loss_fn(mem_sum, targets)

        # Gradient calculation + weights update
        optimizer.zero_grad() # resets the gradients of the previous iteration
        loss_val.backward()
        optimizer.step()

        # Store the loss in the history for plotting purposes
        loss = loss_val.item()
        loss_hist.append(loss)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)

        acc_hist.append(acc)
        print(f"Accuracy: {acc*100:.2f}%\n")

        # Stop after exactly `num_iters` iterations (i runs 0 .. num_iters - 1); checking here
        # rather than at the top avoids loading one extra batch just to discard it.
        if i + 1 >= num_iters:
            break

Epoch 0, Iteration 0 
Train Loss: 448.58
Accuracy: 17.19%

Epoch 0, Iteration 1 
Train Loss: 554.55
Accuracy: 7.03%

Epoch 0, Iteration 2 
Train Loss: 517.61
Accuracy: 9.38%

Epoch 0, Iteration 3 
Train Loss: 535.44
Accuracy: 8.59%

Epoch 0, Iteration 4 
Train Loss: 454.49
Accuracy: 8.59%

Epoch 0, Iteration 5 
Train Loss: 468.26
Accuracy: 12.50%

Epoch 0, Iteration 6 
Train Loss: 493.25
Accuracy: 4.69%

Epoch 0, Iteration 7 
Train Loss: 465.36
Accuracy: 10.94%

Epoch 0, Iteration 8 
Train Loss: 537.54
Accuracy: 6.25%

Epoch 0, Iteration 9 
Train Loss: 550.62
Accuracy: 4.69%

Epoch 0, Iteration 10 
Train Loss: 556.05
Accuracy: 8.59%

Epoch 0, Iteration 11 
Train Loss: 555.95
Accuracy: 8.59%

Epoch 0, Iteration 12 
Train Loss: 537.21
Accuracy: 10.16%

Epoch 0, Iteration 13 
Train Loss: 449.08
Accuracy: 8.59%

Epoch 0, Iteration 14 
Train Loss: 506.67
Accuracy: 10.16%

Epoch 0, Iteration 15 
Train Loss: 549.89
Accuracy: 10.16%

Epoch 0, Iteration 16 
Train Loss: 490.61
Accuracy: 12.50%

