In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import os
import time
from tqdm import tqdm

# Tonic for data loading (Using your existing method)
import tonic
import tonic.transforms as transforms

# Check Device: prefer an Intel XPU, then CUDA, then fall back to CPU.
# `torch.xpu` is not present in every PyTorch build, so probe for the attribute
# before calling it; otherwise this cell raises AttributeError on such builds.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Running on: {device}")

Running on: xpu


In [2]:
class Config:
    """Hyper-parameters shared by the model, data pipeline and training cells.

    All values are plain instance attributes; the single instance ``cfg`` is
    created right below and read by the later cells.
    """

    def __init__(self):
        # ---- Simulation & optimisation ----
        self.T = 20                 # number of time windows (frames) per sample
        self.dt = 1                 # time resolution
        self.b = 8                  # mini-batch size
        self.lr = 1e-4              # Adam learning rate
        self.epochs = 50            # number of training epochs

        # ---- Neuron parameters (packed into pa_dict by the network) ----
        self.alpha = 0.3            # membrane decay factor
        self.beta = 0.              # forwarded in pa_dict; IFCell does not read it
        self.Vreset = 0.            # reset potential
        self.Vthres = 0.3           # firing threshold

        # ---- Architecture ----
        self.target_size = 11       # number of DVS Gesture classes
        self.reduction = 16         # bottleneck ratio of the temporal-attention MLP
        self.mode_select = 'spike'  # 'spike' or 'mem'


cfg = Config()

In [3]:
# --- Temporal Attention Layer (TA) ---
class Tlayer(nn.Module):
    """Temporal attention: rescales every time step by a learned weight in (0, 1).

    The input is average-pooled over every axis after (batch, time), giving one
    statistic per time step. A bottleneck MLP (T -> T//reduction -> T) followed
    by a sigmoid turns those statistics into per-step weights that multiply the
    input.

    Args:
        timeWindows: number of time steps T; must equal ``input.size(1)``.
        reduction: bottleneck ratio of the excitation MLP. The hidden width is
            clamped to at least 1: with T < reduction, ``T // reduction`` would
            be 0, producing a zero-width layer whose output is a constant 0.5.
        dimension: rank of the input. 3 for FC features [B, T, F], 4 for
            [B, T, H, W], anything else (e.g. 5) for conv maps [B, T, C, H, W].
    """

    def __init__(self, timeWindows, reduction=16, dimension=4):
        super(Tlayer, self).__init__()
        # Pooling collapses every axis after (batch, time) down to size 1.
        if dimension == 3:
            self.avg_pool = nn.AdaptiveAvgPool1d(1)
        elif dimension == 4:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.avg_pool = nn.AdaptiveAvgPool3d(1)

        hidden = max(1, int(timeWindows // reduction))
        self.temporal_excitation = nn.Sequential(
            nn.Linear(timeWindows, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, timeWindows),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return ``input`` with each time step scaled by its attention weight."""
        b, t = input.size(0), input.size(1)

        # Per-step statistics: [B, T, 1, ...] -> [B, T]
        pooled = self.avg_pool(input)
        weights = self.temporal_excitation(pooled.reshape(b, t))

        # Broadcast the weights back over the pooled axes.
        return input * weights.reshape(pooled.size())

# --- Custom Integrate and Fire Cell ---
# Used for both Conv and FC layers in the repo
class IFCell(nn.Module):
    """Leaky integrate-and-fire neuron with hard reset, advanced one time step per call.

    The membrane potential ``self.h`` persists across calls until ``reset()``
    is called. The Conv/FC wrappers reset it at the start of every forward pass.

    Args:
        inputSize: unused; kept for signature compatibility.
        hiddenSize: number of neurons (channels for 4-D conv input, features
            for 2-D FC input).
        spikeActFun: callable mapping ``u - Vthres`` to a spike tensor
            (e.g. a surrogate-gradient ``Function.apply``).
        scale, bias: unused; kept for signature compatibility.
        pa_dict: dict providing 'alpha' (decay factor), 'Vreset' (potential
            after a spike) and 'Vthres' (firing threshold).
    """

    def __init__(self, inputSize, hiddenSize, spikeActFun, scale=0.3, pa_dict=None, bias=True):
        super().__init__()
        self.hiddenSize = hiddenSize
        self.spikeActFun = spikeActFun
        self.pa_dict = pa_dict
        self.alpha = pa_dict['alpha']
        self.Vreset = pa_dict['Vreset']
        self.Vthres = pa_dict['Vthres']
        self.h = None  # Membrane potential; lazily allocated on the first step

    def forward(self, input, init_v=None):
        """Integrate ``input`` for one step and return the emitted spikes.

        Args:
            input: synaptic current, [B, C, H, W] (conv) or [B, F] (FC).
            init_v: unused; kept for signature compatibility.
        """
        self.batchSize = input.size(0)

        # Initialize membrane potential on the first step (same device/dtype as input).
        if self.h is None:
            if input.dim() == 4:  # Conv layer [B, C, H, W]
                shape = (self.batchSize, self.hiddenSize, input.size(2), input.size(3))
            else:  # FC layer [B, Features]
                shape = (self.batchSize, self.hiddenSize)
            self.h = input.new_zeros(shape)

        # Integrate and fire.
        u = self.h + input
        x = self.spikeActFun(u - self.Vthres)

        # Hard reset + leak: neurons that fired are set to Vreset (not Vthres,
        # which is neither a hard reset nor threshold subtraction); the rest
        # keep u. Then decay by alpha.
        self.h = (x * self.Vreset + (1 - x) * u) * self.alpha

        return x

    def reset(self):
        """Forget the membrane state so the next call starts from zero."""
        self.h = None

# --- Wrapper for Convolution + Attention + LIF ---
class ConvAttLIF(nn.Module):
    """Conv2d -> BatchNorm2d -> optional AvgPool2d -> temporal attention -> IF neurons.

    The spatial ops are shared across time: input [B, T, C, H, W] is folded
    into a [B*T, C, H, W] batch, unfolded again for the temporal attention, and
    then fed to the spiking cell one time step at a time.
    Returns spikes shaped [B, T, out_c, H', W'].
    """

    def __init__(self, in_c, out_c, kernel_size, spikeActFun, padding=1, pa_dict=None, reduction=16, T=60, stride=1, pooling_kernel=1):
        super().__init__()
        # Construction order is kept stable so random weight init is unchanged.
        self.conv = nn.Conv2d(in_c, out_c, kernel_size, padding=padding, stride=stride)
        self.bn = nn.BatchNorm2d(out_c)
        self.pooling_kernel = pooling_kernel
        if pooling_kernel > 1:
            self.pool = nn.AvgPool2d(pooling_kernel)

        self.attention = Tlayer(timeWindows=T, dimension=5, reduction=reduction)
        self.lif = IFCell(0, out_c, spikeActFun, pa_dict=pa_dict)

    def forward(self, data):
        # Every batch starts from a fresh membrane potential.
        self.lif.reset()

        batch, steps, channels, height, width = data.size()

        # Fold time into the batch axis so Conv2d/BatchNorm2d see 4-D input.
        frames = self.bn(self.conv(data.reshape(batch * steps, channels, height, width)))
        if self.pooling_kernel > 1:
            frames = self.pool(frames)

        # Unfold back to [B, T, C', H', W'] and reweight each time step.
        frames = self.attention(frames.reshape(batch, steps, *frames.shape[1:]))

        # Integrate-and-fire sequentially over time, then restack along dim 1.
        return torch.stack([self.lif(frames[:, k]) for k in range(steps)], dim=1)

# --- Wrapper for FC + Attention + LIF ---
class FCAttLIF(nn.Module):
    """Linear -> BatchNorm1d -> temporal attention -> IF neurons over a [B, T, F] sequence.

    The linear layer is applied to every time step at once via a [B*T, F]
    view. Returns spikes shaped [B, T, out_features].
    """

    def __init__(self, in_features, out_features, spikeActFun, pa_dict=None, reduction=16, T=60):
        super().__init__()
        # Construction order is kept stable so random weight init is unchanged.
        self.linear = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)
        self.attention = Tlayer(timeWindows=T, dimension=3, reduction=reduction)
        self.lif = IFCell(0, out_features, spikeActFun, pa_dict=pa_dict)

    def forward(self, data):
        # Every batch starts from a fresh membrane potential.
        self.lif.reset()

        batch, steps, _ = data.size()

        # Shared linear + BN over all time steps, then back to [B, T, F'].
        feats = self.bn(self.linear(data.reshape(batch * steps, -1)))
        feats = self.attention(feats.reshape(batch, steps, -1))

        # Integrate-and-fire sequentially over time, then restack along dim 1.
        return torch.stack([self.lif(feats[:, k]) for k in range(steps)], dim=1)

In [4]:
class ActFun(torch.autograd.Function):
    """Heaviside spike with a rectangular surrogate gradient.

    Forward: 1.0 where the input is >= 0, else 0.0.
    Backward: the true derivative is replaced by a box of height 1/(2*0.5) = 1
    on |input| < 0.5 and 0 elsewhere.
    """

    @staticmethod
    def forward(ctx, membrane):
        ctx.save_for_backward(membrane)
        return (membrane >= 0.).float()

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        half_width = 0.5
        inside_window = membrane.abs() < half_width
        return grad_output * inside_window.float() / (2 * half_width)

class TA_SNN_Net(nn.Module):
    """TA-SNN classifier: three ConvAttLIF stages followed by two FCAttLIF layers.

    Args:
        config: object providing alpha, beta, Vreset, Vthres, T, reduction and
            target_size.
        input_hw: (height, width) of the input frames. Each conv stage halves
            the spatial size with 2x2 average pooling, so the flattened feature
            size is derived from this rather than hard-coded. The default
            (128, 128) reproduces the original 128 * 16 * 16 layout.

    Forward input is [B, T, 2, H, W]; output is the per-class spike rate
    (spike count / T), shape [B, target_size].
    """

    def __init__(self, config, input_hw=(128, 128)):
        super(TA_SNN_Net, self).__init__()
        self.cfg = config
        pa_dict = {'alpha': config.alpha, 'beta': config.beta, 'Vreset': config.Vreset, 'Vthres': config.Vthres}
        # Keyword arguments shared by every spiking layer.
        lif_kwargs = dict(spikeActFun=ActFun.apply, pa_dict=pa_dict, T=config.T, reduction=config.reduction)

        # Conv stages: 3x3 conv (size preserving) + 2x2 average pooling each.
        self.conv1 = ConvAttLIF(2, 64, kernel_size=3, padding=1, stride=1, pooling_kernel=2, **lif_kwargs)
        self.conv2 = ConvAttLIF(64, 128, kernel_size=3, padding=1, stride=1, pooling_kernel=2, **lif_kwargs)
        self.conv3 = ConvAttLIF(128, 128, kernel_size=3, padding=1, stride=1, pooling_kernel=2, **lif_kwargs)

        # Three AvgPool2d(2) stages, each flooring: e.g. 128 -> 64 -> 32 -> 16.
        h, w = input_hw
        for _ in range(3):
            h, w = h // 2, w // 2
        flat_size = 128 * h * w

        # Fully Connected Layers
        self.fc1 = FCAttLIF(flat_size, 256, **lif_kwargs)
        self.fc2 = FCAttLIF(256, config.target_size, **lif_kwargs)

    def forward(self, x):
        # x shape: [Batch, Time, Channel, H, W]
        x = self.conv3(self.conv2(self.conv1(x)))

        # Flatten spatial/channel axes per time step: [B, T, C*H*W]
        b, t = x.shape[:2]
        x = x.reshape(b, t, -1)

        x = self.fc2(self.fc1(x))

        # Mean spike count over time = firing rate used for classification.
        return x.mean(dim=1)

# Initialize Model
# Seed torch's global RNG before building the model so weight initialisation is
# reproducible under "Restart & Run All"; later draws from the same global
# generator (presumably including the shuffled training DataLoader) follow suit.
SEED = 0
torch.manual_seed(SEED)
model = TA_SNN_Net(cfg).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
# MSE between per-class spike rates and one-hot targets (cross-entropy on the
# rates is a possible alternative).
criterion = nn.MSELoss()

In [5]:
# Sensor geometry of the DVSGesture recordings, as defined by tonic.
sensor_size = tonic.datasets.DVSGesture.sensor_size

# Transform: bin each event stream into frames.
# IMPORTANT: n_time_bins must equal cfg.T, because every Tlayer in the model is
# built with Linear(cfg.T, ...) and cannot accept a different number of steps.
transform = transforms.Compose([
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=cfg.T),
])

train_set = tonic.datasets.DVSGesture(save_to='./data', train=True, transform=transform)
test_set = tonic.datasets.DVSGesture(save_to='./data', train=False, transform=transform)

# Training loader settings: shuffled, ragged final batch dropped.
cached_dataloader_args = {
    "batch_size": cfg.b,
    "collate_fn": tonic.collation.PadTensors(batch_first=True),  # batch_first=True -> [B, T, C, H, W]
    "shuffle": True,
    "num_workers": 2,
    "drop_last": True,
}
# Evaluation must see every test sample in a fixed order: reusing the training
# settings would shuffle the test set and silently drop its last partial batch
# from the reported accuracy.
eval_dataloader_args = {**cached_dataloader_args, "shuffle": False, "drop_last": False}

train_loader = DataLoader(train_set, **cached_dataloader_args)
test_loader = DataLoader(test_set, **eval_dataloader_args)

# Verification: a training batch should be [cfg.b, cfg.T, 2, H, W].
d, t = next(iter(train_loader))
print(f"Input Shape: {d.shape}")
assert d.shape[:2] == (cfg.b, cfg.T), f"expected [B={cfg.b}, T={cfg.T}, ...], got {tuple(d.shape)}"

Input Shape: torch.Size([8, 20, 2, 128, 128])


In [None]:
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Uses the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``train_loader``, ``cfg`` and ``device``. Labels are one-hot encoded and
    compared with the model's per-class spike rates by ``criterion``.

    Args:
        epoch: 1-based epoch number, used for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the epoch, or ``(0.0, 0.0)`` if
        the loader yields no batches.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs}")
    for images, labels in pbar:
        images = images.float().to(device)
        labels = labels.long().to(device)

        # One-hot targets so the loss can compare them with per-class spike rates.
        target = F.one_hot(labels, num_classes=cfg.target_size).float()

        optimizer.zero_grad()
        outputs = model(images)  # Spike rates [B, Classes]
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        num_batches += 1
        train_loss += loss.item()
        predicted = outputs.argmax(dim=1)  # first max on ties, like torch.max
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({'Loss': train_loss / num_batches, 'Acc': 100 * correct / total})

    if num_batches == 0:
        return 0.0, 0.0
    return train_loss / num_batches, 100 * correct / total

def test(epoch):
    """Evaluate ``model`` on ``test_loader``; print and return accuracy in percent.

    Uses the notebook globals ``model``, ``test_loader`` and ``device``.

    Args:
        epoch: current epoch number; not used in the computation, kept so the
            call mirrors ``train(epoch)``.

    Returns:
        Accuracy as a float percentage. If the loader yields no samples (e.g. a
        test set smaller than one batch with ``drop_last=True``), a warning is
        printed and 0.0 is returned instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.float().to(device)
            labels = labels.to(device)

            predicted = model(images).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print("Test Accuracy: n/a (test_loader yielded no samples)")
        return 0.0

    acc = 100 * correct / total
    print(f"Test Accuracy: {acc:.2f}%")
    return acc

# --- Run ---
# Alternate one training epoch with one evaluation pass and keep the best
# test accuracy seen so far.
best_acc = 0
for epoch in range(1, cfg.epochs + 1):
    train(epoch)
    acc = test(epoch)
    if acc <= best_acc:
        continue
    best_acc = acc
    # torch.save(model.state_dict(), 'best_ta_snn.pth')
    print(f"New Best Accuracy: {best_acc:.2f}%")

Epoch 1/50:   0%|                                                                              | 0/134 [00:00<?, ?it/s]