In [1]:
import os
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
import albumentations as A
from hydra import initialize, compose
from pathlib import Path
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
"""
The INT8 statically quantized model achieves 79.3% mIoU on the validation set of the KITTI-360 dataset.
"""

In [2]:
# ────────────────────────────────────────────────────────────────────────────────
# The notebook's parent directory is treated as the project root; put it on
# sys.path so the project-local `models` and `data` packages can be imported.
cur_dir = Path.cwd()
project_dir = cur_dir.parent
sys.path.append(str(project_dir))

from models.DinoFPNbn import DinoFPN
from data.dataset import KittiSemSegDataset
from data.labels_kitti360 import NUM_CLASSES
# ────────────────────────────────────────────────────────────────────────────────

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("CUDA is available." if device.type == "cuda" else "CUDA is not available, using CPU.")

CUDA is available.


In [3]:
# Check if the hardware supports activation quantization.
# The compute capability (major, minor) only exists for CUDA devices, so the
# query is skipped on the CPU fallback instead of raising.
if device.type == "cuda":
    cuda_props = torch.cuda.get_device_properties(device)  # query once, reuse both fields
    print((cuda_props.major, cuda_props.minor))
else:
    print("No CUDA device selected; compute capability is not applicable.")
# Quantized backends compiled into this PyTorch build.
print(torch.backends.quantized.supported_engines)

(7, 5)
['qnnpack', 'none', 'onednn', 'x86', 'fbgemm']


In [4]:
# Compose the `quant_config` Hydra configuration from the ../configs directory.
with initialize(
    version_base=None,
    config_path="../configs",
    job_name="quant_static_ptq",
):
    cfg = compose(config_name="quant_config")

In [5]:
# Build the FP32 reference network from the Hydra config and switch it to
# inference mode; the bare name on the last line displays the module tree.
fp32_model = DinoFPN(
    num_labels=cfg.dataset.num_classes,
    model_cfg=cfg.model,
)
fp32_model.eval()
fp32_model

DinoFPN(
  (backbone): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2Attention(
            (attention): Dinov2SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (layer_scale1): Dinov2LayerScale()
          

In [6]:
# Report the FP32 model's parameter counts and sizes using the model's own
# helper (it prints a per-component breakdown) and keep its return value.
fp32_mem_footprint = fp32_model.get_memory_footprint(detailed=True)

=== Model Memory Footprint ===
Backbone: 86,580,480 params, 330.28 MB
Head:     3,747,617 params, 14.30 MB
Total:    90,328,097 params, 344.57 MB


In [7]:
# from torch.ao.quantization import fuse_modules
# 2. (Optional) Fuse modules manually if you want more control
# fuse_modules(model, [['features.0', 'features.1', 'features.2']], inplace=True)

In [None]:
from torch.ao.quantization import get_default_qconfig, QConfig
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.observer import MinMaxObserver, HistogramObserver, PerChannelMinMaxObserver
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from transformers.models.dinov2.modeling_dinov2 import Dinov2PatchEmbeddings, Dinov2Embeddings

# Observer factories for a hand-written QConfig (this is NOT the backend default;
# `get_default_qconfig("x86")` is the alternative that could be swapped in).
#   - activations: per-tensor affine, unsigned 8-bit, histogram-based range
#   - weights:     per-channel symmetric, signed 8-bit, min/max range
activation = HistogramObserver.with_args(
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine
)
weight = PerChannelMinMaxObserver.with_args(
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    # One scale per slice along dim 0 of the weight tensor, i.e. the output-channel
    # dim of Conv2d weights and the out_features dim of Linear weights.
    ch_axis=0
)
custom_qconfig = QConfig(activation=activation, weight=weight)

# Start from "quantize everything with custom_qconfig", then opt specific pieces out.
qconfig_map = QConfigMapping().set_global(custom_qconfig)

# Keep the patch-embedding projection and the embedding dropout in floating point.
# (Excluding the whole "backbone.embeddings" module is a possible alternative.)
for _module_name in (
    "backbone.embeddings.patch_embeddings",
    "backbone.embeddings.dropout",
):
    qconfig_map.set_module_name(_module_name, None)

# Leave shape/layout operations unquantized. They are registered both by method
# name and as torch.Tensor attributes; "permute" is registered by name only.
for _op in (
    "size",
    "view",
    "reshape",
    "permute",
    torch.Tensor.size,
    torch.Tensor.view,
    torch.Tensor.reshape,
):
    qconfig_map.set_object_type(_op, None)

# Tell FX not to trace into the DINOv2 embedding modules: each call to one of
# these classes is recorded as a single opaque call_module node in the graph.
non_traceable_classes = [
    Dinov2PatchEmbeddings,
    Dinov2Embeddings,
]
prepare_custom_config = PrepareCustomConfig().set_non_traceable_module_classes(
    non_traceable_classes
)

In [9]:
# Display the active QConfig.
# NOTE(review): the saved output for this cell shows `reduce_range=True` and no
# quant_min/quant_max arguments, which does not match the QConfig constructed in
# this notebook — the output looks stale; re-run the cell to refresh it.
custom_qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})

In [9]:
# Load the best FP32 checkpoint into `fp32_model`.
ckpt_path = os.path.join(project_dir, "checkpoints", f"{cfg.checkpoint.model_name}.pth")
if not os.path.exists(ckpt_path):
    # Fail fast with a clear message instead of printing and then letting
    # torch.load raise on the missing path.
    raise FileNotFoundError(f"[Error] Checkpoint not found: {ckpt_path}")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the file holds only tensors and
# plain containers).
checkpoint = torch.load(ckpt_path, map_location="cpu")
# The checkpoint is a dict; the weights live under "model_state_dict".
fp32_model.load_state_dict(checkpoint["model_state_dict"])
print(f"Loaded FP32 weights from {ckpt_path} into fp32_model.")

Loaded FP32 weights from /home/panos/dev/hf_seg/checkpoints/dino-fpn-bn.pth into fp32_model.


In [10]:
from torch.ao.quantization.quantize_fx import prepare_fx

# Prepare the model for static quantization: FX traces `fp32_model` with the
# example input and inserts observers according to `qconfig_map`.
# The example is a single random 3x364x1232 image (364 = 26*14 and 1232 = 88*14).
example_inputs = (torch.randn(1, 3, 364, 1232),)
prep_model = prepare_fx(
    model=fp32_model,
    qconfig_mapping=qconfig_map,
    example_inputs=example_inputs,
    prepare_custom_config=prepare_custom_config,
)
prep_model

GraphModule(
  (backbone): Module(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Module(
      (layer): Module(
        (0): Module(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Module(
            (attention): Module(
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Module(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (layer_scale1): Module()
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e

In [11]:
def print_qconfig_status(model):
    """Print one line per submodule describing its `qconfig` attribute.

    Each module from `model.named_modules()` is classified as:
      * EXCLUDED   - the module has `qconfig` set to None
      * NO QCONFIG - the module has no `qconfig` attribute
      * QUANTIZED  - the module carries any other `qconfig` value

    Returns:
        (quantized_modules, excluded_modules): two lists of module names, in
        traversal order. Modules without a qconfig appear in neither list.
    """
    print("\n=== QConfig Status for All Modules ===")

    quantized_modules, excluded_modules = [], []

    for name, module in model.named_modules():
        # 'not_set' is the marker for "attribute missing".
        qconfig = getattr(module, 'qconfig', 'not_set')

        if qconfig is None:
            excluded_modules.append(name)
            status = "EXCLUDED (qconfig=None)"
        elif qconfig == 'not_set':
            status = "NO QCONFIG"
        else:
            quantized_modules.append(name)
            status = f"QUANTIZED ({type(qconfig).__name__})"

        print(f"{name:60s} → {status}")

    print("\n📊 Summary:")
    print(f"   Quantized modules: {len(quantized_modules)}")
    print(f"   Excluded modules:  {len(excluded_modules)}")

    return quantized_modules, excluded_modules

# Run the report on the prepared (observer-instrumented) model and keep the two
# name lists it returns.
quantized_modules, excluded_modules = print_qconfig_status(prep_model)


=== QConfig Status for All Modules ===
                                                             → NO QCONFIG
backbone                                                     → NO QCONFIG
backbone.embeddings                                          → QUANTIZED (QConfig)
backbone.embeddings.patch_embeddings                         → EXCLUDED (qconfig=None)
backbone.embeddings.patch_embeddings.projection              → EXCLUDED (qconfig=None)
backbone.embeddings.dropout                                  → EXCLUDED (qconfig=None)
backbone.encoder                                             → NO QCONFIG
backbone.encoder.layer                                       → NO QCONFIG
backbone.encoder.layer.0                                     → NO QCONFIG
backbone.encoder.layer.0.norm1                               → QUANTIZED (QConfig)
backbone.encoder.layer.0.attention                           → NO QCONFIG
backbone.encoder.layer.0.attention.attention                 → NO QCONFIG
backbone.encode

In [12]:
# Dump the FX graph of the prepared model (one line per node) to inspect where
# the activation_post_process observer nodes were inserted.
print(prep_model.graph)

graph():
    %images : [num_users=2] = placeholder[target=images]
    %getattr_1 : [num_users=4] = call_function[target=builtins.getattr](args = (%images, shape), kwargs = {})
    %getitem : [num_users=0] = call_function[target=operator.getitem](args = (%getattr_1, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%getattr_1, 1), kwargs = {})
    %getitem_2 : [num_users=5] = call_function[target=operator.getitem](args = (%getattr_1, 2), kwargs = {})
    %getitem_3 : [num_users=5] = call_function[target=operator.getitem](args = (%getattr_1, 3), kwargs = {})
    %backbone_embeddings : [num_users=1] = call_module[target=backbone.embeddings](args = (%images,), kwargs = {bool_masked_pos: None})
    %activation_post_process_0 : [num_users=2] = call_module[target=activation_post_process_0](args = (%backbone_embeddings,), kwargs = {})
    %backbone_encoder_layer_0_norm1 : [num_users=1] = call_module[target=backbone.encoder.layer.0.norm1](args = (%

In [13]:
# Graph.print_tabular() writes the table to stdout itself and returns None, so it
# must not be wrapped in print() (that would append a stray "None" line).
prep_model.graph.print_tabular()

opcode         name                                                 target                                                  args                                                                                                                     kwargs
-------------  ---------------------------------------------------  ------------------------------------------------------  -----------------------------------------------------------------------------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------------------------------------
placeholder    images                                               images                                                  ()                                                                                                                       {}
call_function  getattr_1                                            <built-in function getattr>         

In [14]:
# from torch.fx.passes.graph_drawer import FxGraphDrawer

# # 4. Render to PNG
# drawer = FxGraphDrawer(prep_model, "prep_model")
# dot = drawer.get_dot_graph()
# with open(project_dir / "assets" / "prep_model.png", "wb") as f:
#     f.write(dot.create_png())

In [15]:
# Build the validation and calibration loaders (no augmentations, just center-crop).
crop_h, crop_w = (cfg.augmentation.crop_height, cfg.augmentation.crop_width)

# Single source of truth for the dataset location. Set the KITTI360_ROOT
# environment variable to run on another machine; the default keeps the
# original local path.
kitti_root = os.environ.get("KITTI360_ROOT", '/home/panos/Documents/data/kitti-360')

val_transform = A.Compose([
    A.CenterCrop(crop_h, crop_w)
])

# Validation split: used to score the quantized model.
val_dataset = KittiSemSegDataset(
    root_dir=kitti_root,
    train=False,
    transform=val_transform
)
val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.train.batch_size,
    shuffle=False,
    num_workers=cfg.dataset.num_workers,
    pin_memory=True,
)
print(f"[QuantPTQ] Validation dataset size: {len(val_dataset)}")

# Calibration data: presumably a subset of the training split selected by
# `calibration=True` — confirm in KittiSemSegDataset. It uses the same
# deterministic center-crop as validation.
cal_dataset = KittiSemSegDataset(
    root_dir=kitti_root,
    train=True,
    calibration=True,
    transform=val_transform
)
cal_loader = DataLoader(
    cal_dataset,
    batch_size=cfg.train.batch_size,
    shuffle=True,  # NOTE(review): unseeded shuffle, so the calibration batch order differs between runs
    num_workers=cfg.dataset.num_workers,
    pin_memory=True,
)
print(f"[QuantPTQ] Calibration dataset size: {len(cal_dataset)}")


[QuantPTQ] Validation dataset size: 783
[QuantPTQ] Calibration dataset size: 704


In [16]:
prep_model = prep_model.to(device)

# Calibrate: run every batch of the calibration loader through the prepared
# model so the inserted observers see real activations. Labels are not needed.
# NOTE(review): the evaluation cell casts images with `.float()` before
# `process()`; this loop does not — confirm both paths preprocess identically.
batch_idx = 0  # keeps the summary print valid even if the loader yields nothing
with torch.no_grad():
    cal_bar = tqdm(cal_loader, desc="Calibrating")
    for batch_idx, (imgs, _) in enumerate(cal_bar, start=1):
        imgs = imgs.permute(0, 3, 1, 2) # [B, C, H, W]
        # Named `cal_inputs` rather than `input` to avoid shadowing the builtin.
        cal_inputs = fp32_model.process(imgs).to(device)
        prep_model(cal_inputs)
print(f"Completed {batch_idx} batches of calibration.")

Calibrating: 100%|██████████| 59/59 [06:34<00:00,  6.69s/it]

Completed 59 batches of calibration.





In [17]:
from torch.ao.quantization.quantize_fx import convert_fx

# Convert: move the calibrated model back to the CPU in eval mode, then swap the
# observed float ops for quantized kernels.
prep_model = prep_model.to("cpu")
prep_model.eval()
quant_model = convert_fx(prep_model)

In [18]:
# Score the INT8 model on the validation loader (runs on CPU).
# NOTE(review): `average="micro"` pools all pixels into a single IoU; the value
# usually reported as "mIoU" is the per-class mean (`average="macro"`). Confirm
# the FP32 baseline is scored the same way before comparing numbers.
with torch.no_grad():
    quant_model.eval()
    miou_metric = JaccardIndex(
        task="multiclass",
        num_classes=NUM_CLASSES,
        average="micro",
        ignore_index=255,
    )

    eval_bar = tqdm(val_loader, desc="[Eval]")
    for imgs, masks in eval_bar:
        imgs = imgs.permute(0, 3, 1, 2).float()
        model_inputs = fp32_model.process(imgs)   # same preprocessing helper as the FP32 model
        logits = quant_model(model_inputs)        # [B, num_classes, H, W]
        preds = logits.argmax(dim=1)              # [B, H, W]
        miou_metric.update(preds, masks)

    int8_miou = miou_metric.compute().item()
print(f"[QuantPTQ] INT8‐quantized model mIoU = {int8_miou:.4f}")

[Eval]: 100%|██████████| 66/66 [19:19<00:00, 17.57s/it]

[QuantPTQ] INT8‐quantized model mIoU = 0.7930





In [20]:
# Save the INT8 state_dict next to the FP32 checkpoint. Only the state_dict is
# written, not the converted GraphModule itself.
quant_ckpt_file = f"{cfg.checkpoint.quant_model_name}.pth"
save_path = os.path.join(project_dir, "checkpoints", quant_ckpt_file)
quant_state = quant_model.state_dict()
torch.save(quant_state, save_path)
print(f"[QuantPTQ] Saved INT8 weights to {save_path}")

[QuantPTQ] Saved INT8 weights to /home/panos/dev/hf_seg/checkpoints/dino-fpn-int8.pth


In [25]:
# Display the converted module tree to check which layers became Quantized*
# modules and which stayed in floating point.
quant_model

GraphModule(
  (backbone): Module(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Module(
      (layer): Module(
        (0): Module(
          (norm1): QuantizedLayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Module(
            (attention): Module(
              (key): QuantizedLinear(in_features=768, out_features=768, scale=0.07625889778137207, zero_point=135, qscheme=torch.per_channel_affine)
              (value): QuantizedLinear(in_features=768, out_features=768, scale=0.03319988399744034, zero_point=128, qscheme=torch.per_channel_affine)
              (query): QuantizedLinear(in_features=768, out_features=768, scale=0.07113160938024521, zero_point=122, qscheme=torch.per_channel_affine)
            )
            (output): Module(
              (dense): QuantizedLin

In [None]:
def get_quant_memory_footprint(model):
    """Estimate the parameter memory footprint of a (partially) quantized model.

    Every module from `model.named_modules()` is visited exactly once:

    * Quantized modules (type path contains "quantized" and `weight` is a
      callable accessor) contribute their unpacked weight, their bias (if any)
      and the weight's quantization parameters (scales and zero points).
    * Every other module contributes only the parameters it owns directly.
      Child modules are visited on their own, so recursing here would count the
      same parameter once per ancestor.

    Buffers are not included in the estimate.

    Args:
        model: a torch.nn.Module (float, quantized, or mixed).

    Returns:
        Total size in bytes (int). A one-line summary is also printed.
    """
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_symmetric,
        torch.per_channel_affine_float_qparams,
    )

    total_params = 0
    total_bytes = 0

    for _name, module in model.named_modules():
        weight_fn = getattr(module, 'weight', None)
        is_quantized_module = 'quantized' in str(type(module)).lower() and callable(weight_fn)

        if is_quantized_module:
            # Quantized modules expose weight/bias through accessor methods.
            w = weight_fn()
            bias_fn = getattr(module, 'bias', None)
            b = bias_fn() if callable(bias_fn) else None

            tensors = [w, b]
            if w is not None and w.is_quantized:
                # Select by the tensor's qscheme: the per-channel accessors exist
                # on every tensor but raise for per-tensor quantized weights.
                if w.qscheme() in per_channel_schemes:
                    tensors.append(w.q_per_channel_scales())
                    tensors.append(w.q_per_channel_zero_points())
                else:
                    # Per-tensor: one scale and one zero point, counted as
                    # 64-bit scalars.
                    tensors.append(torch.tensor(w.q_scale(), dtype=torch.float64))
                    tensors.append(torch.tensor(w.q_zero_point(), dtype=torch.int64))
        else:
            # recurse=False: count only parameters owned by this module itself.
            tensors = list(module.parameters(recurse=False))

        for t in tensors:
            if t is None:
                continue
            total_params += t.numel()
            total_bytes += t.numel() * t.element_size()

    print(f"Total memory footprint: {total_params:,} params, {total_bytes / (1024**2):.2f} MB")

    return total_bytes

# Size (in bytes) of the converted INT8 model, as estimated by the helper above.
quant_mem_footprint = get_quant_memory_footprint(quant_model)

In [None]:
def get_memory_footprint(model):
    """Print and return the parameter memory of a model with `backbone` and `head`.

    Args:
        model: object exposing `backbone` and `head` submodules.

    Returns:
        Combined size of the backbone and head parameters, in bytes.
    """
    def _count(module):
        """Return (parameter count, size in bytes) for a single module."""
        n_params = 0
        n_bytes = 0
        for p in module.parameters():
            n_params += p.numel()
            n_bytes += p.numel() * p.element_size()  # elements * bytes per element
        return n_params, n_bytes

    mib = 1024**2
    backbone_params, backbone_bytes = _count(model.backbone)
    head_params, head_bytes = _count(model.head)
    total_bytes = backbone_bytes + head_bytes

    print("=== Model Memory Footprint ===")
    print(f"Backbone: {backbone_params:,} params, {backbone_bytes / mib:.2f} MB")
    print(f"Head:     {head_params:,} params, {head_bytes / mib:.2f} MB")
    print(f"Total:    {backbone_params + head_params:,} params, {total_bytes / mib:.2f} MB")

    return total_bytes

# Recompute the FP32 footprint with the standalone helper.
# NOTE(review): this overwrites the `fp32_mem_footprint` value assigned earlier
# from `fp32_model.get_memory_footprint(detailed=True)`.
fp32_mem_footprint = get_memory_footprint(fp32_model)