In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sinabs
from torchvision import transforms, datasets
import sinabs.layers as sl
import numpy as np
import quartz
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, plot_output_histograms
from mnist_model import ConvNet

# Print NumPy floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 100
# NOTE(review): num_workers is defined here, but the loader below is built with
# a literal num_workers=0 (loading in the main process) — confirm which value
# is intended before wiring this variable in.
num_workers = 4

# MNIST split selected by train=False, converted to tensors with no further
# preprocessing. The dataset is downloaded into ./data if it is not present.
transform = transforms.Compose([transforms.ToTensor()])
valid_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Fixed iteration order (shuffle=False); drop_last=True discards a trailing
# batch that would be smaller than batch_size.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

In [None]:
# Read the pretrained checkpoint, mapping every stored tensor onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint_device = torch.device(device)
state_dict = torch.load("mnist-convnet.pth", map_location=checkpoint_device)
ann = ConvNet()

# Each row: (index into `ann`, checkpoint key for the weight, key for the bias).
checkpoint_layout = (
    (0, 'conv1.weight', 'conv1.bias'),
    (3, 'conv2.weight', 'conv2.bias'),
    (6, 'conv3.weight', 'conv3.bias'),
    (10, 'fc1.weight', 'fc1.bias'),
)
for layer_index, weight_key, bias_key in checkpoint_layout:
    target_layer = ann[layer_index]
    # Rebind the parameter storage directly to the checkpoint tensors.
    target_layer.weight.data = state_dict[weight_key]
    target_layer.bias.data = state_dict[bias_key]

# Switch to inference mode, then report the accuracy on the validation loader
# (the last expression is displayed as the cell output).
ann.eval()
get_accuracy(ann, valid_loader, device)

In [None]:
# Names of the direct children of `ann` that carry weights (Conv2d / Linear),
# and the names of the layers whose outputs are handed to the normalisation:
# every ReLU plus the last weighted layer.
param_layers = [name for name, child in ann.named_children() if isinstance(child, (nn.Conv2d, nn.Linear))]
output_layers = [name for name, child in ann.named_children() if isinstance(child, nn.ReLU)]
output_layers += [param_layers[-1]]

# Draw the normalisation batch with a seeded generator. An unseeded
# shuffle=True would pick a different batch of 1000 images on every run, so the
# normalised weights and all downstream accuracies would not be reproducible.
normalise_seed = 0
normalise_generator = torch.Generator().manual_seed(normalise_seed)
normalise_loader = DataLoader(dataset=valid_dataset, batch_size=1000, shuffle=True, generator=normalise_generator)
sample_data = next(iter(normalise_loader))[0]  # keep the images, drop the labels
percentile = 99.99

# NOTE(review): the return value is not used, so normalize_weights presumably
# rescales `ann` in place; re-running this cell would then normalise already
# normalised weights — confirm, and reload the checkpoint before re-running.
sinabs.utils.normalize_weights(ann, sample_data.to(device), output_layers=output_layers, param_layers=param_layers, percentile=percentile)
# Accuracy after normalisation, displayed as the cell output.
get_accuracy(ann, valid_loader, device)

In [None]:
# Display the current contents of `param_layers` for inspection.
param_layers

In [None]:
# Display the current contents of `output_layers` for inspection.
output_layers

In [None]:
# Sweep t_max over 2**k + 1 for k = 1..6: convert the ANN with quartz at each
# setting and print the accuracy measured on the validation loader.
for exponent in range(1, 7):
    sweep_t_max = 2**exponent + 1
    snn = quartz.from_torch.from_model(
        ann,
        t_max=sweep_t_max,
        add_spiking_output=True,
    )
    snn = snn.to(device).eval()
    accuracy = get_accuracy(snn, valid_loader, device, t_max=sweep_t_max)
    print(accuracy)

In [None]:
# Plot per-layer output histograms of the converted network for one t_max.
t_max = 2**3+1
# valid_loader is built with shuffle=False, so this is its first batch.
images, label = next(iter(valid_loader))
spikes = encode_inputs(images, t_max=t_max).to(device)
snn = quartz.from_torch.from_model(ann, t_max=t_max, add_spiking_output=True).to(device).eval()
# Stateful layer modules of the converted network. They are kept under their
# own name instead of rebinding `output_layers`, which holds layer *names* of
# `ann` in the normalisation cell; reusing that name for module objects made
# its meaning depend on which cell was executed last.
snn_output_layers = [child for child in snn.children() if isinstance(child, sl.StatefulLayer)]

plot_output_histograms(snn, spikes, snn_output_layers, t_max=t_max)

In [None]:
# Plot output histograms of the ANN on the same `images` batch. The layers
# passed in are module objects: every ReLU followed by the last Conv2d/Linear.
ann_children = list(ann.children())
param_layers = [module for module in ann_children if isinstance(module, (nn.Conv2d, nn.Linear))]
output_layers = [module for module in ann_children if isinstance(module, nn.ReLU)]
output_layers.append(param_layers[-1])
plot_output_histograms(ann, images.to(device), output_layers)