In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import sinabs
from torchvision import transforms, datasets
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from mnist_model import ConvNet
from tqdm.auto import tqdm

# Print numpy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Fall back to the CPU when no GPU is present so the notebook still runs end to end
# (a hard-coded 'cuda' makes torch.load(map_location=...) fail on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
num_workers = 4

transform = transforms.Compose([transforms.ToTensor()])
valid_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# drop_last=True keeps every batch exactly `batch_size` long; presumably required because
# the quartz SNN below is built with batch_size=batch_size (TODO confirm).
# NOTE: this skips the trailing partial batch, so every accuracy in this notebook is
# measured on len(valid_dataset) // batch_size * batch_size images, not the full set.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)

In [None]:
# Pretrained ANN weights, deserialised straight onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint_path = "mnist-convnet.pth"
state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
ann = ConvNet()

In [None]:
# ConvNet's layers are addressed positionally, whereas the checkpoint uses named
# layers; map each Sequential index to its checkpoint prefix.
checkpoint_layers = {0: 'conv1', 3: 'conv2', 6: 'conv3', 9: 'fc1'}
for layer_index, prefix in checkpoint_layers.items():
    # load_state_dict *copies* the tensors and raises on missing keys or shape mismatches.
    # Rebinding `.data` instead would make the parameters share storage with `state_dict`,
    # so any later in-place weight rescaling would silently modify the loaded checkpoint too.
    ann[layer_index].load_state_dict({
        'weight': state_dict[f'{prefix}.weight'],
        'bias': state_dict[f'{prefix}.bias'],
    })
# Copying keeps the parameters wherever ConvNet() created them; move the model explicitly.
ann.to(device)

ann.eval();

In [None]:
def get_mnist_accuracy(model, data_loader, device, t_max=None):
    """Compute the top-1 classification accuracy of ``model`` over ``data_loader``.

    Args:
        model: network to evaluate. It is switched to eval mode (and left there).
        data_loader: iterable of ``(inputs, labels)`` batches. Only the samples it
            actually yields are counted (e.g. a ``drop_last`` loader skips the final
            partial batch).
        device: device that inputs and labels are moved to before the forward pass.
        t_max: if given, inputs are encoded with ``quartz.encode_inputs`` and outputs
            decoded with ``quartz.decode_outputs`` using this value (SNN evaluation).
            If ``None``, the model is evaluated directly on the raw inputs.

    Returns:
        0-dim float tensor holding the fraction of correctly classified samples.

    Raises:
        ValueError: if ``data_loader`` yields no samples (accuracy is undefined).
    """
    correct_pred = 0
    n = 0
    model.eval()
    # Evaluation never needs gradients; skip autograd bookkeeping for the whole loop.
    with torch.no_grad():
        for X, y_true in tqdm(data_loader):
            X = X.to(device)
            if t_max is not None:
                X = quartz.encode_inputs(X, t_max=t_max).to(device)
            y_true = y_true.to(device)
            y_prob = model(X)
            if t_max is not None:
                y_prob = quartz.decode_outputs(y_prob, t_max=t_max)
            _, predicted_labels = torch.max(y_prob, 1)
            n += y_true.size(0)
            correct_pred += (predicted_labels == y_true).sum()
    # Without this guard an empty loader fails with an opaque AttributeError
    # (`int.float()`) or a division by zero.
    if n == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return correct_pred.float() / n

In [None]:
# Baseline accuracy of the pretrained ANN, before any weight normalisation.
get_mnist_accuracy(model=ann, data_loader=valid_loader, device=device)

In [None]:
# Layer names handed to sinabs weight normalisation: the parametrised layers
# (Conv2d / Linear), and the layers whose activations are measured (every ReLU,
# with the last parametrised layer appended at the end).
named_layers = list(ann.named_children())
param_layers = [name for name, module in named_layers if isinstance(module, (nn.Conv2d, nn.Linear))]
output_layers = [name for name, module in named_layers if isinstance(module, nn.ReLU)]
output_layers.append(param_layers[-1])

# Images of the first validation batch, used to measure activations.
sample_data, sample_labels = next(iter(valid_loader))

# Activation percentile passed to sinabs.utils.normalize_weights.
percentile = 99

In [None]:
# Rescale the ANN's weights using activation percentiles from one validation batch.
# The result is not captured, so `ann` itself is modified.
# NOTE(review): presumably not idempotent — re-run from the model-loading cells
# rather than re-running this cell on its own.
sinabs.utils.normalize_weights(
    ann,
    sample_data.to(device),
    output_layers=output_layers,
    param_layers=param_layers,
    percentile=percentile,
)

In [None]:
# Accuracy of the ANN after weight normalisation; compare with the baseline above.
get_mnist_accuracy(model=ann, data_loader=valid_loader, device=device)

In [None]:
# t_max used by quartz for both the model conversion and input/output coding (2**7 == 128).
t_max = 1 << 7

# Convert the normalised ANN into a quartz SNN built for a fixed batch size
# (presumably why valid_loader uses drop_last=True).
snn = quartz.from_torch.from_model(
    ann,
    t_max=t_max,
    batch_size=batch_size,
    add_spiking_output=True,
)
snn = snn.to(device).eval()

In [None]:
# Accuracy of the converted SNN: inputs are quartz-encoded and outputs decoded with t_max.
get_mnist_accuracy(model=snn, data_loader=valid_loader, device=device, t_max=t_max)