In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import sinabs
from torchvision import transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from tqdm.auto import tqdm
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, plot_output_histograms


# Print NumPy floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Input pipeline: convert to a tensor, resize to 256, then centre-crop to 224.
# The mean/std Normalize step is commented out here; the same constants are
# applied by the two 1x1 grouped convolutions prepended to the network further
# down in this notebook.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        # NOTE(review): because ToTensor comes first, Resize operates on a tensor
        # rather than on the PIL image; the usual torchvision recipe resizes and
        # crops before ToTensor. Confirm the order is intentional.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# ImageNet validation split, passed through the `preprocess` pipeline.
# NOTE(review): the dataset root is an absolute path on one particular machine;
# adjust it (or lift it into a configurable constant) before running elsewhere.
imagenet = torchvision.datasets.ImageNet('/home/gregorlenz/Development/playground/data/ImageNet/', split='val', transform=preprocess)

In [None]:
# Shuffled loader over the validation set; passed to get_accuracy further down.
testloader = torch.utils.data.DataLoader(imagenet, batch_size=128, shuffle=True, num_workers=4)

In [None]:
# Prefer the GPU, but fall back to the CPU so that the later `.to(device)` calls
# also work on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# VGG11 with torchvision's default pretrained weights (the batch-norm variant is
# kept below as a disabled alternative).
model = models.vgg11(weights=models.vgg.VGG11_Weights.DEFAULT)
# model = models.vgg11_bn(weights=models.vgg.VGG11_BN_Weights.DEFAULT)

In [None]:
# Express the per-channel input normalisation (x - mean) / std as two 1x1
# grouped convolutions (groups=3, so every channel is handled independently).

# layer1 subtracts the per-channel mean: weight 1, bias -mean.
layer1 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
# Set the weights to exactly 1 explicitly. Dividing the random initialisation by
# itself gives the same result only while no weight happens to be 0 (0 / 0 = NaN).
nn.init.ones_(layer1.weight)
layer1.bias = torch.nn.Parameter(-1*torch.tensor([0.485, 0.456, 0.406]))

# layer2 divides by the per-channel std: weight 1/std, bias 0.
layer2 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
# Weight shape for a grouped 1x1 conv is (out_channels, in_channels / groups, 1, 1).
layer2.weight = nn.Parameter(1/torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1, 1))
# Zero the bias explicitly instead of subtracting the random initialisation from itself.
nn.init.zeros_(layer2.bias)

# Flatten the pretrained network into a single Sequential, with the two
# normalisation convolutions in front. Classifier indices 2 and 5 are left out
# (these are the Dropout layers in torchvision's VGG). The result is put in eval mode.
model = nn.Sequential(
    layer1,
    layer2,
    *model.features,
    model.avgpool,
    nn.Flatten(),
    *model.classifier[0:2],
    *model.classifier[3:5],
    model.classifier[-1],
).eval()

In [None]:
# model

In [None]:
# get_accuracy(model, testloader, device)

In [None]:
def normalize_weights(
    ann: nn.Module,
    sample_data: torch.Tensor,
    output_layers,
    param_layers,
    percentile: float = 99,
    scale_factor = 1.
):
    """Rescale layer parameters in place from activation percentiles.

    For every index ``i`` in ``range(len(output_layers))`` the network is run on
    ``sample_data``, the ``percentile``-th percentile of the output of the child
    named ``output_layers[i]`` is measured (and printed), and every parameter of
    the child named ``param_layers[i]`` is multiplied by
    ``scale_factor / measured_percentile``. Layers are handled in order, so each
    measurement already sees the rescaling applied to earlier layers.

    Args:
        ann: Network whose direct children (``ann.named_children()``) are
            addressed by name. It is switched to eval mode.
        sample_data: Batch that is fed through ``ann`` once per processed layer.
        output_layers: Names of the children whose outputs are measured.
        param_layers: Names of the children whose parameters are rescaled;
            ``param_layers[i]`` is paired with ``output_layers[i]``. Entries
            beyond ``len(output_layers)`` are ignored.
        percentile: Percentile of the recorded output used as normalisation value.
        scale_factor: Numerator of the factor applied to the parameters.

    Returns:
        None. ``ann`` is modified in place.

    Raises:
        ValueError: If a measured percentile is not strictly positive (zero,
            negative or NaN). Dividing by it would turn the parameters into
            inf/NaN or flip their sign, so the layer is left untouched instead.
    """
    ann = ann.eval()
    output_data = []

    def save_data(lyr, inputs, output):
        # Forward hook: keep a copy of the hooked layer's output.
        output_data.append(output.clone())

    named_layers = dict(ann.named_children())

    for i in range(len(output_layers)):
        param_layer = named_layers[param_layers[i]]
        output_layer = named_layers[output_layers[i]]

        handle = output_layer.register_forward_hook(save_data)
        try:
            with torch.no_grad():
                _ = ann(sample_data)
                max_lyr_out = np.percentile(output_data[-1].cpu().numpy(), percentile)
                print(max_lyr_out)

                if not max_lyr_out > 0:
                    raise ValueError(
                        f"percentile {percentile} of the output of layer "
                        f"'{output_layers[i]}' is {max_lyr_out}; it must be > 0 "
                        f"to rescale layer '{param_layers[i]}'"
                    )

                for p in param_layer.parameters():
                    p.data *= scale_factor / max_lyr_out
        finally:
            # Always detach the hook and drop the recorded activations, even when
            # the forward pass or the check above raises; otherwise the hook would
            # stay registered on the caller's model.
            output_data.clear()
            handle.remove()

In [None]:
# Work on an independent deep copy (in eval mode) so that rescaling its
# parameters leaves `model` untouched.
norm_model = copy.deepcopy(model).eval()
# norm_model[0].bias.data /= 4
# norm_model[1].weight.data /= 4

In [None]:
# normloader = torch.utils.data.DataLoader(imagenet, batch_size=10, shuffle=True, num_workers=6)
# images, labels = next(iter(normloader))

# np.percentile(norm_model[:2](images).detach().numpy(), percentile)

In [None]:
# Walk the top-level children once, in network order, and sort their names into
# the two lists that normalize_weights pairs up index by index.
param_layers = []
# Child '1' (the second normalisation conv) is the first measurement point,
# followed by every ReLU.
output_layers = ['1']
for name, child in norm_model.named_children():
    if isinstance(child, (nn.Conv2d, nn.Linear)):
        param_layers.append(name)
    elif isinstance(child, nn.ReLU):
        output_layers.append(name)
# The very first conv/linear child is excluded from rescaling.
param_layers = param_layers[1:]
percentile = 99.99

print(param_layers)
print(output_layers)

In [None]:
# Draw one shuffled batch of 150 validation images as sample data for the rescaling.
# NOTE(review): no seed is set before this point in the visible notebook, so the
# batch (and therefore the printed percentiles) may differ between runs — confirm.
normloader = torch.utils.data.DataLoader(imagenet, batch_size=150, shuffle=True, num_workers=6)
images, labels = next(iter(normloader))

# Model and data must live on the same device for the forward passes.
norm_model = norm_model.to(device)
images = images.to(device)

# Rescales norm_model's parameters in place.
# NOTE(review): scale_factor=2.6 is a magic number with no visible derivation.
normalize_weights(norm_model, images, output_layers=output_layers, param_layers=param_layers, percentile=percentile, scale_factor=2.6)
# Move everything back to the CPU afterwards.
norm_model = norm_model.cpu()
images = images.cpu()

In [None]:
# for name, params in norm_model.named_parameters():
#     print(f"Layer {name}    \t {params.min():.2f}/{params.max():.2f}")

In [None]:
# Evaluate the rescaled ANN on the validation loader with quartz's get_accuracy.
# NOTE(review): norm_model was moved to the CPU in the previous cell while `device`
# is passed here; presumably get_accuracy moves the model itself — confirm.
get_accuracy(norm_model, testloader, device=device)#"cpu")

In [None]:
# Small-batch (2 images) shuffled loader, iterated by the SNN inference loop below.
snnloader = torch.utils.data.DataLoader(imagenet, batch_size=2, shuffle=True, num_workers=4)

# Accuracy results are appended to this list further down.
accuracies = []
# for exponent in range(5, 7):
exponent = 5
# t_max = 2**5 = 32
t_max = 2**exponent
# Convert the rescaled ANN with quartz and put the result in eval mode.
snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True).eval()
# Drop the first two entries of the converted network; the inference loop below
# applies norm_model[:2] to the images before encoding them instead.
snn = snn[2:] # nn.Sequential(*snn[:2], quartz.IF(t_max=t_max, rectification=False), *

In [None]:
# This assertion always fails.
# NOTE(review): presumably a deliberate stop so that "Run All" does not continue
# into the cells below — confirm, and consider removing it from the final notebook.
assert 1 == 0;

In [None]:
snn = snn.cpu()
norm_model = norm_model.cpu()

# Seed torch's global RNG before iterating the shuffled loader.
torch.random.manual_seed(6)
# snn[1].module.weight.data /= 2

# images, labels = next(iter(snnloader))
# Pure inference: run without autograd, as the later cells in this notebook do,
# so that no graph is recorded for every batch.
with torch.no_grad():
    for images, labels in snnloader:
        # Apply the two normalisation convolutions in the ANN domain, then encode.
        norm_images = norm_model[:2](images)
        temp_images = encode_inputs(norm_images, t_max=t_max)
        conv_output = snn(temp_images)
        snn_output = decode_outputs(conv_output, t_max=t_max)
        # Per-sample correctness of the SNN prediction for this batch.
        print(snn_output.argmax(1) == labels)

    # ANN reference output for the last batch of the loop.
    ann_output = norm_model(images)

In [None]:
# Inspect the `early_spikes` attribute of the SNN entry at index 1.
snn[1].early_spikes

In [None]:
# Index of the largest ANN output per sample (last batch of the loop above).
ann_output.argmax(1)

In [None]:
# Index of the largest decoded SNN output per sample, for comparison with the ANN.
snn_output.argmax(1)

In [None]:
# Evaluate the SNN on the validation loader on the CPU and record the result.
accuracy = get_accuracy(snn, testloader, device="cpu", t_max=t_max)
accuracies.append(accuracy)
# NOTE(review): the value is printed with a '%' sign; presumably get_accuracy
# already returns a percentage — confirm.
print(f"{t_max} time steps: {round(accuracy, 3)}%")

In [None]:
# Display the converted network.
snn

In [None]:
# NOTE(review): same expression as the previous cell; one of the two can be removed.
snn

In [None]:
# NOTE(review): `ann`, `input_batch` and `print_probabilities` are not defined
# anywhere in the visible part of this notebook, so this cell fails on a fresh
# kernel unless they come from elsewhere — confirm or remove.
with torch.no_grad():
    output = ann(input_batch.cpu())

# Softmax over the outputs of the first sample in the batch.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print_probabilities(probabilities)

In [None]:
# NOTE(review): `ann` and `batch_size` are not defined in the visible notebook, and
# this rebinds `snn`, replacing the network converted from norm_model earlier.
# The keyword arguments also differ from the earlier from_model call — confirm.
snn = quartz.from_torch.from_model(ann, t_max=t_max, batch_size=batch_size)

In [None]:
# NOTE(review): this calls quartz.encode_inputs, whereas the inference loop earlier
# uses encode_inputs imported from quartz.utils; `input_batch` is not defined in
# the visible notebook — confirm.
temp_q_values = quartz.encode_inputs(input_batch, t_max=t_max)

In [None]:
# snn = snn.cuda()
# temp_q_values = temp_q_values.cuda()

In [None]:
# Run the SNN without autograd. The first two dimensions of the encoded input are
# merged before the forward pass and dimension 0 of the result is split back into
# (batch_size, -1) afterwards.
# NOTE(review): `batch_size` is not defined in the visible notebook — confirm.
with torch.no_grad():
    temp_output = snn(temp_q_values.flatten(0, 1)).unflatten(0, (batch_size, -1))
# Decode the SNN output with quartz.decode_outputs.
snn_output = quartz.decode_outputs(temp_output, t_max=t_max)

In [None]:
# Softmax over the decoded SNN outputs of the first sample in the batch.
# NOTE(review): `print_probabilities` is not defined in the visible notebook — confirm.
probabilities = torch.nn.functional.softmax(snn_output[0], dim=0)
print_probabilities(probabilities)

In [None]:
# Shape of the decoded SNN output.
snn_output.shape

In [None]:
# torchvision.datasets.ImageNet("../data", split="val")

In [None]:
# Display the decoded SNN output tensor.
snn_output