In [25]:
import torch
import torchvision
from torch.utils.data import DataLoader
import snn  # your custom spiking network module
import smnist.tools as tools
import os
import numpy as np
from collections import defaultdict

In [26]:
def load_model(model_path, map_location="cpu", **load_kwargs):
    """Load an object saved with ``torch.save`` from ``model_path``.

    Args:
        model_path: path to the serialized ``.pt`` file.
        map_location: where stored tensors are remapped to. Defaults to the
            CPU, which was previously hard-coded.
        **load_kwargs: extra keyword arguments forwarded to ``torch.load``
            (for example ``weights_only``).

    Returns:
        The loaded object, or ``None`` if ``torch.load`` raised (the error is
        printed rather than propagated).

    Raises:
        FileNotFoundError: if ``model_path`` is not an existing file.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # torch.load unpickles arbitrary objects unless weights_only=True, so only
    # load checkpoints from trusted sources.
    try:
        model = torch.load(model_path, map_location=map_location, **load_kwargs)
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller decide.
        print(f"Error loading model {model_path}: {e}")
        return None
    print(f"Successfully loaded model: {model_path}")
    return model

In [27]:
# Candidate checkpoints to evaluate.
# NOTE(review): the "_0.xLoss" suffixes presumably record the training loss each
# snapshot was taken at — confirm against the training run.
(
    optimized_model_path1,
    optimized_model_path2,
    optimized_model_path3,
    optimized_model_path4,
) = (
    "smnist/models/SMNIST_BRF.pt",
    "smnist/models/SMNIST_BRF_0.7Loss.pt",
    "smnist/models/SMNIST_BRF_0.6Loss.pt",
    "smnist/models/SMNIST_BRF_0.5Loss.pt",
)

# Prefer the GPU when one is visible; otherwise evaluate on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Samples per evaluation batch.
batch_size = 10000

# Each 28x28 image is fed as a sequence of 784 time steps with 1 feature each.
sequence_length = 28 * 28
input_size = 1

hidden_size = 256   # width of the recurrent spiking layer
num_classes = 10    # output classes
PERMUTED = False    # set True if testing on PSMNIST

In [28]:
# Build the spiking network. Swap the class to evaluate another neuron type:
#   snn.models.SimpleVanillaRFRNN -> RF
#   snn.models.SimpleResRNN       -> BRF
#   snn.models.SimpleALIFRNN      -> ALIF
model = snn.models.SimpleResRNN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=num_classes,
    label_last=True,
)
model = model.to(device)

# Restore the trained weights, mapping stored tensors straight onto `device`.
checkpoint = torch.load(optimized_model_path4, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

SimpleResRNN(
  (hidden): BRFCell(
    (linear): Linear(in_features=257, out_features=256, bias=False)
  )
  (out): LICell(
    (linear): Linear(in_features=256, out_features=10, bias=False)
  )
)

In [29]:
# MNIST test split, converted to tensors with ToTensor; downloaded into
# smnist/data if it is not already present.
test_dataset = torchvision.datasets.MNIST(
    root="smnist/data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Fixed order keeps evaluation deterministic. Pinned host memory is only
# requested when batches will be copied to a CUDA device.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=device.type == "cuda",
)


In [30]:
# Order in which pixel positions are fed to the network: identity for
# sequential MNIST, a saved fixed permutation for permuted MNIST (PSMNIST).
if PERMUTED:
    # NOTE(review): this file name looks like a model-init checkpoint; confirm it
    # really stores the permutation index tensor used during training.
    # map_location="cpu" keeps a CUDA-saved index loadable on CPU-only hosts.
    permuted_idx = torch.load('smnist/models/SMNIST_BRF_init.pt', map_location="cpu")
    # Fail fast if the file does not hold a permutation of 0..sequence_length-1;
    # otherwise indexing later would silently scramble or truncate the sequence.
    if not (
        isinstance(permuted_idx, torch.Tensor)
        and permuted_idx.dim() == 1
        and torch.equal(torch.sort(permuted_idx.long()).values, torch.arange(sequence_length))
    ):
        raise ValueError(
            f"Expected a 1-D permutation of range({sequence_length}) in the permutation file, "
            f"got {type(permuted_idx).__name__}"
        )
else:
    permuted_idx = torch.arange(sequence_length)


In [31]:
def transform_input_batch(tensor, sequence_length_, batch_size_, input_size_, permuted_idx_, device_=None):
    """Turn an image batch into a time-major input sequence for the network.

    Args:
        tensor: batch whose elements can be laid out as
            (batch_size_, sequence_length_, input_size_) — e.g. (B, 1, 28, 28)
            images with sequence_length_=784, input_size_=1.
        sequence_length_: number of time steps per sample.
        batch_size_: number of samples in this batch.
        input_size_: features per time step.
        permuted_idx_: 1-D index (tensor or sequence) selecting/reordering the
            time steps; torch.arange(sequence_length_) keeps the natural order.
        device_: target device. When None, the notebook-level ``device`` is
            used, which preserves the previous behavior.

    Returns:
        Tensor of shape (len(permuted_idx_), batch_size_, input_size_) on the
        target device, with time steps ordered by permuted_idx_.
    """
    target_device = device if device_ is None else device_
    # reshape (not view) also accepts non-contiguous input.
    tensor = tensor.to(device=target_device).reshape(batch_size_, sequence_length_, input_size_)
    tensor = tensor.permute(1, 0, 2)  # -> (time, batch, features)
    # Put the index on the data's device to avoid cross-device indexing per call.
    index = torch.as_tensor(permuted_idx_, device=tensor.device)
    return tensor[index, :, :]

In [32]:
# Evaluate on the test set: mean loss, accuracy and spike statistics.
# The criterion has no per-batch state, so it is built once, outside the loop.
criterion = torch.nn.NLLLoss()

with torch.no_grad():
    test_loss = 0.
    test_correct = 0
    total_spikes = 0.

    for inputs, targets in test_loader:
        current_batch_size = len(inputs)
        inputs = transform_input_batch(inputs, sequence_length, current_batch_size, input_size, permuted_idx)
        targets = targets.to(device)

        outputs, _, num_spikes = model(inputs)
        total_spikes += num_spikes.item()

        # Loss: when the label is not only on the last step, normalize per time step.
        loss = tools.apply_seq_loss(criterion=criterion, outputs=outputs, target=targets)
        loss_value = loss.item() if model.label_last else loss.item() / sequence_length
        # Weight by batch size so a smaller final batch does not skew the mean.
        # NOTE(review): assumes apply_seq_loss returns a per-sample mean
        # (NLLLoss default reduction="mean") — confirm in smnist.tools.
        test_loss += loss_value * current_batch_size

        # Accuracy on the time-averaged outputs.
        test_correct += tools.count_correct_predictions(outputs.mean(dim=0), targets)

    # Final metrics
    test_loss /= len(test_dataset)
    test_accuracy = (test_correct / len(test_dataset)) * 100.0
    # Spikes per test sample, and per time step.
    # NOTE(review): assumes num_spikes counts hidden-layer spikes summed over
    # time and batch, as the firing-rate denominator implies — confirm in snn.
    SOP = total_spikes / len(test_dataset)
    SOP_per_step = SOP / sequence_length
    firing_rate = total_spikes / (len(test_dataset) * sequence_length * hidden_size)

print(
    f'Test loss: {test_loss:.6f}, Test acc: {test_accuracy:.4f}%, SOP: {SOP:.2f}, SOP per step: {SOP_per_step:.4f}, '
    f'Mean firing rate per neuron: {firing_rate:.6f}'
)

Test loss: 0.160819, Test acc: 95.3700%, SOP: 3436.94, SOP per step: 4.3839, Mean firing rate per neuron: 0.017124
