# Recurrent SNN with Learnable Beta and Threshold

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import os

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils

import tonic
import tonic.transforms as transforms
from tonic import DiskCachedDataset
from tonic.dataset import Dataset

import matplotlib.pyplot as plt
from IPython.display import HTML
from collections.abc import Callable

import optuna

In [2]:
sensor_size = tonic.datasets.DVSGesture.sensor_size

# Bin every recording into 30 event frames, i.e. 30 simulation time steps.
transform = transforms.Compose([
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=30),
])

train_set = tonic.datasets.DVSGesture(save_to='./data', train=True, transform=transform)
test_set = tonic.datasets.DVSGesture(save_to='./data', train=False, transform=transform)

# Shared DataLoader settings. PadTensors(batch_first=False) stacks the samples
# of a batch time-major: (time, batch, 2, 128, 128).
# NOTE: despite the name, no DiskCachedDataset caching is applied here.
cached_dataloader_args = {
    "batch_size": 16,
    "collate_fn": tonic.collation.PadTensors(batch_first=False),
    "shuffle": True,
    "num_workers": 2,
    "pin_memory": True
}

train_loader = DataLoader(train_set, **cached_dataloader_args)
# Evaluation does not need shuffling; a fixed order also makes the example
# test batch inspected later reproducible between runs.
test_loader = DataLoader(test_set, **{**cached_dataloader_args, "shuffle": False})

# Sanity check on one training batch.
data, targets = next(iter(train_loader))
print(f"Data shape: {data.shape}")

Data shape: torch.Size([30, 16, 2, 128, 128])


In [3]:
# Prefer an XPU, then CUDA, then fall back to the CPU. `torch.xpu` is absent
# from older PyTorch builds, so guard the attribute lookup instead of crashing.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Running on: {device}")

Running on: xpu


In [4]:
# Notebook-level LIF defaults. RSNN and OptunaRSNN each construct their own
# surrogate gradient and take beta as a constructor argument.
spike_grad = surrogate.atan()  # arctan surrogate gradient for the spike step
beta = 0.5                     # membrane decay rate

In [5]:
class RSNN(nn.Module):
    """Convolutional spiking network with a recurrent spiking output layer.

    Four Conv -> BatchNorm -> MaxPool -> Leaky (LIF) stages shrink a
    2 x 128 x 128 frame to 128 x 8 x 8 feature maps. These are flattened,
    projected to ``num_classes`` input currents and fed to an ``RLeaky`` layer
    whose all-to-all recurrent weights connect the output neurons from one time
    step to the next. Every spiking layer learns its own beta and threshold.

    The spiking layers keep their state internally (``init_hidden=True``):
    call the network once per time step and clear it with
    ``snntorch.utils.reset`` between sequences.

    Args:
        beta: initial membrane decay rate of every spiking layer.
        threshold: initial firing threshold of every spiking layer.
        num_classes: number of output neurons; defaults to the 11 DVS Gesture
            classes.
    """

    def __init__(self, beta=0.5, threshold=1.0, num_classes=11):
        super().__init__()

        # Surrogate gradient for backprop through the non-differentiable spike.
        spike_grad = surrogate.atan()

        # --- Layer 1: 2 x 128 x 128 -> 16 x 64 x 64 ---
        self.conv1 = nn.Conv2d(2, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2)
        self.lif1 = self._learnable_lif(beta, threshold, spike_grad)

        # --- Layer 2: -> 32 x 32 x 32 ---
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2)
        self.lif2 = self._learnable_lif(beta, threshold, spike_grad)

        # --- Layer 3: -> 64 x 16 x 16 ---
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2)
        self.lif3 = self._learnable_lif(beta, threshold, spike_grad)

        # --- Layer 4: -> 128 x 8 x 8 ---
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool4 = nn.MaxPool2d(2)
        self.lif4 = self._learnable_lif(beta, threshold, spike_grad)

        # --- Layer 5: Flatten -> Linear -> recurrent LIF output ---
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128 * 8 * 8, num_classes)
        # all_to_all=True gives RLeaky a dense (linear_features x linear_features)
        # recurrent weight, so linear_features must equal fc's output width.
        # output=True makes the layer return (spikes, membrane).
        self.rlif = snn.RLeaky(beta=beta, threshold=threshold,
                               learn_beta=True, learn_threshold=True,
                               spike_grad=spike_grad, init_hidden=True,
                               linear_features=num_classes, all_to_all=True,
                               output=True)

    @staticmethod
    def _learnable_lif(beta, threshold, spike_grad):
        """Build a stateful Leaky neuron whose beta and threshold are trainable."""
        return snn.Leaky(beta=beta, threshold=threshold,
                         learn_beta=True, learn_threshold=True,
                         spike_grad=spike_grad, init_hidden=True)

    def forward(self, x):
        """Advance the network by a single time step.

        Args:
            x: one frame per sample, shaped (batch, 2, 128, 128) so that the
                flattened features match ``fc``.

        Returns:
            tuple: ``(spk_out, mem_out)`` of the recurrent output layer, each
            shaped (batch, num_classes).
        """
        stages = (
            (self.conv1, self.bn1, self.pool1, self.lif1),
            (self.conv2, self.bn2, self.pool2, self.lif2),
            (self.conv3, self.bn3, self.pool3, self.lif3),
            (self.conv4, self.bn4, self.pool4, self.lif4),
        )
        for conv, bn, pool, lif in stages:
            x = lif(pool(bn(conv(x))))

        x = self.fc(self.flatten(x))
        spk_out, mem_out = self.rlif(x)
        return spk_out, mem_out

In [None]:
# Reuse the `device` picked in the device-selection cell (XPU > CUDA > CPU).
# Re-deriving it from CUDA alone here would silently move the model to the CPU
# on an XPU machine.
net = RSNN().to(device)

# beta and threshold are registered parameters (learn_beta / learn_threshold),
# so net.parameters() already includes them.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss()

print(f"Model created with {sum(p.numel() for p in net.parameters())} parameters.")

Model created with 188041 parameters.


: 

In [None]:
num_epochs = 10
hist = {"loss": [], "acc": []}

print("Starting Training...")

for epoch in range(num_epochs):
    net.train()
    loss_sum = 0.0
    acc_sum = 0.0
    n_samples = 0

    for data, targets in train_loader:
        data = data.to(device)        # time-major: (time, batch, ...)
        targets = targets.to(device)
        batch_size = data.size(1)

        # Clear the hidden state carried over from the previous batch.
        utils.reset(net)
        spk_rec = []
        for step in range(data.size(0)):
            spk_out, _ = net(data[step])
            spk_rec.append(spk_out)
        spk_rec = torch.stack(spk_rec)  # (time, batch, num_classes)

        loss_val = loss_fn(spk_rec, targets)
        acc = SF.accuracy_rate(spk_rec, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Weight each batch by its size so the smaller final batch does not
        # skew the epoch averages.
        loss_sum += loss_val.item() * batch_size
        acc_sum += acc * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise RuntimeError("train_loader produced no batches")

    epoch_loss = loss_sum / n_samples
    epoch_acc = acc_sum / n_samples
    hist['loss'].append(epoch_loss)
    hist['acc'].append(epoch_acc)

    print(f"Epoch {epoch+1}/{num_epochs} \t Loss: {epoch_loss:.4f} \t Accuracy: {epoch_acc:.4f}")

Starting Training...


In [None]:
print("Starting Testing...")
net.eval()

total = 0
correct = 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        n_steps, batch_size = data.size(0), data.size(1)

        # Each test sequence starts from a fresh hidden state.
        utils.reset(net)
        outputs = [net(data[t]) for t in range(n_steps)]
        spk_rec = torch.stack([spk for spk, _ in outputs])

        # accuracy_rate gives a fraction of the batch; turn it into a count so
        # the final ratio is a per-sample accuracy.
        correct += SF.accuracy_rate(spk_rec, targets) * batch_size
        total += batch_size

test_acc = correct / total
print(f"Test Accuracy: {test_acc*100:.2f}%")

In [None]:
fig, ax1 = plt.subplots()

# Number epochs from 1 to match the training log.
epochs_axis = range(1, len(hist['loss']) + 1)

# Left axis: training loss.
loss_color = 'tab:red'
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color=loss_color)
ax1.plot(epochs_axis, hist['loss'], color=loss_color, label="Loss")
ax1.tick_params(axis='y', labelcolor=loss_color)

# Right axis: training accuracy, sharing the epoch x-axis.
acc_color = 'tab:blue'
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy', color=acc_color)
ax2.plot(epochs_axis, hist['acc'], color=acc_color, label="Accuracy")
ax2.tick_params(axis='y', labelcolor=acc_color)

# One legend for both axes, using the labels set on the lines above.
lines = list(ax1.get_lines()) + list(ax2.get_lines())
ax1.legend(lines, [line.get_label() for line in lines], loc="center right")

# Add the title before tight_layout() so the layout reserves room for it.
ax1.set_title("Training Loss and Accuracy")
fig.tight_layout()
plt.show()

In [None]:
import snntorch.spikeplot as splt

# Run one test batch through the trained network and plot the output-layer
# spikes of a single sample.
data, targets = next(iter(test_loader))
data = data.to(device)
targets = targets.to(device)

net.eval()
utils.reset(net)
spk_rec = []

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    for step in range(data.size(0)):
        spk_out, mem_out = net(data[step])
        spk_rec.append(spk_out)

# (time, batch, num_classes), moved to the CPU for matplotlib.
spk_rec = torch.stack(spk_rec).cpu()

idx = 0
fig, ax = plt.subplots(facecolor='w', figsize=(12, 8))

# One dot per (time step, output neuron) spike of sample `idx`.
splt.raster(spk_rec[:, idx, :], ax, s=20, c="black")

# Output neuron i is read as class i, so label the y-axis with class names.
class_labels = train_set.classes
ax.set_yticks(range(len(class_labels)))
ax.set_yticklabels(class_labels)

true_label = class_labels[targets[idx].item()]
# Rate-coded prediction: the output neuron with the most spikes.
pred_label = class_labels[spk_rec[:, idx, :].sum(dim=0).argmax().item()]

ax.set_title(f"Output Spikes (True Target: {true_label}, Predicted: {pred_label})")
ax.set_xlabel("Time step")
ax.grid(True, linestyle='--', alpha=0.3)
plt.show()

In [None]:
class OptunaRSNN(nn.Module):
    """RSNN variant whose spiking hyperparameters and hidden width are tuned by Optuna.

    Architecture: four Conv -> BatchNorm -> MaxPool -> Leaky stages
    (2 x 128 x 128 -> 128 x 8 x 8), then Flatten -> Linear(linear_hidden) ->
    Leaky -> Linear(num_classes) -> recurrent ``RLeaky`` output. Beta and
    threshold start at the Optuna suggestions but stay learnable, so training
    can fine-tune them.

    Args:
        slope: ``alpha`` of the arctan surrogate gradient.
        beta_init: initial membrane decay rate of every spiking layer.
        threshold_init: initial firing threshold of every spiking layer.
        linear_hidden: number of neurons in the hidden fully connected spiking
            layer.
        num_classes: number of output neurons; defaults to the 11 DVS Gesture
            classes.
    """

    def __init__(self, slope, beta_init, threshold_init, linear_hidden, num_classes=11):
        super().__init__()

        # Surrogate gradient with the slope suggested by Optuna.
        spike_grad = surrogate.atan(alpha=slope)

        # --- Convolutional feature extractor ---
        # Layer 1: 2 x 128 x 128 -> 16 x 64 x 64
        self.conv1 = nn.Conv2d(2, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2)
        self.lif1 = self._learnable_lif(beta_init, threshold_init, spike_grad)

        # Layer 2: -> 32 x 32 x 32
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2)
        self.lif2 = self._learnable_lif(beta_init, threshold_init, spike_grad)

        # Layer 3: -> 64 x 16 x 16
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2)
        self.lif3 = self._learnable_lif(beta_init, threshold_init, spike_grad)

        # Layer 4: -> 128 x 8 x 8
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool4 = nn.MaxPool2d(2)
        self.lif4 = self._learnable_lif(beta_init, threshold_init, spike_grad)

        # --- Hidden + recurrent output block ---
        self.flatten = nn.Flatten()

        # The tuned width sizes a hidden spiking layer. RLeaky's all-to-all
        # recurrent weight is (linear_features x linear_features), so its input
        # must be exactly num_classes wide. Feeding linear_hidden features
        # straight into it fails with a shape mismatch for any width != 11.
        self.fc = nn.Linear(128 * 8 * 8, linear_hidden)
        self.lif_hidden = self._learnable_lif(beta_init, threshold_init, spike_grad)
        self.fc_out = nn.Linear(linear_hidden, num_classes)

        # Recurrent output layer; output=True returns (spikes, membrane).
        self.rlif = snn.RLeaky(beta=beta_init, threshold=threshold_init,
                               learn_beta=True, learn_threshold=True,
                               spike_grad=spike_grad, init_hidden=True,
                               linear_features=num_classes, all_to_all=True,
                               output=True)

    @staticmethod
    def _learnable_lif(beta, threshold, spike_grad):
        """Build a stateful Leaky neuron whose beta and threshold are trainable."""
        return snn.Leaky(beta=beta, threshold=threshold,
                         learn_beta=True, learn_threshold=True,
                         spike_grad=spike_grad, init_hidden=True)

    def forward(self, x):
        """Advance the network by a single time step.

        Args:
            x: one frame per sample, shaped (batch, 2, 128, 128) so that the
                flattened features match ``fc``.

        Returns:
            tuple: ``(spk_out, mem_out)`` of the recurrent output layer, each
            shaped (batch, num_classes).
        """
        stages = (
            (self.conv1, self.bn1, self.pool1, self.lif1),
            (self.conv2, self.bn2, self.pool2, self.lif2),
            (self.conv3, self.bn3, self.pool3, self.lif3),
            (self.conv4, self.bn4, self.pool4, self.lif4),
        )
        for conv, bn, pool, lif in stages:
            x = lif(pool(bn(conv(x))))

        x = self.lif_hidden(self.fc(self.flatten(x)))
        spk_out, mem_out = self.rlif(self.fc_out(x))
        return spk_out, mem_out

In [None]:
def _run_sequence(model, data):
    """Reset `model`'s hidden state and feed it `data` one time step at a time.

    Args:
        model: stateful spiking network returning ``(spikes, membrane)`` per step.
        data: time-major input batch, indexed as ``data[step]``.

    Returns:
        Output spikes stacked over time, shape (time, batch, num_outputs).
    """
    utils.reset(model)
    spk_rec = []
    for step in range(data.size(0)):
        spk_out, _ = model(data[step])
        spk_rec.append(spk_out)
    return torch.stack(spk_rec)


def _evaluate(model, loader):
    """Return the per-sample spike-rate accuracy of `model` over `loader`."""
    model.eval()
    correct = 0.0
    total = 0
    with torch.no_grad():
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            spk_rec = _run_sequence(model, data)
            # accuracy_rate is a batch fraction; weight it by the batch size.
            correct += SF.accuracy_rate(spk_rec, targets) * data.size(1)
            total += data.size(1)
    return correct / total


def objective(trial):
    """Optuna objective: train an OptunaRSNN briefly and return its accuracy.

    Samples the surrogate slope, initial beta/threshold, learning rate and
    hidden width. It then trains for a few epochs on ``train_loader`` and
    reports the accuracy on ``test_loader`` after every epoch so the study's
    pruner can stop weak trials early.

    Args:
        trial: the ``optuna.Trial`` supplying hyperparameters.

    Returns:
        Accuracy after the final epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
    """
    # --- Suggest hyperparameters ---
    slope = trial.suggest_float("slope", 1.0, 10.0)
    beta_init = trial.suggest_float("beta_init", 0.3, 0.95)
    threshold_init = trial.suggest_float("threshold_init", 0.5, 2.0)
    lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
    linear_hidden = trial.suggest_int("linear_hidden", 64, 256, step=64)

    # --- Model & optimiser ---
    # Use the notebook-level `device` (XPU > CUDA > CPU) so an XPU machine is
    # not silently downgraded to the CPU by a CUDA-only check.
    model = OptunaRSNN(slope, beta_init, threshold_init, linear_hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    loss_fn = SF.ce_rate_loss()

    # NOTE(review): selection uses test_loader, so the best accuracy reported by
    # the study is tuned on the test split and is optimistically biased. Hold out
    # part of the training set for tuning if an unbiased test score is needed.
    epochs = 5  # kept low for the search; retrain the best config for longer
    val_acc = 0.0

    for epoch in range(epochs):
        model.train()
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            spk_rec = _run_sequence(model, data)
            loss = loss_fn(spk_rec, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_acc = _evaluate(model, test_loader)

        # Report per-epoch progress and let the pruner stop weak trials early.
        trial.report(val_acc, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_acc

In [None]:
# Maximise accuracy. MedianPruner can stop a trial early when its per-epoch
# accuracy (reported by `objective`) trails the median of earlier trials at
# the same epoch.
median_pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="maximize", pruner=median_pruner)
study.optimize(func=objective, n_trials=20)  # 20 trials as a first pass

print("Best Hyperparameters:", study.best_params)
print("Best Accuracy:", study.best_value)

In [None]:
# best_params = {
#     "beta_init": 0.82,      # Example value
#     "slope": 4.5,           # Example value
#     "lr": 0.001,            # Example value
#     "linear_hidden": 128,   # Example value
#     "threshold_init": 1.0   # Example value (if you optimized it)
# }

# BATCH_SIZE = 16 
# EPOCHS = 100

In [None]:
# net1 = OptunaRSNN(
#     slope=best_params["slope"],
#     beta_init=best_params["beta_init"],
#     threshold_init=best_params["threshold_init"],
#     linear_hidden=best_params["linear_hidden"]
# ).to(device)

In [None]:
# optimizer = torch.optim.Adam(net1.parameters(), lr=best_params["lr"], betas=(0.9, 0.999))
# loss_fn = SF.ce_rate_loss()

# print(f"--- Final Model Initialized ---")
# print(f"Parameters: {sum(p.numel() for p in net1.parameters())}")
# print(f"Hyperparams: {best_params}")

In [None]:
# best_acc = 0.0
# loss_hist = []
# test_acc_hist = []

# print(f"Starting training for {EPOCHS} epochs...")

# for epoch in range(EPOCHS):
#     # --- Training ---
#     net1.train()
#     epoch_loss = 0
#     for data, targets in train_loader:
#         data, targets = data.to(device), targets.to(device)
#         utils.reset(net1)
#         spk_rec = []

#         # Time-loop
#         for step in range(data.size(0)):
#             spk_out, _ = net1(data[step])
#             spk_rec.append(spk_out)

#         spk_rec = torch.stack(spk_rec)
#         loss = loss_fn(spk_rec, targets)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()

#     # --- Validation ---
#     net1.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for data, targets in test_loader:
#             data, targets = data.to(device), targets.to(device)
#             utils.reset(net1)
#             spk_rec = []

#             for step in range(data.size(0)):
#                 spk_out, _ = net1(data[step])
#                 spk_rec.append(spk_out)

#             spk_rec = torch.stack(spk_rec)
#             acc = SF.accuracy_rate(spk_rec, targets)
#             correct += acc * data.size(1)
#             total += data.size(1)

#     # --- Stats & Saving ---
#     epoch_acc = correct / total
#     avg_loss = epoch_loss / len(train_loader)
    
#     loss_hist.append(avg_loss)
#     test_acc_hist.append(epoch_acc)

#     print(f"Epoch {epoch+1}/{EPOCHS} \t Loss: {avg_loss:.4f} \t Acc: {epoch_acc*100:.2f}%")

#     # Save Best Model
#     if epoch_acc > best_acc:
#         best_acc = epoch_acc
#         torch.save(net1.state_dict(), "best_dvs_gesture_model.pth")
#         print(f"  --> New Best Model Saved! ({best_acc*100:.2f}%)")

# print("Training Complete.")
# print(f"Highest Accuracy Achieved: {best_acc*100:.2f}%")
# print("Best model weights saved to 'best_dvs_gesture_model.pth'")