In [None]:
import pytorch_lightning as pl
from cnn import CNN
from nmnist import NMNISTFrames, NMNISTRaster
import sinabs
from tqdm import tqdm
import torch
import logging

# Only surface errors from PyTorch Lightning to keep the notebook output short.
pl_logger = logging.getLogger("pytorch_lightning")
pl_logger.setLevel(logging.ERROR)

Let's load the trained CNN checkpoint and measure its accuracy on the test set with `trainer.test`.

In [None]:
batch_size = 64
# Checkpoint written by the CNN training run.
CHECKPOINT_PATH = 'checkpoints/cnn-step=7030-epoch=04-valid_loss=0.06-valid_acc=0.98.ckpt'

# `logger=False` is the documented way to turn experiment logging off;
# `None` is not guaranteed to mean "no logger" across Lightning versions.
trainer = pl.Trainer(logger=False)
model = CNN.load_from_checkpoint(CHECKPOINT_PATH)
# num_workers > 0 speeds up data loading but caused issues on Mac.
frames_dataset = NMNISTFrames(save_to='data', batch_size=batch_size, precision=32, num_workers=0)

trainer.test(model, frames_dataset)

In [None]:
# Sanity check: print the largest absolute value found in each parameter tensor.
for weight in model.model.parameters():
    largest_magnitude = weight.abs().max().item()
    print(largest_magnitude)

Next, we convert the CNN into a spiking neural network (SNN) with sinabs and evaluate it on spike-raster data.

In [None]:
# Spike-raster version of N-MNIST, binned into 20 time steps per sample.
raster_dataset = NMNISTRaster(
    save_to='data',
    batch_size=batch_size,
    n_time_bins=20,
    precision=32,
    num_workers=0,  # increasing this speeds up loading but caused issues on Mac
)
raster_dataset.setup()
dataloader = raster_dataset.test_dataloader()

In [None]:
# Convert the trained CNN into a spiking network with sinabs.
# The SNN is evaluated on (batch*time)-flattened rasters, so its spiking layers
# must know how to split that dimension back apart. Configuring them by the
# number of time steps (rather than a fixed batch_size) keeps the split correct
# when a batch is smaller than `batch_size` (e.g. an incomplete final test batch).
# Keep this equal to the `n_time_bins` used for the raster dataset above.
n_time_steps = 20
snn = sinabs.from_torch.from_model(
    model.model, num_timesteps=n_time_steps, add_spiking_output=False
).spiking_model

In [None]:
import sinabs.layers as sl

def get_accuracy(model, dataloader, device, flatten_input=False):
    """Return the classification accuracy (in percent) of ``model`` on ``dataloader``.

    Each batch is handled independently: the state of every sinabs
    ``StatefulLayer`` is reset before the forward pass. The model output is then
    summed over dimension 1 (presumably the time axis, assuming an output of shape
    ``(batch, time, classes)``; TODO confirm), and the arg-max class is compared
    with the label.

    Args:
        model: torch module to evaluate; it is moved to ``device`` and put in
            eval mode.
        dataloader: iterable yielding ``(rasters, labels)`` batches.
        device: device to run on, e.g. ``"cpu"``.
        flatten_input: if True, merge the first two input dimensions before the
            forward pass, then split the output's first dimension back into
            ``(batch_size, -1)`` afterwards.

    Returns:
        Percentage of correctly classified samples, as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model = model.to(device)
    # Disable training-only behaviour such as dropout while evaluating.
    model.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for rasters, labels in tqdm(dataloader):
            rasters = rasters.to(device)
            labels = labels.to(device)
            batch_size = rasters.shape[0]
            # Clear any state left over from the previous batch.
            for layer in model.modules():
                if isinstance(layer, sl.StatefulLayer):
                    layer.reset_states()
            if flatten_input:
                rasters = rasters.flatten(0, 1)
            output = model(rasters)
            if flatten_input:
                output = output.unflatten(0, (batch_size, -1))
            correct = output.sum(1).argmax(1) == labels
            n_correct += correct.sum().item()
            n_total += correct.numel()
    if n_total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return n_correct / n_total * 100

In [None]:
# Accuracy (%) of the converted SNN on the spike-raster test set.
snn_accuracy = get_accuracy(snn, dataloader, device="cpu", flatten_input=True)
snn_accuracy