# DVS Gesture Classification: ANN to SNN (snntorch)

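This notebook reuses the DVS Gesture data setup from `dvsges.ipynb`, trains a small CNN (ANN) on time-averaged event frames, and converts it into a spiking network (SNN) with `snntorch` by manually copying the trained weights into a matching spiking architecture.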


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import tonic
import tonic.transforms as transforms
from tonic import DiskCachedDataset
from tqdm import tqdm

import snntorch as snn
from snntorch import surrogate
from snntorch import utils


In [None]:
def _select_device():
    """Return the first available accelerator (CUDA, then XPU), else the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Older torch builds have no `torch.xpu` module at all, hence the hasattr guard.
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return torch.device('xpu')
    return torch.device('cpu')


device = _select_device()
print(f'Running on: {device}')


In [None]:
# Dataset geometry and temporal binning shared by the loaders below.
sensor_size = tonic.datasets.DVSGesture.sensor_size
n_time_bins = 30
batch_size = 16

# Convert each raw event stream into `n_time_bins` stacked frames.
transform = transforms.Compose([
    transforms.ToFrame(sensor_size=sensor_size, n_time_bins=n_time_bins),
])

train_set = tonic.datasets.DVSGesture(save_to='./data', train=True, transform=transform)
test_set = tonic.datasets.DVSGesture(save_to='./data', train=False, transform=transform)

seq_loader_args = {
    'batch_size': batch_size,
    # batch_first=False puts time first in the collated batch: [T, B, C, H, W].
    'collate_fn': tonic.collation.PadTensors(batch_first=False),
    'shuffle': True,
    'num_workers': 2,
    # Pinned host memory only benefits host->CUDA copies; requesting it without
    # CUDA just triggers a warning (same policy as the ANN loaders).
    'pin_memory': torch.cuda.is_available(),
}

train_seq_loader = DataLoader(train_set, **seq_loader_args)
# Evaluation does not need shuffling; a fixed order keeps runs reproducible.
test_seq_loader = DataLoader(test_set, **{**seq_loader_args, 'shuffle': False})

data, targets = next(iter(train_seq_loader))
print(f'Sequence batch shape [T, B, C, H, W]: {data.shape}')


In [None]:
class AggregateFrameDataset(Dataset):
    """Wrap a frame dataset and collapse each sample's time axis by averaging.

    Every item of ``base_dataset`` is unpacked as ``(frames, target)``. The
    frames are cast to float32 and averaged over dim 0 (time), so a
    [T, C, H, W] sample becomes a single [C, H, W] ANN input. The target is
    returned as an int64 tensor.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        raw_frames, label = self.base_dataset[idx]
        time_averaged = torch.as_tensor(raw_frames, dtype=torch.float32).mean(dim=0)
        return time_averaged, torch.as_tensor(label, dtype=torch.long)

# Cache transformed samples on disk: first run is slower, next runs are much faster.
# NOTE: the cache stores the ToFrame output (the full frame sequence); the time
# averaging for the ANN happens afterwards in AggregateFrameDataset.
train_cached = DiskCachedDataset(train_set, cache_path='./cache/dvsgesture_ann_train')
test_cached = DiskCachedDataset(test_set, cache_path='./cache/dvsgesture_ann_test')

train_ann_set = AggregateFrameDataset(train_cached)
test_ann_set = AggregateFrameDataset(test_cached)

# Windows/Jupyter often performs better with fewer workers for small random reads.
ann_num_workers = 0 if os.name == 'nt' else 2
ann_loader_args = {
    'batch_size': batch_size,
    'shuffle': True,
    'num_workers': ann_num_workers,
    'pin_memory': torch.cuda.is_available(),
}
if ann_num_workers > 0:
    # Keep worker processes alive across epochs instead of re-spawning them.
    ann_loader_args['persistent_workers'] = True

train_ann_loader = DataLoader(train_ann_set, **ann_loader_args)
# Accuracy does not depend on sample order; use a fixed order for evaluation.
test_ann_loader = DataLoader(test_ann_set, **{**ann_loader_args, 'shuffle': False})

x_ann, y_ann = next(iter(train_ann_loader))
print(f'ANN batch shape [B, C, H, W]: {x_ann.shape}')
print(f'ANN DataLoader workers: {ann_num_workers}')


In [None]:
class ANNNet(nn.Module):
    def __init__(self, num_classes=11):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 16, kernel_size=5, stride=2, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.fc1(x)
        return x

# Instantiate the ANN on the selected device with cross-entropy loss and Adam.
ann_model = ANNNet(num_classes=11).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=ann_model.parameters(), lr=2e-3)

ann_model  # last expression: displays the architecture in the notebook


In [None]:
def evaluate_ann(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled. Each batch
    is cast to float32 inputs / int64 labels and moved to ``device`` before
    the forward pass.

    Args:
        model: Module that maps a batch of inputs to per-class logits [B, K].
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Target device for the batches. Defaults to the device of the
            model's first parameter, or the CPU when the model has none.

    Returns:
        Fraction of correctly classified samples, or 0.0 for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.float().to(device)
            y = y.long().to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return correct / max(total, 1)

# Train the ANN. After every epoch, report the mean batch loss together with
# eval-mode accuracy on the full train and test loaders.
num_epochs = 15
for epoch in range(num_epochs):
    ann_model.train()
    loss_total, n_batches = 0.0, 0

    progress = tqdm(train_ann_loader, desc=f'ANN Epoch {epoch + 1}/{num_epochs}')
    for inputs, labels in progress:
        inputs = inputs.float().to(device)
        labels = labels.long().to(device)

        optimizer.zero_grad()
        batch_loss = criterion(ann_model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        n_batches += 1

    avg_loss = loss_total / max(n_batches, 1)
    train_acc = evaluate_ann(ann_model, train_ann_loader)
    test_acc = evaluate_ann(ann_model, test_ann_loader)
    print(f'Epoch {epoch + 1}: loss={avg_loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}')


In [None]:
# LIF membrane decay factor and surrogate gradient, shared by every spiking layer.
beta = 0.5
spike_grad = surrogate.atan()


class SNNNet(nn.Module):
    """Spiking counterpart of ANNNet built from snntorch ``Leaky`` neurons.

    The conv/linear layers mirror ANNNet so their weights can be copied
    over. Each stage is conv -> 2x2 max-pool -> Leaky; ANNNet applies ReLU
    before pooling instead. Every Leaky layer uses ``init_hidden=True``, so it
    keeps its membrane state internally; reset it between samples (e.g. with
    ``snntorch.utils.reset``).

    Args:
        num_classes: Number of output neurons.
    """

    def __init__(self, num_classes=11):
        super().__init__()
        # Stage 1: 2 -> 16 channels.
        self.conv1 = nn.Conv2d(2, 16, kernel_size=5, stride=2, padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)

        # Stage 2: 16 -> 32 channels.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2)
        self.pool2 = nn.MaxPool2d(2)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)

        # Output head: output=True makes the last layer return (spikes, membrane).
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 8 * 8, num_classes)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def forward(self, x):
        """Advance the network by one time step; return (output spikes, output membrane)."""
        stages = (
            (self.conv1, self.pool1, self.lif1),
            (self.conv2, self.pool2, self.lif2),
        )
        for conv, pool, lif in stages:
            x = lif(pool(conv(x)))
        spk_out, mem_out = self.lif3(self.fc1(self.flatten(x)))
        return spk_out, mem_out

def convert_ann_to_snn(ann_model):
    """Build an SNNNet whose conv/linear layers copy ``ann_model``'s weights.

    The weight and bias of ``conv1``, ``conv2`` and ``fc1`` are copied
    one-to-one. The Leaky neuron layers have no ANN counterpart and keep
    their constructor settings. The SNN is sized to the ANN's number of
    output classes and placed on the same device as the ANN.

    Args:
        ann_model: Trained ANNNet (or any module exposing conv1/conv2/fc1 with
            matching shapes).

    Returns:
        A new SNNNet instance holding copies of the transferred parameters.
    """
    target_device = ann_model.fc1.weight.device
    # Derive the class count from the ANN head. Relying on SNNNet's default
    # (11) would break the fc1 copy for an ANN built with another num_classes.
    snn_model = SNNNet(num_classes=ann_model.fc1.out_features).to(target_device)
    with torch.no_grad():
        for layer_name in ('conv1', 'conv2', 'fc1'):
            src_layer = getattr(ann_model, layer_name)
            dst_layer = getattr(snn_model, layer_name)
            dst_layer.weight.copy_(src_layer.weight)
            dst_layer.bias.copy_(src_layer.bias)
    return snn_model

# Build the spiking model from the trained ANN weights.
snn_model = convert_ann_to_snn(ann_model=ann_model)
print('Converted ANN to SNN (snntorch) via weight transfer.')


In [None]:
def _argmax_spikes_then_membrane(spk_sum, mem_sum):
    """Pick the class with the most output spikes, breaking ties by membrane.

    ``torch.argmax`` returns the first maximal index. When several output
    neurons fire equally often (e.g. none fire at all), a plain argmax would
    always choose the lowest class index. Here the accumulated membrane is
    rescaled per sample into [0, 0.5] and added to the spike counts. Distinct
    spike counts differ by at least 1, assuming binary 0/1 spikes (TODO
    confirm if graded spikes are ever enabled). So the membrane term only
    reorders classes whose counts are tied.
    """
    mem_min = mem_sum.amin(dim=1, keepdim=True)
    mem_range = mem_sum.amax(dim=1, keepdim=True) - mem_min
    mem_rank = 0.5 * (mem_sum - mem_min) / mem_range.clamp_min(1e-12)
    return (spk_sum + mem_rank).argmax(dim=1)


def evaluate_snn(model, loader, sim_steps=30, device=None):
    """Return the top-1 accuracy of a spiking ``model`` over ``loader``.

    For every batch, the model's hidden state is reset and the same input is
    presented for ``sim_steps`` consecutive steps. Output spikes and output
    membrane potentials are summed over the steps. The prediction is the
    class with the most spikes, with ties broken by the summed membrane.

    Args:
        model: Module whose forward returns ``(spikes, membrane)`` per step,
            each [B, K].
        loader: Iterable of ``(inputs, labels)`` batches.
        sim_steps: Number of simulation steps per batch; must be >= 1.
        device: Target device for the batches. Defaults to the device of the
            model's first parameter, or the CPU when the model has none.

    Returns:
        Fraction of correctly classified samples, or 0.0 for an empty loader.

    Raises:
        ValueError: If ``sim_steps`` is smaller than 1.
    """
    if sim_steps < 1:
        raise ValueError(f'sim_steps must be >= 1, got {sim_steps}')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in tqdm(loader, desc='SNN Eval'):
            x = x.float().to(device)
            y = y.long().to(device)

            # Clear membrane state left over from the previous batch.
            utils.reset(model)
            spk_sum = 0
            mem_sum = 0
            for _ in range(sim_steps):
                spk_out, mem_out = model(x)
                spk_sum = spk_sum + spk_out
                mem_sum = mem_sum + mem_out

            preds = _argmax_spikes_then_membrane(spk_sum, mem_sum)
            correct += (preds == y).sum().item()
            total += y.numel()

    return correct / max(total, 1)

# Compare ANN and converted-SNN accuracy on the test set, then save both models.
ann_test_acc = evaluate_ann(ann_model, test_ann_loader)
snn_test_acc = evaluate_snn(snn_model, test_ann_loader, sim_steps=30)

checkpoints = {
    'ann_dvsgesture.pth': ann_model,
    'snn_from_ann_dvsgesture_snntorch.pth': snn_model,
}
for checkpoint_path, checkpoint_model in checkpoints.items():
    torch.save(checkpoint_model.state_dict(), checkpoint_path)

print(f'ANN test accuracy: {ann_test_acc:.4f}')
print(f'SNN test accuracy: {snn_test_acc:.4f}')
print('Saved ann_dvsgesture.pth and snn_from_ann_dvsgesture_snntorch.pth')
