In [1]:
import os
import sys
from pathlib import Path

from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
from torch.quantization import quantize_dynamic

from hydra import compose, initialize

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
"""Achieves 79.1% mIoU on the KITTI-360 validation set."""

'Achieves 79.1% mIoU on the KITTI-360 validation set.'

In [6]:
# Add the project root directory to the Python path
cur_dir     = Path.cwd()
project_dir = cur_dir.parent
print(f"Adding project directory to sys.path: {project_dir}")
sys.path.append(str(project_dir))

from models.DinoFPNbn import DinoFPN
from data.dataset import KittiSemSegDataset
from utils.others import get_memory_footprint, get_quant_memory_footprint

Adding project directory to sys.path: /home/panos/dev/hf_seg


In [4]:
with initialize(version_base=None, config_path=f"../configs", job_name="inference_metrics"):
    cfg = compose(config_name="config")

In [5]:
# Dataset + DataLoader
dataset_root = '/home/panos/Documents/data/kitti-360'
val_dataset = KittiSemSegDataset(dataset_root, train=False, transform=None)
val_loader  = DataLoader(val_dataset,
                            batch_size=cfg.train.batch_size,
                            shuffle=False,
                            num_workers=cfg.dataset.num_workers,
                            pin_memory=True)

# Initialize original model & load checkpoint
model = DinoFPN(num_labels=cfg.dataset.num_classes, model_cfg=cfg.model)
checkpoint_path = project_dir / f"checkpoints/{cfg.checkpoint.model_name}.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(f"Loaded original model from {checkpoint_path}")

# Create quantized model (dynamic quantization)
qmodel = quantize_dynamic(
    model,
    {torch.nn.Linear, torch.nn.Conv2d},
    dtype=torch.qint8,
    inplace=False
)
qmodel.eval()
print("Created quantized version of the model (INT8 dynamic).")

# Prepare mIoU metric
miou_metric = JaccardIndex(
    task='multiclass',
    num_classes=cfg.dataset.num_classes,
    average='micro',
    ignore_index=None
)

Loaded original model from /home/panos/dev/hf_seg/checkpoints/dino-fpn-bn.pth
Created quantized version of the model (INT8 dynamic).


In [16]:
with torch.no_grad():
    for batch_idx, (imgs, masks) in enumerate(tqdm(val_loader)):
        imgs = imgs.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        input = model.process(imgs)

        # forward + loss
        logits = model(input)
        cls_map = None

        # compute IoU on this batch
        preds = torch.argmax(logits, dim=1)  # [B, H, W]
        miou_metric.update(preds, masks)

    avg_val_miou = miou_metric.compute().item()

print(f"\n  Val mIoU: {avg_val_miou:.4f}")

100%|██████████| 66/66 [29:26<00:00, 26.76s/it]


  Val mIoU: 0.7911



