In [1]:
import torch, time
import torch.nn as nn
import torch.optim as optim

from src.utils import *
# from src.override_resnet import *


class Args:
    arch = 50
    dataset = "ImageNet"
    # dataset = "CIFAR100"
    lr = 0.001
    momentum = 0.9
    batch = 16
    epochs = 10
    save_every = 1
    quan = "static"
    only_eval = True
    verbose = True


args = Args()

In [2]:
# %% override the torchvision.models.resnet
from torchvision.models.resnet import (
    ResNet,
    ResNet50_Weights,
    Bottleneck,
    BasicBlock,
)
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor

from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface

"""
Todo : 
- [x] forward 함수 앞뒤로 quantization 추가
- [ ] skip add에서 그냥 +를 nn.quantized.FloatFunctional()으로 바꾸기
- [x] Conv, bn, relu 하나로 만들어야함.
- [x] ReLU 6면 int계산 안 되는데, 일반 ReLU인 것은 확인 완료
"""


class BottleNeck_quan(Bottleneck):
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(BottleNeck_quan, self).__init__(
            inplanes,
            planes,
            stride,
            downsample,
            groups,
            base_width,
            dilation,
            norm_layer,
        )
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # out += identity
        out = self.add.add(out, identity)
        out = self.relu3(out)

        return out

    # def forward(self, x: Tensor) -> Tensor:
    #     x = super(BottleNeck_quan, self).forward(x)
    #     return x


# class BasicBlock_quan(BasicBlock): << 원하면 Block 내부 override해서 사용
class ResNet_quan(ResNet):
    def __init__(
        self,
        block: Any,
        layers: list[int],
        num_classes: int = 1000,
        weights: Optional[str] = None,
    ) -> None:
        super(ResNet_quan, self).__init__(block, layers, num_classes)
        if weights is not None:
            self.load_state_dict(torch.load(weights))
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    # def forward(self, x: torch.Tensor) -> torch.Tensor:
    #     x = self.quant(x)
    #     x = super(ResNet_quan, self).forward(x)
    #     x = self.dequant(x)
    #     return x

    def forward(self, x:torch.Tensor) -> torch.Tensor:

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.quant(x)  
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.dequant(x)
        return x


def _resnet_quan(
    block: Type[Union[BasicBlock, BottleNeck_quan]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet_quan(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(
            weights.get_state_dict(progress=progress, check_hash=True)
        )

    return model


def resnet50_quan(
    *, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    weights = ResNet50_Weights.verify(weights)
    return _resnet_quan(BottleNeck_quan, [3, 4, 6, 3], weights, progress, **kwargs)

In [3]:
def fuse_model(model) -> nn.Module:
    flag = False
    for m in model.modules():
        # if m.__class__.__name__ == ResNet_quan.__name__:
        #     if flag == True:
        #         raise ValueError("ResNet_quan is already fused")
        #     flag = True
        #     torch.quantization.fuse_modules(
        #         m,
        #         ["conv1", "bn1", "relu"],
        #         inplace=True,
        #     )

        if type(m) == BottleNeck_quan:
            torch.quantization.fuse_modules(
                m,
                [
                    ["conv1", "bn1", "relu1"],
                    ["conv2", "bn2", "relu2"],
                    ["conv3", "bn3"],
                ],
                inplace=True,
            )
            if m.downsample is not None:
                torch.quantization.fuse_modules(
                    m.downsample,
                    ["0", "1"],
                    inplace=True,
                )
    return model

In [4]:
# %% my code

args = Args()
# %% Load the ResNet-50 model
if args.quan == "fp32":
    # case 0 : no quantization case
    print("----------No quantization enabled")
    device = str(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    model = layers_mapping[args.arch](
        weights=pretrained_weights_mapping[args.arch]
    ).to(device)

elif args.quan == "dynamic":
    # case 1 : Dynamic Quantization
    print("----------Dynamic Quantization enabled")
    device = "cuda"
    model = resnet50_quan(weights=pretrained_weights_mapping[args.arch]).to(device)
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    model = quantized_model

elif args.quan == "static":
    # case 2 : Static Quantization
    print("----------Static Quantization enabled")
    device = "cpu"
    model = resnet50_quan(weights=pretrained_weights_mapping[args.arch]).to(device)

elif args.quan == "qat":
    # case 3 : Quantization Aware Training
    print("----------Quantization Aware Training enabled")
else:
    raise ValueError("Invalid quantization method")

# _folder_path = f"resnet{args.arch}_{args.dataset}" + "_" + args.quan
# _file_name = (
#     f"resnet{args.arch}_{args.dataset}_epoch"  # resnet18_cifar10_epoch{epoch}.pth
# )


----------Static Quantization enabled


# 1. The Acc of Reference Network

- Check the origin network architecture

In [5]:
print(model.layer1[0])

BottleNeck_quan(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (add): FloatFunctional(
    (activation_post_process): Identity()
  )
)


In [6]:
# check_accuracy(model=model, device="cpu", batch_size=25)
print("Post Training Quantization: Eval done")

Post Training Quantization: Eval done


- Check the fused network architecture

In [7]:
model.eval()
model = fuse_model(model)
print(print_size_of_model(model))
print(model.layer1[0])

Size (MB): 102.16063
None
BottleNeck_quan(
  (conv1): ConvReLU2d(
    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
  )
  (bn1): Identity()
  (conv2): ConvReLU2d(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (bn2): Identity()
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
  (bn3): Identity()
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): Identity()
  )
  (relu1): Identity()
  (relu2): Identity()
  (relu3): ReLU()
  (add): FloatFunctional(
    (activation_post_process): Identity()
  )
)


# 2. Calibration for Post-Training Static Quantization

- Check the Quantization Configuration

In [8]:
# QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){},
#         weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})


model.qconfig = torch.quantization.get_default_qconfig("x86")
# model.qconfig = torch.quantization.QConfig(
#     activation=torch.quantization.observer.HistogramObserver.with_args(
#         reduce_range=True
#     ),
#     weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(qscheme=torch.per_channel_symmetric),
# )
print(model.qconfig)


QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})


In [9]:
torch.quantization.prepare(model, inplace=True)

print("Post Training Quantization Prepare: Inserting Observers")

Post Training Quantization Prepare: Inserting Observers




- Inference with the representative dataset (calculate the quantization parameters)

In [10]:
criterion = nn.CrossEntropyLoss()
train_loader, test_loader = GetDataset(
    dataset_name=args.dataset,
    device=device,
    root="data",
    batch_size=256,
    num_workers=8,
)
_, _ = SingleEpochEval(model, train_loader, criterion, "cuda", 1000)
print("Post Training Quantization: Calibration done")

 20%|█▉        | 999/5005 [26:03<1:44:28,  1.56s/it]

Post Training Quantization: Calibration done





- Convert to quantized model

In [11]:
device = "cpu"
model.to(device)
torch.quantization.convert(model, inplace=True)
print("Post Training Quantization: Convert done")

Post Training Quantization: Convert done


# 3. Complete 

In [13]:
check_accuracy(model=model, device="cpu", batch_size=25)
print("Post Training Quantization: Eval done")

NotImplementedError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/native/quantized/cpu/qconv.cpp:1912 [kernel]
QuantizedCUDA: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/native/quantized/cudnn/Conv.cpp:388 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/torch/csrc/autograd/TraceTypeManual.cpp:297 [backend fallback]
AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1708025845206/work/aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]
