In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import sinabs
from torchvision import transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from imagenet import ImageNetTest

# Print numpy floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [None]:
import os

# Resize + center-crop to the 224x224 input size used by VGG16. ImageNet mean/std
# normalization is left commented out.
# NOTE(review): there is no ToTensor() step here; this assumes ImageNetTest yields
# tensors (DataLoader's default collate cannot batch PIL images) -- confirm.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Dataset location is configurable through the IMAGENET_TEST_ROOT environment variable;
# the previous machine-specific path is kept as the default for existing setups.
IMAGENET_TEST_ROOT = os.environ.get(
    "IMAGENET_TEST_ROOT", "/home/gregorlenz/Development/playground/data/ImageNet/test"
)
imagenet = ImageNetTest(root=IMAGENET_TEST_ROOT, transform=preprocess)

In [None]:
# Loader over the ImageNet test set: 64 images per batch, default (unshuffled) order.
testloader = torch.utils.data.DataLoader(
    dataset=imagenet,
    batch_size=64,
)

In [None]:
# Pretrained VGG16. torchvision builds models in training mode, where the Dropout layers
# in `model.classifier` randomize every forward pass; switch to eval mode for inference.
model = models.vgg16(weights=models.vgg.VGG16_Weights.DEFAULT)
model.eval();

In [None]:
# Take the first test batch and run `model` on it without tracking gradients.
batch_iter = iter(testloader)
images, labels = next(batch_iter)
with torch.no_grad():
    output = model(images)

In [None]:
# Highest-scoring class index for each image in the batch.
output.argmax(1)

In [None]:
# Labels returned by the test loader for this batch, for comparison with the argmax above.
# NOTE(review): presumably ground-truth class indices -- confirm against ImageNetTest.
labels

In [None]:
# torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
# model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
# model.eval();

In [None]:
# model.features

In [None]:
# Inspect the pooling stage that sits between `model.features` and `model.classifier`.
model.avgpool

In [None]:
# model.classifier

In [None]:
# Download a sample image from the PyTorch Hub repository for single-image sanity checks.
# `import urllib` alone does not load the `urllib.request` submodule, so import it explicitly.
import urllib.request

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# Python 3 API; a failed download raises instead of being swallowed by a bare `except`.
urllib.request.urlretrieve(url, filename)

In [None]:
# Load the sample image and turn it into a one-image batch for the model.
with Image.open(filename) as img:
    # convert() returns an in-memory copy, so the file can be closed right away;
    # forcing RGB guarantees 3 channels even for grayscale or RGBA inputs.
    input_image = img.convert("RGB")
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Normalize is commented out, so values stay in ToTensor's [0, 1] range.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

In [None]:
# Preprocessed image: ToTensor lays it out as (C, H, W) with values in [0, 1];
# H = W = 224 after CenterCrop.
input_tensor

In [None]:
# Single-image forward pass through `model`, without gradient tracking.
with torch.no_grad():
    output = model(input_batch)

# `output` holds unnormalized scores; softmax over the class dimension of the first
# (and only) sample gives class probabilities.
first_sample_scores = output[0]
probabilities = torch.nn.functional.softmax(first_sample_scores, dim=0)

In [None]:
def print_probabilities(probabilities, k=5, categories_path="../imagenet_classes.txt"):
    """Print the ``k`` highest-probability classes, one "name probability" per line.

    Args:
        probabilities: 1-D tensor of per-class scores, ordered like the lines of
            ``categories_path``.
        k: number of top entries to print; clamped to the number of classes so a
            short class list does not make ``torch.topk`` fail.
        categories_path: text file with one class name per line.
    """
    with open(categories_path, "r") as f:
        categories = [line.strip() for line in f]
    k = min(k, probabilities.shape[-1])
    top_prob, top_catid = torch.topk(probabilities, k)
    for prob, catid in zip(top_prob.tolist(), top_catid.tolist()):
        print(categories[catid], prob)

In [None]:
# Top classes predicted by `model` for the sample image.
print_probabilities(probabilities=probabilities)

In [None]:
# Settings passed to the quartz conversion / encoding / decoding calls further down.
batch_size = 1
t_max = 64

# Flatten VGG16 into one nn.Sequential (features -> avgpool -> flatten -> classifier),
# mirroring torchvision's VGG.forward. The deep copy keeps later in-place weight
# changes on `ann` away from `model`.
flat_layers = [*model.features, model.avgpool, nn.Flatten(), *model.classifier]
ann = copy.deepcopy(nn.Sequential(*flat_layers))
ann.eval();

In [None]:
# Show the flattened network; its numeric child names are what the layer-name lists refer to.
ann

In [None]:
# Child names of `ann` handed to sinabs weight normalization:
#   param_layers  -- every Conv2d / Linear layer;
#   output_layers -- every ReLU, plus the last parameter layer (the final Linear,
#                    which has no ReLU after it).
named_layers = list(ann.named_children())
param_layers = [name for name, layer in named_layers if isinstance(layer, (nn.Conv2d, nn.Linear))]
output_layers = [name for name, layer in named_layers if isinstance(layer, nn.ReLU)]
output_layers.append(param_layers[-1])
percentile = 99.9

In [None]:
# list(ann[21].parameters())

In [None]:
# Show the parameter-layer names, then the output-layer names.
for layer_names in (param_layers, output_layers):
    print(layer_names)

In [None]:
# NOTE(review): this replaces the full lists computed earlier, so the normalization call
# only rescales the first Conv2d ('0') using its ReLU ('1'). This looks like a debugging
# restriction -- drop this cell to normalize the whole network.
param_layers, output_layers = ["0"], ["1"]

In [None]:
# Rescale the weights of the listed layers of `ann` (sinabs modifies the model in place),
# based on the given percentile of activations recorded on the sample batch.
sinabs.utils.normalize_weights(
    ann,
    input_batch.cpu(),
    output_layers=output_layers,
    param_layers=param_layers,
    percentile=percentile,
)

In [None]:
# Sanity check: the flattened, weight-normalized `ann` must produce the same output shape
# as `model`. Only shapes are compared, because normalization rescales activations.
# Both passes are inference only, so skip autograd like the other inference cells.
with torch.no_grad():
    output1 = model(input_batch)
    output2 = ann(input_batch)

assert output1.shape == output2.shape, (output1.shape, output2.shape)

In [None]:
# output1

In [None]:
# output2

In [None]:
# Top classes predicted by the weight-normalized ANN for the sample image.
with torch.no_grad():
    output = ann(input_batch.cpu())

first_sample_scores = output[0]
probabilities = torch.nn.functional.softmax(first_sample_scores, dim=0)
print_probabilities(probabilities)

In [None]:
# Build the spiking counterpart of the normalized `ann` with quartz.
snn = quartz.from_torch.from_model(
    ann,
    t_max=t_max,
    batch_size=batch_size,
)

In [None]:
# Encode the sample batch into quartz's temporal input representation (parameterized by t_max).
temp_q_values = quartz.encode_inputs(
    input_batch,
    t_max=t_max,
)

In [None]:
# snn = snn.cuda()
# temp_q_values = temp_q_values.cuda()

In [None]:
# Run the SNN: the encoded input's first two dims are merged for the forward pass, then the
# leading dim is split back into (batch_size, -1).
# NOTE(review): the second dim is presumably the time axis produced by encode_inputs -- confirm.
with torch.no_grad():
    flat_inputs = temp_q_values.flatten(0, 1)
    temp_output = snn(flat_inputs).unflatten(0, (batch_size, -1))
snn_output = quartz.decode_outputs(temp_output, t_max=t_max)

In [None]:
# Top classes from the decoded SNN output for the sample image.
first_snn_sample = snn_output[0]
probabilities = torch.nn.functional.softmax(first_snn_sample, dim=0)
print_probabilities(probabilities)

In [None]:
# Shape of the decoded SNN output.
snn_output.shape

In [None]:
# torchvision.datasets.ImageNet("../data", split="val")

In [None]:
# Raw decoded SNN output values.
snn_output