# DINOv2-Small evaluation on ACDC

This notebook evaluates a trained DINOv2-Small checkpoint on the ACDC validation split.
It includes:
- environment setup for Colab,
- dataset unpacking,
- checkpoint loading,
- metrics + plots,
- ClearML logging.

In [None]:
# Install the third-party packages for this notebook via pip; -q keeps the output quiet.
!pip install -q albumentations torchmetrics pytorch-lightning clearml python-dotenv gdown

In [None]:
import os
import zipfile
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import MulticlassJaccardIndex

import albumentations as A
from albumentations.pytorch import ToTensorV2

from clearml import Task
from google.colab import drive, userdata

In [None]:
# Central run configuration. The same dict is handed to ClearML through
# task.connect(CONFIG) right after the task is created.
CONFIG = {
    "project_name": "Segmentation_Urban_Scene_CourseWork",
    "task_name": "DinoV2_Small_Eval_ACDC_Val",
    "drive_root": "/content/drive/MyDrive",
    # Checkpoint location, resolved relative to "drive_root".
    "weights_rel_path": "weights/dinov2-small-cityscapes-epoch=48-val_miou=0.6292.ckpt",
    # Archives looked up directly under "drive_root" and extracted into "data_dir".
    "acdc_zips": [
        "gt_trainval.zip",
        "rgb_anon_trainvaltest.zip",
    ],
    "data_dir": "/content/data/acdc",
    "split": "val",
    "conditions": ["fog", "night", "rain", "snow"],
    "model_name": "dinov2_vits14",
    "num_classes": 19,
    # (height, width) passed to A.Resize; both are multiples of 14 (37 * 14 and 73 * 14),
    # matching the patch size the segmentation model divides by.
    "image_size": (518, 1022),
    "batch_size": 8,
    "num_workers": 2,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Human-readable class names; the list index is used as the class id in the
# per-class IoU table (presumably Cityscapes trainId order — confirm against the labels).
CITYSCAPES_CLASSES = [
    "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
    "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
    "truck", "bus", "train", "motorcycle", "bicycle"
]

# Per-channel statistics applied by A.Normalize and inverted again when images are displayed.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Copy the ClearML credentials from Colab's secret storage into the environment
# before the task is initialised.
os.environ["CLEARML_API_ACCESS_KEY"] = userdata.get("CLEARML_API_ACCESS_KEY")
os.environ["CLEARML_API_SECRET_KEY"] = userdata.get("CLEARML_API_SECRET_KEY")

# NOTE(review): output_uri=False presumably disables uploading artifacts/models to
# remote storage — confirm against the ClearML documentation.
task = Task.init(
    project_name=CONFIG["project_name"],
    task_name=CONFIG["task_name"],
    output_uri=False,
)
# Attach the configuration dict to the task.
task.connect(CONFIG)

print(f"Using device: {CONFIG['device']}")

In [None]:
# Mount Google Drive so the dataset archives and the checkpoint become reachable.
drive.mount("/content/drive")

os.makedirs(CONFIG["data_dir"], exist_ok=True)

# Each archive is expected directly under the Drive root; all of them are
# extracted into the same data directory. A missing archive aborts the cell.
for zip_name in CONFIG["acdc_zips"]:
    zip_path = Path(CONFIG["drive_root"], zip_name)
    if zip_path.exists():
        print(f"Unpacking {zip_path} -> {CONFIG['data_dir']}")
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(path=CONFIG["data_dir"])
    else:
        raise FileNotFoundError(f"Zip file not found: {zip_path}")

# Resolve the checkpoint and fail early if it is absent.
weights_path = Path(CONFIG["drive_root"]).joinpath(CONFIG["weights_rel_path"])
if not weights_path.exists():
    raise FileNotFoundError(f"Checkpoint not found: {weights_path}")

print(f"Checkpoint path: {weights_path}")

In [None]:
class ACDCDataset(Dataset):
    """Image / label-mask pairs of one ACDC split, gathered across conditions.

    Images are searched recursively under ``<root>/rgb_anon/<condition>/<split>``
    (files ending in ``_rgb_anon.png``); the matching mask is expected at the
    same relative location under ``<root>/gt/<condition>/<split>`` with the
    suffix ``_gt_labelTrainIds.png``. Images without a mask are skipped.

    Args:
        root_dir: Directory containing the ``rgb_anon`` and ``gt`` folders.
        split: Split folder name, e.g. ``"val"``.
        conditions: Condition folder names; defaults to fog/night/rain/snow.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.

    Raises:
        RuntimeError: If no image/mask pair is found at all.
    """

    def __init__(self, root_dir, split="val", conditions=None, augmentation=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.conditions = conditions or ["fog", "night", "rain", "snow"]
        self.augmentation = augmentation
        self.items = []

        for condition in self.conditions:
            rgb_root = self.root_dir / "rgb_anon" / condition / split
            gt_root = self.root_dir / "gt" / condition / split
            if not rgb_root.exists() or not gt_root.exists():
                print(f"Skip missing split folder: {condition}/{split}")
                continue

            # Sorted traversal keeps the sample order reproducible between runs.
            for image_path in sorted(rgb_root.rglob("*_rgb_anon.png")):
                rel = image_path.relative_to(rgb_root)
                mask_name = image_path.name.replace("_rgb_anon.png", "_gt_labelTrainIds.png")
                mask_path = gt_root / rel.parent / mask_name
                if mask_path.exists():
                    self.items.append(
                        {
                            "image": str(image_path),
                            "mask": str(mask_path),
                            "condition": condition,
                            "stem": image_path.stem,
                        }
                    )

        if not self.items:
            raise RuntimeError("No ACDC image/mask pairs found. Check dataset path and split.")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Dict with ``"image"`` (RGB; whatever type the augmentation yields,
            or the raw array when there is none), ``"mask"`` (int64 tensor),
            ``"condition"`` and ``"stem"``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        sample = self.items[idx]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(sample["image"])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {sample['image']}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(sample["mask"], cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {sample['mask']}")

        if self.augmentation is not None:
            transformed = self.augmentation(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]

        return {
            "image": image,
            # as_tensor accepts both the tensor produced by the augmentation and
            # the raw NumPy array present when no augmentation is configured.
            "mask": torch.as_tensor(mask).long(),
            "condition": sample["condition"],
            "stem": sample["stem"],
        }

In [None]:
# Deterministic evaluation preprocessing: resize to the configured
# (height, width), normalise with the ImageNet statistics, convert to tensors.
valid_transform = A.Compose([
    A.Resize(height=CONFIG["image_size"][0], width=CONFIG["image_size"][1]),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

dataset = ACDCDataset(
    root_dir=CONFIG["data_dir"],
    split=CONFIG["split"],
    conditions=CONFIG["conditions"],
    augmentation=valid_transform,
)

loader = DataLoader(
    dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    num_workers=CONFIG["num_workers"],
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only runtime just makes PyTorch emit a warning.
    pin_memory=str(CONFIG["device"]).startswith("cuda"),
)

# Quick sanity check: total sample count and the number of samples per condition.
print(f"ACDC samples: {len(dataset)}")
print(pd.DataFrame(dataset.items)["condition"].value_counts().sort_index())

In [None]:
class LinearSegmentationHead(nn.Module):
    """Linear per-patch classifier (1x1 convolution) with bilinear upsampling.

    Args:
        embed_dim: Channel width of the incoming patch tokens.
        num_classes: Number of output channels (classes).
        patch_size: Side length of one patch in input pixels.
    """

    def __init__(self, embed_dim=384, num_classes=19, patch_size=14):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.patch_size = patch_size
        # A 1x1 convolution applies the same linear map at every patch position.
        self.linear = nn.Conv2d(in_channels=embed_dim, out_channels=num_classes, kernel_size=1)

    def forward(self, x, h, w):
        """Turn a ``(batch, tokens, channels)`` sequence into ``(batch, classes, h, w)`` logits.

        ``h`` and ``w`` are the input image size in pixels; the token sequence
        is folded into a grid of ``h // patch_size`` by ``w // patch_size``.
        """
        batch, _, dim = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        # Fold the flat token sequence back into a 2-D grid, channels first.
        token_grid = x.reshape(batch, grid_h, grid_w, dim)
        coarse_logits = self.linear(token_grid.permute(0, 3, 1, 2))

        # Bring the patch-level logits up to the requested pixel resolution.
        return F.interpolate(coarse_logits, size=(h, w), mode="bilinear", align_corners=False)


class DinoV2SegmentationModel(nn.Module):
    """DINOv2 backbone from torch.hub with a linear segmentation head on top.

    Args:
        model_name: torch.hub entry point in ``facebookresearch/dinov2``.
        num_classes: Number of segmentation classes produced by the head.
    """

    def __init__(self, model_name="dinov2_vits14", num_classes=19):
        super().__init__()
        # Fetches the backbone (code and pretrained weights) through torch.hub.
        self.backbone = torch.hub.load("facebookresearch/dinov2", model_name)

        # Take the token width and patch size from the loaded backbone so other
        # DINOv2 variants get a correctly sized head; the ViT-S/14 values are
        # kept as a fallback in case the attributes are not exposed.
        self.embed_dim = getattr(self.backbone, "embed_dim", 384)
        self.patch_size = getattr(self.backbone, "patch_size", 14)

        self.segmentation_head = LinearSegmentationHead(
            embed_dim=self.embed_dim,
            num_classes=num_classes,
            patch_size=self.patch_size,
        )

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``."""
        height, width = x.shape[-2:]
        features = self.backbone.forward_features(x)
        # Only the normalised patch tokens are passed on to the head.
        patch_features = features["x_norm_patchtokens"]
        return self.segmentation_head(patch_features, height, width)


def load_lightning_checkpoint(model, ckpt_path, device):
    """Load backbone and segmentation-head weights from a checkpoint into ``model``.

    The checkpoint may be a dict with a ``"state_dict"`` entry or a plain state
    dict. Only parameters belonging to ``backbone.`` or ``segmentation_head.``
    are used. Keys nested under a wrapper attribute (for example
    ``model.backbone.*``) are accepted too, by dropping everything in front of
    the module name.

    Args:
        model: Module exposing ``backbone`` and ``segmentation_head``.
        ckpt_path: Path of the checkpoint file.
        device: Target device; also used as ``map_location`` while loading.

    Returns:
        The same ``model``, moved to ``device`` and switched to eval mode.

    Raises:
        RuntimeError: If the checkpoint holds no backbone/head weights at all.
            Without this check ``strict=False`` would load nothing and the
            model would be evaluated with untrained head weights.
    """
    ckpt = torch.load(ckpt_path, map_location=device)
    state_dict = ckpt.get("state_dict", ckpt)

    module_prefixes = ("backbone.", "segmentation_head.")
    model_state = {}
    for key, value in state_dict.items():
        if key.startswith(module_prefixes):
            model_state[key] = value
            continue
        for prefix in module_prefixes:
            cut = key.find("." + prefix)
            if cut != -1:
                # Strip the wrapper part; an un-prefixed key for the same
                # parameter takes precedence if both are present.
                model_state.setdefault(key[cut + 1:], value)
                break

    if not model_state:
        raise RuntimeError(
            f"No 'backbone.' or 'segmentation_head.' weights found in checkpoint: {ckpt_path}"
        )

    missing, unexpected = model.load_state_dict(model_state, strict=False)
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")

    model.to(device)
    model.eval()
    return model


# Build the architecture, then load the trained weights into it; the loader
# also moves the model to the target device and switches it to eval mode.
model = load_lightning_checkpoint(
    DinoV2SegmentationModel(
        model_name=CONFIG["model_name"],
        num_classes=CONFIG["num_classes"],
    ),
    ckpt_path=str(weights_path),
    device=CONFIG["device"],
)

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, num_classes=19, device="cpu", max_cached_per_condition=None):
    """Run the model over ``dataloader`` and collect segmentation metrics.

    Label value 255 is treated as "ignore" for both IoU and pixel accuracy.

    Args:
        model: Callable mapping an image batch to logits with the class
            dimension at index 1.
        dataloader: Iterable of dict batches with the keys ``"image"``,
            ``"mask"``, ``"condition"`` and ``"stem"``.
        num_classes: Number of classes for the IoU metrics.
        device: Device on which inference and metric accumulation run.
        max_cached_per_condition: Upper bound on how many samples per
            condition are kept in the returned cache (first come, first
            kept). ``None`` keeps every sample, which holds the whole split
            in host memory.

    Returns:
        Dict with ``"overall_miou"``, ``"pixel_accuracy"``, ``"class_iou"``
        (NumPy array, one entry per class), ``"condition_miou"`` (one entry
        per condition that actually occurred in the data) and ``"cache"``
        (list of per-batch dicts holding CPU tensors ``"images"``,
        ``"masks"``, ``"preds"`` plus the lists ``"conditions"`` and
        ``"stems"``).
    """

    def _new_iou_metric(average):
        return MulticlassJaccardIndex(
            num_classes=num_classes, average=average, ignore_index=255
        ).to(device)

    metric_all = _new_iou_metric("macro")
    metric_per_class = _new_iou_metric("none")

    # Per-condition metrics are created on first sight of a condition, so the
    # function works with whatever conditions the data contains and reports
    # nothing for conditions that have no samples.
    condition_metrics = {}
    cached_counts = {}

    total_pixels = 0
    total_correct = 0

    cache = []

    for batch in dataloader:
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)
        conditions = list(batch["condition"])
        stems = list(batch["stem"])

        logits = model(images)
        preds = torch.argmax(logits, dim=1)

        metric_all.update(preds, masks)
        metric_per_class.update(preds, masks)

        valid = masks != 255
        total_correct += ((preds == masks) & valid).sum().item()
        total_pixels += valid.sum().item()

        keep = []
        for i, cond in enumerate(conditions):
            if cond not in condition_metrics:
                condition_metrics[cond] = _new_iou_metric("macro")
            # Slicing keeps the batch dimension expected by the metric.
            condition_metrics[cond].update(preds[i:i + 1], masks[i:i + 1])

            already_cached = cached_counts.get(cond, 0)
            if max_cached_per_condition is None or already_cached < max_cached_per_condition:
                cached_counts[cond] = already_cached + 1
                keep.append(i)

        # Move only the retained samples to host memory.
        if keep:
            cache.append(
                {
                    "images": images[keep].detach().cpu(),
                    "masks": masks[keep].detach().cpu(),
                    "preds": preds[keep].detach().cpu(),
                    "conditions": [conditions[i] for i in keep],
                    "stems": [stems[i] for i in keep],
                }
            )

    overall_miou = metric_all.compute().item()
    class_iou = metric_per_class.compute().detach().cpu().numpy()
    # max(..., 1) guards against division by zero when no valid pixel was seen.
    pixel_acc = total_correct / max(total_pixels, 1)

    condition_miou = {
        cond: metric.compute().item()
        for cond, metric in condition_metrics.items()
    }

    return {
        "overall_miou": overall_miou,
        "pixel_accuracy": pixel_acc,
        "class_iou": class_iou,
        "condition_miou": condition_miou,
        "cache": cache,
    }


results = evaluate(
    model,
    loader,
    num_classes=CONFIG["num_classes"],
    device=CONFIG["device"],
    # The cache only feeds the qualitative figure, which shows 3 samples per
    # condition, so there is no need to keep the whole split in RAM.
    max_cached_per_condition=3,
)

In [None]:
# --- Tabular views of the evaluation results -------------------------------
summary_df = pd.DataFrame(
    [
        ("mIoU", results["overall_miou"]),
        ("Pixel Accuracy", results["pixel_accuracy"]),
    ],
    columns=["metric", "value"],
)

condition_df = pd.DataFrame(
    list(results["condition_miou"].items()),
    columns=["condition", "mIoU"],
).sort_values("condition")

# Best-performing classes first.
class_df = pd.DataFrame(
    {
        "class_id": np.arange(CONFIG["num_classes"]),
        "class_name": CITYSCAPES_CLASSES,
        "IoU": results["class_iou"],
    }
).sort_values("IoU", ascending=False)

# Show the three tables inline in the notebook.
for heading, frame in (
    ("Overall metrics", summary_df),
    ("\nPer-condition mIoU", condition_df),
    ("\nPer-class IoU", class_df),
):
    print(heading)
    display(frame)

# Mirror the same tables and the two headline scalars to ClearML.
logger = task.get_logger()
for series_name, frame in (
    ("overall", summary_df),
    ("per_condition", condition_df),
    ("per_class", class_df),
):
    logger.report_table("evaluation", series_name, iteration=0, table_plot=frame)

for series_name, result_key in (("mIoU", "overall_miou"), ("pixel_accuracy", "pixel_accuracy")):
    logger.report_scalar("evaluation", series_name, results[result_key], iteration=0)

In [None]:
# Bar chart 1: mIoU for each condition.
fig1, ax1 = plt.subplots(figsize=(7, 4))
ax1.bar(condition_df["condition"], condition_df["mIoU"])
ax1.set(
    ylim=(0, 1),
    ylabel="mIoU",
    title="DINOv2-Small on ACDC val: per-condition mIoU",
)
ax1.grid(axis="y", alpha=0.25)
fig1.tight_layout()
plt.show()

logger.report_matplotlib_figure(
    title="plots",
    series="per_condition_miou",
    iteration=0,
    figure=fig1,
)

# Bar chart 2: IoU for each class, highest first.
plot_df = class_df.sort_values("IoU", ascending=False)
fig2, ax2 = plt.subplots(figsize=(10, 5))
ax2.bar(plot_df["class_name"], plot_df["IoU"])
ax2.set(ylim=(0, 1), ylabel="IoU", title="Per-class IoU (sorted)")
ax2.grid(axis="y", alpha=0.25)
# Rotate the class names so the labels do not overlap.
ax2.tick_params(axis="x", rotation=70)
fig2.tight_layout()
plt.show()

logger.report_matplotlib_figure(
    title="plots",
    series="per_class_iou",
    iteration=0,
    figure=fig2,
)

In [None]:
def denormalize_for_display(img_tensor):
    """Undo the ImageNet normalisation so the image can be shown with imshow.

    The tensor's first axis is moved to the end (channels-first to
    channels-last), the per-channel mean/std scaling is reverted, and the
    result is clipped to ``[0, 1]``. Returns a float32 NumPy array.
    """
    channels_last = img_tensor.permute(1, 2, 0).numpy().astype(np.float32)
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    return np.clip(channels_last * std + mean, 0, 1)


def _collect_condition_samples_from_cache(cached_batches, conditions, samples_per_condition=3):
    collected = {cond: [] for cond in conditions}
    for batch in cached_batches:
        for i, cond in enumerate(batch["conditions"]):
            if cond in collected and len(collected[cond]) < samples_per_condition:
                collected[cond].append(
                    {
                        "image": batch["images"][i],
                        "mask": batch["masks"][i],
                        "pred": batch["preds"][i],
                    }
                )
        if all(len(collected[c]) >= samples_per_condition for c in conditions):
            break
    return collected


def show_predictions_by_condition(
    cached_batches,
    condition_miou,
    conditions,
    samples_per_condition=3,
    num_classes=19,
    ignore_index=255,
):
    """Plot image / ground truth / prediction triplets grouped by condition.

    Args:
        cached_batches: Cached evaluation batches (dicts with ``"images"``,
            ``"masks"``, ``"preds"`` and ``"conditions"``).
        condition_miou: Mapping from condition to mIoU, shown in the ground
            truth title; a missing condition is shown as ``nan``.
        conditions: Conditions to display, in row order.
        samples_per_condition: Number of rows reserved per condition; rows
            without a sample are left blank.
        num_classes: Number of classes; fixes the colour range to
            ``0 .. num_classes - 1`` for both label panels.
        ignore_index: Label value excluded from the ground-truth panel.

    Returns:
        The matplotlib figure.
    """
    samples = _collect_condition_samples_from_cache(cached_batches, conditions, samples_per_condition)

    rows = len(conditions) * samples_per_condition
    # squeeze=False always yields a 2-D axes array, also for a single row.
    fig, axes = plt.subplots(rows, 3, figsize=(14, rows * 3.2), squeeze=False)
    label_range = {"vmin": 0, "vmax": num_classes - 1}

    row_idx = 0
    for cond in conditions:
        cond_score = condition_miou.get(cond, float("nan"))
        for s_idx in range(samples_per_condition):
            ax_image, ax_gt, ax_pred = axes[row_idx]
            if s_idx < len(samples[cond]):
                sample = samples[cond][s_idx]
                image = denormalize_for_display(sample["image"])
                # Mask out ignore-label pixels; otherwise they are clipped to
                # the top of the colour range and look like the last class.
                gt = np.ma.masked_equal(sample["mask"].numpy(), ignore_index)
                pred = sample["pred"].numpy()

                ax_image.imshow(image)
                ax_image.set_title(f"{cond} #{s_idx + 1} | image")

                ax_gt.imshow(gt, **label_range)
                ax_gt.set_title(f"GT | {cond} mIoU={cond_score:.4f}")

                ax_pred.imshow(pred, **label_range)
                ax_pred.set_title("Prediction")

            # Axes are hidden for filled and empty rows alike.
            for ax in (ax_image, ax_gt, ax_pred):
                ax.axis("off")
            row_idx += 1

    fig.tight_layout()
    plt.show()
    return fig


def _predict_ref_single_image(image_path, model, resize_hw, device):
    """Predict one reference image and summarise how confident the model is.

    Args:
        image_path: Path of the image file.
        model: Callable mapping an image batch to logits with the class
            dimension at index 1.
        resize_hw: ``(height, width)`` the image is resized to before inference.
        device: Device used for inference.

    Returns:
        Dict with ``"image"`` (the resized RGB array for display), ``"pred"``
        (per-pixel argmax class as a NumPy array), ``"mean_conf"`` (mean of
        the per-pixel maximum softmax probability) and ``"mean_entropy"``
        (mean per-pixel entropy of the softmax distribution).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(str(image_path))
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize once and reuse the result for both display and model input.
    image_resized = A.Resize(height=resize_hw[0], width=resize_hw[1])(image=image)["image"]
    to_model_input = A.Compose([
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ])
    image_tensor = to_model_input(image=image_resized)["image"].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).squeeze(0).cpu().numpy()
        max_prob = probs.max(dim=1).values.squeeze(0).cpu().numpy()
        # The small epsilon keeps log() finite for zero probabilities.
        entropy = float((-probs * torch.log(probs + 1e-8)).sum(dim=1).mean().item())

    mean_conf = float(max_prob.mean())

    return {
        "image": image_resized,
        "pred": pred,
        "mean_conf": mean_conf,
        "mean_entropy": entropy,
    }


def show_ref_predictions_by_condition(root_dir, split_ref, conditions, model, resize_hw, device, samples_per_condition=3):
    """Predict and plot a few reference images per condition.

    Reference images are searched recursively under
    ``<root_dir>/rgb_anon/<condition>/<split_ref>`` (files ending in
    ``_rgb_ref_anon.png``); the first ``samples_per_condition`` files in
    sorted order are used. No ground truth is involved, so only confidence
    and entropy are reported.

    Returns:
        ``(fig, metrics_df)``: the figure and a DataFrame with one row per
        condition that had at least one image (``condition``,
        ``mean_confidence``, ``mean_entropy``, ``num_samples``).
    """
    collected = {cond: [] for cond in conditions}

    for cond in conditions:
        ref_root = Path(root_dir) / "rgb_anon" / cond / split_ref
        if ref_root.exists():
            ref_files = sorted(ref_root.rglob("*_rgb_ref_anon.png"))
            collected[cond].extend(
                _predict_ref_single_image(
                    image_path=fp,
                    model=model,
                    resize_hw=resize_hw,
                    device=device,
                )
                for fp in ref_files[:samples_per_condition]
            )

    rows = len(conditions) * samples_per_condition
    # squeeze=False always yields a 2-D axes array, also for a single row.
    fig, axes = plt.subplots(rows, 2, figsize=(12, rows * 3.2), squeeze=False)

    metrics_rows = []
    row_idx = 0
    for cond in conditions:
        items = collected[cond]
        shown = items[:samples_per_condition]

        for s_idx in range(samples_per_condition):
            ax_image, ax_pred = axes[row_idx]
            row_idx += 1

            if s_idx >= len(items):
                # Blank row: fewer images were available than requested.
                ax_image.axis("off")
                ax_pred.axis("off")
                continue

            item = items[s_idx]
            ax_image.imshow(item["image"])
            ax_image.set_title(f"{cond} ref #{s_idx + 1} | image")
            ax_image.axis("off")

            ax_pred.imshow(item["pred"], vmin=0, vmax=18)
            ax_pred.set_title(
                f"Prediction | conf={item['mean_conf']:.3f}, H={item['mean_entropy']:.3f}"
            )
            ax_pred.axis("off")

        # Aggregate over exactly the samples that were drawn for this condition.
        if shown:
            metrics_rows.append(
                {
                    "condition": cond,
                    "mean_confidence": float(np.mean([it["mean_conf"] for it in shown])),
                    "mean_entropy": float(np.mean([it["mean_entropy"] for it in shown])),
                    "num_samples": len(shown),
                }
            )

    fig.tight_layout()
    plt.show()

    return fig, pd.DataFrame(metrics_rows)


# Qualitative check on the labelled split: image / GT / prediction per condition.
fig3 = show_predictions_by_condition(
    results["cache"],
    results["condition_miou"],
    CONFIG["conditions"],
    samples_per_condition=3,
)
logger.report_matplotlib_figure(
    title="plots",
    series="qualitative_predictions_by_condition",
    iteration=0,
    figure=fig3,
)

# Same idea for the "<split>_ref" images, which are scored with
# confidence/entropy proxies instead of IoU.
fig4, ref_metrics_df = show_ref_predictions_by_condition(
    CONFIG["data_dir"],
    f"{CONFIG['split']}_ref",
    CONFIG["conditions"],
    model,
    CONFIG["image_size"],
    CONFIG["device"],
    samples_per_condition=3,
)
logger.report_matplotlib_figure(
    title="plots",
    series="qualitative_predictions_ref_by_condition",
    iteration=0,
    figure=fig4,
)

if len(ref_metrics_df) == 0:
    print("No ref images found for selected split.")
else:
    print("Ref split proxy metrics (no GT-based IoU):")
    display(ref_metrics_df)
    logger.report_table("evaluation", "ref_proxy_metrics", iteration=0, table_plot=ref_metrics_df)
    # One confidence and one entropy scalar per condition.
    for _, row in ref_metrics_df.iterrows():
        cond_name = row["condition"]
        logger.report_scalar("ref_proxy", f"confidence_{cond_name}", row["mean_confidence"], iteration=0)
        logger.report_scalar("ref_proxy", f"entropy_{cond_name}", row["mean_entropy"], iteration=0)

In [None]:
# Close the ClearML task now that all tables, scalars and figures have been reported.
task.close()
print("ClearML task closed.")