In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.quantization
import torchvision
import matplotlib.pyplot as plt
from quartz.utils import get_accuracy, plot_output_histograms
import quartz
from copy import deepcopy

In [None]:
# Evaluation configuration.
batch_size = 16
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CIFAR-10 split selected by train=False, converted to tensors; downloaded into ./data if missing.
valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)
# Fixed-order loader used for the accuracy measurements below.
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=batch_size)
# Shuffled loader over the SAME dataset, used to draw a sample batch of 100 items.
test_loader = DataLoader(dataset=valid_dataset, batch_size=100, shuffle=True)


In [None]:
# Load the pre-trained network onto the configured device and switch it to
# evaluation mode (Module.eval() returns the module itself, so chaining is safe).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
ann = torch.load(
    "./cifar-convnet-normalised.pth",
    map_location=torch.device(device),
).eval()

In [None]:
def make_quantize_hook(t_max):
    """Build a forward hook that snaps a layer's output onto a grid of step 1/t_max.

    t_max is bound when the hook is created, so the hook keeps using that value
    even if a later cell reassigns the global ``t_max``.
    """
    def quantize(module, input, output):
        # Scale up, cast to int (dropping the fractional part), scale back down.
        return (output * t_max).int() / t_max
    return quantize


# Work on a copy so the hooks never touch the original `ann`.
q_ann = deepcopy(ann)

# Layers whose outputs get quantised: every top-level ReLU plus the final layer.
# The final layer is only added if it is not already in the list, so it is
# never quantised twice.
quantized_layers = [module for module in q_ann.children() if isinstance(module, nn.ReLU)]
if not any(module is q_ann[-1] for module in quantized_layers):
    quantized_layers.append(q_ann[-1])

for exponent in range(4, 8):
    t_max = 2**exponent+1
    hook = make_quantize_hook(t_max)
    handles = [module.register_forward_hook(hook) for module in quantized_layers]
    try:
        accuracy = get_accuracy(q_ann, valid_loader, device)
    finally:
        # Remove this iteration's hooks; otherwise they would pile up and every
        # later iteration would run all earlier hooks as well.
        for handle in handles:
            handle.remove()
    print(f"{t_max} time steps: {round(accuracy, 4)}%")

In [None]:
# Take the first element of one batch from the shuffled loader
# (presumably the image tensor -- labels are not used here).
sample_data = next(iter(test_loader))[0]

# Layers of interest: all top-level ReLU modules followed by the network's last layer.
output_layers = [
    *(child for child in ann.children() if isinstance(child, nn.ReLU)),
    ann[-1],
]
# Histogram plotting is currently disabled:
# plot_output_histograms(ann, sample_data, )
output_layers

In [None]:
# Display the shape of the sampled batch as a quick sanity check.
sample_data.shape

In [None]:
# Baseline: accuracy of the unmodified network `ann` on the validation loader.
get_accuracy(ann, valid_loader, device)

In [None]:
# Convert the ANN into a quartz model for a single setting: t_max = 2**4 + 1 = 17.
exponent = 4
snn = quartz.from_torch.from_model(
    ann,
    t_max=2**exponent + 1,
    add_spiking_output=True,
)

In [None]:
# Display the converted model's repr to inspect its layers.
snn

In [None]:
# Move the source network to the configured device (instead of hard-coding CUDA)
# so it matches where the converted models and the evaluation run.
ann = ann.to(device)

# Convert and evaluate the network for t_max = 17, 33, 65, 129.
for exponent in range(4, 8):
    t_max = 2**exponent+1
    # Fresh conversion per setting, moved to the device and put in eval mode.
    snn = quartz.from_torch.from_model(ann, t_max=t_max, add_spiking_output=True).to(device).eval()
    # Turn off the `rectification` flag on the final layer before evaluating.
    snn[-1].rectification = False
    accuracy = get_accuracy(snn, valid_loader, device, t_max=t_max)
    print(f"{t_max} time steps: {round(accuracy, 3)}%")