In [None]:
import numpy as np
from datetime import datetime 
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.quantization
from torch.utils.tensorboard import SummaryWriter
import torchvision
import matplotlib.pyplot as plt
import time
import collections
from functools import partial
from mnist_model import ConvNet
import copy
from utils import *

In [None]:
# Device used by the evaluation cells (get_mnist_accuracy, model loading).
DEVICE = 'cpu'

# parameters
RANDOM_SEED = 42
LEARNING_RATE = 0.001  # NOTE(review): no cell in this notebook reads this — confirm it is needed
BATCH_SIZE = 10000     # the MNIST test split has 10,000 images, so one batch covers it
num_workers = 10

IMG_SIZE = 32  # NOTE(review): MNIST digits are 28x28 and no cell here reads IMG_SIZE — confirm
N_CLASSES = 10

# Seed the RNGs so shuffled DataLoaders (e.g. test_loader) are reproducible
# across Restart & Run All.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [None]:
# MNIST test split, used by the evaluation cells below.
transform = transforms.Compose([transforms.ToTensor()])
# download=True is a no-op when ./data already holds the files, and lets a
# fresh checkout survive Restart & Run All.
valid_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
# Pinned host memory only speeds up host->GPU copies; skip it on CPU.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=num_workers, pin_memory=DEVICE != 'cpu')
# Shuffled single-image loader over the same split.
# NOTE(review): no cell in this notebook uses test_loader — confirm it is needed.
test_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=True,
                         num_workers=num_workers, pin_memory=DEVICE != 'cpu')

In [None]:
# Load the trained float ConvNet onto DEVICE and put it in eval mode.
model = ConvNet(N_CLASSES)
# NOTE(review): torch.load unpickles the checkpoint; for files from untrusted
# sources pass weights_only=True (available from torch 1.13).
model.load_state_dict(torch.load("models/mnist-convnet.pth", map_location=torch.device(DEVICE)))
model.eval();

In [None]:
# Baseline: evaluate the unscaled float model with get_mnist_accuracy
# (presumably provided by the star-imported utils module).
get_mnist_accuracy(
    model,
    valid_loader,
    device=DEVICE,
)

## Check model weights and derive per-layer scaling factors

In [None]:
# Percentile used for the robust max/min statistics below and for test_activations.
percentile = 99.9
# Per-layer weight and bias collections of the float model (from get_weights_biases,
# presumably defined in utils); each element is fed to np.percentile below.
weights, biases = get_weights_biases(model)

In [None]:
# Upper (percentile) and lower (100 - percentile) percentile of each layer's
# weights, as 1-D arrays with one entry per layer.
max_weights, min_weights = (
    np.array([np.percentile(layer_weight, q) for layer_weight in weights])
    for q in (percentile, 100 - percentile)
)

In [None]:
# Per-layer symmetric weight bound: the larger of the upper percentile and
# the magnitude of the lower percentile.
weight_scaling = [
    max(upper, abs(lower))
    for upper, lower in zip(max_weights, min_weights)
]

In [None]:
# Per-layer activation scale; index i pairs with weight_scaling[i] and with the
# layer order conv1, conv2, conv3, fc1 used when rescaling the model.
# NOTE(review): hard-coded — presumably read off an earlier test_activations run
# on the float model; confirm, and ideally compute it in code.
activation_scaling = [
    1.347,  # conv1
    1.862,  # conv2
    1.811,  # conv3
    1.5,    # fc1
]

In [None]:
# Reciprocal of each layer's combined divisor (weight_scaling * activation_scaling),
# i.e. the factor each layer's parameters are effectively multiplied by.
[
    1 / (w_div * act_div)
    for w_div, act_div in zip(weight_scaling, activation_scaling)
]

In [None]:
# Upper percentile of each layer's bias values.
# NOTE(review): no later cell reads bias_scaling, and unlike weight_scaling it is
# not symmetric (ignores negative biases) — confirm the intent.
bias_scaling = np.array([np.percentile(layer_bias, percentile) for layer_bias in biases])

In [None]:
# Build a rescaled copy of the float model: for each listed layer i, divide both
# its weight and its bias by weight_scaling[i] * activation_scaling[i].
# The original `model` is left untouched.
scaled_model = copy.deepcopy(model)
scaled_model.eval()
SCALED_LAYER_NAMES = ("conv1", "conv2", "conv3", "fc1")  # index i <-> scaling lists[i]
with torch.no_grad():  # in-place edits of parameters must not be tracked by autograd
    for layer_idx, layer_name in enumerate(SCALED_LAYER_NAMES):
        layer = getattr(scaled_model, layer_name)
        divisor = weight_scaling[layer_idx] * activation_scaling[layer_idx]
        layer.weight /= divisor
        layer.bias /= divisor

In [None]:
# Accuracy of the rescaled model on the same validation loader as the baseline.
get_mnist_accuracy(scaled_model, valid_loader, device=DEVICE)

In [None]:
def test_activations(scaled_model, percentile):
    criterion = nn.CrossEntropyLoss()
    batch_size_test = 1000
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    valid_dataset = datasets.MNIST(root='./data', train=False,transform=transform_test, download=False)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size_test, shuffle=False, num_workers=4, pin_memory=True)
    device = 'cuda'
    scaled_model.to(device)

    with torch.no_grad():
        scaled_model.eval()

        activations = {}
        def save_activation(name, mod, inp, out):
            if name not in activations.keys():
                activations[name] = out
            else:
                activations[name] = torch.cat((activations[name],out))

        names = []
        handles = []
        max_weights_percentile = []
        min_weights_percentile = []

        for name, module in scaled_model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                handles.append(module.register_forward_hook(partial(save_activation, name)))
                names.append(name)
                max_weights_percentile.append(np.percentile(module.weight.cpu().numpy(), percentile).round(2))
                min_weights_percentile.append(np.percentile(module.weight.cpu().numpy(), 100-percentile).round(2))

        running_loss = 0
        for X, y_true in valid_loader:
            X = X.to(device)
            y_output = scaled_model(X)
        [handle.remove() for handle in handles] # remove forward hooks

        str_output = ''.join(["{}: [{},{}]; ".format(names[i], str(min_weights_percentile[i]), str(max_weights_percentile[i])) for i in range(len(names))])
        print("\t" + str(percentile) + "% weights: " + str_output)

        str_output = ''.join(["{}: {}, ".format(name, round(np.percentile(np.maximum(activation.cpu(),0), percentile),3)) for name, activation in activations.items()])
        print("\t" + str(percentile) + "% activations: " + str_output)

In [None]:
# Weight/activation percentile report for the rescaled model.
test_activations(scaled_model=scaled_model, percentile=percentile)

In [None]:
# Same report for the original float model, for comparison
# (the parameter is named scaled_model but accepts any model).
test_activations(scaled_model=model, percentile=percentile)