In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
import sinabs
from torchvision import datasets, transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from tqdm.auto import tqdm
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, normalize_outputs, plot_output_comparison, plot_output_comparison_new, normalize_weights, count_n_neurons, fuse_all_conv_bn
from typing import List

PRINT_OPTIONS = {"suppress": True}  # print small floats in fixed-point, not scientific notation
np.set_printoptions(**PRINT_OPTIONS)

In [None]:
from cifar10_models.vgg import vgg11_bn

# VGG-11 (batch-norm variant, per the name) with pretrained weights, switched to
# inference mode. nn.Module.eval() returns the module itself, so this is one step.
model = vgg11_bn(pretrained=True).eval()

In [None]:
# Number of trainable parameters, expressed in millions.
n_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
n_trainable_params / 1_000_000

In [None]:
# Make every ReLU / ReLU6 out-of-place (presumably so later activation-based
# processing sees unmodified tensors — confirm) and report how many exist.
relu_modules = [mod for mod in model.modules() if isinstance(mod, (nn.ReLU, nn.ReLU6))]
for relu in relu_modules:
    relu.inplace = False
relu_count = len(relu_modules)
print(f"Model contains {relu_count} relu layers.")

In [None]:
batch_size = 128
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 0  # fixes the shuffled order of `snn_loader` so SNN results are reproducible

# Presumably the CIFAR-10 per-channel mean / std the pretrained weights expect — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
])

valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=batch_size, num_workers=4)
# Smaller batches for SNN evaluation; a seeded generator makes the shuffle deterministic.
snn_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Count neurons using a single validation image (a batch of one) as the probe input.
# NOTE: nn.Module.cpu() moves `model` to CPU in place; copies made later inherit that.
probe_image = next(iter(valid_loader))[0][:1]
count_n_neurons(model.cpu(), probe_image, add_last_layer=True)

In [None]:
# Presumably folds BatchNorm into the preceding convolutions. This is done on a deep
# copy so `model` itself is not modified.
folded_model = copy.deepcopy(model)
# Return value is discarded: assumes fuse_all_conv_bn mutates its argument in place — TODO confirm.
fuse_all_conv_bn(folded_model)

In [None]:
# Copy of the BN-folded model whose layer outputs are rescaled for SNN conversion.
norm_model = copy.deepcopy(folded_model)
# Load the whole validation set as a single batch. It is not shuffled, since every image
# is used anyway (assumes normalize_outputs only uses order-independent statistics such
# as percentiles — confirm).
# The batch stays on CPU: the model was moved there by `model.cpu()`, and the copies inherit that.
test_loader = DataLoader(dataset=valid_dataset, batch_size=len(valid_dataset), shuffle=False, num_workers=4)
sample_data = next(iter(test_loader))[0]
percentile = 99.99
# Presumably scales each layer by the `percentile`-th percentile of its outputs on
# `sample_data` — see quartz.utils.normalize_outputs to confirm.
normalize_outputs(norm_model, sample_data=sample_data, percentile=percentile, max_outputs=[])

In [None]:
# Split the normalised network at index SPLIT_AT of `features`:
#   - the layers before it form a conventional (non-spiking) front-end, later passed
#     as `preprocess` to get_accuracy;
#   - everything after it is the part converted to a spiking network.
# NOTE: these containers share module objects with `norm_model` (nothing is copied).
SPLIT_AT = 3
preprocess_layers = nn.Sequential(*norm_model.features[:SPLIT_AT])
ann = nn.Sequential(
    norm_model.features[SPLIT_AT:],
    norm_model.avgpool,
    nn.Flatten(),
    norm_model.classifier,
)
# Front-end followed by the remainder: the full ANN, evaluated before conversion.
composed_model = nn.Sequential(preprocess_layers, ann)

In [None]:
# Error rate of the normalised ANN (assumes get_accuracy returns a percentage).
ann_error = 100 - get_accuracy(composed_model, valid_loader, device)
ann_error

In [None]:
# Optional diagnostic: presumably compares layer outputs of the BN-folded model and the
# normalised model. Disabled by default; set the flag to True to draw it.
# Input goes to `device` like the rest of the notebook. This assumes both models sit on
# the same device at this point — confirm.
PLOT_LAYER_COMPARISON = False
if PLOT_LAYER_COMPARISON:
    plot_output_comparison_new(
        folded_model,
        norm_model,
        sample_input=next(iter(valid_loader))[0].to(device),
        every_n=1000,
    )

In [None]:
# Convert the normalised ANN into a spiking network for several time horizons
# and record the SNN accuracy obtained with each one.
T_MAX_VALUES = [2**exponent for exponent in range(3, 7)]  # [8, 16, 32, 64]
accuracies = []
for t_max in T_MAX_VALUES:
    # Convert a fresh deep copy each time so `ann` itself is never altered.
    # from_model2's return value is unused, so it presumably converts in place — confirm.
    snn = copy.deepcopy(ann)
    quartz.from_torch.from_model2(snn, t_max=t_max)
    # Append an output IF layer with rectification disabled, then switch to inference mode.
    snn = nn.Sequential(snn, quartz.IF(t_max=t_max, rectification=False)).eval()
    # The non-spiking front-end is run through `preprocess` before the SNN.
    accuracy = get_accuracy(snn, snn_loader, device, preprocess=preprocess_layers, t_max=t_max,
                            print_early_spikes=True, print_output_time=True)
    print(f"t_max={t_max}: accuracy={accuracy}")
    accuracies.append(accuracy)

In [None]:
# Error rate per t_max, in the order the conversion sweep ran (assumes accuracies are percentages).
snn_errors = 100 - np.asarray(accuracies)
snn_errors