# AnomalyDINO Interactive Demo

Select an MVTec AD category, a backbone, whether to use SAM3, and the number of reference images. Then click **Run** to view the metrics and a quick visualization.

In [10]:
import sys, pathlib
# Make the repository's top-level packages importable from this notebook.
# This presumably assumes the notebook runs from a direct subdirectory of the
# repo (hence ".."); confirm if the notebook is moved.
# The membership check keeps re-executions of this cell from appending the
# same directory to sys.path again and again.
repo_root = pathlib.Path("..").resolve()
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

In [11]:
from IPython.display import HTML
# Inject CSS that hides every code cell's input area, so the demo shows only
# widgets and outputs. The selectors appear to target both the classic
# Notebook (.input_area) and JupyterLab (.jp-Cell-inputWrapper) DOM.
HTML(
    "<style>"
    ".input_area,.jp-CodeCell .jp-Cell-inputWrapper{display:none !important;}"
    "</style>"
)

In [12]:
import os
import pathlib
import numpy as np
import cv2
import torch
import matplotlib
matplotlib.use("module://matplotlib_inline.backend_inline")
import matplotlib.pyplot as plt
from io import BytesIO
import ipywidgets as widgets

from backbones import get_backbone
from dataset.dataloader import load_mvtec
from models.model_bank_knn import PatchKNNDetector
from segmenters.sam3 import SAM3Segmenter
from evaluation.anomaly_evaluator import AnomalyEvaluator

def find_repo_root(start: pathlib.Path) -> pathlib.Path:
    """Find the directory that contains ``dataset/mvtec_anomaly_detection``.

    *start* is examined first, then its ancestors. At most five directories
    are examined in total (fewer if the filesystem root comes first).

    Returns:
        The first examined directory that has the dataset folder, or *start*
        itself when none of them does.
    """
    # start, its parent, grandparent, ... truncated to five candidates.
    candidates = [start, *start.parents][:5]
    for candidate in candidates:
        if (candidate / "dataset" / "mvtec_anomaly_detection").exists():
            return candidate
    return start

# `display` is used at the end of this cell; import it explicitly instead of
# relying on IPython's implicit global (the next cell imports it explicitly too).
from IPython.display import display

# Resolve the repository root and the MVTec AD directory. The MVTec_ROOT
# environment variable, when set, overrides the default location.
repo_root = find_repo_root(pathlib.Path().resolve())
root = os.environ.get("MVTec_ROOT", str(repo_root / "dataset" / "mvtec_anomaly_detection"))
if not os.path.isdir(root):
    # find_repo_root falls back to the working directory when it cannot find the
    # dataset, so report a missing dataset now rather than on the first run.
    print(f"Warning: MVTec AD directory not found at {root}. Set MVTec_ROOT to its location.")

# The 15 MVTec AD object/texture categories offered in the dropdown.
categories = [
    "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather",
    "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper",
]

# Available backbones from the registry
backbone_names = [
    "dinov2_small", "dinov2_base", "dinov2_large",
    "dinov3_small", "dinov3_base", "dinov3_large",
]

# Controls read by run_experiment when the Run button is clicked.
category_dd = widgets.Dropdown(options=categories, value="bottle", description="Category")
backbone_dd = widgets.Dropdown(options=backbone_names, value="dinov2_small", description="Backbone")
use_sam_cb = widgets.Checkbox(value=False, description="Use SAM3")
nref_slider = widgets.IntSlider(value=1, min=1, max=10, step=1, description="n_ref")
run_btn = widgets.Button(description="Run", button_style="success")
out = widgets.Output()  # receives the printed metrics and the figure

ui = widgets.VBox([
    widgets.HBox([category_dd, backbone_dd]),
    widgets.HBox([use_sam_cb, nref_slider]),
    run_btn,
    out
])
display(ui)

VBox(children=(HBox(children=(Dropdown(description='Category', options=('bottle', 'cable', 'capsule', 'carpet'…

In [13]:

from IPython.display import display

def mvtec_is_anomalous(path) -> int:
    """Return 0 for a defect-free ("good") MVTec AD test image, 1 otherwise.

    Args:
        path: Test-image path (``str`` or ``pathlib.Path``).

    The label is taken from the image's parent directory name
    (``test/good/<name>.png``). Searching for the substring ``"good"``
    anywhere in the path would instead label *every* image as normal whenever
    the dataset root happens to contain "good" (e.g. ``/home/goodwin/...``).
    """
    return 0 if pathlib.Path(path).parent.name == "good" else 1


def mvtec_gt_mask_path(path) -> str:
    """Return the ground-truth mask path for an MVTec AD test image.

    ``.../<category>/test/<defect>/<name>.png`` maps to
    ``.../<category>/ground_truth/<defect>/<name>_mask.png``. Only the last
    directory component named exactly ``test`` is rewritten, so unrelated
    directories such as ``/home/tester`` are left intact. A plain
    ``str.replace`` would corrupt them.

    Args:
        path: Test-image path (``str`` or ``pathlib.Path``).
    """
    image_path = pathlib.Path(path)
    dirs = list(image_path.parent.parts)
    if "test" in dirs:
        last_test = len(dirs) - 1 - dirs[::-1].index("test")
        dirs[last_test] = "ground_truth"
    return str(pathlib.Path(*dirs) / f"{image_path.stem}_mask{image_path.suffix}")


def load_gt_mask(path, shape):
    """Load the binary (0/1 int) ground-truth mask for test image *path*.

    Args:
        path: Test-image path; the mask location comes from ``mvtec_gt_mask_path``.
        shape: Target ``(height, width)``, i.e. the anomaly map's shape.

    Returns:
        The mask resized to *shape*, or ``None`` if the mask file is missing
        or OpenCV cannot decode it.
    """
    mask_path = mvtec_gt_mask_path(path)
    if not os.path.exists(mask_path):
        return None
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:  # file exists but could not be decoded
        return None
    if mask.shape != tuple(shape):
        # Nearest-neighbour keeps the mask binary (no interpolated grey values).
        mask = cv2.resize(mask, (shape[1], shape[0]), interpolation=cv2.INTER_NEAREST)
    return (mask > 0).astype(int)


def show_anomaly_visualization(image, amap):
    """Show *image*, its min-max-normalised anomaly map, and an overlay.

    The figure is rendered to PNG bytes and displayed as an ``ipywidgets.Image``.
    This is presumably so it renders reliably inside the ``Output`` widget
    from a button callback (TODO confirm). The figure is always closed so
    pyplot does not keep it alive.
    """
    heat = amap.astype(np.float32)
    # Normalise to [0, 1]; the epsilon avoids division by zero on a constant map.
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    fig, axes = plt.subplots(1, 3, figsize=(10, 4))
    try:
        axes[0].imshow(image)
        axes[0].set_title("Input")
        axes[1].imshow(heat, cmap="jet")
        axes[1].set_title("Anomaly map")
        axes[2].imshow(image)
        axes[2].imshow(heat, cmap="jet", alpha=0.5)
        axes[2].set_title("Overlay")
        for ax in axes:
            ax.axis("off")
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    display(widgets.Image(value=buf.getvalue(), format="png"))


def run_experiment(_):
    """Button callback: evaluate AnomalyDINO with the current widget settings.

    Steps:
    - Read the category, backbone, SAM3 flag and ``n_ref`` from the widgets.
    - Fit a ``PatchKNNDetector`` on the first ``n_ref`` training images.
    - Score every 20th test image.
    - Print the evaluator's metrics.
    - Show the first scored image together with its anomaly map.

    All output goes to the ``out`` widget.

    Args:
        _: The clicked button, passed by ``Button.on_click`` (unused).
    """
    with out:
        out.clear_output(wait=True)
        category = category_dd.value
        backbone_name = backbone_dd.value
        use_sam = use_sam_cb.value
        n_ref = nref_slider.value
        device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Category: {category} | Backbone: {backbone_name} | SAM3: {use_sam} | n_ref: {n_ref} | device: {device}")

        train_paths, test_paths = load_mvtec(category=category, root=root)
        test_paths = test_paths[::20]  # subsample for speed
        train_paths = train_paths[:max(1, n_ref)]
        if not train_paths or not test_paths:
            # An empty split would otherwise surface later as a less obvious
            # error (presumably from fit()/compute()); point at the dataset instead.
            print(f"No train/test images found for '{category}' under {root}; check MVTec_ROOT.")
            return

        evaluator = AnomalyEvaluator(pixel_subsample_rate=1.0, compute_pro=False)
        segmenter = SAM3Segmenter(text_prompt=category, device=device) if use_sam else None
        backbone = get_backbone(backbone_name)
        model = PatchKNNDetector(backbone=backbone, segmenter=segmenter, device=device)

        model.fit(train_paths, n_ref=n_ref)

        vis_img = vis_amap = None
        missing_masks = 0
        for p in test_paths:
            image, amap, score = model.predict(p)
            is_anomaly = mvtec_is_anomalous(p)
            gt_mask = load_gt_mask(p, amap.shape) if is_anomaly else None
            if gt_mask is None:
                if is_anomaly:
                    missing_masks += 1
                # Normal images (and anomalous ones without a usable mask) get an all-zero mask.
                gt_mask = np.zeros(amap.shape, dtype=int)
            evaluator.update(score, is_anomaly, amap, gt_mask)
            if vis_img is None:
                vis_img, vis_amap = image, amap

        if missing_masks:
            print(f"Warning: {missing_masks} anomalous image(s) had no readable ground-truth mask; "
                  "an all-zero mask was used, which skews pixel-level metrics.")

        results = evaluator.compute()
        print("Metrics:")
        for k, v in results.items():
            try:
                print(f"  {k}: {v:.4f}")
            except (TypeError, ValueError):
                # In case the evaluator reports a non-numeric value (e.g. None).
                print(f"  {k}: {v}")

        if vis_img is not None:
            show_anomaly_visualization(vis_img, vis_amap)


# Attach the callback. Re-running this cell redefines run_experiment as a new
# function object, which ipywidgets would register *in addition* to the old
# one, so a single click would run the experiment twice. To avoid that,
# first detach the handler a previous execution registered.
_previous_run_handler = globals().get("_registered_run_handler")
if _previous_run_handler is not None:
    run_btn.on_click(_previous_run_handler, remove=True)
run_btn.on_click(run_experiment)
_registered_run_handler = run_experiment
