# SSNN on CIFAR10-DVS — Reproducible Training (Colab)
**Model:** Shrinking Spiking Neural Network (SSNN) with four stages and shrinking timesteps `[8, 6, 4, 2]`  
**Dataset:** CIFAR10-DVS [40]  
**Paper section:** *Training Configuration / Table 1*  
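**Note:** All hyperparameters live in the `CFG` cell below and are meant to follow Table 1. Check them against the table before reproducing the reported numbers; in particular, the default `epochs` is a short smoke-test value.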

In [None]:
# Runtime: Google Colab
# Installs
# NOTE(review): Colab images usually ship a CUDA build of torch/torchvision already;
# this line reinstalls the cu121 wheels, which may differ from the preinstalled version.
!pip -q install torch torchvision --index-url https://download.pytorch.org/whl/cu121
# NOTE(review): tonic is pinned but snntorch is not. Pin the snntorch version used for
# the reported results so that reruns use the same library behavior.
!pip -q install snntorch tonic==1.4.3 einops

In [None]:
import os, math, random, time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from einops import rearrange

# spiking + datasets
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils as snnutils
import tonic
from tonic import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# Reproducibility
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) and pin cuDNN.

    cuDNN autotuning is disabled and deterministic kernels are requested so
    that repeated runs with the same seed pick the same convolution algorithms.
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

set_seed(42)

In [None]:
# ===== Config matching Table 1 (edit to match your paper) =====
# Hyperparameters shared by the data, model and training cells.
CFG = dict(
    batch_size=32,                       # Table 1: Batch size
    epochs=5,                            # Table 1: Epochs (increase for full training)
    lr=1e-3,                             # Table 1: Learning rate
    weight_decay=1e-4,                   # Table 1: Weight decay
    beta=0.9,                            # LIF decay
    n_time_bins=8,                       # Max time bins for CIFAR10-DVS frames
    timesteps_per_stage=[8, 6, 4, 2],    # Shrinking schedule: steps run by stages 1-4
    aux_loss_weights=[0.3, 0.2, 0.1],    # stage1, stage2, stage3 auxiliary heads
    num_workers=2,
)
print(CFG)

In [None]:
# ===== Dataset: CIFAR10-DVS (tonic) =====
from torch.utils.data import Subset

save_dir = "/content/data"
Path(save_dir).mkdir(parents=True, exist_ok=True)

sensor_size = tonic.datasets.CIFAR10DVS.sensor_size  # (128,128,2)

frame_transform = transforms.ToFrame(
    sensor_size=sensor_size,
    n_time_bins=CFG["n_time_bins"]
)

# CIFAR10-DVS has no official train/test split, and tonic's CIFAR10DVS class
# accepts no `train=` argument (passing one raises TypeError). Build a fixed,
# stratified split instead: TRAIN_FRACTION of every class goes to training and
# the rest to testing (9:1 is the split commonly used in the literature).
TRAIN_FRACTION = 0.9
SPLIT_SEED = 42


def split_indices(targets, train_fraction=TRAIN_FRACTION, seed=SPLIT_SEED):
    """Return sorted (train_idx, test_idx) with `train_fraction` of each class in train.

    Each class is shuffled with a dedicated seeded generator, so the split does
    not depend on the global RNG state or on the order in which the class
    folders were discovered on disk.
    """
    labels = torch.as_tensor(list(targets))
    gen = torch.Generator().manual_seed(seed)
    train_idx, test_idx = [], []
    for cls in torch.unique(labels).tolist():  # torch.unique returns labels sorted
        cls_idx = torch.nonzero(labels == cls, as_tuple=True)[0]
        cls_idx = cls_idx[torch.randperm(cls_idx.numel(), generator=gen)]
        n_train = int(round(train_fraction * cls_idx.numel()))
        train_idx.extend(cls_idx[:n_train].tolist())
        test_idx.extend(cls_idx[n_train:].tolist())
    return sorted(train_idx), sorted(test_idx)


full_set = tonic.datasets.CIFAR10DVS(save_to=save_dir, transform=frame_transform)
targets = getattr(full_set, "targets", None)
if targets is None or len(targets) != len(full_set):
    # Labels are not available without decoding every recording: fall back to an
    # unstratified (but still seeded) split by treating all samples as one class.
    print("warning: dataset exposes no usable `targets`; using an unstratified split")
    targets = [0] * len(full_set)
train_idx, test_idx = split_indices(targets)
trainset = Subset(full_set, train_idx)
testset  = Subset(full_set, test_idx)
print('train samples:', len(trainset), 'test samples:', len(testset))

# Pad sequences to the same temporal length within a batch
collate_fn = tonic.collation.PadTensors(batch_first=True)  # output: (B, T, C, H, W)

trainloader = DataLoader(trainset, batch_size=CFG["batch_size"], shuffle=True,
                         collate_fn=collate_fn, num_workers=CFG["num_workers"], pin_memory=True)
testloader  = DataLoader(testset,  batch_size=CFG["batch_size"], shuffle=False,
                         collate_fn=collate_fn, num_workers=CFG["num_workers"], pin_memory=True)

# Peek
xb, yb = next(iter(trainloader))
print('batch frames:', xb.shape, 'labels:', yb.shape)  # (B, T, C=2, H=128, W=128)

In [None]:
# ===== Model: 4-stage SSNN with shrinking timesteps =====
# Each stage: Conv2d -> BN -> LIF -> (optional pooling)
# Temporal alignment: lightweight transformer encoder across time between stages,
# applied independently at every spatial location with the channel vector as token.

class TemporalAlign(nn.Module):
    """Mix a stage's output sequence across the time axis with a small transformer.

    Every spatial position (h, w) is treated as an independent sequence of T
    tokens, each token being the C-dim channel vector at that position. The
    encoder width is therefore the channel count C, not C*H*W. A d_model of
    C*H*W (e.g. 32*64*64 = 131072 after stage 1) would need ~4*d_model**2
    attention weights alone, i.e. hundreds of GB, so it cannot be built.

    Args:
        embed_dim: channel count C of the incoming feature maps; must be
            divisible by ``nhead``.
        nhead, dim_feedforward, num_layers: forwarded to
            ``nn.TransformerEncoderLayer`` / ``nn.TransformerEncoder``.
    """
    def __init__(self, embed_dim, nhead=4, dim_feedforward=256, num_layers=1):
        super().__init__()
        if embed_dim % nhead != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by nhead ({nhead})")
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead,
                                           dim_feedforward=dim_feedforward, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        """Map x of shape (B, T, C, H, W) to the same shape, mixed over T per pixel."""
        B, T, C, H, W = x.shape
        # (B, T, C, H, W) -> (B, H, W, T, C) -> (B*H*W, T, C): one sequence per pixel
        tokens = x.permute(0, 3, 4, 1, 2).reshape(B * H * W, T, C)
        mixed = self.encoder(tokens)
        return mixed.reshape(B, H, W, T, C).permute(0, 3, 4, 1, 2).contiguous()

class Stage(nn.Module):
    """One spiking stage for a single time step: Conv3x3 -> BN -> LIF -> pool.

    ``forward`` returns the pooled spikes and the (un-pooled) membrane potential,
    which must be passed back in on the next time step.
    """
    def __init__(self, in_ch, out_ch, beta=0.9, pool=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn   = nn.BatchNorm2d(out_ch)
        self.lif  = snn.Leaky(beta=beta, learn_beta=False)
        self.pool = nn.AvgPool2d(2) if pool else nn.Identity()

    def forward(self, x_t, mem=None):
        """x_t: (B, C, H, W) for one time step; mem=None starts a new sequence."""
        if mem is None:
            # Start from a fresh membrane. Recent snntorch releases keep the last
            # membrane potential inside the Leaky module and fall back to it when
            # mem is None, which leaks state (and a freed autograd graph) from the
            # previous batch into this one.
            mem = self.lif.init_leaky()
        z = self.bn(self.conv(x_t))
        spk, mem = self.lif(z, mem)
        return self.pool(spk), mem

class Readout(nn.Module):
    """Global-average-pool a (B, C, H, W) map and project it to class logits."""
    def __init__(self, in_ch, n_classes=10):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc  = nn.Linear(in_ch, n_classes)
    def forward(self, x):
        x = self.gap(x).flatten(1)
        return self.fc(x)

class SSNN(nn.Module):
    """Four spiking stages that run for a shrinking number of time steps.

    Stage k runs for at most ``t_per_stage[k]`` steps. Between stages, a
    TemporalAlign block mixes the previous stage's spike sequence over time
    before the next (shorter) stage consumes its first steps. Logits from
    per-step readouts are averaged over time; auxiliary heads after stages
    1-3 are evaluated only when ``train_aux=True``.
    """
    def __init__(self, beta=0.9, n_classes=10, t_per_stage=(8,6,4,2)):
        super().__init__()
        self.t_per_stage = t_per_stage
        self.stage1 = Stage(2,   32, beta=beta, pool=True)
        self.stage2 = Stage(32,  64, beta=beta, pool=True)
        self.stage3 = Stage(64, 128, beta=beta, pool=True)
        self.stage4 = Stage(128,256, beta=beta, pool=True)

        # Alignment width = channel count of the preceding stage (attention runs
        # per pixel); feed-forward width uses the usual 4x ratio to bound memory.
        self.align12 = TemporalAlign(embed_dim=32,  nhead=4, dim_feedforward=4*32,  num_layers=1)
        self.align23 = TemporalAlign(embed_dim=64,  nhead=4, dim_feedforward=4*64,  num_layers=1)
        self.align34 = TemporalAlign(embed_dim=128, nhead=4, dim_feedforward=4*128, num_layers=1)

        # Early classifiers after stage 1-3
        self.aux1 = Readout(32, n_classes)
        self.aux2 = Readout(64, n_classes)
        self.aux3 = Readout(128, n_classes)
        # Final head
        self.head = Readout(256, n_classes)

    def _run_stage(self, stage, seq, n_steps, readout=None):
        """Run `stage` over the first `n_steps` steps of `seq` (B, T, C, H, W).

        Returns the stacked spikes (B, n, C', H', W') and, when `readout` is
        given, its logits averaged over the processed steps (otherwise None).
        """
        mem = None
        spikes, logits = [], []
        for t in range(min(seq.shape[1], n_steps)):
            spk, mem = stage(seq[:, t], mem)
            spikes.append(spk)
            if readout is not None:
                logits.append(readout(spk))
        out = torch.stack(spikes, dim=1)
        avg_logits = torch.stack(logits, dim=1).mean(dim=1) if logits else None
        return out, avg_logits

    def forward(self, x, train_aux=False):
        """x: (B, T, C=2, H, W) event frames -> dict with "logit" (+ "aux1..3")."""
        if x.dim() != 5:
            raise ValueError(f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}")
        if not x.is_floating_point():
            # Padded event-count frames may arrive as integers; Conv2d needs floats.
            x = x.float()
        T1, T2, T3, T4 = self.t_per_stage

        s1, aux1 = self._run_stage(self.stage1, x, T1, self.aux1 if train_aux else None)
        s2, aux2 = self._run_stage(self.stage2, self.align12(s1), T2, self.aux2 if train_aux else None)
        s3, aux3 = self._run_stage(self.stage3, self.align23(s2), T3, self.aux3 if train_aux else None)
        # Final head is evaluated at every stage-4 step and averaged over time.
        _, logit = self._run_stage(self.stage4, self.align34(s3), T4, self.head)

        outs = {"logit": logit}
        if train_aux:
            outs["aux1"] = aux1
            outs["aux2"] = aux2
            outs["aux3"] = aux3
        return outs

# Build the model from CFG and report its size in millions of parameters.
stage_timesteps = tuple(CFG["timesteps_per_stage"])
model = SSNN(beta=CFG["beta"], n_classes=10, t_per_stage=stage_timesteps)
model = model.to(device)
sum_params = 0
for param in model.parameters():
    sum_params += param.numel()
print(type(model).__name__, "params:", sum_params / 1e6, "M")

In [None]:
# ===== Training / Evaluation =====
def accuracy(logits, targets):
    """Return the fraction of rows whose arg-max logit equals the target, as a float."""
    predictions = torch.argmax(logits, dim=1)
    hits = predictions.eq(targets)
    return hits.float().mean().item()

def train_one_epoch(model, loader, opt, epoch):
    """Train `model` for one pass over `loader` and return (loss, accuracy).

    The optimised objective is the main-head cross-entropy plus the auxiliary
    heads' cross-entropies weighted by CFG["aux_loss_weights"]. The *reported*
    loss is the main-head cross-entropy alone, averaged per sample, so that it
    is directly comparable with the loss returned by `evaluate`. Accuracy is
    that of the main head. `epoch` is accepted for logging hooks but unused.
    """
    model.train()
    ce = nn.CrossEntropyLoss()
    w1, w2, w3 = CFG["aux_loss_weights"]
    total, correct = 0., 0
    for xb, yb in loader:
        # Loaders use pin_memory=True, so the host->device copy can overlap compute.
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        outs = model(xb, train_aux=True)
        loss_main = ce(outs["logit"], yb)
        loss_aux = w1*ce(outs["aux1"], yb) + w2*ce(outs["aux2"], yb) + w3*ce(outs["aux3"], yb)
        (loss_main + loss_aux).backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        total   += loss_main.item() * xb.size(0)
        correct += (outs["logit"].argmax(1) == yb).sum().item()
    n = len(loader.dataset)
    return total / n, correct / n

@torch.no_grad()
def evaluate(model, loader):
    """Score the main head on `loader` without gradients.

    Returns (mean per-sample cross-entropy, top-1 accuracy), both normalised
    by the number of samples in ``loader.dataset``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.
    n_correct = 0.
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits = model(inputs, train_aux=False)["logit"]
        loss_sum += criterion(logits, labels).item() * inputs.size(0)
        n_correct += (logits.argmax(1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples

opt = torch.optim.AdamW(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])

# NOTE(review): the "best" checkpoint is selected on the same split whose accuracy
# is reported, so "Best val acc" is an optimistic estimate; hold out a separate
# validation split for unbiased model selection.
best_acc = 0.0
ckpt_path = "/content/ssnn_cifar10dvs.pt"
for epoch in range(1, CFG["epochs"] + 1):
    tr_loss, tr_acc = train_one_epoch(model, trainloader, opt, epoch)
    te_loss, te_acc = evaluate(model, testloader)
    print(f"Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | val loss {te_loss:.4f} acc {te_acc:.3f}")
    if te_acc <= best_acc:
        continue
    # New best: keep only the weights of the best-scoring epoch.
    best_acc = te_acc
    torch.save(model.state_dict(), ckpt_path)
print("Best val acc:", round(best_acc, 4))

### Code availability
This notebook constitutes the *Google Colab file* requested by the reviewer.  
It implements SSNN with four stages using the shrinking timestep schedule `[8,6,4,2]`, auxiliary heads, and the training
configuration of Table 1. If any value in `CFG` differs from Table 1, update `CFG` to match the table.