# Inference with Norse

## 1. Installation, Imports and Dataset

In [None]:
!pip install norse --quiet
!pip install tonic --quiet
!pip install nir --quiet

In [None]:
# imports
import norse.torch as norse

import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import tonic

# POKER-DVS test split (train=False); tonic keeps the raw files under ./data
# (presumably downloading them on first use — confirm against tonic docs).
poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)

In [None]:
import tonic.transforms as transforms
from tonic import DiskCachedDataset

# Preprocessing applied to every raw event recording: tonic's Denoise filter
# (filter_time=10000) followed by ToFrame, which bins the events into dense
# frames spanning time_window=1000 time units each (presumably microseconds —
# confirm against tonic docs).
frame_transform = tonic.transforms.Compose(
    [
        tonic.transforms.Denoise(filter_time=10000),
        tonic.transforms.ToFrame(
            sensor_size=tonic.datasets.POKERDVS.sensor_size,
            time_window=1000,
        ),
    ]
)

batch_size = 8
# Transformed samples are cached on disk under ./cache/pokerdvs/test.
cached_testset = DiskCachedDataset(
    poker_test,
    transform=frame_transform,
    cache_path='./cache/pokerdvs/test',
)
# PadTensors presumably pads recordings of different lengths to a common
# length; batch_first=False keeps time as the leading axis of each batch.
test_loader = DataLoader(
    cached_testset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

## 2. Load the NIR graph into Norse

In [None]:
import nir

In [None]:
# Deserialize the NIR graph stored in nir_model.nir (the file must exist in
# the working directory).
nir_model = nir.read("nir_model.nir")

In [None]:
# Convert the NIR graph into an executable Norse/PyTorch network.
# dt presumably sets the integration time step of the neuron dynamics.
# NOTE(review): confirm dt=0.0001 is the step the model was trained with, and
# how it relates to the ToFrame time_window=1000 used above — TODO confirm units.
net = norse.from_nir(nir_model, dt=0.0001)

In [None]:
# Place the network on the selected device (see `device`); inputs must be
# moved to the same device before they are fed to `net`.
net.to(device)

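`norse.from_nir` returns a `nirtorch.GraphExecutor`: a PyTorch module that executes the NIR graph. Each call `net(x, state)` advances the network by a single time step and returns the output together with the updated state. Passing `state=None` starts from the initial state, as done in `apply` below.

**TODO: expand this section on `nirtorch.GraphExecutor`.**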

## 3. Run the model with a single batch of data

The graph executor only computes a single time step per call. To process a whole recording, we loop over the time axis and feed the returned state back into the next call:

In [None]:
def apply(data):
    """Run the Norse network over every time step of an input batch.

    The graph executor returned by ``norse.from_nir`` computes one time step
    per call, so this function loops over the leading (time) axis of ``data``
    and feeds the returned state back into the next call.

    Args:
        data: Input tensor whose leading axis is time. With the loader above
            (``PadTensors(batch_first=False)``) this is presumably
            ``(time, batch, ...)`` — TODO confirm. Each step is flattened to
            ``(batch, features)`` before being passed to the network.

    Returns:
        Tuple ``(spk_out, hid_rec)``: ``spk_out`` stacks the per-step network
        outputs along a new leading time axis; ``hid_rec`` is a list of the
        states returned after each step.
    """
    # The network lives on ``device``; the input must be there too, otherwise
    # a CUDA run fails with a device mismatch. No-op if already on ``device``.
    data = data.to(device)

    state = None  # the first call starts from the network's initial state
    hid_rec = []
    out = []
    for frame in data:
        # Collapse all non-batch axes into a single feature vector.
        z, state = net(frame.flatten(1), state)
        out.append(z)
        hid_rec.append(state)
    return torch.stack(out), hid_rec

Apply the model to a single batch from the test set:

In [None]:
data, targets = next(iter(test_loader))

# Pure inference: skip autograd bookkeeping across all time steps, and put the
# batch on the same device as the network.
with torch.no_grad():
    spk, hid = apply(data.to(device))

# Count the spikes of each output neuron over time; the most active neuron is
# the predicted class.
predictions = spk.sum(dim=0).argmax(dim=-1).cpu()
print(f"Predicted classes: {predictions}")
print(f"Actual classes:    {targets}")

## 4. Measure accuracy on the test dataset

In [None]:
def measure_accuracy(model, dataloader, *, target_device=None):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Each sample is assigned to the output neuron that emitted the most spikes,
    summed over the leading (time) axis of the model's spike record.

    Args:
        model: Callable taking an input batch and returning a tuple whose first
            element is the spike record; summing its leading axis must leave
            ``(batch, num_outputs)``.
        dataloader: Iterable of ``(data, targets)`` batches.
        target_device: Device each batch is moved to before calling ``model``.
            Defaults to the notebook-level ``device``.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no samples (accuracy undefined).
    """
    if target_device is None:
        target_device = device

    num_correct = 0
    num_samples = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for data, targets in dataloader:
            data = data.to(target_device)
            targets = targets.to(target_device)

            spk_rec, _ = model(data)
            # Winner-take-all readout: most active output neuron per sample.
            predictions = spk_rec.sum(dim=0).argmax(dim=-1)

            # Accumulate exact integer counts rather than float tensors.
            num_correct += int((predictions == targets).sum().item())
            num_samples += len(targets)

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return num_correct / num_samples

In [None]:
# Evaluate the whole test set; the notebook displays the returned accuracy.
measure_accuracy(model=apply, dataloader=test_loader)