In [None]:
#!/usr/bin/env python3
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.quantization as tq
import torch.nn.intrinsic as nni
from torch.quantization.observer import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver
from torch.quantization import QConfig
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
import albumentations as A
from hydra import initialize, compose
from pathlib import Path
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
# ────────────────────────────────────────────────────────────────────────────────
# Make sure your project root is in PYTHONPATH so we can import models & datasets
cur_dir     = Path.cwd()
project_dir = cur_dir.parent
sys.path.append(str(project_dir))

from models.DinoFPNbn import DinoFPN
from data.kitti360.dataset import KittiSemSegDataset
from data.kitti360.labels_kitti360 import NUM_CLASSES
# ────────────────────────────────────────────────────────────────────────────────

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available.")
else:
    print("CUDA is not available, using CPU.")

CUDA is available.


In [3]:
# Check if hardward supports activation quantization
print((torch.cuda.get_device_properties(device).major, torch.cuda.get_device_properties(device).minor))
print(torch.backends.quantized.supported_engines)

(7, 5)
['qnnpack', 'none', 'onednn', 'x86', 'fbgemm']


In [4]:
# Load Hydra config
with initialize(version_base=None, config_path=f"../configs", job_name="quant_static_ptq"):
    cfg = compose(config_name="quant_config")

In [5]:
fp32_model = DinoFPN(num_labels=cfg.dataset.num_classes, model_cfg=cfg.model)
fp32_model

DinoFPN(
  (backbone): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2Attention(
            (attention): Dinov2SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (layer_scale1): Dinov2LayerScale()
          

In [6]:
fp32_mem_footprint = fp32_model.get_memory_footprint(detailed=True)

=== Model Memory Footprint ===
Backbone: 86,580,480 params, 330.28 MB
Head:     3,747,617 params, 14.30 MB
Total:    90,328,097 params, 344.57 MB


In [None]:
# ────────────────────────────────────────────────────────────────────────────────
# 1) Build a wrapper around the FP32 DinoFPN so we can insert QuantStubs and DeQuantStubs.
#    Because FPNHead now uses Conv2d → BatchNorm2d → ReLU, PyTorch’s fuse_modules will handle:
#       Conv2d + BatchNorm2d + ReLU → FusedConvBnRelu
# ────────────────────────────────────────────────────────────────────────────────
class QuantDinoFPN(nn.Module):
    def __init__(self, num_labels: int, model_cfg):
        super().__init__()
        # 1a) QuantStub to quantize input activations
        self.quant = tq.QuantStub()

        # 1b) The original FP32 DinoFPN (with BatchNorm)
        self.fp32_model = DinoFPN(num_labels=num_labels, model_cfg=model_cfg)
                
        # 1c) DeQuantStub to convert final output back to FP32
        self.dequant = tq.DeQuantStub()

    def process(self, images):
        return self.fp32_model.process(images)

    def forward(self, images):
        # 1) Quantize input: attaches observers to measure activation ranges
        x = self.quant(images)

        # 2) Forward through original FP32 model
        logits = self.fp32_model(x)

        # 3) Dequantize output (brings quantized int8 result back to float)
        out = self.dequant(logits)
        return out

    def fuse_model(self):
        """
        Fuse Conv2d + BatchNorm2d + ReLU sequences, wherever they appear, across the entire network.
        Because your FPNHead modules use exactly that pattern, this will fuse:
            - Each proj: Conv2d → BatchNorm2d → ReLU → Dropout2d (Dropout is skipped in fusion)
            - The fuse block: Conv2d → BatchNorm2d → ReLU → Dropout2d
            - The classifier: Conv2d → BatchNorm2d → ReLU → Dropout2d → Conv2d  (we can fuse up to ReLU)
        PyTorch automatically handles fusing only the Conv-BN-ReLU parts, leaving dropout alone.
        """
        # Fuse in the head
        head = self.fp32_model.head

        # 1) Fuse each 1×1 proj: Conv → BatchNorm → ReLU
        for idx, proj in enumerate(head.projs):
            # proj is nn.Sequential([Conv2d, BatchNorm2d, ReLU, Dropout2d])
            torch.quantization.fuse_modules(proj,
                                           ["0", "1", "2"],  # fuse conv (idx 0), bn (idx 1), relu (idx 2)
                                           inplace=True)

        # 2) Fuse the fuse-block: Conv → BatchNorm → ReLU
        torch.quantization.fuse_modules(head.fuse,
                                       ["0", "1", "2"],  # indices: 0=Conv2d, 1=BatchNorm2d, 2=ReLU
                                       inplace=True)

        # 3) Fuse the first part of classifier: Conv → BatchNorm → ReLU
        #    classifier = nn.Sequential([Conv2d, BatchNorm2d, ReLU, Dropout2d, Conv2d])
        torch.quantization.fuse_modules(head.classifier,
                                       ["0", "1", "2"],  # fuse conv(0), bn(1), relu(2)
                                       inplace=True)
        # The final Conv2d (index 4) cannot be fused further, since there's no BatchNorm or ReLU after.

        print("[QuantDinoFPN] Fused all Conv2d+BatchNorm2d+ReLU sequences.")

In [8]:
# 4b) Instantiate QuantDinoFPN
quant_model = QuantDinoFPN(num_labels=cfg.dataset.num_classes, model_cfg=cfg.model)
quant_model.to(torch.device("cpu"))
print("[QuantPTQ] Created QuantDinoFPN.")

[QuantPTQ] Created QuantDinoFPN.


In [9]:
quant_model

QuantDinoFPN(
  (quant): QuantStub()
  (fp32_model): DinoFPN(
    (backbone): Dinov2Model(
      (embeddings): Dinov2Embeddings(
        (patch_embeddings): Dinov2PatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Dinov2Encoder(
        (layer): ModuleList(
          (0-11): 12 x Dinov2Layer(
            (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (attention): Dinov2Attention(
              (attention): Dinov2SelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
              )
              (output): Dinov2SelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inp

In [11]:
# ────────────────────────────────────────────────────────────────────────────────
# 4) Main script: load FP32 checkpoint → wrap in QuantDinoFPN → fuse → set qconfig
#    → prepare → calibrate → convert → evaluate → save INT8 weights
# ────────────────────────────────────────────────────────────────────────────────

# 4a) Build a validation loader (no augmentations, just center-crop)
crop_h, crop_w = (cfg.augmentation.crop_height, cfg.augmentation.crop_width)

val_transform = A.Compose([
    A.CenterCrop(crop_h, crop_w)
])

val_dataset = KittiSemSegDataset(
    root_dir='/home/panos/Documents/data/kitti-360',
    train=False,
    transform=val_transform
)
val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.train.batch_size,
    shuffle=False,
    num_workers=cfg.dataset.num_workers,
    pin_memory=True,
)
print(f"[QuantPTQ] Validation dataset size: {len(val_dataset)}")

cal_dataset = KittiSemSegDataset(
    root_dir='/home/panos/Documents/data/kitti-360',
    train=True,
    calibration=True,
    transform=val_transform
)
cal_loader = DataLoader(
    cal_dataset,
    batch_size=cfg.train.batch_size,
    shuffle=True,
    num_workers=cfg.dataset.num_workers,
    pin_memory=True,
)
print(f"[QuantPTQ] Calibration dataset size: {len(cal_dataset)}")


[QuantPTQ] Validation dataset size: 783
[QuantPTQ] Calibration dataset size: 35


In [12]:
# 4c) Load your best FP32 checkpoint into `quant_model.fp32_model`
ckpt_path = os.path.join(project_dir, "checkpoints", f"{cfg.checkpoint.model_name}.pth")
if not os.path.exists(ckpt_path):
    print(f"[Error] Checkpoint not found: {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location="cpu")
quant_model.fp32_model.load_state_dict(checkpoint["model_state_dict"])
print(f"[QuantPTQ] Loaded FP32 weights from {ckpt_path} into fp32_model.")

[QuantPTQ] Loaded FP32 weights from /home/panos/dev/hf_seg/checkpoints/dino-fpn-quantized.pth into fp32_model.


In [13]:
# 4d) Fuse Conv2d + BatchNorm2d + ReLU sequences in the head
quant_model.eval()
quant_model.fuse_model()

[QuantDinoFPN] Fused all Conv2d+BatchNorm2d+ReLU sequences.


In [None]:
# 4e) Assign QConfig: per‐channel weight quant + HistogramObserver (KL) activations
#     - get_default_qconfig("fbgemm") already uses PerChannelMinMaxObserver for weights
#     - We swap the activation observer to HistogramObserver (KL)
# - Build a per-channel symmetric observer for weights
weight_obs  = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0
)

# - Build a histogram (KL) observer for activations
activation_obs = HistogramObserver.with_args(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False
)

# - Compose them into a QConfig
custom_qconfig = QConfig(
    activation=activation_obs,
    weight=weight_obs
)
quant_model.qconfig = custom_qconfig
# quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
quant_model.qconfig = torch.quantization.default_qconfig('x86')
print(quant_model.qconfig)
print("[QuantPTQ] Assigned QConfig.")

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
[QuantPTQ] Assigned QConfig.


In [None]:
# exclude_problematic_modules(model)
# Exclude embeddings
embeddings = quant_model.fp32_model.backbone.embeddings
embeddings.qconfig = None
for name, module in embeddings.named_modules():
    module.qconfig = None
print("Excluded Dinov2Embeddings from quantization")

# Exclude LayerNorm modules
layernorm_count = 0
for name, module in quant_model.named_modules():
    if isinstance(module, torch.nn.LayerNorm):
        module.qconfig = None
        layernorm_count += 1
        print(f"Excluded LayerNorm: {name}")

print(f"Excluded {layernorm_count} LayerNorm modules from quantization")

In [15]:
# 4f) Prepare for static quantization (inserts observers)
dev = next(quant_model.parameters()).device
assert dev.type == "cpu"
tq.prepare(quant_model, inplace=True)
quant_model

QuantDinoFPN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (fp32_model): DinoFPN(
    (backbone): Dinov2Model(
      (embeddings): Dinov2Embeddings(
        (patch_embeddings): Dinov2PatchEmbeddings(
          (projection): Conv2d(
            3, 768, kernel_size=(14, 14), stride=(14, 14)
            (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
          )
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Dinov2Encoder(
        (layer): ModuleList(
          (0-11): 12 x Dinov2Layer(
            (norm1): LayerNorm(
              (768,), eps=1e-06, elementwise_affine=True
              (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
            )
            (attention): Dinov2Attention(
              (attention): Dinov2SelfAttention(
                (query): Linear(
                  in_features=768, out_features=768, bias=True
                  (activa

In [None]:
# 4g) Calibrate on ~500 validation batches
quant_model.eval()
with torch.no_grad():
    cal_bar = tqdm(cal_loader, desc="Calibrating")
    for batch_idx, (imgs, _) in enumerate(cal_bar, start=1):
        imgs = imgs.permute(0, 3, 1, 2).to(device) # [B, C, H, W]
        input = quant_model.fp32_model.process(imgs)
        _ = quant_model(input)  # Forward pass only: observers record histograms
print(f"[Calibrate] Completed {batch_idx} batches of calibration.")

In [18]:
# 4h) Convert to INT8 (replace FP32 modules with quantized kernels)
quant_model.eval()
tq.convert(quant_model, inplace=True)



QuantDinoFPN(
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (fp32_model): DinoFPN(
    (backbone): Dinov2Model(
      (embeddings): Dinov2Embeddings(
        (patch_embeddings): Dinov2PatchEmbeddings(
          (projection): QuantizedConv2d(3, 768, kernel_size=(14, 14), stride=(14, 14), scale=1.0, zero_point=0)
        )
        (dropout): QuantizedDropout(p=0.0, inplace=False)
      )
      (encoder): Dinov2Encoder(
        (layer): ModuleList(
          (0-11): 12 x Dinov2Layer(
            (norm1): QuantizedLayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (attention): Dinov2Attention(
              (attention): Dinov2SelfAttention(
                (query): QuantizedLinear(in_features=768, out_features=768, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
                (key): QuantizedLinear(in_features=768, out_features=768, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
                (value): Quantized

In [20]:
# 4i) Evaluate mIoU of the INT8 model
with torch.no_grad():
    quant_model.eval()
    miou_metric = JaccardIndex(
        task="multiclass",
        num_classes=NUM_CLASSES,
        average="micro",
        ignore_index=255,
    )

    for imgs, masks in tqdm(val_loader, desc="[Eval]"):
        imgs  = imgs.permute(0, 3, 1, 2).float()
        input = quant_model.process(imgs)
        print(f"Input type: {input.dtype}, shape: {input.shape}")
        logits = quant_model(input)           # [B, num_classes, H, W]
        preds  = torch.argmax(logits, dim=1)  # [B, H, W]
        miou_metric.update(preds, masks)

    int8_miou = miou_metric.compute().item()
print(f"[QuantPTQ] INT8‐quantized model mIoU = {int8_miou:.4f}")

[Eval]:   0%|          | 0/66 [00:00<?, ?it/s]

Input type: torch.float32, shape: torch.Size([12, 3, 364, 1232])
[QuantDinoFPN] Forwarding through the quantized model...
[QuantDinoFPN] Input quantized, shape: torch.Size([12, 3, 364, 1232]), dtype: torch.quint8
Quantized model detected - skipping dtype conversion
Embeddings type: torch.quint8
Embeddings dequantized type: torch.float32
Embeddings type: torch.float32
CLS tokens type: torch.float32


[Eval]:   0%|          | 0/66 [00:01<?, ?it/s]


NotImplementedError: Could not run 'quantized::layer_norm' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::layer_norm' is only available for these backends: [Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastMTIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qnormalization.cpp:132 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:479 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:100 [backend fallback]
AutogradOther: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMPS: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:91 [backend fallback]
AutogradXPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
AutogradLazy: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMTIA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMeta: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:95 [backend fallback]
Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
AutocastMTIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:504 [backend fallback]
AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:208 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:475 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]


In [None]:
# 4j) Save the INT8 state_dict
save_path = os.path.join(project_dir, "checkpoints", "dinofpn_int8.pth")
torch.save(quant_model.state_dict(), save_path)
print(f"[QuantPTQ] Saved INT8 weights to {save_path}")