In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import sinabs
from torchvision import transforms, datasets
import sinabs.layers as sl
import numpy as np
import quartz
from quartz.utils import get_accuracy
from mnist_model import ConvNet

# Print NumPy floats in fixed-point rather than scientific notation (easier to read small weights/activations).
np.set_printoptions(suppress=True)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 100
num_workers = 4

# MNIST test split (train=False) serves as the validation set for every evaluation below.
transform = transforms.Compose([transforms.ToTensor()])
valid_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# drop_last=True guarantees every batch has exactly `batch_size` samples. The SNN built later
# is constructed with a fixed batch_size, so a short final batch would presumably not fit it.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)

In [None]:
# Build the (untrained) architecture first, then load the trained checkpoint.
ann = ConvNet()
# map_location places every checkpoint tensor on `device`, so parameters taken from it
# end up on that device too. torch.load unpickles the file: only load trusted checkpoints.
state_dict = torch.load(
    "mnist-convnet.pth",
    map_location=torch.device(device),
)

In [None]:
# The checkpoint uses named keys (conv1, conv2, conv3, fc1) while ConvNet is indexed
# positionally, so map each checkpoint prefix onto the index of the matching layer.
checkpoint_to_layer = {
    'conv1': 0,
    'conv2': 3,
    'conv3': 6,
    'fc1': 9,
}
for prefix, layer_idx in checkpoint_to_layer.items():
    for param_name in ('weight', 'bias'):
        param = getattr(ann[layer_idx], param_name)
        saved = state_dict[f'{prefix}.{param_name}']
        # Assigning to `.data` performs no shape check, so verify shapes here and fail
        # loudly instead of silently installing a mismatched tensor.
        if param.shape != saved.shape:
            raise ValueError(
                f"shape mismatch for {prefix}.{param_name}: "
                f"model expects {tuple(param.shape)}, checkpoint has {tuple(saved.shape)}"
            )
        param.data = saved

# Inference mode for all evaluations that follow (trailing ';' suppresses the module repr).
ann.eval();

In [None]:
# Baseline: validation accuracy of the ANN with the checkpoint weights, measured before
# weight normalization. NOTE(review): exact return format of quartz.utils.get_accuracy not visible here.
get_accuracy(ann, valid_loader, device)

In [None]:
# Layer names handed to sinabs' weight normalization. param_layers lists every Conv2d/Linear
# layer. output_layers lists every ReLU followed by the final parameter layer.
# NOTE(review): sinabs pairs these lists by position, which assumes each parameter layer except
# the last is followed by a ReLU and the network ends without one. Confirm against ConvNet.
children = list(ann.named_children())
param_layers = [name for name, module in children if isinstance(module, (nn.Conv2d, nn.Linear))]
output_layers = [name for name, module in children if isinstance(module, nn.ReLU)] + [param_layers[-1]]

# A single validation batch (images only) is used as calibration data.
sample_data = next(iter(valid_loader))[0]
# Percentile of the recorded activations used as the normalization reference.
percentile = 99

In [None]:
# Rescale the ANN's weights using activations recorded on the calibration batch.
# sinabs modifies `ann` in place (its return value is not used), so re-running this cell
# without first re-running the weight-loading cell normalizes an already-normalized network.
calibration_input = sample_data.to(device)
sinabs.utils.normalize_weights(
    ann,
    calibration_input,
    output_layers=output_layers,
    param_layers=param_layers,
    percentile=percentile,
)

In [None]:
# Validation accuracy after weight normalization; compare with the baseline measured earlier.
get_accuracy(ann, valid_loader, device)

In [None]:
# Halve the weights and bias of the final parameter layer before SNN conversion.
# NOTE(review): the reason for the factor of 2 is not stated here. Confirm and document it.
# The layer is looked up by name from `param_layers` instead of assuming it is the last child
# (`ann[-1]`). This is the same layer normalize_weights was given as the final output layer.
# WARNING: in-place. Re-running this cell scales the layer again.
OUTPUT_LAYER_SCALE = 2
final_param_layer = dict(ann.named_children())[param_layers[-1]]
final_param_layer.weight.data /= OUTPUT_LAYER_SCALE
final_param_layer.bias.data /= OUTPUT_LAYER_SCALE

In [None]:
# Convert the normalized ANN to a spiking network for simulation lengths
# t_max = 2**3 ... 2**7 and evaluate each on the validation set.
# Results are kept in `snn_accuracy` (t_max -> accuracy) for later inspection.
snn_accuracy = {}
for exponent in range(3, 8):
    t_max = 2 ** exponent
    snn = quartz.from_torch.from_model(
        ann, t_max=t_max, batch_size=batch_size, add_spiking_output=True
    ).to(device).eval()
    snn_accuracy[t_max] = get_accuracy(snn, valid_loader, device, t_max=t_max)
    # Label each line so the printed sweep is readable on its own.
    print(f"t_max={t_max:4d}: accuracy={snn_accuracy[t_max]}")