In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
import sinabs
from torchvision import datasets, transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from tqdm.auto import tqdm
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, normalize_outputs, plot_output_comparison, plot_output_comparison_new, normalize_weights, count_n_neurons, fuse_all_conv_bn
from typing import List
import seaborn as sns
import matplotlib.pyplot as plt

# Print NumPy floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
from cifar10_models.vgg import vgg11_bn

# Load `vgg11_bn` with pretrained weights and switch it to inference mode in
# the same statement; the assignment produces no cell output.
model = vgg11_bn(pretrained=True).eval()

In [None]:
# Number of trainable parameters (those with requires_grad set), in millions.
sum(p.numel() for p in model.parameters() if p.requires_grad)/1_000_000

In [None]:
# Collect every ReLU / ReLU6 module and turn off in-place execution on each.
# NOTE(review): presumably needed so that later hooks see activations that
# have not been overwritten — confirm.
relu_layers = [layer for layer in model.modules() if isinstance(layer, (nn.ReLU, nn.ReLU6))]
for relu in relu_layers:
    relu.inplace = False
relu_count = len(relu_layers)
print(f"Model contains {relu_count} relu layers.")

In [None]:
batch_size = 128
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Convert images to tensors, then normalize each RGB channel with the given
# (mean, std) pair. Presumably these are the CIFAR-10 statistics the
# pretrained model expects — TODO confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
])

# CIFAR-10 test split (train=False); downloaded to ./data if missing.
valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
# Ordered loader over the validation data with `batch_size` images per batch.
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=batch_size, num_workers=4)
# Shuffled loader over the same data with small batches of 16.
snn_loader = DataLoader(dataset=valid_dataset, batch_size=16, shuffle=True, num_workers=4)

In [None]:
# Run the quartz.utils helper on a single validation image (the first batch
# sliced to one sample). Note that model.cpu() moves `model` to the CPU in place.
# NOTE(review): presumably returns the neuron count of the network — confirm.
count_n_neurons(model.cpu(), next(iter(valid_loader))[0][:1], add_last_layer=True)

In [None]:
# Work on a deep copy so that `model` itself is left untouched.
folded_model = copy.deepcopy(model)
# The return value is not assigned, so the helper is relied on to modify
# `folded_model` in place. NOTE(review): presumably folds each batch-norm layer
# into its preceding convolution — confirm against quartz.utils.
fuse_all_conv_bn(folded_model)

In [None]:
from fvcore.nn import FlopCountAnalysis, flop_count_table
# Analyse the folded model on the first validation batch.
# NOTE(review): the input is a whole batch, not a single image, so the total
# below likely scales with the batch size — confirm which number is intended.
flops = FlopCountAnalysis(folded_model, next(iter(valid_loader))[0])
# flop_count_table(flops)
# Total count divided by 1e6, i.e. reported in millions.
print(flops.total()/1e6)
# flops.by_module_and_operator()

In [None]:
n_synops = 1.8e6
n_neurons = 159_754
t_max = 64


def n_operations(n_neurons, t_max, n_synops):
    """Return (n_synops + 2 * n_neurons * t_max) in millions, rounded to 3 decimals."""
    return round((n_synops + 2*n_neurons*t_max)/1e6, 3)


def omega_read(n_neurons, t_max, n_synops):
    """Return (4 * n_neurons * t_max + n_synops) in millions, rounded to an integer."""
    return round((4*n_neurons*t_max+n_synops)/1e6)


def omega_write(n_neurons, t_max, n_synops):
    """Return (n_synops + n_neurons * t_max) in millions, rounded to an integer."""
    return round((n_synops + n_neurons*t_max)/1e6)


print(f"Number of operations: {n_operations(n_neurons, t_max, n_synops)}M.")
print(
    f"Read: {omega_read(n_neurons, t_max, n_synops)}M, "
    f"write: {omega_write(n_neurons, t_max, n_synops)}M, "
    f"total: {omega_read(n_neurons, t_max, n_synops)+omega_write(n_neurons, t_max, n_synops)}M"
)

# Operation count for t_max = 8, 16, 32 and 64 time steps.
[n_operations(n_neurons, time_steps, n_synops) for time_steps in (2**3, 2**4, 2**5, 2**6)]

In [None]:
norm_model = copy.deepcopy(folded_model)
# Seed the shuffle so that the same 1000 calibration images are drawn on every
# run; without it the normalization (and everything downstream) changes from
# run to run.
calibration_rng = torch.Generator().manual_seed(0)
test_loader = DataLoader(dataset=valid_dataset, batch_size=1000, shuffle=True, num_workers=4, generator=calibration_rng)
# Keep only the images of the first batch; labels are not needed here.
sample_data = next(iter(test_loader))[0]#.cuda()
percentile = 99.99
# The same percentile taken over the raw input values of the calibration batch.
input_scale_factor = np.percentile(sample_data, percentile)
# norm_model.cpu() moves the model to the CPU in place, matching the CPU
# sample batch. NOTE(review): normalize_outputs is relied on to rescale
# `norm_model` in place, since its return value is not assigned — confirm.
normalize_outputs(norm_model.cpu(), sample_data=sample_data, percentile=percentile, max_outputs=[])

In [None]:
# valid_dataset.transforms.transform.transforms.append(lambda x: x/input_scale_factor)

In [None]:
# Display the list of transforms currently attached to the validation dataset.
valid_dataset.transforms.transform.transforms

In [None]:
# preprocess_layers = nn.Sequential(
#     *norm_model.features[:3]
# )
# ann = nn.Sequential(
#     norm_model.features[3:],
#     norm_model.avgpool,
#     nn.Flatten(),
#     norm_model.classifier
# )
# composed_model = nn.Sequential(
#     preprocess_layers,
#     ann
# )

In [None]:
# next(iter(valid_loader))[0].shape

In [None]:
# Evaluate the normalized model on the validation loader.
# NOTE(review): presumably a sanity check that normalization did not change
# the accuracy — confirm.
get_accuracy(norm_model, valid_loader, device)

In [None]:
# q_ann = copy.deepcopy(norm_model)
# for exponent in range(1, 7):
#     t_max = 2**exponent
#     def quantize(module, input, output):
#         return (output * t_max).round() / t_max

#     for module in q_ann.modules():
#         if isinstance(module, nn.ReLU):
#             module.register_forward_hook(quantize)
#     q_ann.classifier[-1].register_forward_hook(quantize)

#     accuracy = get_accuracy(q_ann, valid_loader, device)
#     print(f"{t_max} time steps: {round(accuracy, 3)}%")

In [None]:
def plot_quantization_error(model1, model2, sample_input, savefig=None, t_max=2**4):
    """Bar-plot the mean squared error that output quantization adds per layer.

    The ``nn.Conv2d`` / ``nn.Linear`` layers of ``model2`` are quantized one
    after another: a forward hook rounds the layer output to the nearest
    multiple of ``1 / t_max``. Quantization is cumulative, so when the k-th
    layer is measured all previously visited layers of ``model2`` are
    quantized as well. For each layer the mean squared difference between the
    unquantized output of the paired ``model1`` layer and the quantized output
    of the ``model2`` layer is computed, and the values are drawn as a bar plot
    with a logarithmic y axis.

    Args:
        model1: Reference model; its outputs are only recorded.
        model2: Model to quantize. Layers are paired with those of ``model1``
            by the iteration order of ``modules()``, so both models are
            expected to share the same layout.
        sample_input: Input passed to both models, once per paired layer.
        savefig: Optional path; when truthy the figure is saved there.
        t_max: Number of quantization steps per unit of activation. Defaults
            to 16, the value that used to be hard-coded.

    Side effects:
        Both models are switched to eval mode. Every forward hook registered
        here is removed again before the function returns, including when a
        forward pass raises, so ``model2`` is not left quantized.
    """
    sns.set_theme(style="dark")

    weight_layer_types = (nn.Conv2d, nn.Linear)
    layer_pairs = [
        (layer1, layer2)
        for layer1, layer2 in zip(model1.modules(), model2.modules())
        if isinstance(layer1, weight_layer_types) and isinstance(layer2, weight_layer_types)
    ]

    model1.eval()
    model2.eval()

    # Latest output per module. Keying by module means the comparison does not
    # depend on the order in which the hooks fire during the forward pass.
    reference_output = {}
    quantized_output = {}

    def record(module, inputs, output):
        reference_output[module] = output.detach()

    def quantize(module, inputs, output):
        q_output = (output * t_max).round() / t_max
        quantized_output[module] = q_output.detach()
        # Returning a value from a forward hook replaces the layer output.
        return q_output

    distances = []
    quantize_handles = []
    try:
        # Only activations are needed, so skip building autograd graphs.
        with torch.no_grad():
            for layer1, layer2 in layer_pairs:
                record_handle = layer1.register_forward_hook(record)
                # Quantization hooks deliberately stay registered across
                # iterations so that errors of earlier layers propagate.
                quantize_handles.append(layer2.register_forward_hook(quantize))
                try:
                    model1(sample_input)
                    model2(sample_input)
                finally:
                    record_handle.remove()

                difference = (reference_output[layer1] - quantized_output[layer2])**2
                distance = difference.sum() / difference.numel()
                distances.append(distance.item())

                # Drop the stored tensors before the next pair is measured.
                reference_output.clear()
                quantized_output.clear()
    finally:
        # Leave model2 exactly as it was handed in.
        for handle in quantize_handles:
            handle.remove()

    axis = sns.barplot(x=np.arange(len(layer_pairs)), y=distances)
    axis.set_yscale("log")

    if savefig:
        plt.tight_layout()
        plt.savefig(savefig)

In [None]:
# Place the model explicitly instead of relying on where earlier cells left it,
# and use the notebook-wide `device` rather than a hard-coded .cuda().
norm_model = norm_model.to(device)
# The deep copy inherits the device of norm_model.
q_model = copy.deepcopy(norm_model)

sample_input = next(iter(valid_loader))[0].to(device)
plot_quantization_error(norm_model, q_model, sample_input=sample_input)

In [None]:
# Hard-coded per-layer values, one entry per bar.
# NOTE(review): presumably pasted from an earlier plot_quantization_error run —
# confirm, and prefer recomputing over keeping literals.
data = [
    6.208464128576452e-06,
    8.805314791970886e-06,
    1.2459289791877382e-05,
    1.2434165910235606e-05,
    1.755043922457844e-05,
    1.7589043636689894e-05,
    3.5205608583055437e-05,
    2.6136451197089627e-05,
    2.123976628354285e-05,
    2.2427710064221174e-05,
    0.0005176494014449418,
]

sns.barplot(x=np.arange(len(data)), y=data)

In [None]:
# plot_output_comparison_new(folded_model.cuda(), norm_model.cuda(), sample_input=next(iter(valid_loader))[0].cuda(), every_n=1000)

In [None]:
accuracies = []
for exponent in (4, 5, 6):
    t_max = 2**exponent
    # Convert a fresh copy of the normalized model for every t_max. The return
    # value of from_model2 is discarded, so the conversion is relied on to
    # happen in place.
    converted = copy.deepcopy(norm_model)
    quartz.from_torch.from_model2(converted, t_max=t_max)
    # Append an IF output layer without rectification and switch to eval mode.
    snn = nn.Sequential(converted, quartz.IF(t_max=t_max, rectification=False)).eval()
    # print(f"percentile: {percentile}, t_max: {t_max}")
    accuracy = get_accuracy(
        snn,
        snn_loader,
        device,
        t_max=t_max,
        print_early_spikes=True,
        print_output_time=True,
    ) # preprocess=preprocess_layers, 
    print(accuracy)
    accuracies.append(accuracy)

In [None]:
# Complement of each collected accuracy with respect to 100.
# NOTE(review): this is the error rate only if get_accuracy reports percent — confirm.
100 - np.array(accuracies)

In [None]:
# Gather the `n_ops` attribute of every stateful sinabs layer in `snn`
# (the model left over from the last conversion), then show the sum.
n_ops = []
for layer in snn.modules():
    if isinstance(layer, sl.StatefulLayer):
        n_ops.append(layer.n_ops)
torch.stack(n_ops).sum()

In [None]:
# Show the individual per-layer entries that were summed.
n_ops

In [None]:
# Hard-coded synaptic-operation count used for comparison.
# NOTE(review): the source of 576e6 is not shown here — confirm.
n_synops_rate = 576e6

# NOTE(review): the meaning of the factor 3 is not documented — confirm.
3*n_synops_rate