In [None]:
import copy
import os
import urllib.request

import numpy as np
import sinabs
import sinabs.layers as sl
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import transforms

import quartz

# Print numpy floats in fixed-point rather than scientific notation
# (numpy arrays only; torch tensors have their own torch.set_printoptions).
np.set_printoptions(suppress=True)

In [None]:
# NOTE(review): presumably a workaround for torch.hub's fork check failing on
# the official repo; replaces the check with a no-op that always passes.
def _skip_fork_validation(a, b, c):
    return True


torch.hub._validate_not_a_forked_repo = _skip_fork_validation

# Pretrained VGG-11 from the pinned torchvision v0.10.0 hub entry point,
# switched to inference mode (eval() returns the module itself).
model = torch.hub.load("pytorch/vision:v0.10.0", "vgg11", pretrained=True)
model = model.eval()

In [None]:
# Fetch the sample image used throughout the notebook.
# The previous version tried Python-2's `urllib.URLopener` first and used a bare
# `except:` to fall back, which hid both that AttributeError and any real
# download failure; it also relied on `urllib.request` being imported
# implicitly. `urllib.request` is imported in the top import cell.
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")

# Download only once so a Restart & Run All works offline after the first run.
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)

In [None]:
# Resize and center-crop to 224x224, then convert to a float tensor.
# ImageNet mean/std normalization is disabled here (kept below for reference);
# NOTE(review): presumably so pixel values stay in [0, 1] for quartz's input
# encoding — confirm.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_image = Image.open(filename)
input_tensor = preprocess(input_image)

# Prepend a batch dimension of size 1, as the model expects a mini-batch.
input_batch = input_tensor[None, ...]

# Everything runs on CPU; to use a GPU, move both `input_batch` and `model`
# with `.to("cuda")`.

In [None]:
# Reference prediction from the unmodified pretrained model.
with torch.no_grad():
    output = model(input_batch)

# `output` holds unnormalized scores; softmax over the class dimension of the
# first sample converts them to probabilities.
first_sample_scores = output[0]
probabilities = torch.nn.functional.softmax(first_sample_scores, dim=0)

In [None]:
def print_probabilities(probabilities, k=5, classes_path="../imagenet_classes.txt"):
    """Print the `k` highest-scoring categories and their scores.

    Args:
        probabilities: 1-D tensor of per-class scores whose indices line up
            with the lines of `classes_path` (e.g. a softmax over logits).
        k: Number of entries to print. It is clamped to the number of classes,
            because `torch.topk` raises when `k` exceeds the dimension size.
        classes_path: Text file with one category name per line.
    """
    with open(classes_path, "r") as f:
        categories = [line.strip() for line in f]

    k = min(k, probabilities.size(-1))
    # topk returns values sorted in descending order, with their indices.
    top_prob, top_catid = torch.topk(probabilities, k)
    for prob, catid in zip(top_prob, top_catid):
        print(categories[int(catid)], prob.item())

In [None]:
# Top-5 predictions of the original pretrained VGG-11.
print_probabilities(probabilities=probabilities)

In [None]:
# Simulation settings shared by the SNN conversion and input encoding below.
batch_size, t_max = 1, 64

# Flatten VGG into a single nn.Sequential: conv features, average pool,
# flatten, then the classifier layers. The deep copy keeps the in-place
# weight rescaling done later from touching `model`.
flat_layers = [*model.features, model.avgpool, nn.Flatten(), *model.classifier]
ann = copy.deepcopy(nn.Sequential(*flat_layers))
ann = ann.eval()

In [None]:
sample_data = input_batch.cpu()
param_layers = (nn.Conv2d, nn.Linear)
percentile = 90


def normalize_network_activations(
    ann: nn.Module,
    sample_data: torch.Tensor,
    param_layers,
    percentile: float = 99,
):
    """Rescale the parameter layers of a flat network, in place, so their activations are ~1.

    One forward pass on `sample_data` records the output of every direct child
    of `ann` that is an instance of `param_layers`. For each such layer, in
    order, with `s` = the `percentile`-th percentile of its recorded output and
    `s_prev` = that of the previous such layer (1 for the first):
    weight *= s_prev / s and bias /= s.

    Args:
        ann: Network whose direct children are the layers (e.g. a flat
            nn.Sequential); it is modified in place.
        sample_data: Input used for the calibration forward pass.
        param_layers: Layer type or tuple of types to normalize.
        percentile: Percentile (0-100) of each layer's activations used as
            its scale.

    Returns:
        dict mapping each hooked layer to the output tensor it produced on
        `sample_data`. These are recorded before rescaling and are not copied,
        so a following in-place op (torchvision's VGG uses
        `nn.ReLU(inplace=True)` — verify for other models) mutates them.

    Raises:
        ValueError: if a layer's percentile activation is not positive, which
            would otherwise yield inf or sign-flipped weights.
    """
    activations = {}

    def save_activation(layer, inputs, output):
        activations[layer] = output

    layers = list(ann.children())
    handles = [
        layer.register_forward_hook(save_activation)
        for layer in layers
        if isinstance(layer, param_layers)
    ]
    try:
        with torch.no_grad():
            ann(sample_data)
    finally:
        # Always detach the hooks, even if the forward pass fails, so re-runs
        # do not accumulate stale hooks on `ann`.
        for handle in handles:
            handle.remove()

    prev_scale_factor = 1.0
    for layer in layers:
        if layer not in activations:
            continue
        scale_factor = float(np.percentile(activations[layer].cpu().numpy(), percentile))
        if scale_factor <= 0:
            raise ValueError(
                f"{layer}: {percentile}th-percentile activation is {scale_factor}; "
                "cannot normalize with a non-positive scale factor"
            )
        if layer.weight is not None:
            layer.weight.data *= prev_scale_factor / scale_factor
        if getattr(layer, "bias", None) is not None:
            layer.bias.data /= scale_factor
        prev_scale_factor = scale_factor
    return activations


# Keep `activations` at module level; a later cell inspects it.
activations = normalize_network_activations(ann, sample_data, param_layers, percentile)

In [None]:
# Multiply the last layer's weights and bias by OUTPUT_GAIN, which scales the
# logits by the same factor. Normalization divided the logits by the last
# layer's percentile; when that is > 1 the softmax printed below gets flatter.
# NOTE(review): the gain value looks empirically chosen — confirm.
# Not idempotent: re-running this cell applies the gain again.
OUTPUT_GAIN = 4

module = list(ann.children())[-1]
module.weight.data *= OUTPUT_GAIN
if getattr(module, "bias", None) is not None:
    module.bias.data *= OUTPUT_GAIN

In [None]:
# ann

In [None]:
# Mean of each recorded layer output from the calibration pass (captured before
# rescaling; possibly altered afterwards by in-place activations).
[tensor.mean() for tensor in activations.values()]

In [None]:
# Re-classify the image with the rescaled ANN, for comparison with the original model.
ann_input = input_batch.cpu()
with torch.no_grad():
    output = ann(ann_input)

ann_scores = output[0]
probabilities = torch.nn.functional.softmax(ann_scores, dim=0)
print_probabilities(probabilities)

In [None]:
# Convert the rescaled ANN into a quartz spiking network.
snn = quartz.from_torch.from_model(
    ann,
    t_max=t_max,
    batch_size=batch_size,
)

In [None]:
# Encode the image batch as quartz SNN input, using the same t_max as the conversion.
temp_q_values = quartz.encode_inputs(
    input_batch,
    t_max=t_max,
)

In [None]:
# snn = snn.cuda()
# temp_q_values = temp_q_values.cuda()

In [None]:
# Merge the first two dims of the encoded input before the SNN forward pass,
# then split the leading dim back into (batch_size, -1) before decoding.
# NOTE(review): presumably the encoded layout is (batch, time, ...) — confirm.
with torch.no_grad():
    flat_inputs = temp_q_values.flatten(0, 1)
    flat_outputs = snn(flat_inputs)
    temp_output = flat_outputs.unflatten(0, (batch_size, -1))
snn_output = quartz.decode_outputs(temp_output, t_max=t_max)

In [None]:
# Top-5 classes from the SNN's decoded output for the first sample.
snn_scores = snn_output[0]
probabilities = torch.nn.functional.softmax(snn_scores, dim=0)
print_probabilities(probabilities)

In [None]:
# Size of the decoded SNN output (Tensor.size() returns the same torch.Size as .shape).
snn_output.size()

In [None]:
# torchvision.datasets.ImageNet("../data", split="val")

In [None]:
# Preview the first 10 decoded class scores instead of dumping all of them
# (torch only summarizes tensors with more than 1000 elements, so a 1000-class
# output would print in full). See `snn_output.shape` above for the full size.
snn_output[..., :10]