In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import snntorch as snn
from snntorch import surrogate, functional as SF, utils
from snntorch import spikeplot as splt
import tonic
import tonic.transforms as transforms
import torch.nn.functional as F
torch.manual_seed(42)
np.random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sensor_size = tonic.datasets.DVSGesture.sensor_size

transform = transforms.Compose([
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=30),
])

train_set = tonic.datasets.DVSGesture(save_to='./data', train=True, transform=transform)
test_set = tonic.datasets.DVSGesture(save_to='./data', train=False, transform=transform)

# Dataloaders
cached_dataloader_args = {
    "batch_size": 16,
    "collate_fn": tonic.collation.PadTensors(batch_first=True), 
    "shuffle": True,
    "num_workers": 2,
    "pin_memory": True
}

train_loader = DataLoader(train_set, **cached_dataloader_args)
test_loader = DataLoader(test_set, **cached_dataloader_args)

data, targets = next(iter(train_loader))
print(f"Data shape: {data.shape}") 



Data shape: torch.Size([16, 30, 2, 128, 128])


In [3]:
if torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Running on: {device}")

Running on: cuda


In [4]:
class SResnetNM(nn.Module):
    def __init__(self, n, nFilters, num_steps, leak_mem=0.95, num_cls=11):
        super(SResnetNM, self).__init__()

        self.n = n
        self.num_cls = num_cls
        self.num_steps = num_steps
        self.leak_mem = leak_mem

        print(">>>>>>>>>>>>>>>>>>> S-ResNet NM for DVSGesture >>>>>>>>>>>>>>>>>>>>>>")

        affine_flag = True
        bias_flag = False
        self.nFilters = nFilters

        # Define spike gradient surrogate
        self.spike_grad = surrogate.atan()

        # Initialize layers
        self.conv1 = nn.Conv2d(2, self.nFilters, kernel_size=3, stride=2, padding=1, bias=bias_flag)
        self.bn1 = nn.BatchNorm2d(self.nFilters, eps=1e-4, momentum=0.1, affine=affine_flag)
        self.lif1 = snn.Leaky(beta=leak_mem, spike_grad=self.spike_grad, init_hidden=True)

        # Store layers in lists
        self.conv_list = nn.ModuleList([self.conv1])
        self.bn_list = nn.ModuleList([self.bn1])
        self.lif_list = nn.ModuleList([self.lif1])

        # Create ResNet blocks
        layer_idx = 0
        self.block_sizes = []
        for block in range(3):
            num_layers = 2 * n
            self.block_sizes.append(num_layers)
            for layer in range(num_layers):
                if block != 0 and layer == 0:
                    stride = 2
                    in_channels = self.nFilters * (2 ** (block - 1))
                else:
                    stride = 1
                    in_channels = self.nFilters * (2 ** block)
                
                out_channels = self.nFilters * (2 ** block)
                
                conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                stride=stride, padding=1, bias=bias_flag)
                bn = nn.BatchNorm2d(out_channels, eps=1e-4,
                                   momentum=0.1, affine=affine_flag)
                lif = snn.Leaky(beta=leak_mem, spike_grad=self.spike_grad, init_hidden=True)
                
                self.conv_list.append(conv)
                self.bn_list.append(bn)
                self.lif_list.append(lif)
                layer_idx += 1

        # Skip connection resize layers for downsampling
        self.conv_resize_1 = nn.Conv2d(self.nFilters, self.nFilters * 2,
                                      kernel_size=1, stride=2, padding=0, bias=bias_flag)
        self.resize_bn_1 = nn.BatchNorm2d(self.nFilters * 2, eps=1e-4,
                                         momentum=0.1, affine=affine_flag)
        self.conv_resize_2 = nn.Conv2d(self.nFilters * 2, self.nFilters * 4,
                                      kernel_size=1, stride=2, padding=0, bias=bias_flag)
        self.resize_bn_2 = nn.BatchNorm2d(self.nFilters * 4, eps=1e-4,
                                         momentum=0.1, affine=affine_flag)
        
        # Spiking neurons for resize layers
        self.lif_resize_1 = snn.Leaky(beta=leak_mem, spike_grad=self.spike_grad, init_hidden=True)
        self.lif_resize_2 = snn.Leaky(beta=leak_mem, spike_grad=self.spike_grad, init_hidden=True)

        # Adaptive pooling to handle variable spatial dimensions
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.nFilters * 4, self.num_cls, bias=bias_flag)
        self.lif_out = snn.Leaky(beta=leak_mem, spike_grad=self.spike_grad, init_hidden=True, output=True)

        self.conv1x1_list = nn.ModuleList([self.conv_resize_1, self.conv_resize_2])
        self.bn_conv1x1_list = nn.ModuleList([self.resize_bn_1, self.resize_bn_2])
        self.lif_conv1x1_list = nn.ModuleList([self.lif_resize_1, self.lif_resize_2])

        # Turn off bias of BatchNorm
        for bn_temp in self.bn_list:
            if hasattr(bn_temp, 'bias') and bn_temp.bias is not None:
                bn_temp.bias = None
        for bn_temp in self.bn_conv1x1_list:
            if hasattr(bn_temp, 'bias') and bn_temp.bias is not None:
                bn_temp.bias = None

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=2)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=2)

    def forward(self, inp):
        # Input shape: [N, T, 2, H, W] (batch_first=True)
        batch_size = inp.size(0)
        num_timesteps = inp.size(1)
        
        # Initialize lists to store outputs
        spk_out_list = []
        
        # Process each timestep
        for t in range(num_timesteps):
            # Get current timestep input
            x = inp[:, t]  # [N, 2, H, W]
            
            # Resize from 128x128 to 64x64 if needed
            if x.shape[-1] == 128:
                x = F.interpolate(x, size=64, mode='bilinear', align_corners=False)
            
            # Initial conv block
            x = self.conv1(x)
            x = self.bn1(x)
            spike, mem = self.lif1(x)
            
            skip = spike.clone()
            x = spike
            
            # Track skip connection index
            skip_idx = 0
            conv_idx = 1  # Start after first conv (index 0)
            
            # Process ResNet blocks
            block_start_idx = 0
            for block in range(3):
                block_layers = self.block_sizes[block]
                
                for layer_in_block in range(block_layers):
                    # Get current layer
                    conv = self.conv_list[conv_idx]
                    bn = self.bn_list[conv_idx]
                    lif = self.lif_list[conv_idx]
                    
                    # Forward through layer
                    x = conv(x)
                    x = bn(x)
                    spike, mem = lif(x)
                    
                    # Add skip connection at the END of each block (except first layer of block)
                    if layer_in_block == block_layers - 1:  # Last layer in block
                        # For blocks after the first one, we need to process skip connection
                        if block > 0:
                            # Get the skip connection processing layer
                            skip_processed = self.conv1x1_list[skip_idx](skip)
                            skip_processed = self.bn_conv1x1_list[skip_idx](skip_processed)
                            skip_spike, _ = self.lif_conv1x1_list[skip_idx](skip_processed)
                            spike = spike + skip_spike
                            skip_idx += 1
                        else:
                            # For first block, just add
                            spike = spike + skip
                        
                        # Update skip for next block
                        skip = spike.clone()
                    
                    x = spike
                    conv_idx += 1
            
            # Final pooling and classification
            x = self.pool(x)
            x = x.view(batch_size, -1)
            x = self.fc(x)
            spk_out, mem_out = self.lif_out(x)
            
            spk_out_list.append(spk_out)
        
        # Stack outputs across time
        spk_out = torch.stack(spk_out_list, dim=0)  # [T, N, num_cls]
        
        return spk_out

net = SResnetNM(n=3, nFilters=16, num_steps=30, num_cls=11).to(device)


>>>>>>>>>>>>>>>>>>> S-ResNet NM for DVSGesture >>>>>>>>>>>>>>>>>>>>>>


In [5]:
optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss()

In [6]:
num_epochs = 100
hist = {"loss": [], "acc": []}

print("Starting Training...")

for epoch in range(num_epochs):
    iter_loss = 0
    iter_acc = 0
    counter = 0

    net.train()
    
    # Add tqdm for progress bar
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for data, targets in pbar:
        data = data.to(device)
        targets = targets.to(device)
        
        # Reset states for this batch
        utils.reset(net)
        
        # Forward pass - process all timesteps
        spk_rec = net(data)  # [T, N, num_cls]
        
        # Calculate loss and accuracy
        loss_val = loss_fn(spk_rec, targets)
        acc = SF.accuracy_rate(spk_rec, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        iter_loss += loss_val.item()
        iter_acc += acc
        counter += 1
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss_val.item():.4f}',
            'acc': f'{acc:.4f}'
        })

    epoch_loss = iter_loss / counter
    epoch_acc = iter_acc / counter
    hist['loss'].append(epoch_loss)
    hist['acc'].append(epoch_acc)
    
    print(f"Epoch {epoch+1}/{num_epochs} \t Loss: {epoch_loss:.4f} \t Accuracy: {epoch_acc:.4f}")

Starting Training...


Epoch 1/100:   0%|                                       | 0/68 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [None]:
print("Starting Testing...")
net.eval()

total = 0
correct = 0

with torch.no_grad():
    for data, targets in tqdm(test_loader, desc="Testing"):
        data = data.to(device)
        targets = targets.to(device)
        
        # Reset states
        utils.reset(net)
        
        # Forward pass
        spk_rec = net(data)
        
        # Calculate accuracy
        acc = SF.accuracy_rate(spk_rec, targets)
        correct += acc * data.size(0)  # data.size(0) is batch size
        total += data.size(0)

test_acc = correct / total if total > 0 else 0
print(f"Test Accuracy: {test_acc*100:.2f}%")

In [None]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(hist['loss'])
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.subplot(1, 2, 2)
plt.plot(hist['acc'])
plt.title('Training Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.tight_layout()
plt.savefig('training_history.png')
plt.show()

In [None]:
fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color=color)
ax1.plot(hist['loss'], color=color, label="Loss")
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  
color = 'tab:blue'
ax2.set_ylabel('Accuracy', color=color)  
ax2.plot(hist['acc'], color=color, label="Accuracy")
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()  
plt.title("Training Loss and Accuracy")
plt.show()

In [None]:
import matplotlib.animation as animation

def play_anim(data, labels, index=0):
    """
    data: Tensor of shape (Time, Batch, Channels, Height, Width)
    labels: Target integers
    index: Which sample in the batch to visualize
    """
    
    # Select the specific sample from the batch: (Time, Channels, Height, Width)
    # We detach from graph and move to cpu for plotting
    sample = data[:, index, :, :, :].cpu()
    label_id = labels[index].item()
    class_name = train_set.classes[label_id]

    # Combine channels for visualization
    # Channel 1 (Positive) -> Red, Channel 0 (Negative) -> Blue
    # We create a single image where positive is +1 and negative is -1
    frames = sample[:, 1, :, :] - sample[:, 0, :, :]

    fig, ax = plt.subplots()
    im = ax.imshow(frames[0], cmap='seismic', vmin=-1.5, vmax=1.5)
    ax.axis('off')
    ax.set_title(f"Label: {class_name}")

    def update(frame_idx):
        # Update the image data for the next frame
        im.set_data(frames[frame_idx])
        return [im]

    # Create animation
    # frames=sample.shape[0] ensures we loop through all 15 time steps
    ani = animation.FuncAnimation(fig, update, frames=sample.shape[0], interval=200, blit=True)
    
    plt.close() # Prevent static plot from showing up separately
    return ani

In [None]:
# 1. Get a batch of data
data_batch, targets_batch = next(iter(test_loader))

# 2. Generate the animation for the first sample in the batch (index 0)
anim = play_anim(data_batch, targets_batch, index=0)

# 3. Render it in the notebook
HTML(anim.to_jshtml())

In [None]:
import snntorch.spikeplot as splt
import matplotlib.pyplot as plt

data, targets = next(iter(test_loader))
data = data.to(device)
targets = targets.to(device)

net.eval()
utils.reset(net)
spk_rec = []

for step in range(data.size(0)):
    spk_out, mem_out = net(data[step])
    spk_rec.append(spk_out)

spk_rec = torch.stack(spk_rec)

idx = 0 
fig, ax = plt.subplots(facecolor='w', figsize=(12, 8))

# Plot the spikes
splt.raster(spk_rec[:, idx, :], ax, s=20, c="black")


class_labels = train_set.classes
ax.set_yticks(range(len(class_labels)))
ax.set_yticklabels(class_labels)


plt.title(f"Output Spikes (True Target: {class_labels[targets[idx]]})")
plt.xlabel("Time step")
plt.grid(True, linestyle='--', alpha=0.3) 
plt.show()