In [None]:
import torch
import torch.nn as nn
import snntorch as snn
import numpy as np
from snntorch.import_nir import import_from_nir
import nir

In [None]:
from tonic import datasets, transforms

def get_data_loaders(batch_size, data_dir='./data', n_time_bins=30):
    """Build the N-MNIST train and test data loaders.

    Each sample is passed through tonic's ``ToFrame`` transform (sensor size
    ``(34, 34, 2)``, ``n_time_bins`` bins, incomplete frames included) and
    then converted from a NumPy array to a ``float32`` torch tensor.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Directory handed to ``datasets.NMNIST`` as the dataset
            location. Defaults to ``'./data'``.
        n_time_bins: Number of time bins requested from ``ToFrame``.
            Defaults to ``30``.

    Returns:
        A ``(trainloader, testloader)`` tuple of ``torch.utils.data.DataLoader``
        objects. Only the train loader shuffles. Both use ``drop_last=True``,
        so a trailing batch smaller than ``batch_size`` is discarded.
    """
    transform = transforms.Compose([
        transforms.ToFrame(
            sensor_size=(34, 34, 2),
            n_time_bins=n_time_bins,
            include_incomplete=True,
        ),
        lambda x: torch.from_numpy(x.astype(np.float32)),
    ])

    trainset = datasets.NMNIST(data_dir, train=True, transform=transform)
    testset = datasets.NMNIST(data_dir, train=False, transform=transform)

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # machine without CUDA just makes the DataLoader emit a warning.
    pin_memory = torch.cuda.is_available()

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=0,
        drop_last=True, pin_memory=pin_memory
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, num_workers=0,
        drop_last=True, pin_memory=pin_memory
    )

    return trainloader, testloader

In [None]:
# Read the serialized NIR graph from disk.
graph = nir.read("cnn_sinabs.nir")

# A bare expression in the middle of a cell is evaluated but never displayed,
# so print the node names explicitly to actually see them.
print(list(graph.nodes.keys()))

# Build the network from the graph with snnTorch's NIR importer and show it.
net = import_from_nir(graph)
print(net)

In [None]:
# Collect the underlying module of every node, in the network's execution order.
modules = [node.elem for node in net.get_execution_order()]

# Empty container; nothing in this cell writes to it.
mem_dict = {}

# Give every snn.Leaky layer a freshly initialised membrane state.
for layer in modules:
    if not isinstance(layer, snn.Leaky):
        continue
    layer.mem = layer.init_leaky()

In [None]:
from neurobench.processors.postprocessors import ChooseMaxCount
from neurobench.benchmarks import Benchmark
from neurobench.models import SNNTorchModel

from neurobench.metrics.workload import (
    ActivationSparsity,
    MembraneUpdates,
    SynapticOperations,
    ActivationSparsityByLayer,
)
from neurobench.metrics.static import (
    ParameterCount,
    Footprint,
    ConnectionSparsity,
)

# Metric classes to evaluate, kept as two groups that are passed together below.
static_metrics = [ParameterCount, Footprint, ConnectionSparsity]
workload_metrics = [
    ActivationSparsity,
    ActivationSparsityByLayer,
    MembraneUpdates,
    SynapticOperations,
]

# Wrap the imported network for NeuroBench.
model = SNNTorchModel(net, custom_forward=False)

# Only the test split is benchmarked (one sample per batch); the train loader
# returned alongside it is discarded.
_, testloader = get_data_loaders(batch_size=1)

# NOTE(review): the two empty lists are presumably the pre- and post-processor
# arguments (none used here) - confirm against the Benchmark signature.
benchmark = Benchmark(model, testloader, [], [], [static_metrics, workload_metrics])
results = benchmark.run(verbose=False)

print(results)