In [None]:
import copy
import os

import numpy as np
import quartz
import sinabs
import sinabs.layers as sl
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

from imagenet import ImageNetTest

# Print NumPy floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Per-channel ImageNet statistics used by torchvision's pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Evaluation preprocessing. Resize and crop the PIL image *before* ToTensor, as in
# torchvision's reference evaluation transforms. PIL resizing is antialiased, but
# tensor resizing only antialiases by default in newer torchvision releases, so
# converting first can change the resized pixels and lower the measured accuracy.
# Cropping first also means only the 224x224 crop is converted to a float tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

In [None]:
# ImageNet root directory. Set the IMAGENET_ROOT environment variable to point at your
# copy instead of editing the notebook; the author's local path remains the default.
IMAGENET_ROOT = os.environ.get("IMAGENET_ROOT", "/home/gregorlenz/Development/playground/data/ImageNet/")
imagenet = torchvision.datasets.ImageNet(IMAGENET_ROOT, split="val", transform=preprocess)

In [None]:
# Unshuffled batches over the validation split, used for the accuracy evaluation.
testloader = torch.utils.data.DataLoader(
    dataset=imagenet,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained VGG16 (torchvision's default weights) in inference mode. `.to()` and
# `.eval()` both return the module itself, so the calls can be chained.
model = models.vgg16(weights=models.vgg.VGG16_Weights.DEFAULT).to(device).eval()

In [None]:
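# Disabled experiment: fold the input normalization into the model as two depthwise
# 1x1 convolutions. layer1 (weight 1, bias -mean) computes x - mean; layer2
# (weight 1/std, bias 0) then divides by std. Do not enable this while `preprocess`
# still applies transforms.Normalize, or the input is normalized twice.
# layer1 = nn.Conv2d(3, 3, kernel_size=1, groups=3)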
# layer1.weight.data /= layer1.weight.data
# layer1.bias = torch.nn.Parameter(-1*torch.tensor([0.485, 0.456, 0.406]))
# layer1 = layer1.to(device)

# layer2 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
# layer2.weight = nn.Parameter(1/torch.tensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(1).unsqueeze(1))
# layer2.bias.data -= layer2.bias.data
# layer2 = layer2.to(device)

# model.features = nn.Sequential(
#     layer1,
#     layer2,
#     model.features
# )

In [None]:
# Top-1 accuracy of the pretrained model over the whole validation split.
predictions = []
with torch.no_grad():
    for images, labels in tqdm(testloader):
        images, labels = images.to(device), labels.to(device)
        output = model(images.float()).argmax(1)
        predictions.append(output == labels)
accuracy = torch.cat(predictions).float().mean()
accuracy

In [None]:
# Intentional stop for "Run All". The cells below are the ANN-to-SNN conversion part,
# presumably meant to be run manually, cell by cell. Unlike `assert`, an explicit raise
# cannot be disabled by running Python with -O, and it states why execution stopped.
raise RuntimeError(
    "Stopping here on purpose: run the ANN-to-SNN conversion cells below manually."
)

In [None]:
# Display the module tree of the loaded torchvision model.
model

In [None]:
batch_size = 1
t_max = 64

# Flatten VGG16 into one nn.Sequential, so that each layer is a direct child addressable
# by its index name ('0', '1', ...). The layer-name lists below refer to these names.
ann = nn.Sequential(*model.features, model.avgpool, nn.Flatten(), *model.classifier)
# Deep-copy so that in-place weight normalization of `ann` leaves `model` untouched.
# The conversion cells below pass CPU tensors (`.cpu()`, with `.cuda()` commented out),
# so keep the copy on the CPU regardless of `device`.
ann = copy.deepcopy(ann).cpu()
ann.eval();

# The first `batch_size` validation images (CPU tensors, same preprocessing as the
# accuracy run). They are used for weight normalization and for comparing ANN and SNN
# outputs. No earlier cell defines `input_batch`, so it is built here from the dataset.
input_batch = torch.stack([imagenet[i][0] for i in range(batch_size)])

In [None]:
# Display the flattened Sequential. Its child names ('0', '1', ...) are the names used
# in param_layers / output_layers.
ann

In [None]:
# Collect Sequential child names: weight-bearing layers (Conv2d/Linear) go to
# `param_layers`; ReLU activations go to `output_layers` (both are passed to
# sinabs.utils.normalize_weights).
param_layers = []
output_layers = []
for layer_name, layer in ann.named_children():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        param_layers.append(layer_name)
    elif isinstance(layer, nn.ReLU):
        output_layers.append(layer_name)
# The last parameter layer's own output is tracked too (presumably because no ReLU
# follows the final classifier layer — TODO confirm).
output_layers.append(param_layers[-1])
percentile = 99.9

In [None]:
# list(ann[21].parameters())

In [None]:
# Show the selected layer names, one list per line.
print(param_layers, output_layers, sep="\n")

In [None]:
# NOTE(review): this overrides the lists computed in the previous cells. It restricts
# normalization to child '0' as the only parameter layer and child '1' as the only output
# layer (presumably VGG16's first Conv2d and its ReLU). This looks like a debugging
# shortcut: skip or delete this cell to normalize every layer. Confirm which is intended.
param_layers = ['0']
output_layers = ['1']

In [None]:
# Rescale `ann`'s weights using the activations of `output_layers` on a sample batch,
# presumably at the given `percentile` — see sinabs docs. The return value is discarded,
# so the effect is an in-place change to `ann`. Re-running this cell without rebuilding
# `ann` therefore normalizes twice.
# Send the sample to whatever device `ann` lives on; a hardcoded `.cpu()` fails when
# `ann` is on CUDA.
ann_device = next(ann.parameters()).device
sinabs.utils.normalize_weights(
    ann,
    input_batch.to(ann_device),
    output_layers=output_layers,
    param_layers=param_layers,
    percentile=percentile,
)

In [None]:
# Sanity check that flattening preserved the output shape. Only shapes are compared,
# because the normalization cell may already have rescaled `ann`'s weights.
# - Each network receives the input on its own device.
# - Gradients are not tracked, since this is inference only.
with torch.no_grad():
    output1 = model(input_batch.to(device))
    output2 = ann(input_batch.to(next(ann.parameters()).device))

assert output1.shape == output2.shape

In [None]:
# output1

In [None]:
# output2

In [None]:
def print_probabilities(probabilities, k=5):
    """Print the ``k`` most probable ImageNet classes with their probabilities.

    No earlier cell defines this helper, so a fresh-kernel "Run All" would otherwise
    fail here.

    Args:
        probabilities: 1-D tensor of per-class probabilities (e.g. a softmax of logits).
        k: number of top classes to print.
    """
    categories = models.vgg.VGG16_Weights.DEFAULT.meta["categories"]
    top_prob, top_idx = torch.topk(probabilities, k)
    for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
        print(f"{categories[idx]}: {prob:.4f}")


# ANN class probabilities for the first sample. The input is moved to `ann`'s device
# rather than hardcoded to the CPU.
with torch.no_grad():
    output = ann(input_batch.to(next(ann.parameters()).device))

probabilities = torch.nn.functional.softmax(output[0], dim=0)
print_probabilities(probabilities)

In [None]:
# Convert the (weight-normalized) ANN into a spiking network with quartz.
# NOTE(review): presumably `t_max` is the number of time steps per input, and
# `batch_size` must match the leading dimension of `input_batch`. Confirm with quartz docs.
snn = quartz.from_torch.from_model(ann, t_max=t_max, batch_size=batch_size)

In [None]:
# Encode `input_batch` into the SNN's input representation over `t_max` steps.
# NOTE(review): the layout is presumably (batch, time, ...), given that the forward
# cell merges the first two dims — confirm.
temp_q_values = quartz.encode_inputs(input_batch, t_max=t_max)

In [None]:
# snn = snn.cuda()
# temp_q_values = temp_q_values.cuda()

In [None]:
# Run the SNN without gradient tracking. The first two dims of the encoded input are
# merged for the forward pass, and the output's leading dim is split back into
# (batch_size, -1).
# NOTE(review): assumes the encoded layout is (batch, time, ...) — confirm with quartz.
with torch.no_grad():
    temp_output = snn(temp_q_values.flatten(0, 1)).unflatten(0, (batch_size, -1))
# Convert the SNN's output back to values comparable with the ANN output (quartz decoding).
snn_output = quartz.decode_outputs(temp_output, t_max=t_max)

In [None]:
# Class probabilities for the first sample of the decoded SNN output.
# NOTE(review): `print_probabilities` must be defined in an earlier cell.
probabilities = snn_output[0].softmax(dim=0)
print_probabilities(probabilities)

In [None]:
# Shape of the decoded SNN output (expected to match the ANN output shape — TODO confirm).
snn_output.shape

In [None]:
# torchvision.datasets.ImageNet("../data", split="val")

In [None]:
# Top-5 decoded SNN scores and their class indices per sample, instead of dumping
# every class score.
snn_output.topk(5, dim=-1)