In [None]:
import torch, time
import torch.nn as nn
import torch.optim as optim

from src.utils import *
from src.override_resnet import *


class Args:
    """Notebook-wide experiment settings, exposed as plain class attributes."""

    # --- model / data selection ---
    arch = 50  # key used to index layers_mapping / pretrained_weights_mapping
    dataset = "ImageNet"  # forwarded to GetDataset as dataset_name
    # dataset = "CIFAR100"

    # --- quantization mode: one of "fp32", "dynamic", "static", "qat" ---
    quan = "static"

    # --- optimisation settings (not read by the code visible in this notebook) ---
    lr = 0.001
    momentum = 0.9
    batch = 16
    epochs = 10
    save_every = 1

    # --- run flags (not read by the code visible in this notebook) ---
    only_eval = True
    verbose = True


# Shared settings instance used by the cells below.
args = Args()

In [None]:
def run_benchmark(model, img_loader, device, num_batches=1) -> float:
    """Measure the average forward-pass latency per image.

    The model is switched to eval mode and run on at most ``num_batches``
    batches taken from ``img_loader``. Only the forward pass is timed; moving
    the batch to ``device`` is excluded from the measurement.

    Args:
        model: Callable module exposing ``eval()``; invoked as ``model(images)``.
        img_loader: Iterable yielding ``(images, target)`` pairs, where
            ``images`` supports ``.to(device)`` and ``.size(0)``.
        device: Device (string or ``torch.device``) the images are moved to.
        num_batches: Maximum number of batches to time (default 1).

    Returns:
        Average latency in milliseconds per image, over the images actually
        processed.

    Raises:
        ValueError: If no image was processed (empty loader or
            ``num_batches <= 0``).
    """
    model.eval()
    target_device = torch.device(device)
    # CUDA kernels launch asynchronously, so wall-clock timing is only
    # meaningful when the device is synchronized around the forward pass.
    needs_sync = target_device.type == "cuda"

    elapsed = 0.0
    num_images = 0
    with torch.no_grad():  # inference only: no autograd bookkeeping in the timing
        # range() comes first so zip stops before pulling an extra batch
        # from the loader once num_batches batches have been consumed.
        for _, (images, _target) in zip(range(num_batches), img_loader):
            images = images.to(device)
            if needs_sync:
                torch.cuda.synchronize(target_device)
            start = time.perf_counter()
            model(images)
            if needs_sync:
                torch.cuda.synchronize(target_device)
            elapsed += time.perf_counter() - start
            # Count what was really seen; the last batch may be smaller and
            # the loader may hold fewer than num_batches batches.
            num_images += images.size(0)

    if num_images == 0:
        raise ValueError("run_benchmark: img_loader produced no images to benchmark")

    ms_per_image = elapsed / num_images * 1000
    print("Elapsed time: %3.0f ms" % ms_per_image)
    return ms_per_image

def check_accuracy(model, device, batch_size=25) -> tuple:
    """Benchmark, size-report and evaluate ``model`` on the test loader.

    The model is put in eval mode and moved to ``device``. A test loader for
    the dataset named by the notebook-level ``args.dataset`` is built with
    ``GetDataset``; the model is then timed with ``run_benchmark``, its size
    is reported through ``print_size_of_model``, and it is evaluated with
    ``SingleEpochEval`` using a cross-entropy criterion.

    Args:
        model: Network to evaluate; moved to ``device`` in place.
        device: Device used for the model, the loader and the evaluation.
        batch_size: Batch size of the evaluation loader (default 25).

    Returns:
        ``(eval_loss, eval_acc)`` exactly as returned by ``SingleEpochEval``.
    """
    model.eval()
    model.to(device)

    # Only the second loader returned by GetDataset is used here.
    _, test_loader = GetDataset(
        dataset_name=args.dataset,
        device=device,
        root="data",
        batch_size=batch_size,
        num_workers=8,
    )
    run_benchmark(model, test_loader, device)
    print_size_of_model(model)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the meaning of the trailing 500 is defined by
    # SingleEpochEval in src.utils (not visible here) - confirm.
    eval_loss, eval_acc = SingleEpochEval(model, test_loader, criterion, device, 500)

    print(f"Eval Loss: {eval_loss:.4f}, Eval Acc: {eval_acc:.2f}%")
    # Previously a bare `return` discarded the metrics despite the `-> tuple` annotation.
    return eval_loss, eval_acc

In [None]:
def fuse_model(model) -> nn.Module:
    """Fuse conv/bn(/relu) groups of a quantizable ResNet in place.

    Walks ``model.modules()`` and applies ``torch.quantization.fuse_modules``:

    * on the module whose class name matches ``ResNet_quan``: the stem
      ``conv1``/``bn1``/``relu``;
    * on every module of exact type ``BottleNeck_quan``: its three conv groups
      and, when present, entries ``"0"`` and ``"1"`` of ``downsample``.

    Args:
        model: Module tree to fuse; modified in place.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If more than one ``ResNet_quan`` module is encountered.
    """
    stem_seen = False
    for module in model.modules():
        # Stem: matched by class name, allowed at most once per model.
        if module.__class__.__name__ == ResNet_quan.__name__:
            if stem_seen:
                raise ValueError("ResNet_quan is already fused")
            stem_seen = True
            torch.quantization.fuse_modules(
                module, ["conv1", "bn1", "relu"], inplace=True
            )

        # Everything below applies to bottleneck blocks only (exact type match).
        if type(module) != BottleNeck_quan:
            continue

        torch.quantization.fuse_modules(
            module,
            [
                ["conv1", "bn1", "relu1"],
                ["conv2", "bn2", "relu2"],
                ["conv3", "bn3"],
            ],
            inplace=True,
        )

        shortcut = module.downsample
        if shortcut is not None:
            torch.quantization.fuse_modules(shortcut, ["0", "1"], inplace=True)
    return model

In [None]:
# %% my code

args = Args()
# %% Load the ResNet-50 model
# Each branch must define both `model` and `device`; the cells below rely on them.
if args.quan == "fp32":
    # case 0 : no quantization case
    print("----------No quantization enabled")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = layers_mapping[args.arch](
        weights=pretrained_weights_mapping[args.arch]
    ).to(device)

elif args.quan == "dynamic":
    # case 1 : Dynamic Quantization
    print("----------Dynamic Quantization enabled")
    # Dynamically quantized modules run on the CPU backends, so the model is
    # kept on the CPU instead of being moved to "cuda" (which also failed
    # outright on hosts without a GPU).
    device = "cpu"
    model = resnet50_quan(weights=pretrained_weights_mapping[args.arch]).to(device)
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

elif args.quan == "static":
    # case 2 : Static Quantization
    print("----------Static Quantization enabled")
    device = "cpu"
    model = resnet50_quan(weights=pretrained_weights_mapping[args.arch]).to(device)

elif args.quan == "qat":
    # case 3 : Quantization Aware Training
    print("----------Quantization Aware Training enabled")
    # This branch never built a model, so later cells crashed with a NameError
    # on `model`; fail here with an explicit message instead.
    raise NotImplementedError(
        "Quantization Aware Training is not implemented in this notebook"
    )
else:
    raise ValueError("Invalid quantization method")

# _folder_path = f"resnet{args.arch}_{args.dataset}" + "_" + args.quan
# _file_name = (
#     f"resnet{args.arch}_{args.dataset}_epoch"  # resnet18_cifar10_epoch{epoch}.pth
# )


# 1. Accuracy of the Reference Network

- Check the original network architecture

In [None]:
# Print the first entry of layer1 as loaded, before any of the later
# quantization steps, so it can be compared with the printouts further down.
print(model.layer1[0])

In [None]:
# Baseline evaluation of the model as loaded, run on the CPU.
check_accuracy(model=model, device="cpu", batch_size=25)
# This cell runs before prepare/calibrate/convert, so it must not report a
# post-training-quantization result (the old message was copied from the final cell).
print("Reference Network: Eval done")

- Check the fused network architecture

In [None]:
# NOTE(review): the fusion call below is commented out, so the model printed
# here has not been passed through fuse_model, although this cell is introduced
# as showing the fused architecture - confirm whether fusion was disabled on purpose.
# model = fuse_model(model)
# NOTE(review): print_size_of_model is called as a bare statement inside
# check_accuracy; if it returns None, the outer print() emits an extra "None" - verify.
print(print_size_of_model(model))
print(model.layer1[0])

# 2. Calibration for Post-Training Static Quantization

- Check the Quantization Configuration

In [None]:
# Reference QConfig repr kept by the author (presumably what the print below shows - confirm):
# QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){},
#         weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})


# Attach the default "x86" quantization configuration to the top-level model.
model.qconfig = torch.quantization.get_default_qconfig("x86")
# Alternative: build the QConfig by hand from explicit observers.
# model.qconfig = torch.quantization.QConfig(
#     activation=torch.quantization.observer.HistogramObserver.with_args(
#         reduce_range=True
#     ),
#     weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(qscheme=torch.per_channel_symmetric),
# )
print(model.qconfig)


In [None]:
# Prepare the model for calibration; inplace=True modifies `model` itself
# rather than returning a prepared copy.
torch.quantization.prepare(model, inplace=True)

print("Post Training Quantization Prepare: Inserting Observers")

- Run inference on a representative dataset to calculate the quantization parameters

In [None]:
criterion = nn.CrossEntropyLoss()
# NOTE(review): GetDataset still receives the notebook-level `device`; what it
# does with that argument is defined in src.utils (not visible here) - confirm
# whether it should follow the calibration device chosen below.
train_loader, test_loader = GetDataset(
    dataset_name=args.dataset,
    device=device,
    root="data",
    batch_size=256,
    num_workers=8,
)
# Calibrate on the GPU when one exists, otherwise fall back to the CPU. The
# device used to be hard-coded to "cuda", which fails on CPU-only hosts and
# could disagree with the model's device (an earlier cell moves it to the CPU).
calib_device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(calib_device)
# Calibration is inference only: keep the model in eval mode.
model.eval()
# NOTE(review): the meaning of the trailing 5000 is defined by SingleEpochEval
# in src.utils (not visible here) - confirm.
_, _ = SingleEpochEval(model, train_loader, criterion, calib_device, 5000)
print("Post Training Quantization: Calibration done")

- Convert to quantized model

In [None]:
# Move the calibrated model to the CPU before converting it.
device = "cpu"
model.to(device)
# inplace=True converts `model` itself rather than returning a converted copy.
torch.quantization.convert(model, inplace=True)
print("Post Training Quantization: Convert done")

# 3. Complete: Evaluate the Quantized Model

In [None]:
# Evaluate the converted model on the CPU with the same settings as the
# reference evaluation (batch_size=25) so the two results are comparable.
check_accuracy(model=model, device="cpu", batch_size=25)
print("Post Training Quantization: Eval done")