In [None]:
import copy
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.quantization
import torchvision
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

from cifar_model import MobileNet
from utils import *

In [None]:
# Device identifiers passed to model.to / get_accuracy / check_activations.
cpu_device = 'cpu'
gpu_device = 'cuda'
# Percentile used by check_activations when reporting weight/activation ranges.
percentile = 99.9

# parameters
RANDOM_SEED = 42
LEARNING_RATE = 0.001
BATCH_SIZE = 128
num_workers = 10

# Actually apply the seed: without this RANDOM_SEED is never used and the
# shuffled loader created below yields a different order on every run.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

In [None]:
# Images are only converted to tensors; no augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 test split (train=False), read from ./data.
valid_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
)

# Fixed-order, batched loader over the test split.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

# Same dataset served one sample at a time, in shuffled order.
test_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Build the network (constructor argument 10; presumably the class count - TODO confirm)
# and restore the saved weights onto the CPU.
model = MobileNet(10)
model.load_state_dict(
    torch.load(
        "./cifar-convnet.pth",
        map_location=torch.device(cpu_device),
    )
)
# eval() returns the module itself; binding it to `capture` keeps the cell
# from printing the model repr.
capture = model.eval()

In [None]:
# Baseline: evaluate the loaded model on the CIFAR-10 test loader, on CPU.
# get_accuracy comes from `utils`; its exact return value is not visible here.
get_accuracy(model, valid_loader, device=cpu_device)

In [None]:
# Display the model's `features` attribute (presumably its layer stack - TODO confirm)
# to inspect the layout before folding.
model.features

## Fold BatchNorm layers into the preceding conv layers

In [None]:
# Flatten the model into a list of leaf layers, merging every BatchNorm2d that
# directly follows a Conv2d into that convolution. Modules of any other type
# (containers included) are skipped.
new_layers = []
previous_module = None
for module in model.modules():
    if not isinstance(module, (nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d, nn.Linear, nn.ReLU6, nn.ReLU, nn.Flatten)):
        continue
    follows_conv = isinstance(previous_module, nn.Conv2d)
    if follows_conv and isinstance(module, nn.BatchNorm2d):
        # The convolution is the last entry appended; swap it for the fused layer.
        new_layers[-1] = torch.nn.utils.fuse_conv_bn_eval(previous_module, module)
    else:
        new_layers.append(module)
    previous_module = module

In [None]:
# Rebuild the network as a flat nn.Sequential from the collected (and fused) layers.
folded_model = nn.Sequential(*new_layers)
# Display the resulting layer list.
folded_model

In [None]:
# Evaluate the folded model on the same loader; compare with the result for `model`.
get_accuracy(folded_model, valid_loader, device=cpu_device)

## Check weight and activation ranges

In [None]:
def check_activations(model, percentile, device, data_loader=None):
    """Print and return weight/activation percentiles of Conv2d and Linear layers.

    A forward hook is attached to every ``nn.Conv2d`` and ``nn.Linear`` module,
    the model is run over the data in eval mode without gradients, and the
    requested percentile of each layer's weights and outputs is printed.
    Layer outputs have negative values clamped to zero before the percentile
    is taken.

    Args:
        model: Module to inspect. It is moved to ``device`` and left in eval mode.
        percentile: Upper percentile; the lower weight bound uses
            ``100 - percentile``.
        device: Device handed to ``model.to`` and to each input batch.
        data_loader: Optional iterable of ``(inputs, targets)`` batches. When
            omitted, the CIFAR-10 test split is read from ``./data`` in
            batches of 1000, as before.

    Returns:
        Tuple ``(min_weights_percentile, max_weights_percentile, names)`` of
        three equally long lists, one entry per hooked layer in
        ``model.named_modules()`` order. Weight percentiles are rounded to
        two decimals.
    """
    # Local import: the notebook never imports `partial` explicitly.
    from functools import partial

    if data_loader is None:
        batch_size_test = 1000
        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])
        valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform_test, download=False)
        data_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size_test, shuffle=False, num_workers=4, pin_memory=True)

    model.to(device)
    model.eval()

    # layer name -> list of per-batch outputs; concatenated once at the end
    # instead of re-copying all earlier batches on every hook call.
    collected = {}

    def save_activation(name, mod, inp, out):
        collected.setdefault(name, []).append(out)

    names = []
    handles = []
    max_weights_percentile = []
    min_weights_percentile = []

    with torch.no_grad():
        try:
            for name, module in model.named_modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    handles.append(module.register_forward_hook(partial(save_activation, name)))
                    names.append(name)
                    # detach() first: .cpu() returns the parameter itself when it is
                    # already on the CPU, and numpy() can refuse tensors that require grad.
                    weights = module.weight.detach().cpu().numpy()
                    max_weights_percentile.append(np.percentile(weights, percentile).round(2))
                    min_weights_percentile.append(np.percentile(weights, 100 - percentile).round(2))

            for X, _ in data_loader:
                model(X.to(device))
        finally:
            # Always detach the hooks, even if the forward pass raised.
            for handle in handles:
                handle.remove()

        str_output = ''.join(
            "{}: [{},{}]; ".format(layer, str(low), str(high))
            for layer, low, high in zip(names, min_weights_percentile, max_weights_percentile)
        )
        print("\t" + str(percentile) + "% weights: " + str_output)

        pieces = []
        for name, chunks in collected.items():
            non_negative = torch.cat(chunks).clamp(min=0).cpu().numpy()
            pieces.append("{}: {}, ".format(name, round(np.percentile(non_negative, percentile), 3)))
        print("\t" + str(percentile) + "% activations: " + ''.join(pieces))
    return min_weights_percentile, max_weights_percentile, names

In [None]:
# Print the weight/activation percentiles of the folded model and keep the
# per-layer weight bounds and layer names for computing scaling factors.
min_weights_percentile, max_weights_percentile, names = check_activations(folded_model, percentile, cpu_device)

In [None]:
# Same report for the original (unfolded) model. The returned tuple is stored
# in `results`, which no visible cell reads afterwards.
results = check_activations(model, percentile, cpu_device)

In [None]:
# One (lower, upper) weight-percentile pair per hooked layer.
weights = [*zip(min_weights_percentile, max_weights_percentile)]
# Reciprocal of the larger of |lower| and upper for each layer.
scaling_factors = [
    1 / max(abs(lower), upper)
    for lower, upper in weights
]

In [None]:
# Work on an independent copy so the in-place weight scaling applied later
# does not modify folded_model. (`copy` is imported in the top import cell.)
scaled_model = copy.deepcopy(folded_model)
# eval() returns the module; assigning it keeps the cell from printing the repr.
capture = scaled_model.eval()

In [None]:
# Map each hooked layer name to its scaling factor.
scaling_dict = {
    layer_name: factor
    for layer_name, factor in zip(names, scaling_factors)
}
# Show the mapping.
scaling_dict

In [None]:
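# Disabled override: a scaling_dict restricted to layers '4', '9' and '14'.
# The values look like output saved from an earlier run - TODO confirm.
# scaling_dict = {
#  '4': 0.4854368932038835,
#  '9': 0.37593984962406013,
#  '14': 0.4424778761061947,}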

In [None]:
# Scale weight and bias of every direct child listed in scaling_dict.
# NOTE: the multiplication is in place, so re-running this cell compounds it.
# The factor is a fixed 1.2; the per-layer value scaling_dict[name] is
# deliberately left commented out.
with torch.no_grad():
    for name, module in scaled_model.named_children():
        if name not in scaling_dict:
            continue
        if hasattr(module, 'weight'):
            module.weight *= 1.2 #scaling_dict[name]
        if getattr(module, 'bias', None) is not None:
            module.bias *= 1.2 #scaling_dict[name]

In [None]:
# Print the weight/activation percentiles again, now for the scaled model.
# The returned tuple is stored in `result`, which no visible cell reads afterwards.
result = check_activations(scaled_model, percentile, cpu_device)

In [None]:
# Evaluate the scaled model on the same loader to see the effect of the scaling.
get_accuracy(scaled_model, valid_loader, device=cpu_device)

In [None]:
# Leftover post-mortem debugger magic, disabled so Restart & Run All never
# drops into an interactive prompt. Uncomment only while investigating a failure.
# %debug