In [1]:
import torch, time, copy
import torch.nn as nn
import torch.optim as optim
from torch.quantization import prepare, convert
from src.utils import *
from src.override_resnet import *
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Build the model on the CPU in eval mode. `resnet50_quan` and `resnet` come
# from the project's star imports (presumably src.override_resnet — confirm).
device = "cpu"
model = resnet50_quan(weights=resnet.ResNet50_Weights.IMAGENET1K_V2)
model.to(device)
model.eval()

# Apply the project's fusion helper before any observers are attached.
model = fuse_ALL(model)

# Walk every submodule: report any qconfig that is already set, then give each
# module whose name ends in "act_obs" a qconfig whose activation observer is a
# RecordingObserver (no weight observer).
for name, module in model.named_modules():
    existing_qconfig = getattr(module, "qconfig", None)
    if existing_qconfig is not None:
        print(f"{name} | {existing_qconfig}")

    if not name.endswith("act_obs"):
        continue

    recording_observer = torch.quantization.RecordingObserver.with_args(
        dtype=torch.quint8
    )
    module.qconfig = torch.quantization.QConfig(
        activation=recording_observer,
        weight=None,
    )
    print(f"{name} | {module.qconfig}")

In [None]:
# Prepare the model for calibration in place (inplace=True mutates `model`
# rather than working on a copy). The histogram cell further down looks for
# the `activation_post_process` attribute on the "act_obs" modules.
# NOTE: this is a bare expression, so the notebook shows the call's return
# value as the cell output.
prepare(model, inplace=True)

In [None]:
# Calibration pass: evaluate the prepared model on training data.
batch_size = 32
criterion = nn.CrossEntropyLoss()

# Only the first loader returned by GetDataset is used here.
loader_options = dict(
    dataset_name="ImageNet",
    device=device,
    root="data",
    batch_size=batch_size,
    num_workers=8,
)
train_loader, _ = GetDataset(**loader_options)

# The NVIDIA paper calibrates on 1024 images.
# NOTE(review): presumably limit=32 caps the run at 32 batches
# (32 * 32 = 1024 images) — confirm against SingleEpochEval.
calibration_result = SingleEpochEval(model, train_loader, criterion, device, limit=32)
print(calibration_result)

In [None]:
# For every module whose name ends in "act_obs" and that carries an
# `activation_post_process` observer, plot a log-scale histogram of the
# activation values the observer recorded.
for name, module in model.named_modules():
    if not (name.endswith("act_obs") and hasattr(module, "activation_post_process")):
        continue

    recorded = module.activation_post_process.get_tensor_value()
    if len(recorded) == 0:
        # The observer never saw a forward pass; there is nothing to stack or plot.
        print(f"{name} [] (no activations recorded)")
        continue

    # Detach before converting so tensors that still track gradients do not
    # make the NumPy conversion raise; stacking keeps one leading axis per
    # recorded call, matching what np.array(recorded) produced.
    activations = np.stack([t.detach().cpu().numpy() for t in recorded])
    print(name, list(activations.shape))

    # nbytes = element count * item size, so the reported footprint follows the
    # real dtype instead of assuming 4 bytes per element.
    size_mb = activations.nbytes / 1024 / 1024

    fig, ax = plt.subplots()
    ax.hist(activations.ravel(), bins=200)
    ax.set_title(f"{name} ({size_mb:.2f} MB)")
    ax.set_yscale("log")
    ax.set_xlabel("Activation value")
    ax.set_ylabel("Frequency")
    plt.show()
    # Release the figure so a model with many observers does not accumulate them.
    plt.close(fig)