In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import sinabs
from torchvision import transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from tqdm.auto import tqdm
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, plot_output_histograms, plot_output_comparison


np.set_printoptions(suppress=True)

In [None]:
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
imagenet = torchvision.datasets.ImageNet('/home/gregorlenz/Development/playground/data/ImageNet/', split='val', transform=preprocess)

In [None]:
testloader = torch.utils.data.DataLoader(imagenet, batch_size=128, shuffle=True, num_workers=4)

In [None]:
device = "cuda"
cpu = "cpu"

model = models.vgg11(weights=models.vgg.VGG11_Weights.DEFAULT)
# model = models.vgg11_bn(weights=models.vgg.VGG11_BN_Weights.DEFAULT)

In [None]:
layer1 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
layer1.weight.data /= layer1.weight.data
layer1.bias = torch.nn.Parameter(-1*torch.tensor([0.485, 0.456, 0.406]))

layer2 = nn.Conv2d(3, 3, kernel_size=1, groups=3)
layer2.weight = nn.Parameter(1/torch.tensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(1).unsqueeze(1))
layer2.bias.data -= layer2.bias.data

model = nn.Sequential(layer1, layer2, *model.features, model.avgpool, nn.Flatten(), *model.classifier[0:2], *model.classifier[3:5], model.classifier[-1]) #*model.classifier)
model = model.eval()

In [None]:
for layer in list(model.children())[2:]:
    if isinstance(layer, nn.ReLU):
        layer.inplace = False
    # if hasattr(layer, "bias"):
    #     layer.bias.data *= 0

In [None]:
# model

In [None]:
# get_accuracy(model, testloader, device)

In [None]:
def normalize_weights(
    ann: nn.Module,
    sample_data: torch.Tensor,
    output_layers,
    param_layers,
    percentile: float = 99,
    scale_factor = 1.
):
    """
    One-pass, data-based weight normalization (after Rueckauer et al., 2017).

    NOTE(review): a second `normalize_weights` is defined in a later cell of this
    notebook; once that cell has run, it shadows this definition.

    A single forward pass over `sample_data` records lambda_l, the `percentile`-th
    percentile of the output of each layer in `param_layers`. Each layer is then
    rescaled as

        W_l <- W_l * lambda_{l-1} / lambda_l,    b_l <- b_l / lambda_l    (lambda_{-1} = 1)

    so that its output becomes out_l / lambda_l, provided only positively
    homogeneous operations (e.g. ReLU, pooling, flatten) sit between consecutive
    param layers.

    Args:
        ann: module whose direct children are named by `param_layers`. It is put in
            eval mode and modified in place.
        sample_data: input batch used to measure activity.
        output_layers: unused; kept for signature compatibility.
        param_layers: names of the children to rescale, in forward-pass order. Each
            needs a `weight` and may have a `bias` (None is allowed).
        percentile: value in [0, 100]; 100 uses the maximum activity.
        scale_factor: unused; kept for signature compatibility.

    Returns:
        list of float: lambda_l for each entry of `param_layers`.

    Raises:
        ValueError: if a param layer produced no output during the forward pass, or
            its activity percentile is not positive. No weights are changed in that case.
    """
    ann = ann.eval()
    named_layers = dict(ann.named_children())
    activity = {}

    def make_hook(name):
        # Key each measurement by layer name so the result does not depend on the
        # order in which the hooks fire.
        def save_data(lyr, input, output):
            activity[name] = float(np.percentile(output.detach().cpu().numpy(), percentile))
        return save_data

    handles = []
    try:
        for name in param_layers:
            handles.append(named_layers[name].register_forward_hook(make_hook(name)))
        with torch.no_grad():
            ann(sample_data)
    finally:
        # Always detach the hooks, even if registration or the forward pass raises.
        for handle in handles:
            handle.remove()

    # Validate every scale before touching any weights.
    scales = []
    for name in param_layers:
        if name not in activity:
            raise ValueError(f"layer {name!r} produced no output for sample_data")
        scale = activity[name]
        # A non-positive (or NaN) scale would flip signs through the following ReLU
        # and break the equivalence of the rescaled network.
        if not scale > 0:
            raise ValueError(f"activity percentile of layer {name!r} is {scale}; expected > 0")
        scales.append(scale)

    with torch.no_grad():
        for i, name in enumerate(param_layers):
            param_layer = named_layers[name]
            prev_scale = 1.0 if i == 0 else scales[i - 1]
            param_layer.weight.data *= prev_scale / scales[i]
            bias = getattr(param_layer, "bias", None)
            if bias is not None:
                bias.data /= scales[i]

    return scales

In [None]:
def normalize_weights(
    ann: nn.Module,
    sample_data: torch.Tensor,
    output_layers,
    param_layers,
    percentile: float = 99,
    scale_factor: float = 1.0,
):
    """
    Rescale the weights of the network, such that the activity of each specified layer is normalized.

    The method implemented here roughly follows the paper:
    `Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification` by Rueckauer et al.
    https://www.frontiersin.org/article/10.3389/fnins.2017.00682

    Layers are processed one at a time in `param_layers` order, each measured with
    all earlier layers already rescaled. For layer l, with P the product of the
    factors applied to the earlier layers:

    1. divide the bias by P, so the layer computes its original output divided by P;
    2. run `sample_data` through `ann` and take m_l, the `percentile`-th percentile
       of the layer's output;
    3. divide weight and bias by m_l.

    The layer's output is then its original output divided by P * m_l, i.e. by its
    original `percentile` activity. This holds provided only positively homogeneous
    operations (e.g. ReLU, pooling, flatten) sit between consecutive param layers.

    Args:
         ann (nn.Module): Torch module whose direct children are named by `param_layers`; modified in place
         sample_data (torch.Tensor): Input data to normalize the network with
         output_layers (List[str]): Unused; layers (typically ReLUs) whose activity could be inspected after normalization
         param_layers (List[str]): Layers whose parameters precede `output_layers`, in forward-pass order.
          Each needs a `weight` and may have a `bias` (None is allowed).
         percentile (float): A number between 0 and 100 to determine activity to be normalized by,
          where 100 corresponds to the max activity of the network. Defaults to 99.
         scale_factor (float): Unused; accepted so that calls written for the earlier
          `normalize_weights` definition in this notebook keep working.

    Returns:
        List[float]: the factor m_l applied to each layer of `param_layers`.

    Raises:
        ValueError: if a layer produces no output, or its percentile is not positive (or is NaN).
          Layers processed before the failing one stay rescaled, and the failing layer's
          bias has already been divided by P.
    """
    # Outputs captured by the forward hook of the layer currently being measured
    output_data = []

    def save_data(lyr, input, output):
        output_data.append(output.detach().clone())

    # All the named layers of the module
    named_layers = dict(ann.named_children())

    all_outputs = []
    prev_scale = 1.0  # product of the factors applied to the earlier layers

    for name in param_layers:
        param_layer = named_layers[name]
        # Conv/Linear layers built with bias=False have `bias` set to None.
        bias = getattr(param_layer, "bias", None)

        with torch.no_grad():
            # Put the bias on the same scale as the (already rescaled) input before
            # measuring; otherwise the measured percentile mixes two scales.
            if bias is not None:
                bias.data /= prev_scale

            handle = param_layer.register_forward_hook(save_data)
            try:
                ann(sample_data)
            finally:
                # Deregister the hook even if the forward pass raises
                handle.remove()

        if not output_data:
            raise ValueError(f"layer {name!r} produced no output for sample_data")
        max_lyr_out = float(np.percentile(output_data[-1].cpu().numpy(), percentile))
        output_data.clear()

        # A non-positive (or NaN) factor would flip signs through the following ReLU.
        if not max_lyr_out > 0:
            raise ValueError(f"activity percentile of layer {name!r} is {max_lyr_out}; expected > 0")

        with torch.no_grad():
            param_layer.weight.data /= max_lyr_out
            if bias is not None:
                bias.data /= max_lyr_out

        all_outputs.append(max_lyr_out)
        prev_scale *= max_lyr_out

    return all_outputs


In [None]:
norm_model = copy.deepcopy(model)
norm_model = norm_model.eval()
norm_model[2].weight.data /= 10
norm_model[2].bias.data /= 10
norm_model[5].weight.data /= 2
norm_model[5].bias.data /= 20

for layer in ['8', '10', '13', '15', '18', '20', '25', '27', '29']:
    norm_model[int(layer)].bias.data /= 20

In [None]:
# normloader = torch.utils.data.DataLoader(imagenet, batch_size=10, shuffle=True, num_workers=6)
# images, labels = next(iter(normloader))

# np.percentile(norm_model[:2](images).detach().numpy(), percentile)

In [None]:
# Every Conv2d/Linear layer except the two input-normalization convs ('0', '1'),
# and every ReLU, in module order.
all_param_layers = [name for name, child in norm_model.named_children() if isinstance(child, (nn.Conv2d, nn.Linear))][2:]
all_output_layers = [name for name, child in norm_model.named_children() if isinstance(child, nn.ReLU)]

# Normalize only a hand-picked subset while experimenting
# (e.g. just the classifier: param_layers = ['25', '27'], output_layers = ['26', '28']).
param_layers = ['2', '5', '8', '10', '13']
output_layers = ['3', '6']
# Fail fast if a hand-picked name is misspelled or refers to a layer of the wrong type.
assert set(param_layers) <= set(all_param_layers), f"not Conv2d/Linear layers: {sorted(set(param_layers) - set(all_param_layers))}"
assert set(output_layers) <= set(all_output_layers), f"not ReLU layers: {sorted(set(output_layers) - set(all_output_layers))}"
percentile = 99.99

print(param_layers)
print(output_layers)

In [None]:
normloader = torch.utils.data.DataLoader(imagenet, batch_size=150, shuffle=True, num_workers=0)
images, labels = next(iter(normloader))
scale_factor = 1.

In [None]:
# normalize_weights(norm_model.to(device), images.to(device), output_layers=output_layers, param_layers=param_layers, percentile=percentile, scale_factor=scale_factor)
# norm_model = norm_model.cpu()
# images = images.cpu()

In [None]:
device = 'cpu'
plot_output_comparison(model.to(device), norm_model.to(device), images.to(device), output_layers=param_layers, every_n=10000, every_c=10)#, savefig=f"norm_activation_comparison_scale{scale_factor}.png")
device = 'cuda'

In [None]:
# for name, params in norm_model.named_parameters():
#     print(f"Layer {name}    \t {params.min():.2f}/{params.max():.2f}")

In [None]:
get_accuracy(norm_model, testloader, device=device)#"cpu")

In [None]:
# Deliberate stop: keeps "Run All" from continuing into the SNN conversion cells below.
# Raised explicitly because `assert` statements are removed under `python -O`.
raise RuntimeError("Execution stopped intentionally; run the following cells manually.")

In [None]:
snnloader = torch.utils.data.DataLoader(imagenet, batch_size=2, shuffle=True, num_workers=4)

accuracies = []
# for exponent in range(5, 7):
exponent = 5
t_max = 2**exponent
snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True).eval()
snn = snn[2:] # nn.Sequential(*snn[:2], quartz.IF(t_max=t_max, rectification=False), *

In [None]:
assert 1 == 0;

In [None]:
snn = snn.cpu()
norm_model = norm_model.cpu()

# Seeds torch's global RNG, which the shuffling of `snnloader` (created without its
# own generator) draws from, so the batch order is reproducible.
torch.random.manual_seed(6)
# snn[1].module.weight.data /= 2

# Optional cap on the number of batches for a quick check (>= 1); None iterates the
# whole loader.
MAX_DEBUG_BATCHES = None

# Inference only: disable autograd so no graphs are built for the ANN or the SNN.
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(snnloader):
        if MAX_DEBUG_BATCHES is not None and batch_idx >= MAX_DEBUG_BATCHES:
            break
        # The two input-normalization convs run in the ANN domain because they were
        # sliced off the SNN (`snn = snn[2:]`).
        norm_images = norm_model[:2](images)
        temp_images = encode_inputs(norm_images, t_max=t_max)
        conv_output = snn(temp_images)
        snn_output = decode_outputs(conv_output, t_max=t_max)
        # Per-sample correctness of the SNN prediction for this batch
        print(snn_output.argmax(1) == labels)

    # ANN reference output for the last batch processed above
    ann_output = norm_model(images)

In [None]:
# Inspect a quartz-specific attribute of the second SNN module (its meaning is not visible in this notebook).
snn[1].early_spikes

In [None]:
# ANN predictions for the last batch of the debug loop above.
ann_output.argmax(1)

In [None]:
# SNN predictions for the same batch, for side-by-side comparison.
snn_output.argmax(1)

In [None]:
# NOTE(review): `snn` had its first two (input-normalization) modules sliced off, and the
# debug loop applies `norm_model[:2]` to the images before `encode_inputs`. `testloader`
# yields unnormalized images, so unless `get_accuracy` applies that normalization itself,
# this accuracy is measured on un-normalized inputs -- confirm against quartz.utils.get_accuracy.
# NOTE(review): the "%" suffix assumes get_accuracy returns a percentage -- confirm.
accuracy = get_accuracy(snn, testloader, device="cpu", t_max=t_max)
accuracies.append(accuracy)
print(f"{t_max} time steps: {round(accuracy, 3)}%")

In [None]:
snn

In [None]:
# NOTE(review): duplicate of the previous display cell.
snn

In [None]:
# NOTE(review): the cells below use `ann`, `input_batch`, `print_probabilities` and
# `batch_size`, none of which is defined in this notebook. They appear to be left over
# from an earlier version and raise NameError on a fresh kernel.
# Classify the batch with the ANN and show the class probabilities of its first sample.
with torch.no_grad():
    output = ann(input_batch.cpu())

probabilities = torch.nn.functional.softmax(output[0], dim=0)
print_probabilities(probabilities)

In [None]:
# NOTE(review): overwrites the `snn` built from `norm_model` earlier in the notebook.
snn = quartz.from_torch.from_model(ann, t_max=t_max, batch_size=batch_size)

In [None]:
# NOTE(review): elsewhere `encode_inputs`/`decode_outputs` are imported from quartz.utils;
# confirm they are also exposed as `quartz.encode_inputs` / `quartz.decode_outputs`.
temp_q_values = quartz.encode_inputs(input_batch, t_max=t_max)

In [None]:
# snn = snn.cuda()
# temp_q_values = temp_q_values.cuda()

In [None]:
# Merge the first two dims of the encoded input for the SNN, then split the leading
# dim of its output back into (batch_size, -1).
# NOTE(review): presumably the encoded layout is (batch, time, ...) -- confirm.
with torch.no_grad():
    temp_output = snn(temp_q_values.flatten(0, 1)).unflatten(0, (batch_size, -1))
snn_output = quartz.decode_outputs(temp_output, t_max=t_max)

In [None]:
# SNN class probabilities of the first sample, for comparison with the ANN above.
probabilities = torch.nn.functional.softmax(snn_output[0], dim=0)
print_probabilities(probabilities)

In [None]:
snn_output.shape

In [None]:
# torchvision.datasets.ImageNet("../data", split="val")

In [None]:
snn_output