## Libraries & Setup

### Setup

In [1]:
from functools import partial

import numpy as np
import pandas as pd
import torch.nn.parallel

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from utils.adaquant import optimize_layer_adaquant
from utils.load_dataset import load_dataset
from utils.utils_functions import val_loop, set_seeds
from utils.mobilenet_v2 import MobileNetV2
from utils.quantize import QConv2d, QLinear
from utils.trainer import Trainer
# Called once, before any data loading or model construction below.
# NOTE(review): presumably fixes the RNG seeds for reproducibility -- set_seeds
# lives in utils.utils_functions; confirm which libraries it actually seeds.
set_seeds()

In [2]:
# Sentinel start values. `acc` is overwritten inside the per-layer optimization
# loop; the other three names are not read by any cell visible in this notebook.
acc = loss = -1
best_prec1 = 0
dtype = torch.float32

## Load dataset

In [3]:
# Paths are handed to load_dataset positionally, in exactly this order.
dataset_paths = (
    './data/calibration',
    './data/calibration_labels.csv',
    './data/cifar-10/train',
    './data/cifar-10/trainLabels.csv',
)
cal_dataloader, train_dataloader, val_dataloader = load_dataset(*dataset_paths)

## Model, Optimizer, Trainer

In [4]:
# Quantization-aware MobileNetV2 (num_bits=8, num_bits_weight=8) restored from
# the pretrained checkpoint.
model = MobileNetV2(num_bits=8, num_bits_weight=8)
# map_location='cpu' keeps the load working on CPU-only machines even when the
# checkpoint was saved from a GPU; everything in this notebook runs on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('models/mobilenet1.pt', map_location='cpu'))
criterion = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.5, weight_decay=0)
trainer = Trainer(model, criterion, optimizer, device=torch.device("cpu"))

# Reference score on the calibration loader before any layer is optimized
# (kept as the last expression so the notebook displays the value).
val_loop(cal_dataloader, model, 'cpu')

44.04404404404404

## Prepare for quantization

In [5]:
# module -> single-element list holding the first input seen by that module.
cached_qinput = {}


def Qhook(name, module, input, output):
    """Forward hook that records only the FIRST input a module receives.

    Later invocations for an already-recorded module are ignored, so each
    entry of ``cached_qinput`` holds exactly one tensor (detached, on CPU).

    Args:
        name: Layer name bound via ``functools.partial``; not used here.
        module: The hooked module; used as the dictionary key.
        input: Tuple of positional inputs to the module; only ``input[0]`` is kept.
        output: The module's output; not used here.
    """
    if module in cached_qinput:
        return
    captured = []
    cached_qinput[module] = captured
    captured.append(input[0].detach().cpu())

In [6]:
# module -> list of (input, output) pairs, one pair per forward call.
cached_input_output = {}


def hook(name, module, input, output):
    """Forward hook that records every (input, output) pair of a module.

    Each call appends one tuple of detached CPU tensors to the module's list
    in ``cached_input_output``, creating the list on first use.

    Args:
        name: Layer name bound via ``functools.partial``; not used here.
        module: The hooked module; used as the dictionary key.
        input: Tuple of positional inputs to the module; only ``input[0]`` is kept.
        output: The module's output tensor.
    """
    records = cached_input_output.setdefault(module, [])
    layer_input = input[0].detach().cpu()
    layer_output = output.detach().cpu()
    records.append((layer_input, layer_output))

In [7]:
# Tag every quantizable layer with its qualified name so that later cells can
# refer to `layer.name`; non-quantizable modules are skipped.
for name, m in model.named_modules():
    if not isinstance(m, (QConv2d, QLinear)):
        continue
    print(name)
    m.name = name

features.0.0
features.1.conv.0.0
features.1.conv.1
features.2.conv.0.0
features.2.conv.1.0
features.2.conv.2
features.3.conv.0.0
features.3.conv.1.0
features.3.conv.2
features.4.conv.0.0
features.4.conv.1.0
features.4.conv.2
features.5.conv.0.0
features.5.conv.1.0
features.5.conv.2
features.6.conv.0.0
features.6.conv.1.0
features.6.conv.2
features.7.conv.0.0
features.7.conv.1.0
features.7.conv.2
features.8.conv.0.0
features.8.conv.1.0
features.8.conv.2
features.9.conv.0.0
features.9.conv.1.0
features.9.conv.2
features.10.conv.0.0
features.10.conv.1.0
features.10.conv.2
features.11.conv.0.0
features.11.conv.1.0
features.11.conv.2
features.12.conv.0.0
features.12.conv.1.0
features.12.conv.2
features.13.conv.0.0
features.13.conv.1.0
features.13.conv.2
features.14.conv.0.0
features.14.conv.1.0
features.14.conv.2
features.15.conv.0.0
features.15.conv.1.0
features.15.conv.2
features.16.conv.0.0
features.16.conv.1.0
features.16.conv.2
features.17.conv.0.0
features.17.conv.1.0
features.17.conv

In [8]:
# Switch quantization off on every quantizable layer and attach the caching
# hook; the returned handles are kept so the hooks can be removed afterwards.
handlers = []
count = 0
for name, m in model.named_modules():
    if not isinstance(m, (QConv2d, QLinear)):
        continue
    m.quantize = False
    hook_handle = m.register_forward_hook(partial(hook, name))
    handlers.append(hook_handle)
    count += 1

In [9]:
# Store input/output for all quantizable layers: running validation drives the
# model forward, which fires the `hook` callbacks registered above and fills
# cached_input_output (the layers have quantize=False at this point).
# NOTE(review): this pass uses val_dataloader, whereas the per-layer input
# re-capture later in the notebook runs on cal_dataloader. The re-captured input
# is paired with the output stored here, so confirm both loaders yield the same
# first batch -- otherwise inputs and reference outputs do not correspond.
trainer.validate(val_dataloader)
print("Input/outputs cached")

Input/outputs cached


In [10]:
# Detach all caching hooks now that the input/output pairs are recorded.
for registered_hook in handlers:
    registered_hook.remove()

# Turn quantization back on for every quantizable layer.
quantizable_layers = [
    module for module in model.modules()
    if isinstance(module, (QConv2d, QLinear))
]
for quantizable_layer in quantizable_layers:
    quantizable_layer.quantize = True

## Loop through layers - run quantization

In [11]:
# One (initially empty) result row per cached layer; filled in by the
# per-layer optimization loop.
result_columns = ["name", "bit", "shape", "mse_before", "mse_after", "acc"]
row_index = np.arange(len(cached_input_output))
mse_df = pd.DataFrame(columns=result_columns, index=row_index)

print_freq = 100
# Prefix used for the output file names written after the loop.
evaluate = "evaluate"

In [12]:
# Optimize the quantized layers one at a time, in the order in which they were
# first reached during the caching pass (dict insertion order).
for i, layer in enumerate(cached_input_output):
    if i > 0:
        # Re-capture this layer's input from the model in its current state
        # (quantization enabled, earlier layers already optimized). Layer 0
        # keeps the input recorded by the caching pass.
        cached_qinput = {}
        # `layer` is the module object itself (the caching hook used it as the
        # dictionary key), so it is hooked directly instead of scanning
        # model.named_modules() for a module whose name matches.
        handler = layer.register_forward_hook(partial(Qhook, layer.name))
        try:
            trainer.validate(cal_dataloader)
        finally:
            # Always detach the hook, so a failed pass or a re-run of this
            # cell cannot leave stale hooks stacked on the layer.
            handler.remove()
        print("cashed quant Input%s" % layer.name)
        # Keep the stored reference output and swap in the fresh input.
        # Only cached pair 0 is refreshed; any further pairs in the list keep
        # the inputs recorded by the caching pass.
        # NOTE(review): the reference output was cached from val_dataloader
        # while this input comes from cal_dataloader -- confirm both loaders
        # produce the same first batch.
        reference_output = cached_input_output[layer][0][1]
        cached_input_output[layer][0] = (cached_qinput[layer][0], reference_output)

    print(
        "\nOptimize {}:{} for {} bit of shape {}".format(
            i, layer.name, layer.num_bits, layer.weight.shape
        )
    )

    mse_before, mse_after = optimize_layer_adaquant(layer, cached_input_output[layer])
    # Score the whole model after each layer so the effect of every step is recorded.
    acc = val_loop(val_dataloader, model, 'cpu')

    print("\nMSE before optimization: {}".format(mse_before))
    print("MSE after optimization:  {}".format(mse_after))

    # Fill row i of the results table.
    layer_results = {
        "name": layer.name,
        "bit": layer.num_bits,
        "shape": str(layer.weight.shape),
        "mse_before": mse_before,
        "mse_after": mse_after,
        "acc": acc,
    }
    for column, value in layer_results.items():
        mse_df.loc[i, column] = value


# Persist the per-layer results table as CSV ...
mse_csv = f"{evaluate}.mse.csv"
mse_df.to_csv(mse_csv)

# ... and the optimized model weights.
filename = f"{evaluate}_adaquant_val02"
torch.save(model.state_dict(), filename)


Optimize 0:features.0.0 for 8 bit of shape torch.Size([32, 3, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 87.14it/s]



MSE before optimization: 0.2940675616264343
MSE after optimization:  0.03231118991971016
cashed quant Inputfeatures.1.conv.0.0

Optimize 1:features.1.conv.0.0 for 8 bit of shape torch.Size([32, 1, 3, 3])


100%|██████████| 100/100 [00:02<00:00, 43.87it/s]



MSE before optimization: 0.002355064731091261
MSE after optimization:  0.00234135240316391
cashed quant Inputfeatures.1.conv.1

Optimize 2:features.1.conv.1 for 8 bit of shape torch.Size([16, 32, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 60.00it/s]



MSE before optimization: 0.003050582716241479
MSE after optimization:  0.0029362684581428766
cashed quant Inputfeatures.2.conv.0.0

Optimize 3:features.2.conv.0.0 for 8 bit of shape torch.Size([96, 16, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 52.72it/s]



MSE before optimization: 0.05361859127879143
MSE after optimization:  0.05361330136656761
cashed quant Inputfeatures.2.conv.1.0

Optimize 4:features.2.conv.1.0 for 8 bit of shape torch.Size([96, 1, 3, 3])


100%|██████████| 100/100 [00:04<00:00, 24.46it/s]



MSE before optimization: 0.045015767216682434
MSE after optimization:  0.0361471064388752
cashed quant Inputfeatures.2.conv.2

Optimize 5:features.2.conv.2 for 8 bit of shape torch.Size([24, 96, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 68.90it/s]



MSE before optimization: 0.08585748076438904
MSE after optimization:  0.08518283814191818
cashed quant Inputfeatures.3.conv.0.0

Optimize 6:features.3.conv.0.0 for 8 bit of shape torch.Size([144, 24, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 104.05it/s]



MSE before optimization: 0.08344480395317078
MSE after optimization:  0.07618572562932968
cashed quant Inputfeatures.3.conv.1.0

Optimize 7:features.3.conv.1.0 for 8 bit of shape torch.Size([144, 1, 3, 3])


100%|██████████| 100/100 [00:02<00:00, 45.69it/s]



MSE before optimization: 0.00898977741599083
MSE after optimization:  0.002702984493225813
cashed quant Inputfeatures.3.conv.2

Optimize 8:features.3.conv.2 for 8 bit of shape torch.Size([24, 144, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 54.63it/s]



MSE before optimization: 0.019225656986236572
MSE after optimization:  0.016296228393912315
cashed quant Inputfeatures.4.conv.0.0

Optimize 9:features.4.conv.0.0 for 8 bit of shape torch.Size([144, 24, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 104.08it/s]



MSE before optimization: 0.2163679599761963
MSE after optimization:  0.1883907914161682
cashed quant Inputfeatures.4.conv.1.0

Optimize 10:features.4.conv.1.0 for 8 bit of shape torch.Size([144, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 56.63it/s]



MSE before optimization: 0.09327231347560883
MSE after optimization:  0.01956268586218357
cashed quant Inputfeatures.4.conv.2

Optimize 11:features.4.conv.2 for 8 bit of shape torch.Size([32, 144, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 175.47it/s]



MSE before optimization: 1.4994641542434692
MSE after optimization:  1.3670309782028198
cashed quant Inputfeatures.5.conv.0.0

Optimize 12:features.5.conv.0.0 for 8 bit of shape torch.Size([192, 32, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 201.01it/s]



MSE before optimization: 1.5233330726623535
MSE after optimization:  1.3364012241363525
cashed quant Inputfeatures.5.conv.1.0

Optimize 13:features.5.conv.1.0 for 8 bit of shape torch.Size([192, 1, 3, 3])


100%|██████████| 100/100 [00:00<00:00, 152.88it/s]



MSE before optimization: 0.0440930612385273
MSE after optimization:  0.040624555200338364
cashed quant Inputfeatures.5.conv.2

Optimize 14:features.5.conv.2 for 8 bit of shape torch.Size([32, 192, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 161.91it/s]



MSE before optimization: 0.7615174651145935
MSE after optimization:  0.6663913130760193
cashed quant Inputfeatures.6.conv.0.0

Optimize 15:features.6.conv.0.0 for 8 bit of shape torch.Size([192, 32, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 197.92it/s]



MSE before optimization: 2.025031805038452
MSE after optimization:  1.6768062114715576
cashed quant Inputfeatures.6.conv.1.0

Optimize 16:features.6.conv.1.0 for 8 bit of shape torch.Size([192, 1, 3, 3])


100%|██████████| 100/100 [00:00<00:00, 150.54it/s]



MSE before optimization: 0.03506797179579735
MSE after optimization:  0.03471272438764572
cashed quant Inputfeatures.6.conv.2

Optimize 17:features.6.conv.2 for 8 bit of shape torch.Size([32, 192, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 157.88it/s]



MSE before optimization: 0.3706902265548706
MSE after optimization:  0.34559449553489685
cashed quant Inputfeatures.7.conv.0.0

Optimize 18:features.7.conv.0.0 for 8 bit of shape torch.Size([192, 32, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 202.44it/s]



MSE before optimization: 2.4101109504699707
MSE after optimization:  2.224879264831543
cashed quant Inputfeatures.7.conv.1.0

Optimize 19:features.7.conv.1.0 for 8 bit of shape torch.Size([192, 1, 3, 3])


100%|██████████| 100/100 [00:00<00:00, 153.19it/s]



MSE before optimization: 0.1415058672428131
MSE after optimization:  0.07243413478136063
cashed quant Inputfeatures.7.conv.2

Optimize 20:features.7.conv.2 for 8 bit of shape torch.Size([64, 192, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 200.05it/s]



MSE before optimization: 1.549709677696228
MSE after optimization:  1.3535579442977905
cashed quant Inputfeatures.8.conv.0.0

Optimize 21:features.8.conv.0.0 for 8 bit of shape torch.Size([384, 64, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 184.62it/s]



MSE before optimization: 2.2502737045288086
MSE after optimization:  2.199319362640381
cashed quant Inputfeatures.8.conv.1.0

Optimize 22:features.8.conv.1.0 for 8 bit of shape torch.Size([384, 1, 3, 3])


100%|██████████| 100/100 [00:00<00:00, 103.74it/s]



MSE before optimization: 0.01570802927017212
MSE after optimization:  0.013197551481425762
cashed quant Inputfeatures.8.conv.2

Optimize 23:features.8.conv.2 for 8 bit of shape torch.Size([64, 384, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 162.62it/s]



MSE before optimization: 0.5048291087150574
MSE after optimization:  0.4900008738040924
cashed quant Inputfeatures.9.conv.0.0

Optimize 24:features.9.conv.0.0 for 8 bit of shape torch.Size([384, 64, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 190.28it/s]



MSE before optimization: 1.828735113143921
MSE after optimization:  1.8266786336898804
cashed quant Inputfeatures.9.conv.1.0

Optimize 25:features.9.conv.1.0 for 8 bit of shape torch.Size([384, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 95.92it/s]



MSE before optimization: 0.0016530179418623447
MSE after optimization:  0.0015722020762041211
cashed quant Inputfeatures.9.conv.2

Optimize 26:features.9.conv.2 for 8 bit of shape torch.Size([64, 384, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 161.00it/s]



MSE before optimization: 0.040724463760852814
MSE after optimization:  0.040046047419309616
cashed quant Inputfeatures.10.conv.0.0

Optimize 27:features.10.conv.0.0 for 8 bit of shape torch.Size([384, 64, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 188.42it/s]



MSE before optimization: 0.8355658054351807
MSE after optimization:  0.8292196393013
cashed quant Inputfeatures.10.conv.1.0

Optimize 28:features.10.conv.1.0 for 8 bit of shape torch.Size([384, 1, 3, 3])


100%|██████████| 100/100 [00:00<00:00, 104.96it/s]



MSE before optimization: 0.0014586912002414465
MSE after optimization:  0.0012677176855504513
cashed quant Inputfeatures.10.conv.2

Optimize 29:features.10.conv.2 for 8 bit of shape torch.Size([64, 384, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 160.79it/s]



MSE before optimization: 0.048190440982580185
MSE after optimization:  0.031466081738471985
cashed quant Inputfeatures.11.conv.0.0

Optimize 30:features.11.conv.0.0 for 8 bit of shape torch.Size([384, 64, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 170.55it/s]



MSE before optimization: 5.981011390686035
MSE after optimization:  5.909394264221191
cashed quant Inputfeatures.11.conv.1.0

Optimize 31:features.11.conv.1.0 for 8 bit of shape torch.Size([384, 1, 3, 3])


100%|██████████| 100/100 [00:00<00:00, 105.56it/s]



MSE before optimization: 0.03958893567323685
MSE after optimization:  0.036917828023433685
cashed quant Inputfeatures.11.conv.2

Optimize 32:features.11.conv.2 for 8 bit of shape torch.Size([96, 384, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 152.72it/s]



MSE before optimization: 0.9285637140274048
MSE after optimization:  0.8980165719985962
cashed quant Inputfeatures.12.conv.0.0

Optimize 33:features.12.conv.0.0 for 8 bit of shape torch.Size([576, 96, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 149.94it/s]



MSE before optimization: 0.13901114463806152
MSE after optimization:  0.13779087364673615
cashed quant Inputfeatures.12.conv.1.0

Optimize 34:features.12.conv.1.0 for 8 bit of shape torch.Size([576, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 79.88it/s]



MSE before optimization: 0.0007512808660976589
MSE after optimization:  0.0005068076425231993
cashed quant Inputfeatures.12.conv.2

Optimize 35:features.12.conv.2 for 8 bit of shape torch.Size([96, 576, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 128.46it/s]



MSE before optimization: 0.09863002598285675
MSE after optimization:  0.07016261667013168
cashed quant Inputfeatures.13.conv.0.0

Optimize 36:features.13.conv.0.0 for 8 bit of shape torch.Size([576, 96, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 140.12it/s]



MSE before optimization: 0.1611613929271698
MSE after optimization:  0.15906818211078644
cashed quant Inputfeatures.13.conv.1.0

Optimize 37:features.13.conv.1.0 for 8 bit of shape torch.Size([576, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 79.62it/s]



MSE before optimization: 0.0006240135990083218
MSE after optimization:  0.0004206638259347528
cashed quant Inputfeatures.13.conv.2

Optimize 38:features.13.conv.2 for 8 bit of shape torch.Size([96, 576, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 130.11it/s]



MSE before optimization: 0.05925614759325981
MSE after optimization:  0.028244568035006523
cashed quant Inputfeatures.14.conv.0.0

Optimize 39:features.14.conv.0.0 for 8 bit of shape torch.Size([576, 96, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 151.38it/s]



MSE before optimization: 0.2645561099052429
MSE after optimization:  0.26002272963523865
cashed quant Inputfeatures.14.conv.1.0

Optimize 40:features.14.conv.1.0 for 8 bit of shape torch.Size([576, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 84.48it/s]



MSE before optimization: 0.003313385881483555
MSE after optimization:  0.00039144637412391603
cashed quant Inputfeatures.14.conv.2

Optimize 41:features.14.conv.2 for 8 bit of shape torch.Size([160, 576, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 131.26it/s]



MSE before optimization: 0.16470959782600403
MSE after optimization:  0.16192737221717834
cashed quant Inputfeatures.15.conv.0.0

Optimize 42:features.15.conv.0.0 for 8 bit of shape torch.Size([960, 160, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 120.04it/s]



MSE before optimization: 0.8386932611465454
MSE after optimization:  0.7928159236907959
cashed quant Inputfeatures.15.conv.1.0

Optimize 43:features.15.conv.1.0 for 8 bit of shape torch.Size([960, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 65.72it/s]



MSE before optimization: 0.0019710054621100426
MSE after optimization:  0.0018117798026651144
cashed quant Inputfeatures.15.conv.2

Optimize 44:features.15.conv.2 for 8 bit of shape torch.Size([160, 960, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 115.38it/s]



MSE before optimization: 0.05379612743854523
MSE after optimization:  0.030546288937330246
cashed quant Inputfeatures.16.conv.0.0

Optimize 45:features.16.conv.0.0 for 8 bit of shape torch.Size([960, 160, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 126.35it/s]



MSE before optimization: 8.20903491973877
MSE after optimization:  4.843896389007568
cashed quant Inputfeatures.16.conv.1.0

Optimize 46:features.16.conv.1.0 for 8 bit of shape torch.Size([960, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 68.50it/s]



MSE before optimization: 0.0033551654778420925
MSE after optimization:  0.0016540881479158998
cashed quant Inputfeatures.16.conv.2

Optimize 47:features.16.conv.2 for 8 bit of shape torch.Size([160, 960, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 126.74it/s]



MSE before optimization: 1.3423482179641724
MSE after optimization:  0.5264734625816345
cashed quant Inputfeatures.17.conv.0.0

Optimize 48:features.17.conv.0.0 for 8 bit of shape torch.Size([960, 160, 1, 1])


100%|██████████| 100/100 [00:00<00:00, 131.32it/s]



MSE before optimization: 129.00953674316406
MSE after optimization:  101.76636505126953
cashed quant Inputfeatures.17.conv.1.0

Optimize 49:features.17.conv.1.0 for 8 bit of shape torch.Size([960, 1, 3, 3])


100%|██████████| 100/100 [00:01<00:00, 66.20it/s]



MSE before optimization: 0.004908871371299028
MSE after optimization:  0.0013554277829825878
cashed quant Inputfeatures.17.conv.2

Optimize 50:features.17.conv.2 for 8 bit of shape torch.Size([320, 960, 1, 1])


100%|██████████| 100/100 [00:01<00:00, 55.21it/s]



MSE before optimization: 1.0320674180984497
MSE after optimization:  0.5273337960243225
cashed quant Inputfeatures.18.0

Optimize 51:features.18.0 for 8 bit of shape torch.Size([1280, 320, 1, 1])


100%|██████████| 100/100 [00:02<00:00, 48.39it/s]



MSE before optimization: 9.769645690917969
MSE after optimization:  9.340081214904785
cashed quant Inputclassifier.1

Optimize 52:classifier.1 for 8 bit of shape torch.Size([10, 1280])


100%|██████████| 100/100 [00:00<00:00, 198.39it/s]



MSE before optimization: 3332.45947265625
MSE after optimization:  3047.751953125
