In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
import sinabs
from torchvision import datasets, transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from tqdm.auto import tqdm
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, remove_identity_layers, plot_output_histograms, normalize_outputs, plot_output_comparison, plot_output_comparison_new, normalize_weights, count_n_neurons, fuse_all_conv_bn
from typing import List

# Print floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# from cifar10_models.mobilenetv2 import mobilenet_v2

# model = mobilenet_v2(pretrained=True)
# model = nn.Sequential(*model.features, nn.AvgPool2d(4), nn.Flatten(), *model.classifier)
# model.eval();

In [None]:
from cifar10_models.vgg import vgg11_bn

# Pretrained VGG11-BN switched to inference mode in a single expression.
# NOTE(review): the following cell rebinds `model` to a ResNet-18, so this
# network is only used if that cell is skipped.
model = vgg11_bn(pretrained=True).eval()

In [None]:
from cifar10_models.resnet import resnet18

# Pretrained ResNet-18 switched to inference mode in a single expression.
# The later cells access ResNet attributes (conv1, layer1..layer4, fc) on
# copies of this model.
model = resnet18(pretrained=True).eval()

In [None]:
# Gather every ReLU / ReLU6 in the network first, then switch each one to
# out-of-place operation and report how many were found.
relu_layers = [
    layer for layer in model.modules()
    if isinstance(layer, (nn.ReLU, nn.ReLU6))
]
for relu_layer in relu_layers:
    relu_layer.inplace = False
relu_count = len(relu_layers)
print(f"Model contains {relu_count} relu layers.")

In [None]:
# model

In [None]:
# Evaluation configuration.
batch_size = 128
# Prefer the GPU, but fall back to the CPU so the notebook does not hard-fail
# on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-channel (mean, std) normalisation applied after conversion to a tensor.
# NOTE(review): the constants are presumably the CIFAR-10 channel statistics
# the pretrained weights expect -- confirm against the model source.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
])

# CIFAR-10 split loaded with train=False; downloaded to ./data when missing.
valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
# Deterministic order, used for all accuracy measurements.
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=batch_size, num_workers=4)
# Shuffled loader over the same dataset, used to draw one large calibration
# batch. The explicitly seeded generator makes that batch (and therefore the
# normalisation factors derived from it) reproducible across kernel restarts.
test_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=1000,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Baseline: run quartz's get_accuracy on the pretrained model (before any
# conv/BN fusing or output normalisation) over the validation loader.
get_accuracy(model, valid_loader, device)

In [None]:
# Work on a deep copy so `model` keeps its original layers.
folded_model = copy.deepcopy(model)
# The return value is not assigned, so the helper is relied on to modify
# `folded_model` in place.
# NOTE(review): presumably folds each BatchNorm into its preceding
# convolution -- confirm against quartz.utils.fuse_all_conv_bn.
fuse_all_conv_bn(folded_model)

In [None]:
# Re-run the same evaluation on the fused copy for comparison with the baseline.
get_accuracy(folded_model, valid_loader, device)

In [None]:
def normalize_outputs(
    model: nn.Module,
    sample_data: torch.Tensor,
    percentile: float = 99,
    max_outputs=None
):
    """Rescale Conv2d/Linear parameters in place using output percentiles.

    Immediate children of ``model`` are visited in registration order.
    Children that themselves contain sub-modules are processed recursively,
    using the input they receive during a forward pass of ``model`` as their
    sample data. For every ``nn.Conv2d`` / ``nn.Linear`` child, the requested
    percentile of its output on ``sample_data`` is recorded, its weights are
    divided by that value, and its bias (if present) is divided by the product
    of all scale factors recorded so far.

    Note: this definition shadows the ``normalize_outputs`` imported from
    ``quartz.utils`` at the top of the notebook.

    Args:
        model: Module whose layers are rescaled. Modified in place.
        sample_data: Input fed through ``model`` to measure layer outputs.
        percentile: Percentile (0-100) passed to ``np.percentile``.
        max_outputs: Scale factors accumulated so far. Entries are appended
            to the given list in place; ``None`` (the default) starts a fresh
            list for this call.

    Returns:
        The list of recorded scale factors, one per normalised layer, in
        visiting order.
    """
    # A literal `[]` default would be shared between calls and leak the scale
    # factors of one model into the bias scaling of the next.
    if max_outputs is None:
        max_outputs = []

    def save_data(lyr, input, output):
        # Record the chosen percentile of this layer's output.
        max_outputs.append(np.percentile(output.detach().cpu().numpy(), percentile))

    module_input = []

    def get_module_input(module, input, output):
        # Forward hooks receive the positional inputs as a tuple.
        module_input.append(input[0])

    for name, module in model.named_children():  # immediate children
        if list(module.named_children()):  # is not empty (not a leaf)
            handle = module.register_forward_hook(get_module_input)
            try:
                with torch.no_grad():
                    model(sample_data)
            finally:
                # Always detach the hook, even if the forward pass fails.
                handle.remove()
            sample_input = module_input[-1]
            max_outputs = normalize_outputs(module, sample_input, percentile, max_outputs)

        if isinstance(module, (nn.Conv2d, nn.Linear)):
            handle = module.register_forward_hook(save_data)
            try:
                with torch.no_grad():
                    _ = model(sample_data)
            finally:
                handle.remove()

            max_layer_output = max_outputs[-1]
            module.weight.data /= float(max_layer_output)
            # Conv2d/Linear always expose a `bias` attribute; it is None when
            # the layer was built with bias=False, so test the value instead
            # of using hasattr().
            if module.bias is not None:
                # np.prod replaces np.product, which was removed in NumPy 2.0.
                bias_scale = np.prod(np.array(max_outputs))
                module.bias.data /= float(bias_scale)
    return max_outputs

In [None]:
# Normalise a deep copy so `folded_model` stays untouched.
norm_model = copy.deepcopy(folded_model)
# One calibration batch (images only) from the shuffled loader, moved to the
# device configured for the notebook instead of a hard-coded .cuda() call.
sample_data = next(iter(test_loader))[0].to(device)
percentile = 99.99
# An explicit empty list is passed for the accumulated scale factors; the
# returned list is displayed as the cell output.
normalize_outputs(norm_model, sample_data=sample_data, percentile=percentile, max_outputs=[])

In [None]:
# Evaluate the output-normalised copy on the same validation loader.
get_accuracy(norm_model, valid_loader, device)

In [None]:
# Stem of the fused ResNet; later passed to get_accuracy as the `preprocess`
# stage in front of the spiking network.
stem_layers = [folded_model.conv1, folded_model.bn1, folded_model.relu]
preprocess_layers = nn.Sequential(*stem_layers)

# Remainder of the network, in forward order, with an explicit Flatten
# between the average pool and the final fully connected layer.
# NOTE(review): these layers come from `folded_model`, not from the
# output-normalised `norm_model` -- confirm this is intended.
backbone_names = ("maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool")
backbone_layers = [getattr(folded_model, layer_name) for layer_name in backbone_names]
ann = nn.Sequential(*backbone_layers, nn.Flatten(), folded_model.fc)

# Stem followed by backbone, kept as two nested Sequential stages.
composed_model = nn.Sequential(preprocess_layers, ann)

In [None]:
# Evaluate the re-assembled preprocess + ann pipeline built from folded_model's layers.
get_accuracy(composed_model, valid_loader, device)

In [None]:
# plot_output_comparison_new(folded_model, norm_model, sample_input=next(iter(valid_loader))[0].cuda(), every_n=1000)

In [None]:
# Sweep t_max over 8, 16 and 32 (2**3 .. 2**5): convert `ann` to a spiking
# network with quartz for each value and evaluate it, feeding the stem
# (`preprocess_layers`) through get_accuracy's `preprocess` argument.
accuracies = []
for exponent in range(3, 6):
    t_max = 2**exponent
    snn = quartz.from_torch.from_model(ann, t_max=t_max, add_spiking_output=True).eval()
    print(f"percentile: {percentile}, t_max: {t_max}")
    accuracy = get_accuracy(snn, valid_loader, device, preprocess=preprocess_layers, t_max=t_max, print_early_spikes=True, print_output_time=True)
    # Keep every result; previously `accuracies` stayed empty and only the
    # last iteration's value survived in `accuracy`.
    accuracies.append(accuracy)
    # np.save(f"{accuracy}_accuracy_{t_max}_t_max.npy", accuracy)