In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sinabs
from torchvision import transforms, datasets
import sinabs.layers as sl
import numpy as np
import quartz
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, normalize_outputs, plot_output_comparison, normalize_weights, plot_output_comparison_ann_snn, count_n_neurons, n_operations, omega_read, omega_write
from mnist_model import ConvNet
from typing import List
import matplotlib.pyplot as plt
import copy
import pickle

# Print NumPy floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
batch_size = 200
num_workers = 4

# MNIST split selected by train=False; ToTensor is the only transform applied.
transform = transforms.Compose([transforms.ToTensor()])
valid_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Evaluation loader: fixed order, `batch_size` images per batch.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# Normalisation loader: shuffled, up to 10000 images per batch.
norm_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=10000,
    shuffle=True,
    num_workers=num_workers,
)

In [None]:
# Read the checkpoint onto the selected device and build a fresh network.
state_dict = torch.load("mnist-convnet.pth", map_location=torch.device(device))
model = ConvNet()

# (index into `model`, weight key, bias key) for every layer that is restored.
checkpoint_layout = (
    (0, 'conv1.weight', 'conv1.bias'),
    (3, 'conv2.weight', 'conv2.bias'),
    (6, 'conv3.weight', 'conv3.bias'),
    (10, 'fc1.weight', 'fc1.bias'),
)
for layer_index, weight_key, bias_key in checkpoint_layout:
    target_layer = model[layer_index]
    target_layer.weight.data = state_dict[weight_key]
    target_layer.bias.data = state_dict[bias_key]

# Switch to inference mode; the trailing semicolon hides the module repr.
model.eval();
# get_accuracy(model, valid_loader, device)

In [None]:
# Number of trainable parameters, displayed in thousands.
n_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
n_trainable_params / 1_000

In [None]:
# Hard-coded synaptic-operation count; calculated further down, search for n_ops.
n_synops = 31570

# Count neurons with a single-image batch; note that model.cpu() moves `model` to the CPU.
first_batch = next(iter(valid_loader))
single_image = first_batch[0][0:1]
n_neurons = count_n_neurons(model.cpu(), single_image, add_last_layer=True)

In [None]:
from fvcore.nn import FlopCountAnalysis, flop_count_table

# Analyse one validation batch, then print the total divided by batch_size.
flop_batch = next(iter(valid_loader))[0]
flops = FlopCountAnalysis(model, flop_batch)
flops_per_image = flops.total() / batch_size
print(flops_per_image)

In [None]:
# Sweep the normalisation percentile and the number of time steps.
# Results are collected as stats[percentile][t_max] -> metrics dict.
stats = {}

# Draw the normalisation sample once and reuse it for every percentile.
# Previously a new (re-shuffled) batch was loaded from norm_loader on each
# iteration, which repeated the load and meant the percentiles were not
# guaranteed to be compared on the same sample.
# NOTE(review): assumes normalize_outputs does not modify sample_data in place — confirm.
norm_sample = next(iter(norm_loader))[0].to(device)

for percentile in [98, 99, 99.5, 99.9, 99.99, 99.999]:
    # Normalise a private copy so `model` itself stays untouched.
    norm_model = copy.deepcopy(model)
    normalize_outputs(norm_model.to(device), sample_data=norm_sample, percentile=percentile, max_outputs=[])

    for exponent in range(1, 6):
        t_max = 2**exponent
        snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True).to(device).eval()
        metric = get_accuracy(snn, valid_loader, device, t_max=t_max, calculate_early_spikes=True, calculate_output_time=True)
        # `metric` is indexed by t_max; attach the operation counts next to the accuracy data.
        metric[t_max]['n_ops'] = n_operations(n_neurons, t_max, n_synops)
        metric[t_max]['read_ops'] = omega_read(n_neurons, t_max, n_synops)
        metric[t_max]['write_ops'] = omega_write(n_neurons, t_max, n_synops)
        # Merge this t_max entry into the per-percentile dict, creating it on first use.
        stats.setdefault(percentile, {}).update(metric)

In [None]:
# Persist the collected sweep results as a binary pickle file.
with open('mnist-results.pkl', 'wb') as results_file:
    pickle.dump(stats, results_file)

In [None]:
# Normalise a fresh copy of the model for the single-configuration cells that follow.
# Only the `copy` module is imported, so the function must be called as copy.deepcopy.
norm_model = copy.deepcopy(model)

# One batch from norm_loader (up to 10000 images) serves as the normalisation sample,
# the same source the percentile sweep uses.
all_test_images = next(iter(norm_loader))[0]

# NOTE(review): this cell used to rely on `percentile` leaking out of the sweep loop,
# where its final value is 99.999. It is set explicitly so the cell runs on its own —
# confirm that this is the intended percentile.
percentile = 99.999
normalize_outputs(norm_model.to(device), all_test_images.to(device), percentile=percentile, max_outputs=[])

In [None]:
# cpu = 'cpu'
# t_max = 2**6

# sample_data = next(iter(valid_loader))[0]

# snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True).to(cpu).eval()
# snn_output_layers = [name for name, child in snn.named_children() if isinstance(child, quartz.IF)]

In [None]:
# plot_output_comparison_model_snn(norm_model.to(cpu), snn.to(cpu), sample_data.to(cpu), 
#     ann_output_layers=output_layer_names, 
#     snn_output_layers=snn_output_layers, 
#     t_max=t_max, 
#     every_n=10, 
#     every_c=1, 
#     savefig=f"ann-snn-comparison-tmax{t_max}-percentile{percentile}.png"
# )

In [None]:
# q_model = deepcopy(norm_model)
# for exponent in range(1, 7):
#     t_max = 2**exponent
#     def quantize(module, input, output):
#         return (output * t_max).round() / t_max

#     for module in q_model.children():
#         if isinstance(module, nn.ReLU):
#             module.register_forward_hook(quantize)
#     q_model[-1].register_forward_hook(quantize)

#     accuracy = get_accuracy(q_model, valid_loader, device)
#     print(f"{t_max} time steps: {round(accuracy, 3)}%")

In [None]:
# Round the two values to three decimal places.
np.round(np.array([3.30459, 2.20934]), 3)

In [None]:
# Evaluate the SNN converted from `norm_model` for t_max = 2, 4, 8, 16, 32,
# printing each result and keeping all of them in `metrics`.
metrics = []
for exponent in range(1, 6):
    t_max = 1 << exponent  # same integer as 2**exponent
    snn = quartz.from_torch.from_model(
        norm_model,
        t_max=t_max,
        add_spiking_output=True,
    )
    snn = snn.to(device).eval()
    metric = get_accuracy(
        snn,
        valid_loader,
        device,
        t_max=t_max,
        calculate_early_spikes=True,
        calculate_output_time=True,
    )
    metrics.append(metric)
    print(metric)

In [None]:
# Compare normalisation percentiles at a fixed number of time steps.
# One batch from norm_loader (up to 10000 images) is the normalisation sample; it is
# loaded here because `all_test_images` is not defined by any earlier cell.
all_test_images = next(iter(norm_loader))[0].to(device)
t_max = 16  # constant for the whole loop

for percentile in [97, 98, 99, 99.9, 99.99, 99.999]:  # alternative: [99]
    # Only the `copy` module is imported, so the function must be called as copy.deepcopy.
    norm_model = copy.deepcopy(model)
    normalize_outputs(norm_model.to(device), all_test_images, percentile=percentile, max_outputs=[])
    snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True).to(device).eval()
    accuracy = get_accuracy(snn, valid_loader, device, t_max=t_max, calculate_early_spikes=True, calculate_output_time=True)
    print(accuracy)

In [None]:
# NOTE(review): `accuracies` is not assigned in any cell visible in this notebook, so
# this raises NameError on a fresh kernel unless it is defined elsewhere — confirm.
accuracies

In [None]:
# Element-wise 100 minus each entry of `accuracies` (presumably error rates in percent — confirm).
# NOTE(review): `accuracies` is not assigned in any cell visible in this notebook.
100 - np.array(accuracies)

In [None]:
# Gather the n_ops attribute of every stateful layer in the current `snn` and display their sum.
stateful_layers = [module for module in snn if isinstance(module, sl.StatefulLayer)]
n_ops = [module.n_ops for module in stateful_layers]
torch.sum(torch.stack(n_ops))

In [None]:
# NOTE(review): the only visible binding of `layer` is the loop variable of a list
# comprehension, which does not leak into the notebook namespace in Python 3; this
# cell raises NameError unless `layer` is assigned elsewhere — confirm.
layer.n_ops

In [None]:
# Plot output histograms of the stateful layers of an SNN built with t_max = 9.
t_max = 2**3 + 1
images, label = next(iter(valid_loader))
spikes = encode_inputs(images, t_max=t_max).to(device)

snn = quartz.from_torch.from_model(norm_model, t_max=t_max, add_spiking_output=True)
snn = snn.to(device).eval()

# Keep only the direct children that are sinabs stateful layers.
output_layers = list(filter(lambda child: isinstance(child, sl.StatefulLayer), snn.children()))

# NOTE(review): plot_output_histograms is not imported in the visible cells — confirm where it comes from.
plot_output_histograms(snn, spikes, output_layers, t_max=t_max)

In [None]:
t_max = 2**3 + 1


def quantize(module, input, output):
    """Forward hook that snaps `output` onto a grid of step 1 / t_max.

    The scaled output is cast to int (truncating toward zero) and divided by
    t_max again. `t_max` is read from the notebook namespace when the hook runs.
    Returning a value from a forward hook replaces the module's output.
    """
    return (output * t_max).int() / t_max


# Hook a private copy instead of `model` itself: hooks registered on the shared model
# would pile up every time this cell is re-run and would stay active in later cells.
# The copy is also moved to `device`, because `model` was moved to the CPU earlier
# while the inputs below are sent to `device`.
q_model = copy.deepcopy(model).to(device)
for module in q_model.children():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(quantize)

# Histogram targets: every ReLU output plus the last Conv2d/Linear layer.
param_layers = [child for child in q_model.children() if isinstance(child, (nn.Conv2d, nn.Linear))]
output_layers = [child for child in q_model.children() if isinstance(child, nn.ReLU)]
output_layers += [param_layers[-1]]

# NOTE(review): plot_output_histograms is not imported in the visible cells — confirm where it comes from.
plot_output_histograms(q_model, images.to(device), output_layers)

In [None]:
# Display the repr of the most recently built `snn`.
snn

In [None]:
# Convert `model` (not the normalised copy) with 32 time steps, attach a
# synaptic-operation counter, and report the accuracy on the validation loader.
t_max = 1 << 5  # same integer as 2**5
snn = quartz.from_torch.from_model(
    model,
    t_max=t_max,
    add_spiking_output=True,
)
snn = snn.to(device).eval()
synop_counter = sinabs.SNNSynOpCounter(snn)
accuracy = get_accuracy(snn, valid_loader, device, t_max=t_max)
print(f"{t_max} time steps: {round(accuracy, 3)}%")

In [None]:
# Display the total reported by the counter attached to `snn`.
synop_counter.get_total_synops()

In [None]:
# NOTE(review): identical to the preceding cell; nothing runs between the two calls.
synop_counter.get_total_synops()

In [None]:
# NOTE(review): bare numeric literal with no effect beyond being displayed; presumably a
# total copied from get_total_synops() — confirm or remove.
1.4933e+08