In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import sinabs
from torchvision import transforms, datasets
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from mnist_model import ConvNet

# Print numpy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Fall back to CPU so Restart & Run All also works on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
num_workers = 4

transform = transforms.Compose([transforms.ToTensor()])
# The MNIST test split serves as the validation set throughout this notebook.
valid_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# NOTE(review): shuffled single-sample loader over the same split; it is not used
# in the cells shown here and shuffling is unseeded — confirm it is still needed.
test_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

In [None]:
# torch.load unpickles the checkpoint, so only load files from a trusted source.
state_dict = torch.load("mnist-convnet.pth", map_location=device)

In [None]:
# Copy the checkpoint into ConvNet. The checkpoint keys use conv1/conv2/conv3/fc1
# prefixes, while ConvNet's layers are addressed by integer index (ann[i]).
# copy_ (rather than `.data =`) keeps the model's own storage, so the in-place
# weight normalization done later does not also mutate `state_dict`; the shape
# check catches a checkpoint/architecture mismatch immediately.
ann = ConvNet().to(device)
with torch.no_grad():
    for index, prefix in ((0, 'conv1'), (4, 'conv2'), (8, 'conv3'), (12, 'fc1')):
        for attr in ('weight', 'bias'):
            param = getattr(ann[index], attr)
            value = state_dict[f'{prefix}.{attr}']
            if param.shape != value.shape:
                raise ValueError(
                    f'{prefix}.{attr}: checkpoint shape {tuple(value.shape)} '
                    f'does not match model shape {tuple(param.shape)}'
                )
            param.copy_(value)

In [None]:
def get_mnist_accuracy(model, data_loader, device):
    """Return the top-1 classification accuracy of ``model`` over ``data_loader``.

    Switches ``model`` to eval mode (and leaves it there), runs every batch
    without gradient tracking, and compares the arg-max of the model output
    along dim 1 with the target labels.

    Args:
        model: module mapping an input batch to per-class scores; assumes the
            scores are laid out as (batch, classes) — TODO confirm for new models.
        data_loader: iterable yielding ``(inputs, targets)`` pairs.
        device: device that inputs and targets are moved to before the forward pass.

    Returns:
        0-dim float tensor on ``device`` holding the fraction of correct predictions.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)
            predicted = scores.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum()
    if total == 0:
        # Guard against 0/0 (the old int accumulator also crashed here with AttributeError).
        raise ValueError("data_loader yielded no samples")
    return correct.float() / total

In [None]:
# Baseline accuracy of the ANN with the checkpoint weights, before weight normalization.
get_mnist_accuracy(ann, valid_loader, device)

In [None]:
# Single pass over the ANN's direct children: ReLU layer names go to
# output_layers, Conv2d/Linear layer names go to param_layers (both are handed
# to sinabs' normalize_weights).
output_layers, param_layers = [], []
for layer_name, layer in ann.named_children():
    if isinstance(layer, nn.ReLU):
        output_layers.append(layer_name)
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        param_layers.append(layer_name)

# Inputs of the first validation batch, used as the sample data for normalization.
sample_data = next(iter(valid_loader))[0]

# Percentile forwarded to normalize_weights.
percentile = 99

In [None]:
# Normalize ann's weights with sinabs, using sample_data and the ReLU / Conv2d+Linear
# layer names collected earlier. The return value is discarded, so this presumably
# modifies ann in place (the next cell re-measures ann's accuracy) — confirm in sinabs docs.
sinabs.utils.normalize_weights(ann, sample_data.to(device), output_layers=output_layers, param_layers=param_layers, percentile=percentile)

In [None]:
# Accuracy after weight normalization; compare with the baseline measured earlier.
get_mnist_accuracy(ann, valid_loader, device)

In [None]:
# t_max (2**8) is shared by the quartz conversion, input encoding and output decoding.
t_max = 256

# NOTE(review): the SNN is built for a fixed batch_size — confirm whether quartz
# requires every forward pass to use exactly that many samples.
snn = quartz.from_torch.from_model(ann, t_max=t_max, batch_size=batch_size)
snn = snn.to(device)
snn.eval();

In [None]:
# Take the first validation batch and encode it with quartz for the SNN
# (labels are not used in the cells shown here).
data, labels = next(iter(valid_loader))
temp_data = quartz.encode_inputs(data, t_max=t_max)
temp_data = temp_data.to(device)

In [None]:
# Run the SNN on the encoded batch: the first two dims of temp_data are merged for
# the forward pass and the leading one restored afterwards. Its size is read from
# temp_data instead of assuming batch_size, so a short batch is not mis-grouped.
with torch.no_grad():
    temp_output = snn(temp_data.flatten(0, 1)).unflatten(0, (temp_data.shape[0], -1))
    snn_output = quartz.decode_outputs(temp_output, t_max=t_max)

In [None]:
# Decoded SNN output for the first validation batch (presumably per-class scores,
# one row per sample — confirm against quartz.decode_outputs).
snn_output