In [None]:
import numpy as np
from datetime import datetime 
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.quantization
from torch.utils.tensorboard import SummaryWriter
import torchvision
import matplotlib.pyplot as plt
import ipdb
import time
from cifar_model import ConvNet
from utils import *

In [None]:
# Device that the checkpoint is mapped to and that input tensors are moved to.
DEVICE = 'cpu'

# parameters
RANDOM_SEED = 42
# Apply the seed right away: test_loader shuffles, so without this the sample
# drawn for the example-activation plot differs on every Restart & Run All.
torch.manual_seed(RANDOM_SEED)
LEARNING_RATE = 0.001
BATCH_SIZE = 128   # batch size of the validation loader
num_workers = 10   # worker processes handed to each DataLoader

In [None]:
# ToTensor is the only transform applied to the images.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 split selected by train=False, read from ./data.
valid_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
)

# Fixed-order, batched loader over that split.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

# Same split, served one shuffled sample at a time.
test_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Instantiate the network (the argument 10 is presumably the number of
# CIFAR-10 classes -- confirm against cifar_model.ConvNet).
model = ConvNet(10)

# Restore the saved parameters, remapping stored tensors onto DEVICE.
model.load_state_dict(
    torch.load("cifar-convnet.pth", map_location=torch.device(DEVICE))
)

# Switch to evaluation mode; binding the returned module to a throwaway
# name keeps its repr out of the cell output.
capture = model.eval()

In [None]:
# Score the loaded model on valid_loader and show the result as the cell output.
# get_accuracy is not defined in this notebook; presumably it comes from the
# star-import of `utils` -- confirm its return value/units there.
get_accuracy(model, valid_loader, device=DEVICE)

In [None]:
# Unpack the two collections returned by get_all_weights_biases (presumably
# provided by `utils`). The following cells iterate over both and call
# .max()/.min() on every entry.
weights, biases = get_all_weights_biases(model)

In [None]:
# Largest entry of each tensor in `weights`, gathered into one NumPy array.
np.array([
    w.max().detach().numpy()
    for w in weights
])

In [None]:
# Smallest entry of each tensor in `weights`, gathered into one NumPy array.
np.array([
    w.min().detach().numpy()
    for w in weights
])

In [None]:
# Largest entry of each tensor in `biases`, gathered into one NumPy array.
np.array([
    b.max().detach().numpy()
    for b in biases
])

In [None]:
# Smallest entry of each tensor in `biases`, gathered into one NumPy array.
np.array([
    b.min().detach().numpy()
    for b in biases
])

## Fold batch-norm layers into the preceding conv layers

In [None]:
# Walk every sub-module of `model` and collect its Conv2d / MaxPool2d /
# BatchNorm2d / Linear layers in traversal order. A BatchNorm2d whose
# predecessor (among the collected layer types) is a Conv2d is not appended;
# instead the last collected entry is replaced by the fused conv+bn module.
# Modules of any other type (containers, activations, ...) are ignored
# entirely, so the resulting list is not a drop-in copy of the network.
new_layers = []
previous_module = None
i = 0  # number of conv+bn pairs fused so far
for module in model.modules():
    if not isinstance(module, (nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d, nn.Linear)):
        # Skipped modules do not update previous_module either.
        continue
    if isinstance(previous_module, nn.Conv2d) and isinstance(module, nn.BatchNorm2d):
        print(module)
        # new_layers[-1] is the convolution that was appended on the previous hit.
        new_layers[-1] = torch.nn.utils.fuse_conv_bn_eval(previous_module, module)
        i += 1
    else:
        new_layers.append(module)
    previous_module = module

In [None]:
# Chain the collected layers in order. Only Conv2d / MaxPool2d / BatchNorm2d /
# Linear modules were gathered into new_layers, so anything else the original
# model applies between them is absent here.
# NOTE(review): presumably meant for parameter inspection only -- confirm
# before running inputs through folded_model.
folded_model = nn.Sequential(*new_layers)

In [None]:
# Print each entry of the folded model, skipping nested Sequential containers.
for layer in folded_model:
    if type(layer) is nn.Sequential:
        continue
    print(layer)

In [None]:
# For every parameter of the folded model, print its name and then its largest value.
for param_name, param in folded_model.named_parameters():
    print(param_name)
    print(param.max())

In [None]:
# for params in folded_model.named_parameters():
#     print(params)

In [None]:
# Pick one nested sub-module by index path for closer inspection. The name
# and the running_var access in the next cell suggest it is a batch-norm
# layer -- confirm against the ConvNet definition.
bn = model.features[1].conv[3][1]

In [None]:
# Largest value in the running_var buffer of the layer selected above.
bn.running_var.max()

In [None]:
# For every parameter of the original (unfolded) model, print its name and
# then its largest value.
for param_name, param in model.named_parameters():
    print(param_name)
    print(param.max())

In [None]:
# for module in model.modules():
#     print(module)

## Plot weights

In [None]:
# Histogram of every learnable value in the original model.
#
# All tensors returned by model.parameters() are flattened and concatenated
# into a single 1-D vector. The vector gets its own name so the per-layer
# `weights` list produced by get_all_weights_biases() earlier in the notebook
# is not overwritten (re-running the min/max cells afterwards would otherwise
# operate on a different object). `plt` is already imported in the notebook's
# first cell, so it is not re-imported here.
flat_params = torch.cat([param.detach().flatten() for param in model.parameters()])

# Bind the return value to a throwaway name to keep the bin arrays out of
# the cell output.
capture = plt.hist(flat_params.cpu().numpy(), bins=200)

## Model parameters

In [None]:
# Print the name and size of every entry in the folded model's state dict.
# Sizes are read from folded_model itself: its keys are Sequential indices
# and need not exist in (or match) the original model's state dict.
for param_tensor, value in folded_model.state_dict().items():
    print(param_tensor, "\t", value.size())

In [None]:
# Display the running_mean buffer of the model's `bn1` attribute.
# NOTE(review): other cells reach layers through model.features[...];
# confirm that ConvNet actually exposes a top-level `bn1`.
model.bn1.running_mean

# Quantization

In [None]:
# model.qconfig = torch.quantization.default_qconfig
# print(model.qconfig)
# torch.quantization.prepare(model, inplace=True)

In [None]:
# Run the quantization conversion on `model` in place and display the result.
# NOTE(review): the cell that sets model.qconfig and calls prepare() is
# commented out, so presumably no module carries a qconfig or observers at
# this point and the float layers are left as they are -- confirm whether
# qconfig/prepare/calibration should run before this.
torch.quantization.convert(model, inplace=True)

## Example activation

In [None]:
# Run a single sample from test_loader through the model and plot the
# first element of whatever the model returns.
model.eval()

X, y_true = next(iter(test_loader))
X, Y = X.to(DEVICE), y_true.to(DEVICE)

Y_hat = model(X)[0]
plt.plot(Y_hat.detach().cpu().numpy().reshape(-1))


In [None]:
# Print the name and size of every entry in the model's state dict.
for param_tensor, value in model.state_dict().items():
    print(param_tensor, "\t", value.size())

In [None]:
# Display the weight of the model's `conv1` attribute.
# NOTE(review): other cells reach layers through model.features[...];
# confirm that ConvNet actually exposes a top-level `conv1`.
model.conv1.weight