In [17]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import os

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils

import tonic
import tonic.transforms as transforms
from tonic import DiskCachedDataset
from tonic.dataset import Dataset

import matplotlib.pyplot as plt
from IPython.display import HTML
from collections.abc import Callable

import optuna
import time
from tqdm import tqdm

In [18]:
sensor_size = tonic.datasets.DVSGesture.sensor_size

# Number of time bins each recording is sliced into. This is also the number of
# simulation steps the SNN runs per sample (dimension 0 of every batch below).
# NOTE(review): activation memory for backprop-through-time grows linearly with this
# value; lower it (or the batch size) if the accelerator runs out of memory.
n_time_bins = 30

transform = transforms.Compose([
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=n_time_bins),
])

train_set = tonic.datasets.DVSGesture(save_to='./data', train=True, transform=transform)
test_set = tonic.datasets.DVSGesture(save_to='./data', train=False, transform=transform)

# Dataloaders. With batch_first=False the collated batches are time-first,
# i.e. [T, B, C, H, W] (see the printed shape below).
cached_dataloader_args = {
    "batch_size": 16,
    "collate_fn": tonic.collation.PadTensors(batch_first=False),
    "shuffle": True,
    "num_workers": 1,
    "pin_memory": True
}

train_loader = DataLoader(train_set, **cached_dataloader_args)
# Evaluation gains nothing from a random order; a fixed order keeps test passes reproducible.
test_loader = DataLoader(test_set, **{**cached_dataloader_args, "shuffle": False})

data, targets = next(iter(train_loader))
print(f"Data shape: {data.shape}")

Data shape: torch.Size([30, 16, 2, 128, 128])


In [19]:
# Pick the best available accelerator: Intel XPU first, then CUDA, then CPU.
# `torch.xpu` is only present in newer PyTorch builds, so probe for it before
# calling into it instead of letting an AttributeError abort the cell.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Running on: {device}")

Running on: xpu


In [20]:
# Global neuron parameters shared by every snn.Leaky layer built below.
spike_grad = surrogate.atan()  # arctan surrogate-gradient function, passed as spike_grad
beta = 0.5  # decay factor handed to each snn.Leaky

In [21]:
# -----------------------------------------------------------
# 2. ResNet Building Blocks (Adapted for Single-Step)
# -----------------------------------------------------------

def conv3x3(in_channels, out_channels, beta, spike_grad):
    """Build a 3x3 conv -> BatchNorm2d -> stateful Leaky spiking unit.

    The convolution uses stride 1, padding 1 and no bias, so the spatial size of
    the input feature map is preserved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        beta: decay factor forwarded to snn.Leaky.
        spike_grad: surrogate gradient forwarded to snn.Leaky.

    Returns:
        nn.Sequential applying the three layers in order.
    """
    layers = [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, padding=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels),
        # init_hidden=True: the neuron keeps its membrane state internally between calls.
        snn.Leaky(init_hidden=True, beta=beta, spike_grad=spike_grad),
    ]
    return nn.Sequential(*layers)

def conv1x1(in_channels, out_channels, beta, spike_grad):
    """Build a pointwise (1x1) conv -> BatchNorm2d -> stateful Leaky spiking unit.

    Stride 1 with no padding, so only the channel count changes; the spatial size
    of the feature map is kept.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        beta: decay factor forwarded to snn.Leaky.
        spike_grad: surrogate gradient forwarded to snn.Leaky.

    Returns:
        nn.Sequential applying the three layers in order.
    """
    layers = [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels),
        # init_hidden=True: the neuron keeps its membrane state internally between calls.
        snn.Leaky(init_hidden=True, beta=beta, spike_grad=spike_grad),
    ]
    return nn.Sequential(*layers)

class SEWBlock(nn.Module):
    """Spike-element-wise (SEW) residual block.

    Two stacked conv3x3 spiking units map ``in_channels -> mid_channels -> in_channels``.
    Their output ``out`` is merged with the block input ``x`` according to ``connect_f``:

    * ``'ADD'``  -> ``out + x``
    * ``'AND'``  -> ``out * x``
    * ``'IAND'`` -> ``x * (1 - out)``

    The inner units keep their own hidden state, so the block is meant to be called
    once per time step.

    Raises:
        NotImplementedError: if ``connect_f`` is not one of the names above. This is
            raised at construction so a bad configuration fails before any training step.
    """

    _CONNECT_FUNCTIONS = ('ADD', 'AND', 'IAND')

    def __init__(self, in_channels, mid_channels, connect_f, beta, spike_grad):
        super(SEWBlock, self).__init__()
        # Validate eagerly; otherwise a typo only surfaces on the first forward step.
        if connect_f not in self._CONNECT_FUNCTIONS:
            raise NotImplementedError(connect_f)
        self.connect_f = connect_f
        self.conv = nn.Sequential(
            conv3x3(in_channels, mid_channels, beta, spike_grad),
            conv3x3(mid_channels, in_channels, beta, spike_grad),
        )

    def forward(self, x):
        """Apply the residual branch to one time step ``x`` and merge it with ``x``."""
        out = self.conv(x)
        if self.connect_f == 'ADD':
            return out + x
        if self.connect_f == 'AND':
            return out * x
        if self.connect_f == 'IAND':
            return x * (1. - out)
        # Still guarded in case connect_f was reassigned after construction.
        raise NotImplementedError(self.connect_f)

class PlainBlock(nn.Module):
    """Two stacked conv3x3 spiking units with no skip connection.

    Maps ``in_channels -> mid_channels -> in_channels``, so it has the same
    interface as the residual blocks but without a shortcut path.
    """

    def __init__(self, in_channels, mid_channels, beta, spike_grad):
        super().__init__()
        first = conv3x3(in_channels, mid_channels, beta, spike_grad)
        second = conv3x3(mid_channels, in_channels, beta, spike_grad)
        self.conv = nn.Sequential(first, second)

    def forward(self, x):
        """Run one time step through both conv units."""
        return self.conv(x)

class BasicBlock(nn.Module):
    """ResNet-style block: the residual branch is added to the input before spiking.

    Branch = conv3x3 spiking unit (conv -> BN -> Leaky) followed by a plain
    3x3 conv -> BN with no neuron. ``x + branch(x)`` then drives ``self.sn``.
    """

    def __init__(self, in_channels, mid_channels, beta, spike_grad):
        super().__init__()
        head = conv3x3(in_channels, mid_channels, beta, spike_grad)
        tail = nn.Sequential(
            nn.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(in_channels),
        )
        self.conv = nn.Sequential(head, tail)
        self.sn = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)

    def forward(self, x):
        """Run one time step: spike on the sum of the input and its residual branch."""
        residual = self.conv(x)
        return self.sn(x + residual)

In [22]:
# -----------------------------------------------------------
# 3. Main Network Class
# -----------------------------------------------------------

class ResNetN(nn.Module):
    """Configurable spiking ResNet that processes ONE time step per forward call.

    Every spiking layer is an ``snn.Leaky`` created with ``init_hidden=True``, so membrane
    state lives inside the modules between calls. Reset it (e.g. ``utils.reset(net)``)
    before feeding a new sequence.

    Args:
        layer_list: list of stage config dicts. Supported keys:
            ``channels`` (required); ``mid_channels`` (defaults to ``channels``);
            ``up_kernel_size`` (1 or 3, read only when ``channels`` differs from the
            previous stage's); ``num_blocks`` together with ``block_type``
            (``'sew'``, ``'plain'`` or ``'basic'``); ``k_pool`` (max-pool kernel and stride).
        num_classes: number of output neurons.
        connect_f: SEW merge function (``'ADD'``, ``'AND'``, ``'IAND'``); used by ``'sew'`` blocks.
        beta: decay factor forwarded to every ``snn.Leaky``.
        spike_grad: surrogate gradient forwarded to every ``snn.Leaky``.
        input_channels: channels of one input frame (2 = on/off polarity for DVS data).
        input_size: (height, width) of one input frame; only used to size the linear head.

    Raises:
        NotImplementedError: for an unsupported ``up_kernel_size`` or ``block_type``.
    """

    def __init__(self, layer_list, num_classes, connect_f=None, beta=0.5, spike_grad=None,
                 input_channels=2, input_size=(128, 128)):
        super(ResNetN, self).__init__()

        # --- Build Layers ---
        self.conv = nn.Sequential(
            *self._build_layers(layer_list, input_channels, connect_f, beta, spike_grad)
        )

        # --- Output Head ---
        out_features = self._infer_feature_count(input_channels, input_size)
        self.fc = nn.Linear(out_features, num_classes, bias=True)
        # output=True makes the final neuron return (spk, mem)
        self.final_lif = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    @staticmethod
    def _build_layers(layer_list, in_channels, connect_f, beta, spike_grad):
        """Translate the stage configs into a flat module list that ends with Flatten."""
        layers = []
        for cfg_dict in layer_list:
            channels = cfg_dict['channels']
            mid_channels = cfg_dict.get('mid_channels', channels)

            # 1. Channel projection when the width changes between stages
            if in_channels != channels:
                up_kernel_size = cfg_dict['up_kernel_size']
                if up_kernel_size == 3:
                    layers.append(conv3x3(in_channels, channels, beta, spike_grad))
                elif up_kernel_size == 1:
                    layers.append(conv1x1(in_channels, channels, beta, spike_grad))
                else:
                    raise NotImplementedError(f"up_kernel_size={up_kernel_size!r}")

            in_channels = channels

            # 2. Residual Blocks
            if 'num_blocks' in cfg_dict:
                for _ in range(cfg_dict['num_blocks']):
                    block_type = cfg_dict['block_type']
                    if block_type == 'sew':
                        layers.append(SEWBlock(in_channels, mid_channels, connect_f, beta, spike_grad))
                    elif block_type == 'plain':
                        layers.append(PlainBlock(in_channels, mid_channels, beta, spike_grad))
                    elif block_type == 'basic':
                        layers.append(BasicBlock(in_channels, mid_channels, beta, spike_grad))
                    else:
                        # Fail loudly: skipping an unknown type would silently build a shallower net.
                        raise NotImplementedError(f"block_type={block_type!r}")

            # 3. Pooling
            if 'k_pool' in cfg_dict:
                layers.append(nn.MaxPool2d(cfg_dict['k_pool'], cfg_dict['k_pool']))

        # Flatten features before linear layer
        layers.append(nn.Flatten(1))
        return layers

    def _infer_feature_count(self, input_channels, input_size):
        """Return the flattened feature size ``self.conv`` produces for a single frame.

        The probe runs with the conv stack in eval mode so BatchNorm uses, but does not
        update, its running statistics. In train mode the all-zeros dummy batch would be
        folded into running_mean / running_var before training even starts. The previous
        train/eval mode is restored afterwards.

        The probe still leaves batch-size-1 hidden state in the Leaky layers, so reset
        neuron state before the first real sequence.
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                dummy_x = torch.zeros([1, input_channels, *input_size])
                # Output is [1, Features] thanks to the trailing Flatten(1)
                return self.conv(dummy_x).shape[1]
        finally:
            self.conv.train(was_training)

    def forward(self, x):
        """
        Processes a SINGLE time-step.
        Input x: [Batch, Channel, Height, Width]
        Returns (spk, mem) from the final Leaky neuron, each [Batch, num_classes].
        """
        # Pass through the Conv Blocks (state is handled internally by init_hidden=True)
        features = self.conv(x)

        # Pass through Linear
        cur = self.fc(features)

        # Pass through Final LIF
        spk, mem = self.final_lif(cur)

        return spk, mem

In [23]:
# -----------------------------------------------------------
# 4. Model Wrapper (The SEW ResNet Configuration)
# -----------------------------------------------------------

def SEWResNet(connect_f='ADD', num_classes=11, beta=0.5, spike_grad=None):
    """Build the 5-stage SEW ResNet used in this notebook.

    Every stage is identical: 32 channels, one 'sew' block with mid width 32, and
    a 2x2 max-pool. Five pools shrink a 128x128 frame to 4x4 before the linear head.
    Because the first stage's 32 channels differ from the network input's channel
    count, ResNetN also inserts a 1x1 projection there.

    Args:
        connect_f: SEW merge function passed to every SEWBlock.
        num_classes: number of output neurons.
        beta: decay factor forwarded to ResNetN.
        spike_grad: surrogate gradient forwarded to ResNetN.

    Returns:
        A ResNetN instance.
    """
    num_stages = 5
    stage_cfg = dict(channels=32, up_kernel_size=1, mid_channels=32,
                     num_blocks=1, block_type='sew', k_pool=2)
    # dict(stage_cfg) gives every stage its own independent copy of the config.
    layer_list = [dict(stage_cfg) for _ in range(num_stages)]
    return ResNetN(layer_list, num_classes, connect_f, beta, spike_grad)

In [24]:
# -----------------------------------------------------------
# 5. Initialization & Training Loop
# -----------------------------------------------------------

# Initialize Model
# Seed the global torch RNG so weight init and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)
net = SEWResNet(connect_f='ADD', num_classes=11, beta=beta, spike_grad=spike_grad).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss()

num_epochs = 10
hist = {"loss": [], "acc": []}


def run_sequence(model, frames):
    """Feed a time-first batch ``frames`` ([T, B, ...]) through ``model`` one step at a time.

    Neuron state is reset first so nothing carries over from the previous batch.
    Returns the output spikes stacked along a new leading time axis: [T, B, num_classes].
    """
    utils.reset(model)
    spikes = []
    for frame in frames:  # iterating a tensor walks dimension 0, i.e. time
        spk_out, _mem_out = model(frame)
        spikes.append(spk_out)
    return torch.stack(spikes)


print("Starting Training with SEW ResNet...")

for epoch in range(num_epochs):
    t0 = time.time()
    iter_loss = 0.0
    iter_acc = 0.0
    counter = 0

    net.train()

    # tqdm shows a per-epoch progress bar; running averages go in its postfix.
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:

        for data, targets in pbar:
            # Explicit float cast in case the collated frames are integer event counts.
            # NOTE(review): the recorded XPU failure (UR_RESULT_ERROR_UNKNOWN) is
            # unexplained here. Device OOM is a plausible suspect, because BPTT keeps every
            # step's activations (30 steps x batch 16 x 32 ch x 128x128 in the first
            # stage). Try a smaller batch_size / n_time_bins to confirm.
            data = data.to(device, dtype=torch.float32)
            targets = targets.to(device)

            spk_rec = run_sequence(net, data)

            # Loss & Backprop
            loss_val = loss_fn(spk_rec, targets)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Accuracy
            acc = SF.accuracy_rate(spk_rec, targets)

            iter_loss += loss_val.item()
            iter_acc += acc
            counter += 1

            pbar.set_postfix({
                "Loss": f"{iter_loss/counter:.4f}",
                "Acc": f"{iter_acc/counter:.4f}"
            })

    if counter == 0:
        raise RuntimeError("train_loader yielded no batches; cannot compute epoch statistics")

    # End of epoch stats
    epoch_loss = iter_loss / counter
    epoch_acc = iter_acc / counter
    hist['loss'].append(epoch_loss)
    hist['acc'].append(epoch_acc)

    print(f"Epoch {epoch+1} finished in {time.time()-t0:.2f}s")

RuntimeError: Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)