In [1]:
import copy
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

import sinabs
import sinabs.layers as sl

import quartz
from quartz.utils import (
    decode_outputs,
    encode_inputs,
    get_accuracy,
    normalize_weights,
    plot_output_comparison,
    plot_output_histograms,
)

# Print numpy arrays in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [2]:
# Image pipeline: convert to a [0, 1] tensor, resize the shorter side to 256,
# then centre-crop to 224x224.
# ImageNet mean/std normalisation is intentionally NOT applied here. The model
# built below starts with two 1x1 conv layers that perform (x - mean) / std.
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
]
preprocess = transforms.Compose(preprocess_steps)

In [3]:
# ImageNet validation split read from data/ImageNet/, with the pipeline above applied.
imagenet = torchvision.datasets.ImageNet(root='data/ImageNet/', split='val', transform=preprocess)

In [4]:
# Batched loader over the validation split, used for the ANN accuracy checks.
testloader = torch.utils.data.DataLoader(
    imagenet,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

In [5]:
# Compute device: use the GPU when one is available. Otherwise fall back to the
# CPU so the notebook still runs (slowly) instead of failing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Explicit CPU target, used when moving models/tensors back for plotting.
cpu = "cpu"

# ImageNet-pretrained VGG11 without batch norm.
# Alternative: models.vgg11 -> models.vgg11_bn with models.vgg.VGG11_BN_Weights.DEFAULT.
vgg11_weights = models.vgg.VGG11_Weights.DEFAULT
model = models.vgg11(weights=vgg11_weights)

In [6]:
# Fold the standard ImageNet input normalisation, (x - mean) / std, into two
# depthwise (groups=3) 1x1 convolutions that are prepended to the network.
# The dataset transform therefore omits transforms.Normalize.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])

# layer1: x -> x - mean  (unit weights, bias = -mean).
# nn.init.ones_ replaces the old `weight /= weight` trick, which produced NaN
# whenever a randomly initialised weight happened to be exactly 0.
layer1 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
nn.init.ones_(layer1.weight)
layer1.bias = nn.Parameter(-IMAGENET_MEAN)

# layer2: x -> x / std  (per-channel weights = 1/std, zero bias).
# Depthwise 1x1 conv weights have shape (out_channels, 1, 1, 1).
layer2 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
layer2.weight = nn.Parameter((1 / IMAGENET_STD).view(3, 1, 1, 1))
nn.init.zeros_(layer2.bias)

# Flatten everything into one nn.Sequential:
#   input normalisation (layer1, layer2) -> VGG feature extractor -> avgpool
#   -> Flatten -> classifier without classifier[2] and classifier[5]
#   (the Dropout layers in torchvision's VGG head).
# NOTE(review): this rebinds `model`, so re-running this cell on its own fails.
# Re-run the cell that loads the VGG11 weights first.
vgg_classifier = model.classifier
model = nn.Sequential(
    layer1,
    layer2,
    *model.features,
    model.avgpool,
    nn.Flatten(),
    *vgg_classifier[0:2],
    *vgg_classifier[3:5],
    vgg_classifier[-1],
).eval()

# Switch every ReLU after the two normalisation convs to out-of-place mode.
relus = (module for module in list(model.children())[2:] if isinstance(module, nn.ReLU))
for relu in relus:
    relu.inplace = False

In [7]:
# model

In [8]:
# get_accuracy(model, testloader, device)

In [9]:
# Independent copy that will receive weight normalisation; `model` is left untouched.
norm_model = copy.deepcopy(model).eval()

In [10]:
# Names of the Conv2d/Linear children whose weights get normalised. The first
# two such children ('0' and '1') are the fixed input-normalisation convs and
# are skipped.
param_layer_names = []
for child_name, child_module in norm_model.named_children():
    if isinstance(child_module, (nn.Conv2d, nn.Linear)):
        param_layer_names.append(child_name)
param_layer_names = param_layer_names[2:]

# Activation percentile passed to normalize_weights below.
percentile = 99.99

print(param_layer_names)

['2', '5', '8', '10', '13', '15', '18', '20', '25', '27', '29']


In [11]:
# Draw a single batch of validation images to calibrate the weight normalisation.
# Seeding the loader's generator makes the sampled batch, and therefore the
# resulting normalisation scales, reproducible across kernel restarts.
NORM_SEED = 0
NORM_BATCH_SIZE = 150
normloader = torch.utils.data.DataLoader(
    imagenet,
    batch_size=NORM_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(NORM_SEED),
)
images, labels = next(iter(normloader))

In [12]:
# Move the copy to the compute device, then normalise its weights using
# activations on the calibration batch. The return value is not used, so this
# presumably modifies norm_model in place.
norm_model = norm_model.to(device)
normalize_weights(
    norm_model,
    images.to(device),
    param_layer_names=param_layer_names,
    percentile=percentile,
)

In [13]:
# plot_output_comparison(model.to(cpu), norm_model.to(cpu), images.to(cpu), output_layers=param_layer_names, every_n=10000, every_c=10, savefig=f"norm_activation_correct_biases.png")

In [14]:
# output_layer_names = [name for name, child in norm_model.named_children() if isinstance(child, nn.ReLU)]
# output_layer_names += [param_layer_names[-1]]
# sinabs.utils.normalize_weights(norm_model.to(device), images.to(device), param_layers=param_layer_names, output_layers=output_layer_names, percentile=percentile)

In [15]:
# get_accuracy(norm_model, testloader, device=device)#"cpu")

  0%|          | 0/391 [00:00<?, ?it/s]

68.10799837112427

In [None]:
# Convert the weight-normalised ANN to an SNN for several time horizons
# (t_max = 16, 32, 64) and evaluate each one on the validation split.
snnloader = torch.utils.data.DataLoader(imagenet, batch_size=1, shuffle=True, num_workers=4)

# The first N_ANN_LAYERS modules of norm_model are run as ordinary ANN layers
# that pre-process the input:
#   - indices 0-1: the normalisation convs
#   - index 2: the first VGG conv
#   - index 3: presumably its ReLU
# The converted SNN is evaluated from the same index onwards.
# NOTE(review): assumes quartz.from_torch.from_model preserves module indices. Confirm.
N_ANN_LAYERS = 4

accuracies = []
for exponent in range(4, 7):
    t_max = 2**exponent
    snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True).eval()
    # Sliced after conversion, exactly as before, in case from_model alters norm_model.
    preprocess_layers = norm_model[:N_ANN_LAYERS]
    snn = snn[N_ANN_LAYERS:]
    print(f"percentile: {percentile}, t_max: {t_max}")
    accuracy = get_accuracy(snn, snnloader, device, preprocess=preprocess_layers, t_max=t_max, print_early_spikes=True, print_output_time=True)
    # Previously `accuracies` was created but never filled, so results were lost.
    accuracies.append(accuracy)
    np.save(f"{accuracy}_accuracy_{t_max}_t_max.npy", accuracy)

accuracies