No Sign of activation quantization with QAT #647

Closed
Veccoy opened this issue Jun 18, 2024 · 1 comment

Veccoy commented Jun 18, 2024

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read related documents and don't know what to do.

Describe the question you meet

I am trying to train a model using QAT. I use the default config files provided by MMRazor for the ResNet18 model and the QAT algorithm. The training seems to work well until I check what is inside the model: I can see the observers for the weights, but nothing related to activation quantization...

Can someone explain why? How can I check that the activation observers and the fake quant nodes exist in the quantized model?

Even after deploying and checking the ONNX model, I see only FixedPerChannelAffine and Identity blocks corresponding to weight quantization, but nothing for activation quantization.
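As a point of comparison, here is a minimal sketch with plain torch eager-mode QAT on a toy model (independent of MMRazor; the toy Sequential is just a stand-in) showing where weight and activation fake-quant submodules normally show up by name:

```python
# Hedged sketch: eager-mode QAT on a toy model, NOT the MMRazor-prepared one.
# Weight fake-quant lives under `weight_fake_quant`; activation observers are
# attached as `activation_post_process` submodules after prepare_qat().
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).train()
model.qconfig = get_default_qat_qconfig('fbgemm')
qmodel = prepare_qat(model)

weight_fq = [name for name, _ in qmodel.named_modules()
             if name.endswith('weight_fake_quant')]
act_fq = [name for name, _ in qmodel.named_modules()
          if name.endswith('activation_post_process')
          and 'weight_fake_quant' not in name]
print(weight_fq)  # weight-side fake-quant modules
print(act_fq)     # activation-side observers; empty would match the symptom here
```

In a healthy QAT model both lists are non-empty; the symptom described above corresponds to the second list being empty.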

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch":
    mmcv           2.0.1
    mmrazor        1.0.0     /workspace/mmlab/mmrazor
    torch          1.13.1
    torchelastic   0.2.2
    torchtext      0.14.1
    torchvision    0.14.1
  2. Your config file if you modified it or created a new one.
data_preprocessor = dict(
    mean=[
        123.675,
        116.28,
        103.53,
    ],
    std=[
        58.395,
        57.12,
        57.375,
    ],
    to_rgb=True)
data_root = '/home/XXX/val/Data'
default_hooks = dict(
    checkpoint=dict(interval=10, type='mmengine.hooks.CheckpointHook'),
    logger=dict(interval=100, type='mmengine.hooks.LoggerHook'),
    param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
    sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
    sync=dict(type='mmengine.hooks.SyncBuffersHook'),
    timer=dict(type='mmengine.hooks.IterTimerHook'),
    visualization=dict(
        enable=False, type='mmcls.engine.hooks.VisualizationHook'))
default_scope = None
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
float_checkpoint = '/workspace/mmlab/MMR/checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'
global_qconfig = dict(
    a_fake_quant=dict(type='FakeQuantize'),
    a_observer=dict(type='MovingAverageMinMaxObserver'),
    a_qscheme=dict(bit=8, is_symmetry=True, qdtype='quint8'),
    w_fake_quant=dict(type='FakeQuantize'),
    w_observer=dict(type='MovingAverageMinMaxObserver'),
    w_qscheme=dict(
        bit=8, is_symmetric_range=True, is_symmetry=True, qdtype='qint8'))
launcher = 'none'
load_from = None
log_level = 'INFO'
model = dict(
    _scope_='mmrazor',
    architecture=dict(
        backbone=dict(
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch',
            type='mmcls.models.backbones.ResNet'),
        head=dict(
            in_channels=512,
            loss=dict(
                loss_weight=1.0, type='mmcls.models.losses.CrossEntropyLoss'),
            num_classes=1000,
            topk=(
                1,
                5,
            ),
            type='mmcls.models.heads.LinearClsHead'),
        neck=dict(type='mmcls.models.necks.GlobalAveragePooling'),
        type='mmcls.models.classifiers.ImageClassifier'),
    data_preprocessor=dict(
        mean=[
            123.675,
            116.28,
            103.53,
        ],
        num_classes=1000,
        std=[
            58.395,
            57.12,
            57.375,
        ],
        to_rgb=True,
        type='mmcls.models.utils.ClsDataPreprocessor'),
    float_checkpoint=
    '/workspace/mmlab/MMR/checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
    quantizer=dict(
        global_qconfig=dict(
            a_fake_quant=dict(type='FakeQuantize'),
            a_observer=dict(type='MovingAverageMinMaxObserver'),
            a_qscheme=dict(bit=8, is_symmetry=True, qdtype='quint8'),
            w_fake_quant=dict(type='FakeQuantize'),
            w_observer=dict(type='MovingAverageMinMaxObserver'),
            w_qscheme=dict(
                bit=8,
                is_symmetric_range=True,
                is_symmetry=True,
                qdtype='qint8')),
        tracer=dict(
            skipped_methods=[
                'mmcls.models.heads.cls_head.ClsHead._get_loss',
                'mmcls.models.heads.cls_head.ClsHead._get_predictions',
            ],
            type='mmrazor.models.task_modules.tracer.fx.CustomTracer'),
        type='mmrazor.models.quantizers.TensorRTQuantizer'),
    type='mmrazor.models.algorithms.quantization.MMArchitectureQuant')
model_wrapper_cfg = dict(
    broadcast_buffers=False,
    find_unused_parameters=False,
    type='mmrazor.models.algorithms.quantization.MMArchitectureQuantDDP')
optim_wrapper = dict(
    clip_grad=None,
    optimizer=dict(
        lr=0.0001, momentum=0.9, type='torch.optim.SGD', weight_decay=0.0001),
    type='mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper')
param_scheduler = dict(
    by_epoch=True,
    factor=1.0,
    type='mmengine.optim.scheduler.lr_scheduler.ConstantLR')
resnet18_model = dict(
    backbone=dict(
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
        type='mmcls.models.backbones.ResNet'),
    head=dict(
        in_channels=512,
        loss=dict(
            loss_weight=1.0, type='mmcls.models.losses.CrossEntropyLoss'),
        num_classes=1000,
        topk=(
            1,
            5,
        ),
        type='mmcls.models.heads.LinearClsHead'),
    neck=dict(type='mmcls.models.necks.GlobalAveragePooling'),
    type='mmcls.models.classifiers.ImageClassifier')
resume = False
test_cfg = dict(type='mmrazor.engine.runner.QATValLoop')
test_dataloader = dict(
    batch_size=32,
    dataset=dict(
        ann_file='meta/val.txt',
        data_root='/home/XXX/val/Data',
        pipeline=[
            dict(type='mmcv.transforms.loading.LoadImageFromFile'),
            dict(
                edge='short',
                scale=256,
                type='mmcls.datasets.transforms.ResizeEdge'),
            dict(crop_size=224, type='mmcv.transforms.processing.CenterCrop'),
            dict(type='mmcls.datasets.transforms.PackClsInputs'),
        ],
        type='mmcls.datasets.ImageNet'),
    num_workers=5,
    persistent_workers=True,
    sampler=dict(
        shuffle=False, type='mmengine.dataset.sampler.DefaultSampler'))
test_evaluator = dict(
    topk=(
        1,
        5,
    ), type='mmcls.evaluation.metrics.Accuracy')
test_pipeline = [
    dict(type='mmcv.transforms.loading.LoadImageFromFile'),
    dict(edge='short', scale=256, type='mmcls.datasets.transforms.ResizeEdge'),
    dict(crop_size=224, type='mmcv.transforms.processing.CenterCrop'),
    dict(type='mmcls.datasets.transforms.PackClsInputs'),
]
train_cfg = dict(
    max_epochs=10,
    type='mmrazor.engine.runner.QATEpochBasedLoop',
    val_interval=10)
train_dataloader = dict(
    batch_size=32,
    dataset=dict(
        ann_file='meta/train.txt',
        data_root='/home/XXX/train/Data',
        pipeline=[
            dict(type='mmcv.transforms.loading.LoadImageFromFile'),
            dict(
                scale=224, type='mmcls.datasets.transforms.RandomResizedCrop'),
            dict(
                direction='horizontal',
                prob=0.5,
                type='mmcv.transforms.processing.RandomFlip'),
            dict(type='mmcls.datasets.transforms.PackClsInputs'),
        ],
        type='mmcls.datasets.ImageNet'),
    num_workers=5,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='mmengine.dataset.sampler.DefaultSampler'))
train_pipeline = [
    dict(type='mmcv.transforms.loading.LoadImageFromFile'),
    dict(scale=224, type='mmcls.datasets.transforms.RandomResizedCrop'),
    dict(
        direction='horizontal',
        prob=0.5,
        type='mmcv.transforms.processing.RandomFlip'),
    dict(type='mmcls.datasets.transforms.PackClsInputs'),
]
val_cfg = dict(type='mmrazor.engine.runner.QATValLoop')
val_dataloader = dict(
    batch_size=32,
    dataset=dict(
        ann_file='meta/val.txt',
        data_root='/home/XXX/val/Data',
        pipeline=[
            dict(type='mmcv.transforms.loading.LoadImageFromFile'),
            dict(
                edge='short',
                scale=256,
                type='mmcls.datasets.transforms.ResizeEdge'),
            dict(crop_size=224, type='mmcv.transforms.processing.CenterCrop'),
            dict(type='mmcls.datasets.transforms.PackClsInputs'),
        ],
        type='mmcls.datasets.ImageNet'),
    num_workers=5,
    persistent_workers=True,
    sampler=dict(
        shuffle=False, type='mmengine.dataset.sampler.DefaultSampler'))
val_evaluator = dict(
    topk=(
        1,
        5,
    ), type='mmcls.evaluation.metrics.Accuracy')
vis_backends = [
    dict(type='mmengine.visualization.vis_backend.LocalVisBackend'),
    dict(type='mmengine.visualization.vis_backend.TensorboardVisBackend'),
]
visualizer = dict(
    type='mmcls.visualization.ClsVisualizer',
    vis_backends=[
        dict(type='mmengine.visualization.vis_backend.LocalVisBackend'),
        dict(type='mmengine.visualization.vis_backend.TensorboardVisBackend'),
    ])
work_dir = '/workspace/mmlab/MMR/qat/cls/checkpoints/'
  3. Your train log file if you met the problem during training.

When printing self.qmodels['tensor'].head, I should see the observers for the activations, shouldn't I? Instead, I have:

Module(
  (fc): Linear(
    in_features=512, out_features=1000, bias=True
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([0], device='cuda:0', dtype=torch.uint8), quant_min=-127, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0051], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.3029666841030121, max_val=0.6515278220176697)
    )
  )
)

Also, in the get_deploy_model method of MMArchitectureQuant, I see that I should have activation_post_process blocks when printing the node names. However, I have this instead:

inputs
data_samples
mode_1
eq
_assert
backbone_conv1
backbone_maxpool
backbone_layer1_0_conv1
backbone_layer1_0_conv2
backbone_layer1_0_drop_path
add
backbone_layer1_0_relu_1
backbone_layer1_1_conv1
backbone_layer1_1_conv2
backbone_layer1_1_drop_path
add_1
backbone_layer1_1_relu_1
backbone_layer2_0_conv1
backbone_layer2_0_conv2
backbone_layer2_0_downsample_0
backbone_layer2_0_drop_path
add_2
backbone_layer2_0_relu_1
backbone_layer2_1_conv1
backbone_layer2_1_conv2
backbone_layer2_1_drop_path
add_3
backbone_layer2_1_relu_1
backbone_layer3_0_conv1
backbone_layer3_0_conv2
backbone_layer3_0_downsample_0
backbone_layer3_0_drop_path
add_4
backbone_layer3_0_relu_1
backbone_layer3_1_conv1
backbone_layer3_1_conv2
backbone_layer3_1_drop_path
add_5
backbone_layer3_1_relu_1
backbone_layer4_0_conv1
backbone_layer4_0_conv2
backbone_layer4_0_downsample_0
backbone_layer4_0_drop_path
add_6
backbone_layer4_0_relu_1
backbone_layer4_1_conv1
backbone_layer4_1_conv2
backbone_layer4_1_drop_path
add_7
backbone_layer4_1_relu_1
neck_gap
size
view
head_fc
output
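For comparison, a minimal plain-torch FX-mode sketch (toy model, not the MMRazor pipeline) where the prepared graph does contain the expected observer nodes:

```python
# Hedged sketch: FX-mode QAT preparation on a toy model, as a stand-in for the
# MMRazor-prepared graph. After prepare_qat_fx(), activation observers appear
# in the graph as call_module nodes named `activation_post_process_*`.
import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).train()
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig('fbgemm'))
prepared = prepare_qat_fx(model, qconfig_mapping,
                          example_inputs=(torch.randn(1, 3, 8, 8),))

node_names = [node.name for node in prepared.graph.nodes]
print(node_names)  # interleaves activation_post_process_* nodes between ops
```

The node dump above contains no such activation_post_process entries, which is exactly the missing-activation-quantization symptom.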

Veccoy commented Jun 21, 2024

The issue was coming from TensorRT. It seems that this backend doesn't support activation quantization for some layers, such as max pooling or add. When preparing the model with TensorRTQuantizer, the torch prepare() function (version 1.13.1) always returns False when checking whether the qconfig is supported by the backend for such layers (L1255). This results in no activation quantization at all in the ONNX file...

Switching to OpenVINOQuantizer solved the problem.
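Concretely, the fix amounts to one line in the quantizer block of the config posted above (fragment only; global_qconfig and tracer stay exactly as posted):

```
quantizer=dict(
    global_qconfig=dict(...),  # unchanged from the config above
    tracer=dict(...),          # unchanged from the config above
    type='mmrazor.models.quantizers.OpenVINOQuantizer'),  # was TensorRTQuantizer
```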

@Veccoy Veccoy closed this as completed Jun 21, 2024