In [1]:
import torch
import tonic
import norse
import nir
import numpy as np
from collections import defaultdict
import tqdm

In [2]:
# Load the exported NIR graph and show which NIR primitive type backs each node name.
nir_graph = nir.read("scnn_mnist.nir")
{name: type(node) for name, node in nir_graph.nodes.items()}

{'0': nir.ir.Conv2d,
 '1': nir.ir.IF,
 '10': nir.ir.IF,
 '11': nir.ir.Affine,
 '12': nir.ir.IF,
 '2': nir.ir.Conv2d,
 '3': nir.ir.IF,
 '4': nir.ir.SumPool2d,
 '5': nir.ir.Conv2d,
 '6': nir.ir.IF,
 '7': nir.ir.SumPool2d,
 '8': nir.ir.Flatten,
 '9': nir.ir.Affine,
 'input': nir.ir.Input,
 'output': nir.ir.Output}

In [19]:
# Convert the NIR graph into a Norse (PyTorch) module and snapshot its direct
# submodules as a list; `evaluate` below picks layers out of it by index.
# NOTE(review): the list order presumably follows the lexicographic order of the
# NIR node names ('0', '1', '10', '11', '12', '2', ...) — confirm before relying on it.
g = norse.torch.from_nir(nir_graph)
children = [*g.children()]

In [23]:
# Bin NMNIST events into dense frames of `time_window=1e3` (presumably 1 ms, if tonic
# timestamps are in microseconds — TODO confirm) and pad/collate them with
# batch_first=False, i.e. presumably time-major (time, batch, ...) so that
# iterating a batch walks over time steps.
# shuffle=False: accuracy over the full test set does not depend on order, and a
# fixed order makes partial or interrupted evaluation runs reproducible.
to_frame = tonic.transforms.ToFrame(
    sensor_size=tonic.datasets.NMNIST.sensor_size,
    time_window=1e3,
)
dataset = tonic.datasets.NMNIST(".", transform=to_frame, train=False)
loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    batch_size=10,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ./NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting ./NMNIST/test.zip to ./NMNIST


In [24]:
def evaluate(xs, order=(0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 2, 3, 4)):
    """Run the converted network over a sequence of input frames; return the mean output.

    Each element of ``xs`` is pushed through the layers in ``children`` (selected
    and ordered by ``order``). The state of every ``norse.torch.IAFCell`` is
    carried from one element of ``xs`` to the next. The last layer's outputs are
    summed over all steps and divided by the number of steps.

    Parameters
    ----------
    xs : sequence of torch.Tensor
        Input frames, iterated along the first dimension. With the loader built
        with ``PadTensors(batch_first=False)`` this is presumably time-major,
        ``(time, batch, ...)`` — TODO confirm.
    order : sequence of int, optional
        Indices into the global ``children`` list, in execution order. The default
        corresponds to NIR nodes '0'..'12' executed in numeric order, assuming
        ``children`` is ordered by the string-sorted node names
        ('0', '1', '10', '11', '12', '2', ...). Re-derive it if the graph changes.

    Returns
    -------
    torch.Tensor
        Sum of the last layer's per-step outputs divided by ``len(xs)``.

    Raises
    ------
    ValueError
        If ``xs`` is empty (instead of silently returning NaNs from 0 / 0).
    """
    n_steps = len(xs)
    if n_steps == 0:
        raise ValueError("evaluate() needs at least one time step")

    states = {}  # layer index -> IAF cell state, persisted across time steps
    total = None
    for frame in xs:
        h = frame
        for i in order:
            module = children[i]
            if isinstance(module, norse.torch.IAFCell):
                # Stateful cell: first call receives None and initialises its state.
                h, states[i] = module(h, states.get(i))
            else:
                h = module(h)
            # Some converted modules return (output, state); keep only the output.
            if isinstance(h, tuple):
                h = h[0]
        # Accumulate on whatever shape/device/dtype the network produced, rather
        # than a hard-coded CPU tensor of 10 classes.
        total = h if total is None else total + h
    return total / n_steps

In [None]:
# Run the converted network over the whole test set without tracking gradients.
# `accuracies` holds one boolean "correct" tensor per batch; `losses` holds each
# batch's mean cross-entropy; `batch_sizes` lets the overall loss be a true
# per-sample mean even when the final batch is shorter than the others.
losses = []
accuracies = []
batch_sizes = []
with torch.no_grad():
    for batch in tqdm.tqdm(loader):
        x, y = batch
        pred = evaluate(x)  # presumably (batch, classes) averaged network output
        acc = pred.argmax(1) == y
        accuracies.append(acc)
        loss = torch.nn.functional.cross_entropy(pred, y)
        losses.append(loss)
        batch_sizes.append(len(y))

# Weight each batch-mean loss by its batch size before averaging.
mean_loss = (torch.stack(losses) * torch.tensor(batch_sizes)).sum() / sum(batch_sizes)
mean_loss

  0%|▌                                                                                                     | 5/1000 [00:02<07:04,  2.34it/s]

In [120]:
# Join the per-batch correctness flags into one flat vector (torch.cat accepts a
# shorter final batch, which torch.stack would reject) and show the fraction correct.
acc_tensor = torch.cat(accuracies).flatten()
acc_tensor.sum() / len(acc_tensor)

tensor(0.9804)

In [122]:
# Persist the overall test accuracy (correct predictions / total samples) to
# norse_acc.npy as a 0-d array.
np.save(
    "norse_acc.npy",
    (acc_tensor.sum() / acc_tensor.numel()).numpy(),
)