In [None]:
import numpy as np
from datetime import datetime 
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.quantization
from torch.utils.tensorboard import SummaryWriter
import torchvision
import matplotlib.pyplot as plt
import time
from cifar_model import MobileNet
from utils import *

In [None]:
DEVICE = 'cpu'

# parameters
RANDOM_SEED = 42
LEARNING_RATE = 0.001  # NOTE(review): not used anywhere in this notebook
BATCH_SIZE = 128
num_workers = 10

# Seed the RNGs the notebook draws from (torch.rand inputs, the shuffled
# test_loader order) so that Restart & Run All reproduces the same outputs.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [None]:
# CIFAR-10 test split (train=False), used both for batched evaluation and for
# drawing single random examples. No download=True, so the dataset is expected
# to already exist under ./data.
transform = transforms.Compose([transforms.ToTensor()])
valid_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform)

# Page-locked host memory only speeds up host->accelerator copies; on a CPU-only
# run it is wasted work (and PyTorch may warn about it).
pin_memory = DEVICE != 'cpu'
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

In [None]:
# Build the 10-class MobileNet on DEVICE, load the trained weights and switch to
# eval mode so any BatchNorm layers use their running statistics (the BN folding
# below, via fuse_conv_bn_eval, requires eval mode).
# torch.load unpickles the checkpoint: only load trusted files.
model = MobileNet(10).to(DEVICE)
model.load_state_dict(torch.load("cifar-convnet.pth", map_location=torch.device(DEVICE)))
model.eval();

In [None]:
# Baseline accuracy of the unmodified model on the CIFAR-10 test split.
baseline_accuracy = get_accuracy(model, valid_loader, device=DEVICE)
baseline_accuracy

In [None]:
# Re-save the weights in the legacy (non-zipfile) serialization format so older
# PyTorch versions can read them. Only the state_dict is stored: after loading,
# rebuild the model and call model.eval() before running inference.
legacy_state = model.state_dict()
torch.save(legacy_state, "./cifar-convnet-pickle2.pth", _use_new_zipfile_serialization=False)

## Fold BatchNorm layers into the preceding conv layers

In [None]:
# Rebuild `model` as a flat list of leaf layers, in the order `model.modules()`
# yields them. Each Conv2d whose next listed layer is a BatchNorm2d is replaced
# by a single conv whose weight/bias absorb the BN transform.
# NOTE(review): the resulting nn.Sequential is only equivalent if
# MobileNet.forward applies these layers in registration order, with no skip
# connections, functional ops, or module instance used twice (`modules()` yields
# each instance only once). Confirm against cifar_model.py; the accuracy checks
# below are the safety net.
FOLDABLE_TYPES = (nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d, nn.Linear, nn.ReLU6, nn.ReLU, nn.Flatten)

previous_module = None
new_layers = []
skipped_layers = []
for module in model.modules():
    if not isinstance(module, FOLDABLE_TYPES):
        # Containers are expected here. A *leaf* of any other type would be
        # silently dropped from the folded model, so record it.
        if next(module.children(), None) is None:
            skipped_layers.append(type(module).__name__)
        continue
    if isinstance(module, nn.BatchNorm2d) and isinstance(previous_module, nn.Conv2d):
        # fuse_conv_bn_eval requires both modules to be in eval mode.
        new_layers[-1] = torch.nn.utils.fuse_conv_bn_eval(previous_module, module)
    else:
        new_layers.append(module)
    previous_module = module

if skipped_layers:
    print("WARNING: leaf layers not copied into the folded model:", skipped_layers)

In [None]:
# Chain the folded layers into a Sequential, registered under "0", "1", ... in list order.
folded_model = nn.Sequential()
for idx, layer in enumerate(new_layers):
    folded_model.add_module(str(idx), layer)

In [None]:
# Accuracy of the BN-folded Sequential on the same data (expected to match the baseline if folding is exact).
folded_accuracy = get_folded_accuracy(folded_model, valid_loader, device=DEVICE)
folded_accuracy

In [None]:
# Debug helper (disabled): print each Conv2d / BatchNorm2d module of the original model and its largest weight value.
# for module in model.modules():
#     if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
#         print(module)
#         print(module.weight.max().cpu().detach().numpy())

In [None]:
# Debug helper (disabled): print the largest value of each parameter of the folded model.
# for name, module in list(folded_model.named_parameters())[::1]:
#     print(name, module.max().detach().cpu().numpy())

In [None]:
# Debug helper (disabled): print the largest value of each parameter of the original model.
# for name, module in list(model.named_parameters()):
#     print(name, module.max().detach().cpu().numpy())

In [None]:
# Random batch of 10 CIFAR-sized (3x32x32) images with values in [0, 1), used to
# compare `model` and `folded_model`. It is drawn from its own seeded generator,
# so it is identical on every run regardless of cell execution order.
# (`F` is already imported in the first cell.)
input_tensor = torch.rand((10, 3, 32, 32), generator=torch.Generator().manual_seed(RANDOM_SEED))

In [None]:
# Display the flattened, BN-folded layer stack.
folded_model

In [None]:
# Sanity-check the BN folding on the random batch: both models should produce
# the same outputs up to floating-point rounding.
# NOTE(review): assumes both models return (batch, num_classes) logits; confirm
# against MobileNet.forward.
model.eval()
folded_model.eval()
batch = input_tensor.to(next(model.parameters()).device)
with torch.no_grad():
    logits = model(batch)
    folded_output = folded_model(batch)
probs = F.softmax(logits, dim=1)
folded_probs = F.softmax(folded_output, dim=1)

print("max |folded - original| logit difference:", (folded_output - logits).abs().max().item())
list(zip(folded_probs.flatten(), probs.flatten()))

In [None]:
# NOTE(review): repeats an earlier display of `folded_model`; consider deleting this cell.
folded_model

In [None]:
# Display the original (unfolded) model for comparison with the folded stack.
model

In [None]:
# Evaluate the folded model with the same get_accuracy helper used for the baseline.
folded_generic_accuracy = get_accuracy(folded_model, valid_loader, device=DEVICE)
folded_generic_accuracy

In [None]:
# Scratch example: a Conv -> BatchNorm -> ReLU6 stack, displayed as an nn.Sequential.
modules = [
    nn.Conv2d(2, 3, 3),
    nn.BatchNorm2d(3, momentum=0.4),
    nn.ReLU6(inplace=True),
]
nn.Sequential(*modules)

## Plot weights

In [None]:
# Histogram of every learnable parameter of the original model (conv/linear
# weights and biases, plus any BatchNorm affine parameters), pooled into one
# flat vector. plt/np/torch are already imported in the first cell.
with torch.no_grad():
    weights = torch.cat([param.flatten() for param in model.parameters()])

fig, ax = plt.subplots()
ax.hist(weights.cpu().numpy(), bins=200)
ax.set(title="Distribution of model parameter values", xlabel="parameter value", ylabel="count");

## Model parameters

In [None]:
# Print every state_dict entry of the folded model with its shape. Keys and
# tensors both come from folded_model: its keys ("0.weight", ...) do not exist
# in model.state_dict(). The state_dict is built once rather than per entry.
for param_tensor, tensor in folded_model.state_dict().items():
    print(param_tensor, "\t", tensor.size())

In [None]:
# Names of all entries (parameters and buffers) in the original model's state_dict.
list(model.state_dict().keys())

# Quantization

In [None]:
# Eager-mode post-training static quantization setup (currently disabled):
# attach the default qconfig and insert observers. A calibration pass over
# representative data is needed between prepare() and convert().
# model.qconfig = torch.quantization.default_qconfig
# print(model.qconfig)
# torch.quantization.prepare(model, inplace=True)

In [None]:
# Eager-mode convert() only swaps modules that carry a `qconfig` and have been
# observed by torch.quantization.prepare(). That setup is commented out above,
# so presumably nothing gets quantized here. Convert a copy (inplace=False) so
# the float `model` used by the following cells is never mutated by this step.
quantized_model = torch.quantization.convert(model, inplace=False)
quantized_model

## Example model output for a single test image

In [None]:
# Run the model on one randomly drawn test image (test_loader uses batch_size=1
# and shuffle=True) and plot its output vector, one point per output index
# (presumably the 10 class scores).
model.eval()
X, y_true = next(iter(test_loader))
X = X.to(DEVICE)
with torch.no_grad():
    Y_hat = model(X)[0]  # first (and only) sample in the batch

fig, ax = plt.subplots()
ax.plot(Y_hat.cpu().numpy().flatten(), marker="o")
ax.set(title=f"Model output for one test image (true class {y_true.item()})",
       xlabel="output index", ylabel="output value");


In [None]:
# Print every state_dict entry of `model` with its shape. The state_dict is
# built once; the original rebuilt it twice per entry.
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())

In [None]:
# Raw weight tensor of the model's `conv1` layer (presumably its first convolution; attribute defined in cifar_model.MobileNet).
model.conv1.weight