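# Introduction

This notebook trains a spiking convolutional network (built with sinabs) on the Neuromorphic MNIST (N-MNIST) dataset and then measures the spiking activity of the saved models: synaptic operations, firing rates and activation sparsity.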


In [2]:
!pip install tonic
!pip install sinabs
!pip install torchmetrics
!pip install numpy --upgrade


Collecting numpy<2.0.0 (from tonic)
  Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.3.2
    Uninstalling numpy-2.3.2:
      Successfully uninstalled numpy-2.3.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
thinc 8.3.6

Collecting numpy
  Using cached numpy-2.3.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Using cached numpy-2.3.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.6 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tonic 1.6.0 requires numpy<2.0.0, but you have numpy 2.3.2 which is incompatible.
opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 2.3.2 which is incompatible.
opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 2.3.2 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchmetrics
import os
import sinabs
import sinabs.layers as sl
from tonic import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

from sinabs.hooks import register_synops_hooks, firing_rate_per_neuron_hook

In [2]:
# Prefer the Apple-Silicon GPU, then CUDA, and fall back to the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print({
    "mps": "Using MPS (Apple Silicon GPU)",
    "cuda": "Using CUDA GPU",
    "cpu": "Using CPU",
}[device.type])
if device.type == "cuda":
    # Let cuDNN autotune convolution algorithms for the fixed input shapes.
    torch.backends.cudnn.benchmark = True

Using CPU


In [3]:
def create_model(batch_size):
    """Build the spiking CNN.

    Four conv -> IAF -> 2x2 sum-pool stages (2 -> 8 -> 16 -> 32 -> 64 channels),
    followed by a 2x2 conv to 10 channels, a final IAF layer, a flatten, and
    ``UnflattenTime``.

    Args:
        batch_size: Batch size baked into every ``IAFSqueeze`` and into
            ``UnflattenTime``; inputs must have exactly this batch size.

    Returns:
        An ``nn.Sequential`` model.
    """
    min_v_mem = -1.

    def spiking_layer():
        # Every spiking layer shares the same batch size and membrane floor.
        return sl.IAFSqueeze(batch_size=batch_size, min_v_mem=min_v_mem)

    layers = [sl.FlattenTime()]
    widths = [2, 8, 16, 32, 64]
    for in_ch, out_ch in zip(widths[:-1], widths[1:]):
        layers.extend([
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            spiking_layer(),
            sl.SumPool2d(2),
        ])
    layers.extend([
        nn.Conv2d(64, 10, kernel_size=2, padding=0, bias=False),
        spiking_layer(),
        nn.Flatten(),
        sl.UnflattenTime(batch_size=batch_size),
    ])
    return nn.Sequential(*layers)


The convolution biases are disabled (`bias=False`) so that the chip does not have to run at a fixed frequency; see https://sinabs.readthedocs.io/v3.0.3/speck/notebooks/leak_neuron.html

# Data loading

Load the Neuromorphic MNIST (N-MNIST) dataset and bin each event stream into 30 frames.

In [4]:
from tonic import datasets, transforms

def get_data_loaders(batch_size, data_dir='./data', n_time_bins=30, pin_memory=None):
    """Build N-MNIST train/test DataLoaders that yield dense event frames.

    tonic's ``ToFrame`` turns each event stream of the 34x34, 2-polarity sensor
    into ``n_time_bins`` frames, which are returned as a float32 tensor.

    Args:
        batch_size: Samples per batch. Both loaders use ``drop_last=True``
            because the spiking layers in ``create_model`` are built for a
            fixed batch size, so every batch must contain exactly this many
            samples. A partial final batch is therefore dropped.
        data_dir: Root directory where tonic stores or downloads the dataset.
        n_time_bins: Number of time bins per sample.
        pin_memory: Whether to pin host memory. Defaults to ``True`` only when
            CUDA is available; pinning only speeds up host-to-CUDA copies and
            merely produces a warning on CPU-only machines.

    Returns:
        ``(trainloader, testloader)``. The train loader is shuffled; the test
        loader is not.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    transform = transforms.Compose([
        transforms.ToFrame(sensor_size=(34, 34, 2), n_time_bins=n_time_bins, include_incomplete=True),
        lambda x: torch.from_numpy(x.astype(np.float32)),
    ])

    trainset = datasets.NMNIST(data_dir, train=True, transform=transform)
    testset = datasets.NMNIST(data_dir, train=False, transform=transform)

    # num_workers=0 loads in the main process. Note that the lambda in
    # `transform` cannot be pickled, which matters for the 'spawn' start method.
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=0,
        drop_last=True, pin_memory=pin_memory
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, num_workers=0,
        drop_last=True, pin_memory=pin_memory
    )

    return trainloader, testloader

In [5]:
def evaluate_model(model, testloader, device, num_classes=10):
    """Evaluate a spiking model on every batch of a data loader.

    The model output is summed over dimension 1 (time) to obtain per-class
    spike counts. These counts are used as logits for both accuracy and
    cross-entropy.

    Args:
        model: Sinabs model. Its states are reset before every batch, and it is
            left in eval mode.
        testloader: Iterable of ``(data, targets)`` batches.
        device: Device the batches and the accuracy metric are moved to.
        num_classes: Number of output classes.

    Returns:
        ``(accuracy, avg_loss)``, where ``avg_loss`` is the mean of the
        per-batch losses.

    Raises:
        ValueError: If ``testloader`` yields no batches.
    """
    acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes).to(device)
    model.eval()

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data, targets in tqdm(testloader, desc="Evaluating"):
            data = data.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            # Membrane potentials persist across calls; start each batch from rest.
            sinabs.reset_states(model)

            pred = model(data).sum(1)
            loss = nn.functional.cross_entropy(pred, targets)

            acc.update(pred, targets)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        # With drop_last=True, a dataset smaller than batch_size yields no batches.
        raise ValueError("testloader yielded no batches; nothing to evaluate")

    return acc.compute().item(), total_loss / num_batches

In [7]:
def train_and_evaluate(learning_rate=1e-4, batch_size=32, n_epochs=1, save_dir="./saved_models"):
    """Train the SNN on N-MNIST, evaluate it after every epoch, and save it.

    Uses the module-level ``device`` chosen earlier in the notebook.

    Args:
        learning_rate: Adam learning rate.
        batch_size: Batch size. It is also baked into the spiking layers, so the
            saved model only accepts batches of exactly this size.
        n_epochs: Number of passes over the training set.
        save_dir: Directory the trained model is written to; created if missing.

    Returns:
        Path of the saved model file. The file holds the whole pickled
        ``nn.Module``, moved to the CPU.
    """
    model = create_model(batch_size).to(device)
    trainloader, testloader = get_data_loaders(batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    final_accuracy = final_loss = None
    for epoch in range(n_epochs):
        model.train()
        epoch_losses = []

        pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{n_epochs}')
        for data, targets in pbar:
            # Membrane potentials persist across calls; start each batch from rest.
            sinabs.reset_states(model)
            optimizer.zero_grad()

            data, targets = data.to(device, non_blocking=True), targets.to(device, non_blocking=True)
            # Sum output spikes over time (dim 1) to get per-class logits.
            output = model(data).sum(1)
            loss = nn.functional.cross_entropy(output, targets)

            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            epoch_losses.append(current_loss)
            pbar.set_postfix({'loss': f'{current_loss:.4f}'})

        # Reuse this evaluation as the "final" one instead of running the test set twice.
        final_accuracy, final_loss = evaluate_model(model, testloader, device)
        mean_train_loss = sum(epoch_losses) / max(len(epoch_losses), 1)
        print(f"Epoch {epoch+1}/{n_epochs}: train loss {mean_train_loss:.4f}, "
              f"test accuracy {final_accuracy:.2%}, test loss {final_loss:.4f}")

    if final_accuracy is None:
        # n_epochs == 0: still report the untrained model's performance.
        final_accuracy, final_loss = evaluate_model(model, testloader, device)

    print(f"Final accuracy: {final_accuracy:.2%}")
    print(f"Final loss: {final_loss:.4f}")

    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, f"{final_accuracy:.4f}.pth")
    torch.save(model.cpu(), model_path)
    return model_path

In [8]:
# Train, evaluate and save the model using the hyperparameters in train_and_evaluate.
# NOTE: this is slow on CPU; the recorded run estimated ~1h49m for a single epoch.
train_and_evaluate()

Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to ../../data/NMNIST/train.zip


  0%|          | 0/1011893601 [00:00<?, ?it/s]

Extracting ../../data/NMNIST/train.zip to ../../data/NMNIST
Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ../../data/NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting ../../data/NMNIST/test.zip to ../../data/NMNIST


Epoch 1/1:   2%|▏         | 29/1875 [01:42<1:48:59,  3.54s/it, loss=2.2289, best_loss=inf]


KeyboardInterrupt: 

In [9]:
import sinabs.layers as sl

from sinabs.hooks import register_synops_hooks, firing_rate_hook, get_hook_data_dict

def record_output_hook(module, input, output):
    """Forward hook that stores the layer's latest output in its sinabs hook-data dict.

    The output is detached so the stored tensor does not keep the autograd graph
    of the whole forward pass alive when the model runs with gradients enabled.
    Each forward pass overwrites the previously stored value.
    """
    data = get_hook_data_dict(module)
    data["output"] = output.detach()

def register_hooks(model):
    """Attach activity-measuring hooks to ``model``.

    Registers sinabs' synaptic-operation hooks on the model. Then, for every
    top-level spiking layer, registers ``firing_rate_hook`` followed by
    ``record_output_hook``. Calling this twice registers the hooks twice.
    """
    register_synops_hooks(model)
    spiking_layers = (
        module for module in model
        if isinstance(module, sl.StatefulLayer) and module.does_spike
    )
    for spiking_layer in spiking_layers:
        for hook in (firing_rate_hook, record_output_hook):
            spiking_layer.register_forward_hook(hook)

def sinabs_total_syn_ops(model):
    """Total synaptic operations per timestep recorded by the synops hooks, as a float.

    NOTE: the value is averaged across the batch and timesteps.
    """
    total = model.hook_data['total_synops_per_timestep']
    return total.item()

def sinabs_syn_ops_by_layer(model):
    """Synaptic operations per timestep for each hooked layer, as a list of floats.

    The values follow the iteration order of the hook-data dict.
    NOTE: each value is averaged across the batch and timesteps.
    """
    per_layer = model.hook_data['synops_per_timestep']
    counts = []
    for synops in per_layer.values():
        counts.append(synops.item())
    return counts

def sinabs_firing_rate_by_layer(model):
    """Firing rate of every top-level spiking layer, in model order.

    Reads the ``'firing_rate'`` entry stored by ``firing_rate_hook``.
    NOTE: each value is per neuron, averaged across the batch and timesteps.
    """
    return [
        layer.hook_data['firing_rate'].item()
        for layer in model
        if isinstance(layer, sl.StatefulLayer) and layer.does_spike
    ]

def activation_sparsity(model):
    """Fraction of non-positive activations in the spiking layers' last recorded outputs.

    Reads ``layer.hook_data['output']``, which ``record_output_hook`` stores
    during a forward pass, for every top-level spiking layer.

    Returns:
        ``(sparsity, sparsity_by_layer)``:
            - ``sparsity``: the fraction of entries that are not > 0, pooled
              over all spiking layers.
            - ``sparsity_by_layer``: the same fraction for each spiking layer,
              in model order (``nan`` for an empty output).

    Raises:
        ValueError: If there are no spiking-layer activations to measure.
    """
    total_spike_num = 0  # entries > 0 across all spiking layers
    total_neuro_num = 0  # all entries across all spiking layers
    sparsity_by_layer = []

    for layer in model:
        if not (isinstance(layer, sl.StatefulLayer) and layer.does_spike):
            continue
        output = layer.hook_data['output']
        n_active = output.gt(0).sum().item()
        n_total = output.numel()

        total_spike_num += n_active
        total_neuro_num += n_total
        sparsity_by_layer.append((n_total - n_active) / n_total if n_total else float('nan'))

    if total_neuro_num == 0:
        raise ValueError("No spiking-layer activations recorded; register hooks and run a forward pass first.")

    sparsity = (total_neuro_num - total_spike_num) / total_neuro_num
    return sparsity, sparsity_by_layer

In [14]:
# Use one fixed (unshuffled) test batch so the activity metrics below are reproducible.
# The batch size must match the batch_size the saved models were built with (32).
trainloader, testloader = get_data_loaders(32)
loader_data = next(iter(testloader))



In [16]:
# Shape of the probe input batch.
loader_data[0].shape

torch.Size([32, 30, 2, 34, 34])


In [None]:
# Probe every saved model with the same batch and collect its activity metrics.
from pathlib import Path

MODEL_DIR = Path("./saved_models")
data = loader_data[0]  # input batch only; targets are not needed here

results_list = []
for path in sorted(MODEL_DIR.rglob("*.pth")):
    # The files hold whole pickled nn.Module objects (torch.save(model.cpu(), ...)),
    # so weights_only=False is required. Only load files you created yourself.
    model = torch.load(path, weights_only=False)
    model.eval()
    register_hooks(model)  # hooks must be attached to the *loaded* model
    sinabs.reset_states(model)
    with torch.no_grad():
        model(data)

    sparsity, sparsity_by_layer = activation_sparsity(model)
    results_list.append({
        "model": path.name,
        "total_synops": sinabs_total_syn_ops(model),
        "synops_by_layer": sinabs_syn_ops_by_layer(model),
        "firing_rate_by_layer": sinabs_firing_rate_by_layer(model),
        "sparsity": sparsity,
        "sparsity_by_layer": sparsity_by_layer,
    })

results_list

Evaluating:   6%|▌         | 19/312 [00:27<08:23,  1.72s/it]