In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
import sinabs
from torchvision import datasets, transforms
from PIL import Image
import sinabs.layers as sl
import numpy as np
import quartz
import copy
from tqdm.auto import tqdm
from quartz.utils import get_accuracy, encode_inputs, decode_outputs, normalize_outputs, plot_output_comparison, plot_output_comparison_new, normalize_weights, count_n_neurons, fuse_all_conv_bn, n_operations, omega_read, omega_write
from typing import List
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F


# Print small floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [None]:
class Block(nn.Module):
    """Depthwise-separable convolution unit.

    A 3x3 depthwise convolution (one group per input channel) is followed by
    a 1x1 pointwise convolution. Each convolution is followed by batch
    normalisation and a ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels produced by the pointwise stage.
        stride: stride of the depthwise convolution (default 1).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()

        # Depthwise stage: groups == in_planes, so every channel gets its own 3x3 filter.
        self.conv1 = nn.Conv2d(
            in_planes,
            in_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_planes,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU()

        # Pointwise stage: 1x1 convolution that mixes channels and sets the output width.
        self.conv2 = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage, and return the result."""
        depthwise = self.conv1(x)
        depthwise = self.bn1(depthwise)
        depthwise = self.relu1(depthwise)

        pointwise = self.conv2(depthwise)
        pointwise = self.bn2(pointwise)
        return self.relu2(pointwise)


class MobileNet(nn.Module):
    """MobileNet-style classifier: a stem convolution followed by a stack of `Block`s.

    The stack is described by `cfg`. After the stack the features are
    average-pooled, flattened and passed to a linear classifier that expects
    1024 features.

    Args:
        num_classes: number of outputs of the final linear layer (default 10).
    """

    # Each entry describes one Block: a bare int is the output width with
    # stride 1, a tuple is (output width, stride).
    cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512,  (1024,2), 1024]
    # Shorter alternative configuration (currently disabled):
    # cfg = [64, (256,2), (512,2), 512, (1024,2), (1024,2), 1024]

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: 3-channel input -> 32 feature maps, spatial size preserved.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.linear = nn.Linear(1024, num_classes)
        self.avg_pool = nn.AvgPool2d(2, stride=1)
        self.flatten = nn.Flatten()

    def _make_layers(self, in_planes):
        """Build the Block stack described by `cfg`, starting from `in_planes` channels."""
        blocks = []
        channels = in_planes
        for entry in self.cfg:
            if isinstance(entry, int):
                width = entry
                step = 1
            else:
                width = entry[0]
                step = entry[1]
            blocks.append(Block(channels, width, step))
            # The next block consumes what this one produces.
            channels = width
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the classifier outputs for the input batch `x`."""
        stem = self.conv1(x)
        stem = self.bn1(stem)
        stem = F.relu(stem)

        features = self.layers(stem)
        pooled = self.avg_pool(features)
        flat = self.flatten(pooled)
        return self.linear(flat)


# Build the network in inference mode and restore the trained weights.
model = MobileNet().eval()

# map_location='cpu' lets the checkpoint load on machines without the device it
# was saved from; the model itself lives on the CPU at this point.
checkpoint = torch.load('mobilenetv1.pth', map_location='cpu')
# checkpoint = torch.load('mobilenetv1-short.pth', map_location='cpu')

# The saved keys carry a 7-character prefix (presumably the 'module.' that
# nn.DataParallel prepends). Strip it only where it is actually present, so a
# checkpoint saved without the wrapper is not silently corrupted by slicing.
prefix = 'module.'
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['net'].items()
}
model.load_state_dict(state_dict)

In [None]:
batch_size = 128
# Fall back to the CPU when no GPU is visible so the notebook still runs end to end.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The shuffled loaders below take their permutation seed from the global torch
# RNG; seeding it makes "Restart & Run All" reproduce the same batches.
torch.manual_seed(0)

# NOTE(review): the normalisation statistics presumably have to match the ones
# used when the checkpoint was trained — confirm before switching variants.
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 test split; all three loaders below read from it.
valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
# Ordered loader for the ANN accuracy check.
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=batch_size, num_workers=4)
# Small shuffled batches for the SNN evaluation.
snn_loader = DataLoader(dataset=valid_dataset, shuffle=True, batch_size=8, num_workers=4)
# One large shuffled batch used as the calibration sample for output normalisation.
norm_loader = DataLoader(dataset=valid_dataset, shuffle=True, batch_size=10000, num_workers=4)

In [None]:
# Feed a single validation image through the CPU copy of the network to count its neurons.
first_batch = next(iter(valid_loader))
single_image = first_batch[0][:1]
n_neurons = count_n_neurons(model.cpu(), single_image, add_last_layer=True)
n_neurons

In [None]:
# Number of trainable parameters, expressed in millions.
trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]
sum(trainable_sizes) / 1_000_000

In [None]:
# Work on a deep copy so the original `model` is left untouched.
folded_model = copy.deepcopy(model)
# Fuse conv/BN pairs in the copy (helper from quartz.utils); presumably this
# folds the BatchNorm statistics into the preceding conv weights — TODO confirm.
fuse_all_conv_bn(folded_model)

In [None]:
# Accuracy of the folded network on the validation loader — presumably a
# sanity check that fusing conv and BN left the predictions unchanged.
get_accuracy(folded_model, valid_loader, device)

In [None]:
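# Disabled check: `norm_model` is only assigned inside the sweep cell further down,
# so this line can only run after that cell has been executed.
# get_accuracy(norm_model, valid_loader, device)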

In [None]:
# Disabled inspection: `norm_model` is only assigned inside the sweep cell further down.
# norm_model

In [None]:
# Sweep the normalisation percentile and the SNN time window t_max, collecting
# the metrics returned by get_accuracy plus operation counts into
# stats[percentile][t_max].
stats = {}

# Draw the calibration sample once. norm_loader is shuffled, so re-drawing it
# inside the loop would calibrate every percentile on a different subset and
# reload the whole dataset on each iteration.
calibration_data = next(iter(norm_loader))[0][::3]

for percentile in [99.5, 99.9, 99.99, 99.999]:
    norm_model = copy.deepcopy(folded_model)
    normalize_outputs(norm_model.cpu(), sample_data=calibration_data, percentile=percentile, max_outputs=[])
    with torch.no_grad():
        # NOTE(review): the factor 2 on the classifier weights is not explained
        # anywhere in this notebook — confirm the intent.
        norm_model.linear.weight *= 2
    # conv1 is handed to get_accuracy as `preprocess` and replaced by Identity
    # inside the SNN below.
    preprocess_layers = copy.deepcopy(norm_model.conv1).to(device)

    for t_max in [64, 128]:
        snn = copy.deepcopy(norm_model).to(device)
        snn.conv1 = nn.Identity()
        quartz.from_torch.from_model2(snn, t_max=t_max)
        snn = nn.Sequential(snn, quartz.IF(t_max=t_max, rectification=False))
        metric = get_accuracy(snn, snn_loader, device, t_max=t_max, calculate_early_spikes=True, calculate_output_time=True, preprocess=preprocess_layers,)

        # Iterating the wrapper directly would visit only its two direct
        # children (the converted MobileNet and the output IF layer), so the
        # stateful layers nested inside the network would be left out of the
        # count. modules() walks the whole tree.
        n_synops = torch.stack([
            layer.n_ops
            for layer in snn.modules()
            if isinstance(layer, sl.StatefulLayer)
        ]).sum().item()

        metric[t_max]['n_synops'] = n_synops
        metric[t_max]['n_neurons'] = n_neurons
        metric[t_max]['n_ops'] = n_operations(n_neurons, t_max, n_synops)
        metric[t_max]['read_ops'] = omega_read(n_neurons, t_max, n_synops)
        metric[t_max]['write_ops'] = omega_write(n_neurons, t_max, n_synops)

        # Merge this t_max entry into the results for the current percentile.
        stats.setdefault(percentile, {}).update(metric)


In [None]:
# Persist the collected sweep results to disk.
with open('cifar10-results-mobilenet.pkl', 'wb') as results_file:
    pickle.dump(stats, results_file)