In [1]:
import os
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
import albumentations as A
from hydra import initialize, compose
from pathlib import Path
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ────────────────────────────────────────────────────────────────────────────────
# Make sure your project root is in PYTHONPATH so we can import models & datasets
cur_dir     = Path.cwd()
project_dir = cur_dir.parent
sys.path.append(str(project_dir))

from models.DinoFPNbn import DinoFPN
from data.dataset import KittiSemSegDataset
from data.labels_kitti360 import NUM_CLASSES
from utils.others import get_memory_footprint, get_quant_memory_footprint

# ────────────────────────────────────────────────────────────────────────────────

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available.")
else:
    print("CUDA is not available, using CPU.")

CUDA is available.


In [3]:
"""Achieves 83.94% mIoU on KITTI-360 validation set with QAT"""

'Achieves 83.94% mIoU on KITTI-360 validation set with QAT'

In [4]:
# Check if hardward supports activation quantization
print((torch.cuda.get_device_properties(device).major, torch.cuda.get_device_properties(device).minor))
print(torch.backends.quantized.supported_engines)

(7, 5)
['qnnpack', 'none', 'onednn', 'x86', 'fbgemm']


In [5]:
# Load Hydra config
with initialize(version_base=None, config_path=f"../configs", job_name="quant_static_ptq"):
    cfg = compose(config_name="qat_config")

In [7]:
fp32_model = DinoFPN(num_labels=cfg.dataset.num_classes, model_cfg=cfg.model)

In [8]:
fp32_mem_footprint = get_memory_footprint(fp32_model, detailed=True)

=== Model Memory Footprint ===
Backbone: 86,580,480 params, 330.28 MB
Head:     3,747,617 params, 14.30 MB
Total:    90,328,097 params, 344.57 MB


In [9]:
# from torch.ao.quantization import fuse_modules
# 2. (Optional) Fuse modules manually if you want more control
# fuse_modules(model, [['features.0', 'features.1', 'features.2']], inplace=True)

In [10]:
from torch.ao.quantization import get_default_qconfig, QConfig
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.observer import PerChannelMinMaxObserver, HistogramObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize as LearnableFakeQuantize
import torch.quantization as tq

# Build a MinMax observer for weights
weight_obs = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    reduce_range=False,
)
learnable_qat_weight = LearnableFakeQuantize.with_args(
    observer=weight_obs,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    scale=1.0,        # Initial scale (will be learned)
    zero_point=0.0    # Initial zero point (will be learned)
)
qat_weight = FakeQuantize.with_args(
    observer=weight_obs
)

# - Build a histogram (KL) observer for activations
activation_obs = HistogramObserver.with_args(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False
)

learnable_qat_activation = LearnableFakeQuantize.with_args(
    observer=activation_obs,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
    scale=1.0,             # Initial scale (will be learned)
    zero_point=128.0       # Initial zero point (will be learned)
)
qat_activation = FakeQuantize.with_args(
    observer=activation_obs
)

custom_qconfig = QConfig(
    activation=qat_activation, 
    weight=tq.default_fused_per_channel_wt_fake_quant
)
# custom_qconfig = get_default_qconfig("x86")

In [11]:
qconfig_map = (
    QConfigMapping()
    .set_global(custom_qconfig)                   # applies to all modules by default
    # .set_module_name("backbone.embeddings", None)  # disable quant for embeddings
    .set_module_name("backbone.embeddings.patch_embeddings", None)
    .set_module_name("backbone.embeddings.dropout", None)
    # Exclude normalization layers
    .set_object_type(torch.nn.LayerNorm, None)
    .set_object_type(torch.nn.BatchNorm2d, None)
    .set_object_type(torch.nn.GroupNorm, None)
    # Exclude operations that return non-tensor objects
    .set_object_type("size", None)
    .set_object_type("view", None)
    .set_object_type("reshape", None)
    .set_object_type("permute", None)
    .set_object_type(torch.Tensor.size, None)
    .set_object_type(torch.Tensor.view, None)
    .set_object_type(torch.Tensor.reshape, None)
)

from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from transformers.models.dinov2.modeling_dinov2 import Dinov2PatchEmbeddings, Dinov2Embeddings

# Tell FX to treat Dinov2PatchEmbeddings as non-traceable
## “Whenever you hit SomeModule in the model, don’t open it up and record its internal steps. 
## Instead just treat the whole call as a single step in the recipe
prepare_custom_config = (
    PrepareCustomConfig()
    .set_non_traceable_module_classes([
        Dinov2PatchEmbeddings, 
        Dinov2Embeddings
    ])
)

In [12]:
# Load your best FP32 checkpoint into `quant_model.fp32_model`
ckpt_path = os.path.join(project_dir, "checkpoints", f"{cfg.checkpoint.model_name}.pth")
if not os.path.exists(ckpt_path):
    print(f"[Error] Checkpoint not found: {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location="cpu")
fp32_model.load_state_dict(checkpoint["model_state_dict"])
print(f"Loaded FP32 weights from {ckpt_path} into fp32_model.")

Loaded FP32 weights from /home/panos/dev/hf_seg/checkpoints/dino-fpn-bn.pth into fp32_model.


In [13]:
from torch.ao.quantization.quantize_fx import prepare_qat_fx

# 4. Prepare the model for static quantization
example_inputs = (torch.randn(1, 3, 364, 1232),)
prep_model = prepare_qat_fx(
    fp32_model, 
    qconfig_map, 
    example_inputs,
    prepare_custom_config=prepare_custom_config
)
# prep_model

In [12]:
def print_qconfig_status(model):
    """Print qconfig status for all modules"""
    print("\n=== QConfig Status for All Modules ===")
    
    quantized_modules = []
    excluded_modules = []
    
    for name, module in model.named_modules():
        qconfig = getattr(module, 'qconfig', 'not_set')
        
        if qconfig is None:
            excluded_modules.append(name)
            print(f"{name:60s} → EXCLUDED (qconfig=None)")
        elif qconfig == 'not_set':
            print(f"{name:60s} → NO QCONFIG")
        else:
            quantized_modules.append(name)
            print(f"{name:60s} → QUANTIZED ({type(qconfig).__name__})")
    
    print(f"\n📊 Summary:")
    print(f"   Quantized modules: {len(quantized_modules)}")
    print(f"   Excluded modules:  {len(excluded_modules)}")
    
    return quantized_modules, excluded_modules

# Use it after prepare_fx
# quantized_modules, excluded_modules = print_qconfig_status(prep_model)

In [13]:
# print(prep_model.graph)

In [14]:
# print(prep_model.graph.print_tabular())

In [15]:
# from torch.fx.passes.graph_drawer import FxGraphDrawer

# # 4. Render to PNG
# drawer = FxGraphDrawer(prep_model, "prep_model")
# dot = drawer.get_dot_graph()
# with open(project_dir / "assets" / "prep_model.png", "wb") as f:
#     f.write(dot.create_png())

In [16]:
# Build a validation/calibration loader (no augmentations, just center-crop)
crop_h, crop_w = (cfg.augmentation.crop_height, cfg.augmentation.crop_width)

qat_transform = A.Compose([
    A.CenterCrop(crop_h, crop_w)
])

val_dataset = KittiSemSegDataset(
    root_dir='/home/panos/Documents/data/kitti-360',
    train=False,
    transform=qat_transform
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=cfg.dataset.num_workers,
    pin_memory=True,
)
print(f"[QuantQAT] Validation dataset size: {len(val_dataset)}")

train_dataset = KittiSemSegDataset(
    root_dir='/home/panos/Documents/data/kitti-360',
    train=True,
    transform=qat_transform
)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=cfg.dataset.num_workers,
    pin_memory=True,
)
print(f"[QuantQAT] Train dataset size: {len(train_dataset)}")

[QuantQAT] Validation dataset size: 783
[QuantQAT] Train dataset size: 7042


In [17]:
import torch.nn as nn
import torch.optim as optim
from models.tools import CombinedLoss

# CRITICAL: Move to GPU BEFORE training
prep_model = prep_model.to(device)
print(f"QAT model moved to {device}")

# Setup training components
# criterion = CombinedLoss(alpha=0.8, ignore_index=255)
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(prep_model.parameters(), lr=1e-5)  # Lower learning rate for QAT
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader))
miou_metric = JaccardIndex(
    task='multiclass',
    num_classes=cfg.dataset.num_classes,
    average='micro',
    ignore_index=255
).to(device)

QAT model moved to cuda


In [None]:
# QAT Training Loop
num_epochs = 5  # Usually 1-5 epochs for QAT fine-tuning
print(f"Starting QAT training for {num_epochs} epochs...")

for epoch in range(num_epochs):
    prep_model = prep_model.to(device)
    print(f"QAT model moved to {device}")

    ####### TRAIN #######
    prep_model.train()
    miou_metric.reset()
    running_train_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"[QAT Epoch {epoch+1}/{num_epochs}] Train")
    for batch_idx, (imgs, masks) in enumerate(train_bar, start=1):
        if batch_idx >= 100:
            break
        imgs = imgs.permute(0, 3, 1, 2)
        
        # Forward pass through preprocessing
        input = fp32_model.process(imgs).to(device)  # [B, 3, H, W]
        
        # Forward pass through QAT model
        logits = prep_model(input)
        
        # Compute loss
        masks = masks.to(device)
        loss = criterion(logits, masks.long())
        running_train_loss += loss.item()
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # compute IoU
        preds = torch.argmax(logits, dim=1)  # [B, H, W]
        miou_metric.update(preds, masks)
        
        optimizer.step()
        scheduler.step()

        # Update progress bar
        train_bar.set_postfix(loss=running_train_loss / (batch_idx + 1), lr=scheduler.get_last_lr()[0])

    print(f"\t Average Train loss: {running_train_loss/batch_idx:.4f}")
    print(f"\t Train mIoU: {miou_metric.compute().item():.4f}")

    # ####### VALIDATION #######
    prep_model.eval()
    miou_metric.reset()
    running_val_loss = 0.0

    with torch.no_grad():
        val_bar = tqdm(val_loader, desc=f"[QAT Epoch {epoch + 1}/{num_epochs}]  Val")
        for batch_idx, (imgs, masks) in enumerate(val_bar, start=1):
            imgs = imgs.permute(0, 3, 1, 2).to(device)  # [B, H, W, C] -> [B, C, H, W]
            input = fp32_model.process(imgs).to(device)

            # forward + loss
            logits = prep_model(input)

            # Loss
            masks = masks.to(device)  # [B, H, W]
            loss = criterion(logits, masks.long())

            # accumulate losses
            running_val_loss += loss.item()

            # compute IoU on this batch
            preds = torch.argmax(logits, dim=1)  # [B, H, W]
            miou_metric.update(preds, masks)

            val_bar.set_postfix(val_loss=running_val_loss / batch_idx)

    avg_val_loss = running_val_loss / len(val_loader)
    avg_val_miou = miou_metric.compute().item()

    print(f"\t Average Val loss: {running_val_loss/len(val_loader):.4f}")
    print(f"\t Val mIoU: {miou_metric.compute().item():.4f}")

    print(f"Epoch {epoch+1} completed")

print("QAT training completed!")

Starting QAT training for 5 epochs...
QAT model moved to cuda


[QAT Epoch 1/5] Train:  11%|█         | 99/881 [16:58<2:14:07, 10.29s/it, loss=0.19, lr=9.69e-6] 


	 Average Train loss: 0.1900
	 Train mIoU: 0.8773
QAT Model moved to cuda


[Eval]: 100%|██████████| 98/98 [22:01<00:00, 13.48s/it]


[QAT] INT8‐quantized model mIoU = 0.8170
Epoch 1 completed
QAT model moved to cuda


[QAT Epoch 2/5] Train:   0%|          | 0/881 [00:03<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 216.00 MiB. GPU 0 has a total capacity of 3.63 GiB of which 32.00 MiB is free. Including non-PyTorch memory, this process has 3.59 GiB memory in use. Of the allocated memory 3.41 GiB is allocated by PyTorch, and 107.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [14]:
from torch.ao.quantization.quantize_fx import convert_fx

# 6. Convert: swap out float ops for quantized kernels
prep_model = prep_model.to("cpu")
prep_model.eval()
quant_model = convert_fx(prep_model)



In [None]:
# Evaluate mIoU of the INT8 model
with torch.no_grad():
    miou_metric = JaccardIndex(
        task="multiclass",
        num_classes=NUM_CLASSES,
        average="micro",
        ignore_index=255,
    )

    for imgs, masks in tqdm(val_loader, desc="[Eval]"):
        imgs  = imgs.permute(0, 3, 1, 2).float()
        input = fp32_model.process(imgs)
        logits = quant_model(input)           # [B, num_classes, H, W]
        preds  = torch.argmax(logits, dim=1)  # [B, H, W]
        miou_metric.update(preds, masks)

    int8_miou = miou_metric.compute().item()
print(f"[QAT] INT8‐quantized model mIoU = {int8_miou:.4f}")

In [None]:
# 4j) Save the INT8 state_dict
save_path = os.path.join(project_dir, "checkpoints", f"dino-fpn-qat-int8.pth")
torch.save(quant_model.state_dict(), save_path)
print(f"[QuantPTQ] Saved INT8 weights to {save_path}")

In [None]:
quant_model

In [15]:
fp32_mem_footprint = get_memory_footprint(fp32_model)
quant_mem_footprint = get_quant_memory_footprint(quant_model)

=== Model Memory Footprint ===
Total:    90,328,097 params, 344.57 MB
=== Model Memory Footprint ===
Total:     94,686,819 params, 108.13 MB
