## Libraries & Setup

### Setup

In [2]:
from functools import partial

import numpy as np
import pandas as pd
import torch.nn.parallel

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from utils.adaquant import optimize_layer_adaquant
from utils.load_dataset import load_dataset
from utils.utils_functions import val_loop, set_seeds
from utils.mobilenet_v2 import MobileNetV2
from utils.quantize import QConv2d, QLinear
from utils.trainer import Trainer
set_seeds()

In [3]:
acc = -1
loss = -1
best_prec1 = 0
dtype = torch.float32

## Data

In [11]:
cal_dataloader, train_dataloader, val_dataloader=load_dataset(
'./data/calibration',
'./data/calibration_labels.csv',
'./data/cifar-10/train',
'./data/cifar-10/trainLabels.csv'
)

## Model, Optimizer, Trainer

In [12]:
model = MobileNetV2(num_bits=8, num_bits_weight=8)
model.load_state_dict(torch.load('models/mobilenet1.pt'))
criterion = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.5, weight_decay=0)
trainer = Trainer(model, criterion, optimizer, device=torch.device("cpu"))

val_loop(cal_dataloader, model, 'cpu')

45.74574574574575

## Cache, Hook

In [13]:
cached_qinput = {}

def Qhook(name, module, input, output):
    if module not in cached_qinput:
        cached_qinput[module] = []
        cached_qinput[module].append(input[0].detach().cpu())

In [14]:
cached_input_output = {}

def hook(name, module, input, output):
    if module not in cached_input_output:
        cached_input_output[module] = []
    cached_input_output[module].append((input[0].detach().cpu(), output.detach().cpu()))

In [15]:
for name, m in model.named_modules():
    if isinstance(m, QConv2d) or isinstance(m, QLinear):
        print(name)
        m.name = name

features.0.0
features.1.conv.0.0
features.1.conv.1
features.2.conv.0.0
features.2.conv.1.0
features.2.conv.2
features.3.conv.0.0
features.3.conv.1.0
features.3.conv.2
features.4.conv.0.0
features.4.conv.1.0
features.4.conv.2
features.5.conv.0.0
features.5.conv.1.0
features.5.conv.2
features.6.conv.0.0
features.6.conv.1.0
features.6.conv.2
features.7.conv.0.0
features.7.conv.1.0
features.7.conv.2
features.8.conv.0.0
features.8.conv.1.0
features.8.conv.2
features.9.conv.0.0
features.9.conv.1.0
features.9.conv.2
features.10.conv.0.0
features.10.conv.1.0
features.10.conv.2
features.11.conv.0.0
features.11.conv.1.0
features.11.conv.2
features.12.conv.0.0
features.12.conv.1.0
features.12.conv.2
features.13.conv.0.0
features.13.conv.1.0
features.13.conv.2
features.14.conv.0.0
features.14.conv.1.0
features.14.conv.2
features.15.conv.0.0
features.15.conv.1.0
features.15.conv.2
features.16.conv.0.0
features.16.conv.1.0
features.16.conv.2
features.17.conv.0.0
features.17.conv.1.0
features.17.conv

In [16]:
handlers = []
count = 0
for name, m in model.named_modules():
    if isinstance(m, QConv2d) or isinstance(m, QLinear):
        m.quantize = False
        handlers.append(m.register_forward_hook(partial(hook, name)))
        count += 1

In [17]:
# Store input/output for all quantizable layers
trainer.validate(val_dataloader)
print("Input/outputs cached")

Input/outputs cached


In [18]:
for handler in handlers:
    handler.remove()

for m in model.modules():
    if isinstance(m, QConv2d) or isinstance(m, QLinear):
        m.quantize = True

## Loop Through Layers

In [19]:
mse_df = pd.DataFrame(
    index=np.arange(len(cached_input_output)),
    columns=["name", "bit", "shape", "mse_before", "mse_after", "acc"],
)
print_freq = 100
evaluate = "evaluate"

In [20]:
for i, layer in enumerate(cached_input_output):
    if i > 0: 
        count = 0
        cached_qinput = {}
        for name, m in model.named_modules():
            if layer.name == name:
                if count < 5:
                    handler = m.register_forward_hook(partial(Qhook, name))
                    count += 1
        trainer.validate(cal_dataloader)
        print("cashed quant Input%s" % layer.name)
        cached_input_output[layer][0] = (
            cached_qinput[layer][0],
            cached_input_output[layer][0][1],
        )
        handler.remove()
    print(
        "\nOptimize {}:{} for {} bit of shape {}".format(
            i, layer.name, layer.num_bits, layer.weight.shape
        )
    )

    mse_before, mse_after = optimize_layer_adaquant(layer, cached_input_output[layer])
    acc=val_loop(val_dataloader, model, 'cpu')

    print("\nMSE before optimization: {}".format(mse_before))
    print("MSE after optimization:  {}".format(mse_after))
    mse_df.loc[i, "name"] = layer.name
    mse_df.loc[i, "bit"] = layer.num_bits
    mse_df.loc[i, "shape"] = str(layer.weight.shape)
    mse_df.loc[i, "mse_before"] = mse_before
    mse_df.loc[i, "mse_after"] = mse_after
    mse_df.loc[i, "acc"] = acc


mse_csv = evaluate + ".mse.csv"
mse_df.to_csv(mse_csv)

filename = evaluate + "_adaquant_val02"
torch.save(model.state_dict(), filename)


Optimize 0:features.0.0 for 8 bit of shape torch.Size([32, 3, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 92.82it/s]



MSE before optimization: 0.3005015552043915
MSE after optimization:  0.032455552369356155
cashed quant Inputfeatures.1.conv.0.0

Optimize 1:features.1.conv.0.0 for 8 bit of shape torch.Size([32, 1, 3, 3])


100%|██████████| 100/100 [00:02<00:00, 45.97it/s]



MSE before optimization: 0.002374418079853058
MSE after optimization:  0.0023612570948898792
cashed quant Inputfeatures.1.conv.1

Optimize 2:features.1.conv.1 for 8 bit of shape torch.Size([16, 32, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 56.84it/s]



MSE before optimization: 0.003858943236991763
MSE after optimization:  0.0032873444724828005
cashed quant Inputfeatures.2.conv.0.0

Optimize 3:features.2.conv.0.0 for 8 bit of shape torch.Size([96, 16, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 55.58it/s]



MSE before optimization: 0.05347907543182373
MSE after optimization:  0.053474221378564835
cashed quant Inputfeatures.2.conv.1.0

Optimize 4:features.2.conv.1.0 for 8 bit of shape torch.Size([96, 1, 3, 3])


100%|██████████| 100/100 [00:04<00:00, 24.96it/s]



MSE before optimization: 0.04491991177201271
MSE after optimization:  0.03614884242415428
cashed quant Inputfeatures.2.conv.2

Optimize 5:features.2.conv.2 for 8 bit of shape torch.Size([24, 96, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 60.35it/s]



MSE before optimization: 0.08661017566919327
MSE after optimization:  0.0783320963382721
cashed quant Inputfeatures.3.conv.0.0

Optimize 6:features.3.conv.0.0 for 8 bit of shape torch.Size([144, 24, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 86.40it/s]



MSE before optimization: 0.08387216180562973
MSE after optimization:  0.07656430453062057
cashed quant Inputfeatures.3.conv.1.0

Optimize 7:features.3.conv.1.0 for 8 bit of shape torch.Size([144, 1, 3, 3])


100%|██████████| 100/100 [00:02<00:00, 43.86it/s]



MSE before optimization: 0.008655977435410023
MSE after optimization:  0.0027042729780077934
cashed quant Inputfeatures.3.conv.2

Optimize 8:features.3.conv.2 for 8 bit of shape torch.Size([24, 144, 1, 1])


100%|██████████| 100/100 [00:02<00:00, 49.17it/s]



MSE before optimization: 0.018840037286281586
MSE after optimization:  0.015646224841475487
cashed quant Inputfeatures.4.conv.0.0

Optimize 9:features.4.conv.0.0 for 8 bit of shape torch.Size([144, 24, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 77.66it/s]



MSE before optimization: 0.2172471582889557
MSE after optimization:  0.18928572535514832
cashed quant Inputfeatures.4.conv.1.0

Optimize 10:features.4.conv.1.0 for 8 bit of shape torch.Size([144, 1, 3, 3])


100%|██████████| 100/100 [00:02<00:00, 45.14it/s]



MSE before optimization: 0.0950741097331047
MSE after optimization:  0.02344476990401745
cashed quant Inputfeatures.4.conv.2

Optimize 11:features.4.conv.2 for 8 bit of shape torch.Size([32, 144, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 157.33it/s]
