# SA-1B/SA-V Val Split Qualitative

이 노트북은 SA-1B val 이미지 5~10장과 SA-V val 비디오 2~3개를 무작위로 추출해 모든 모델 조합의 포인트 프롬프트 결과를 PDF로 저장합니다. 실행 전에 `conda activate SAM2` 환경을 활성화한 뒤 Context7 MCP 설정과 동일한 라이브러리 버전을 사용하세요.


In [1]:

import json
import math
import random
import pickle
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import font_manager
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image
from pycocotools import mask as mask_utils

from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Register every .ttf file found in the local Times New Roman directory with
# matplotlib's font manager; only warn (do not fail) when the directory is absent.
TIMES_FONT_DIR = Path("../../Times New Roman")
if TIMES_FONT_DIR.exists():
    for font_path in TIMES_FONT_DIR.glob("*.ttf"):
        font_manager.fontManager.addfont(str(font_path))
else:
    print(f"[WARN] Times New Roman directory not found: {TIMES_FONT_DIR}")

# White backgrounds for displayed and saved figures, and a single font family
# name for the generic, sans-serif and serif font slots.
plt.rcParams["figure.facecolor"] = "white"
plt.rcParams["savefig.facecolor"] = "white"
plt.rcParams["font.family"] = "Times New Roman Cyr"
plt.rcParams["font.sans-serif"] = ["Times New Roman Cyr"]
plt.rcParams["font.serif"] = ["Times New Roman Cyr"]

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    # The bfloat16 autocast context is entered by hand and no matching
    # __exit__ call appears in this cell, so it stays active afterwards.
    autocast_context = torch.autocast("cuda", dtype=torch.bfloat16)
    autocast_context.__enter__()
    # Enable TF32 matmul/cuDNN kernels when device 0 reports a compute
    # capability major version of 8 or higher.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Seed Python, NumPy (legacy global state) and torch with the same value.
GLOBAL_SEED = 2024
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)


Using device: cuda


<torch._C.Generator at 0x7f9df3c7f7f0>

In [2]:
# Dataset locations, all relative to the notebook's working directory.
DATA_ROOT = Path("../../datasets")
SA1B_VAL_ROOT = DATA_ROOT / "sa-1b_split" / "val"
SAV_VAL_FRAMES = DATA_ROOT / "sa-v" / "sav_val" / "JPEGImages_24fps"
SAV_VAL_ANN = DATA_ROOT / "sa-v" / "sav_val" / "Annotations_6fps"

# Sampling parameters.
SA1B_NUM_IMAGES = 8  # target range [5, 10]
SAV_NUM_VIDEOS = 3   # target range [2, 3]
POINT_SET_SIZES = [1, 3, 5]  # numbers of prompt points sampled per SA-1B image
# Minimum mask area as a fraction of the image: the primary threshold is tried
# first, and the secondary one is used for a second pass by sample_sa1b_prompts.
SA1B_PRIMARY_AREA_FRAC = 0.05
SA1B_SECONDARY_AREA_FRAC = 0.02
# SA-V: minimum per-frame mask area fraction, minimum number of qualifying
# annotated frames per object, and number of frames shown per timeline.
SAV_MIN_AREA_FRAC = 0.02
SAV_MIN_MASK_FRAMES = 6
SAV_TIMELINE_FRAMES = 6

# Clamp the requested counts into the target ranges noted above.
SA1B_NUM_IMAGES = min(10, max(5, SA1B_NUM_IMAGES))
SAV_NUM_VIDEOS = min(3, max(2, SAV_NUM_VIDEOS))

# Dedicated NumPy generator seeded with the global seed.
RNG = np.random.default_rng(GLOBAL_SEED)

# Output and cache locations; the directories are created eagerly.
OUTPUT_DIR = Path("../qualitative_val_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR = OUTPUT_DIR / "cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)
PDF_IMAGES_PATH = OUTPUT_DIR / "qualitative_val_images.pdf"
PDF_VIDEOS_PATH = OUTPUT_DIR / "qualitative_val_videos.pdf"
IMAGE_CACHE_PATH = CACHE_DIR / "sa1b_preds.pkl"
VIDEO_CACHE_PATH = CACHE_DIR / "sav_preds.pkl"

# Plot styling. The colour triples are presumably RGB in the 0-255 range
# (they are blended into image arrays by blend_mask) -- TODO confirm.
REFERENCE_COLOR = (0, 196, 255)
PRED_COLOR = (255, 100, 100)
GT_COLOR = (111, 255, 0)
MASK_ALPHA = 0.55  # blend weight of the mask colour over the image
POINT_MARKER_SIZE = 150
POINT_EDGE_COLOR = "white"
POINT_COLOR = "#00ff69"
POINT_LINEWIDTH = 1.1
BASE_FIG_COLS = 4

print(f"SA-1B val root: {SA1B_VAL_ROOT}")
print(f"SA-V val root : {SAV_VAL_FRAMES}")
print(f"Image PDF : {PDF_IMAGES_PATH.resolve()}")
print(f"Video PDF : {PDF_VIDEOS_PATH.resolve()}")


SA-1B val root: ../../datasets/sa-1b_split/val
SA-V val root : ../../datasets/sa-v/sav_val/JPEGImages_24fps
Image PDF : /home/lji/SAM/sam2/qualitative_val_outputs/qualitative_val_images.pdf
Video PDF : /home/lji/SAM/sam2/qualitative_val_outputs/qualitative_val_videos.pdf


In [3]:
@dataclass
class SA1BPromptSample:
    """One SA-1B image together with a chosen ground-truth mask and point prompts."""

    # Image file stem, used as the key for predictions.
    sample_id: str
    # Paths of the image and of its sibling ".json" annotation file.
    image_path: Path
    json_path: Path
    # Image converted to RGB and stored as a NumPy array.
    image_np: np.ndarray
    # Boolean mask of the selected annotation.
    gt_mask: np.ndarray
    # Maps a point count to an array of (x, y) pixel coordinates sampled inside gt_mask.
    points_by_count: Dict[int, np.ndarray]
    # Extra info; sample_sa1b_prompts fills "ann_idx", "mask_area", "mask_area_frac".
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SAVVideoSample:
    """One SA-V video/object pair with its prompt frame and ground-truth masks."""

    # "<video_id>_obj<object_id>" identifier, used as the key for predictions.
    sample_id: str
    # Name of the video directory.
    video_id: str
    # Integer parsed from the object's annotation directory name.
    object_id: int
    # Directory holding the video's JPEG frames.
    video_dir: Path
    # Sorted frame numbers parsed from the JPEG file stems.
    frame_indices: List[int]
    # Frame number -> JPEG path.
    frame_idx_to_path: Dict[int, Path]
    # Frame number of the first qualifying annotated frame; its mask is the prompt.
    prompt_frame_idx: int
    # Boolean mask on the prompt frame.
    prompt_mask: np.ndarray
    # (x, y) pixel coordinates sampled inside prompt_mask.
    prompt_points: np.ndarray
    # Frame numbers chosen for the rendered timeline.
    timeline_indices: List[int]
    # Frame number -> boolean ground-truth mask for every qualifying annotated frame.
    gt_masks_by_frame: Dict[int, np.ndarray]
    # Extra info; sample_sav_videos fills "area_frac", "num_frames", "mask_frame_coverage".
    metadata: Dict[str, Any] = field(default_factory=dict)


def ensure_rle(segmentation, height: int, width: int) -> Dict:
    """Return ``segmentation`` as a single RLE dict whose counts are text.

    Args:
        segmentation: Either a dict containing a ``"counts"`` entry (copied
            shallowly, so the caller's object is not modified) or a list,
            which is converted with ``mask_utils.frPyObjects`` and merged.
        height: Image height passed to ``frPyObjects`` for list input.
        width: Image width passed to ``frPyObjects`` for list input.

    Returns:
        An RLE dict; a ``bytes`` value under ``"counts"`` is decoded as ASCII.

    Raises:
        ValueError: If ``segmentation`` is neither of the supported forms.
    """
    already_rle = isinstance(segmentation, dict) and "counts" in segmentation
    if not already_rle and not isinstance(segmentation, list):
        raise ValueError("Unsupported segmentation format")

    if already_rle:
        rle = dict(segmentation)
    else:
        rle = mask_utils.merge(mask_utils.frPyObjects(segmentation, height, width))

    counts = rle.get("counts")
    if isinstance(counts, bytes):
        rle["counts"] = counts.decode("ascii")
    return rle


def decode_rle_mask(rle: Dict) -> np.ndarray:
    """Decode one RLE dict with pycocotools and return it as a boolean array.

    The RLE is decoded as a one-element list and the first (only) plane of
    the result is taken.
    """
    decoded_stack = mask_utils.decode([rle])
    first_plane = decoded_stack[:, :, 0]
    return first_plane.astype(bool)


def sample_points_from_mask(mask: np.ndarray, num_points: int, rng: np.random.Generator) -> np.ndarray:
    """Draw ``num_points`` random foreground pixels from ``mask``.

    Args:
        mask: 2-D array; non-zero entries are foreground.
        num_points: Number of coordinates to draw.
        rng: Generator used for the draw.

    Returns:
        Array with one ``(x, y)`` row per drawn pixel. Pixels are drawn
        without replacement unless the mask has fewer foreground pixels than
        ``num_points``.

    Raises:
        ValueError: If the mask has no foreground pixel.
    """
    rows, cols = np.nonzero(mask)
    candidates = np.stack([cols, rows], axis=1)
    total = len(candidates)
    if total == 0:
        raise ValueError("Empty mask cannot provide prompts")
    picks = rng.choice(total, size=num_points, replace=total < num_points)
    return candidates[picks]


In [4]:
def _load_sa1b_entry(img_path: Path, rng: np.random.Generator, cache: Dict[Path, Dict]) -> Dict:
    if img_path in cache:
        return cache[img_path]

    json_path = img_path.with_suffix(".json")
    if not json_path.exists():
        raise FileNotFoundError(f"Missing annotation for {img_path}")

    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
        image_np = np.array(rgb)
        width, height = rgb.size

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    annotations = data["annotations"] if isinstance(data, dict) and "annotations" in data else data

    entry = {
        "image_np": image_np,
        "annotations": annotations,
        "width": width,
        "height": height,
        "ann_order": rng.permutation(len(annotations)).tolist() if annotations else [],
        "mask_cache": {},
    }
    cache[img_path] = entry
    return entry


def _select_mask(entry: Dict, area_thresh: float) -> Optional[Tuple[int, np.ndarray, float]]:
    if not entry["annotations"]:
        return None
    area = entry["width"] * entry["height"]
    for ann_idx in entry["ann_order"]:
        ann = entry["annotations"][ann_idx]
        seg = ann.get("segmentation")
        if seg is None:
            continue
        if ann_idx not in entry["mask_cache"]:
            try:
                rle = ensure_rle(seg, entry["height"], entry["width"])
                mask = decode_rle_mask(rle)
            except Exception:
                mask = None
            entry["mask_cache"][ann_idx] = mask
        mask = entry["mask_cache"].get(ann_idx)
        if mask is None or mask.sum() == 0:
            continue
        area_frac = mask.sum() / max(area, 1)
        if area_frac >= area_thresh:
            return ann_idx, mask, area_frac
    return None


def sample_sa1b_prompts(root: Path, num_images: int, point_sizes: List[int], rng: np.random.Generator) -> List[SA1BPromptSample]:
    """Randomly pick SA-1B images and build point prompts for one mask per image.

    Images are visited in a random order. A first pass accepts masks covering
    at least ``SA1B_PRIMARY_AREA_FRAC`` of the image; if that does not yield
    enough samples, a second pass over the still-unused images relaxes the
    threshold to ``SA1B_SECONDARY_AREA_FRAC``.

    Args:
        root: Directory containing ``*.jpg`` images with sibling ``.json`` files.
        num_images: Number of samples wanted; values below 1 yield an empty list.
        point_sizes: Point counts for which prompt sets are sampled.
        rng: Generator driving every random choice.

    Returns:
        Up to ``num_images`` samples (fewer if not enough images qualify, in
        which case a warning is printed).

    Raises:
        FileNotFoundError: If ``root`` contains no ``*.jpg`` file.
    """
    image_paths = sorted(root.glob("*.jpg"))
    if not image_paths:
        raise FileNotFoundError(f"No SA-1B images found under {root}")

    # Clamp to [0, available]; a non-positive request must not select anything.
    target = max(0, min(len(image_paths), num_images))
    order = rng.permutation(len(image_paths))
    cache: Dict[Path, Dict] = {}
    used_indices = set()
    samples: List[SA1BPromptSample] = []

    for area_thresh in (SA1B_PRIMARY_AREA_FRAC, SA1B_SECONDARY_AREA_FRAC):
        for idx in order:
            # Checked before doing any work so the loop can never overshoot
            # the target (the previous equality test after appending did for
            # targets below 1).
            if len(samples) >= target:
                break
            if idx in used_indices:
                continue
            img_path = image_paths[idx]
            json_path = img_path.with_suffix(".json")
            if not json_path.exists():
                continue
            try:
                entry = _load_sa1b_entry(img_path, rng, cache)
            except FileNotFoundError:
                continue
            selected = _select_mask(entry, area_thresh)
            if selected is None:
                continue
            ann_idx, mask, area_frac = selected

            points_by_count = {
                count: sample_points_from_mask(mask, count, rng)
                for count in point_sizes
            }
            samples.append(
                SA1BPromptSample(
                    sample_id=img_path.stem,
                    image_path=img_path,
                    json_path=json_path,
                    image_np=entry["image_np"],
                    gt_mask=mask,
                    points_by_count=points_by_count,
                    metadata={
                        "ann_idx": int(ann_idx),
                        "mask_area": int(mask.sum()),
                        "mask_area_frac": float(area_frac),
                    },
                )
            )
            used_indices.add(idx)
        if len(samples) >= target:
            break

    if len(samples) < target:
        print(f"[WARN] Requested {target} SA-1B samples but only collected {len(samples)}.")
    return samples


def _compute_timeline_indices(frame_indices: List[int], desired: int, ensure_idx: int) -> List[int]:
    if not frame_indices:
        return []
    desired = max(1, min(desired, len(frame_indices)))
    timeline = []
    if ensure_idx in frame_indices:
        timeline.append(ensure_idx)
    positions = np.linspace(0, len(frame_indices) - 1, desired).astype(int)
    for pos in positions:
        idx = frame_indices[pos]
        if idx not in timeline:
            timeline.append(idx)
        if len(timeline) == desired:
            break
    timeline.sort()
    return timeline


def sample_sav_videos(frames_root: Path, ann_root: Path, num_videos: int, rng: np.random.Generator) -> List[SAVVideoSample]:
    """Randomly pick SA-V videos and, for each, one object with enough usable masks.

    Videos are visited in a random order. Within a video, objects are visited
    in a random order and the first one with at least ``SAV_MIN_MASK_FRAMES``
    annotated frames whose mask covers at least ``SAV_MIN_AREA_FRAC`` of the
    frame is chosen. Its earliest qualifying frame becomes the prompt frame.

    Args:
        frames_root: Directory with one sub-directory of ``*.jpg`` frames per video.
        ann_root: Directory with ``<video_id>/<object_id>/*.png`` masks.
        num_videos: Number of samples wanted; values below 1 yield an empty list.
        rng: Generator driving every random choice.

    Returns:
        Up to ``num_videos`` samples (fewer if not enough videos qualify, in
        which case a warning is printed).

    Raises:
        FileNotFoundError: If ``frames_root`` contains no sub-directory.
    """
    video_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()])
    if not video_dirs:
        raise FileNotFoundError(f"No SA-V videos found under {frames_root}")

    # Clamp to [0, available]; a non-positive request must not select anything.
    target = max(0, min(len(video_dirs), num_videos))
    order = rng.permutation(len(video_dirs))
    samples: List[SAVVideoSample] = []

    for idx in order:
        if len(samples) >= target:
            break
        video_dir = video_dirs[idx]
        video_id = video_dir.name
        frame_idx_to_path = {}
        for frame_path in sorted(video_dir.glob("*.jpg")):
            try:
                frame_idx = int(frame_path.stem)
            except ValueError:
                continue
            frame_idx_to_path[frame_idx] = frame_path
        frame_indices = sorted(frame_idx_to_path)
        if not frame_indices:
            continue

        ann_dir = ann_root / video_id
        if not ann_dir.exists():
            continue
        # iterdir() yields entries in arbitrary order; sort so the seeded
        # permutation below selects the same object on every machine.
        obj_dirs = sorted(p for p in ann_dir.iterdir() if p.is_dir())
        if not obj_dirs:
            continue
        obj_order = rng.permutation(len(obj_dirs))

        chosen = None
        for obj_idx in obj_order:
            obj_dir = obj_dirs[obj_idx]
            # Skip directories whose name is not an object number, mirroring
            # how non-numeric frame and mask file names are skipped.
            try:
                candidate_id = int(obj_dir.name)
            except ValueError:
                continue
            mask_entries: List[Tuple[int, np.ndarray]] = []
            for mask_file in sorted(obj_dir.glob("*.png")):
                try:
                    frame_idx = int(mask_file.stem)
                except ValueError:
                    continue
                if frame_idx not in frame_idx_to_path:
                    continue
                # Context manager releases the file handle deterministically.
                with Image.open(mask_file) as mask_img:
                    mask = np.array(mask_img.convert("L")) > 0
                area_frac = mask.mean()
                if area_frac < SAV_MIN_AREA_FRAC:
                    continue
                mask_entries.append((frame_idx, mask))
            if len(mask_entries) < SAV_MIN_MASK_FRAMES:
                continue
            mask_entries.sort(key=lambda x: x[0])
            chosen = (candidate_id, mask_entries)
            break

        if chosen is None:
            continue

        object_id, mask_entries = chosen
        prompt_frame_idx, prompt_mask = mask_entries[0]
        # Between 1 and 5 prompt points, limited by the number of mask pixels.
        num_prompt_points = min(5, max(1, int(prompt_mask.sum())))
        prompt_points = sample_points_from_mask(prompt_mask, num_prompt_points, rng)
        mask_frame_indices = [frame_idx for frame_idx, _ in mask_entries]
        timeline = _compute_timeline_indices(mask_frame_indices, SAV_TIMELINE_FRAMES, prompt_frame_idx)
        gt_masks_by_frame = dict(mask_entries)

        samples.append(
            SAVVideoSample(
                sample_id=f"{video_id}_obj{object_id:03d}",
                video_id=video_id,
                object_id=object_id,
                video_dir=video_dir,
                frame_indices=frame_indices,
                frame_idx_to_path=frame_idx_to_path,
                prompt_frame_idx=prompt_frame_idx,
                prompt_mask=prompt_mask,
                prompt_points=prompt_points,
                timeline_indices=timeline,
                gt_masks_by_frame=gt_masks_by_frame,
                metadata={
                    "area_frac": float(prompt_mask.mean()),
                    "num_frames": len(frame_indices),
                    "mask_frame_coverage": len(mask_entries) / len(frame_indices),
                },
            )
        )

    if len(samples) < target:
        print(f"[WARN] Requested {target} SA-V samples but only collected {len(samples)}.")
    return samples


In [5]:
def load_prediction_cache(cache_path: Path, expected_metadata: Dict[str, Any]):
    """Return cached predictions if the cache file matches ``expected_metadata``.

    Args:
        cache_path: Pickle file written by ``save_prediction_cache``.
        expected_metadata: Key/value pairs that must all be equal to the
            values stored in the cache payload.

    Returns:
        The stored ``"predictions"`` object, or ``None`` when the file is
        missing, unreadable, not a dict payload, or its metadata differs.
    """
    if not cache_path.exists():
        return None
    try:
        # SECURITY: unpickling can execute arbitrary code. Only point this at
        # cache files produced locally by this notebook, never at files from
        # an untrusted source.
        with open(cache_path, "rb") as f:
            data = pickle.load(f)
    except Exception as exc:
        print(f"[cache] Failed to load {cache_path.name}: {exc}")
        return None
    # A readable pickle that is not a dict (for example written by another
    # tool) is treated as a cache miss rather than crashing on data.get.
    if not isinstance(data, dict):
        print(f"[cache] Ignoring {cache_path.name}: unexpected payload type {type(data).__name__}")
        return None
    for key, value in expected_metadata.items():
        if data.get(key) != value:
            return None
    return data.get("predictions")


def save_prediction_cache(cache_path: Path, metadata: Dict[str, Any], predictions: Dict[str, Any]):
    """Pickle ``metadata`` plus ``predictions`` into ``cache_path``.

    The payload is a copy of ``metadata`` with the predictions stored under
    the ``"predictions"`` key. Any failure while writing is reported with a
    message and otherwise ignored, so caching stays best-effort.
    """
    payload = {**metadata, "predictions": predictions}
    try:
        with open(cache_path, "wb") as sink:
            pickle.dump(payload, sink)
    except Exception as exc:
        print(f"[cache] Failed to save {cache_path.name}: {exc}")


In [6]:
# Cache-aware inference helper overrides

def run_image_inference(model_specs: List[Dict], sa1b_samples: List[SA1BPromptSample], point_sizes: List[int]):
    """Run every model on every SA-1B sample for each point count, with caching.

    Cached predictions are returned when the cache file matches the current
    sample ids, model keys and point sizes. Otherwise each model whose
    checkpoint exists is loaded in turn, run on all samples, and the results
    are written back to the cache.

    Returns:
        ``{sample_id: {model_key: {point_count: mask}}}``.
    """
    cache_meta = {
        "type": "sa1b",
        "version": 1,
        "sample_ids": sorted(item.sample_id for item in sa1b_samples),
        "model_keys": sorted(entry["key"] for entry in model_specs),
        "point_sizes": list(point_sizes),
    }
    cached = load_prediction_cache(IMAGE_CACHE_PATH, cache_meta)
    if cached is not None:
        print(f"[cache] Using cached SA-1B predictions from {IMAGE_CACHE_PATH.name}")
        return cached

    predictions = {}
    for item in sa1b_samples:
        predictions[item.sample_id] = {}

    for spec in model_specs:
        ckpt_path = Path(spec["checkpoint"])
        if not ckpt_path.exists():
            print(f"[SKIP] Missing checkpoint for {spec['label']}: {ckpt_path}")
            continue
        print(f"[Image] Loading {spec['label']} from {ckpt_path}")
        predictor = build_image_predictor(spec["model_cfg"], spec["checkpoint"])
        for item in sa1b_samples:
            # The image is set once per sample; every prompt size reuses it.
            predictor.set_image(item.image_np)
            masks_by_count = {}
            for count in point_sizes:
                masks_by_count[count] = predict_with_points(predictor, item.points_by_count[count])
            predictions[item.sample_id][spec["key"]] = masks_by_count
        # Drop the model before loading the next one to limit GPU memory use.
        del predictor
        if device.type == "cuda":
            torch.cuda.empty_cache()

    save_prediction_cache(IMAGE_CACHE_PATH, cache_meta, predictions)
    return predictions


@torch.inference_mode()
def run_video_inference(model_specs: List[Dict], sav_samples: List[SAVVideoSample]):
    """Propagate each sample's prompt mask through its video for every model, with caching.

    Cached predictions are returned when the cache file matches the current
    sample ids and model keys. Otherwise each model whose checkpoint exists is
    loaded in turn and run on all samples; propagation for a sample stops as
    soon as a mask has been recorded for every timeline frame.

    Note: a function with the same name is defined again later in this file
    (under the "Override inference helpers" cell) and replaces this one once
    that cell has run.

    Returns:
        ``{sample_id: {model_key: {"timeline_indices": [...], "masks": {frame: mask}}}}``.
    """
    sample_ids = sorted(sample.sample_id for sample in sav_samples)
    model_keys = sorted(spec["key"] for spec in model_specs)
    # Metadata that must match for an existing cache file to be reused.
    cache_meta = {
        "type": "sav",
        "version": 1,
        "sample_ids": sample_ids,
        "model_keys": model_keys,
    }
    cached = load_prediction_cache(VIDEO_CACHE_PATH, cache_meta)
    if cached is not None:
        print(f"[cache] Using cached SA-V predictions from {VIDEO_CACHE_PATH.name}")
        return cached

    predictions = {sample.sample_id: {} for sample in sav_samples}
    for spec in model_specs:
        ckpt_path = Path(spec["checkpoint"])
        if not ckpt_path.exists():
            print(f"[SKIP] Missing checkpoint for {spec['label']}: {ckpt_path}")
            continue
        print(f"[Video] Loading {spec['label']} from {ckpt_path}")
        predictor = build_video_predictor(spec["model_cfg"], spec["checkpoint"])
        for sample in sav_samples:
            # Fall back to a timeline over all frames if the sample has none.
            timeline = sample.timeline_indices or _compute_timeline_indices(
                sample.frame_indices, SAV_TIMELINE_FRAMES, sample.prompt_frame_idx
            )
            timeline_set = set(timeline)
            inference_state = predictor.init_state(str(sample.video_dir))
            # The prompt is the ground-truth mask on the prompt frame.
            prompt_mask_tensor = torch.from_numpy(sample.prompt_mask.astype(np.uint8))
            frame_idx, obj_ids, prompt_masks = predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=sample.prompt_frame_idx,
                obj_id=sample.object_id,
                mask=prompt_mask_tensor,
            )
            mask_records: Dict[int, np.ndarray] = {}
            # Record the output returned for the prompt frame itself, thresholded at 0.
            if frame_idx in timeline_set and sample.object_id in obj_ids:
                obj_position = obj_ids.index(sample.object_id)
                mask_records[frame_idx] = (
                    prompt_masks[obj_position].squeeze().gt(0).cpu().numpy()
                )
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
                if not timeline_set:
                    break
                # Only frames on the timeline are kept.
                if out_frame_idx not in timeline_set:
                    continue
                for obj_position, obj_id in enumerate(out_obj_ids):
                    if obj_id != sample.object_id:
                        continue
                    # Logits above 0 count as foreground.
                    mask = out_mask_logits[obj_position].squeeze().gt(0).cpu().numpy()
                    mask_records[out_frame_idx] = mask
                # Stop propagating once every timeline frame has a mask.
                if len(mask_records) == len(timeline_set):
                    break
            predictions[sample.sample_id][spec["key"]] = {
                "timeline_indices": timeline,
                "masks": mask_records,
            }
        # Drop the model before loading the next one to limit GPU memory use.
        del predictor
        if device.type == "cuda":
            torch.cuda.empty_cache()

    save_prediction_cache(VIDEO_CACHE_PATH, cache_meta, predictions)
    return predictions



In [7]:
# Predictor helpers

def build_image_predictor(model_cfg_path: str, checkpoint_path: str) -> SAM2ImagePredictor:
    """Build a SAM2 model on the notebook's ``device`` and wrap it in an image predictor."""
    return SAM2ImagePredictor(build_sam2(model_cfg_path, checkpoint_path, device=device))


def build_video_predictor(model_cfg_path: str, checkpoint_path: str):
    """Build a SAM2 video predictor on the notebook's ``device``."""
    video_predictor = build_sam2_video_predictor(model_cfg_path, checkpoint_path, device=device)
    return video_predictor


def predict_with_points(predictor: SAM2ImagePredictor, point_coords: np.ndarray) -> np.ndarray:
    """Predict one mask from positive point prompts on the predictor's current image.

    Args:
        predictor: Image predictor on which ``set_image`` was already called.
        point_coords: Point coordinates; converted to ``float32``.

    Returns:
        The first returned mask as a boolean array. Every point is labelled
        ``1`` and a single mask is requested (``multimask_output=False``).
    """
    xy = np.array(point_coords, dtype=np.float32)
    positive_labels = np.ones(len(xy), dtype=np.int32)
    request = {
        "point_coords": xy,
        "point_labels": positive_labels,
        "multimask_output": False,
    }
    masks, _, _ = predictor.predict(**request)
    first_mask = masks[0]
    return first_mask.astype(bool)



In [8]:
# Registry of the model variants compared in this notebook. Each entry has:
#   key        - unique identifier, used as the key for predictions and caches
#   label      - human-readable name used in log messages and figure labels
#   family     - method group (one of FAMILY_ORDER below)
#   scale      - backbone size (one of SCALE_ORDER below)
#   model_cfg  - config path handed to the SAM2 builders
#   checkpoint - weights path; the inference helpers skip entries whose file is missing
# Entries of the same scale share the same SAM2.1 config and differ only in
# the checkpoint they load.
MODEL_SPECS = [
    {
        "key": "sam2_base_plus",
        "label": "SAM2.1 B+",
        "family": "sam2",
        "scale": "base_plus",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "checkpoint": "../checkpoints/sam2.1_hiera_base_plus.pt",
    },
    {
        "key": "sam2_small",
        "label": "SAM2.1 S",
        "family": "sam2",
        "scale": "small",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_s.yaml",
        "checkpoint": "../checkpoints/sam2.1_hiera_small.pt",
    },
    {
        "key": "sam2_tiny",
        "label": "SAM2.1 T",
        "family": "sam2",
        "scale": "tiny",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_t.yaml",
        "checkpoint": "../checkpoints/sam2.1_hiera_tiny.pt",
    },
    {
        "key": "minmax_base_plus",
        "label": "MinMax B+",
        "family": "minmax",
        "scale": "base_plus",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "checkpoint": "../sam2_minmax/minmax_qat_base_plus_20251111_122542/checkpoints/checkpoint_sam2.pt",
    },
    {
        "key": "minmax_small",
        "label": "MinMax S",
        "family": "minmax",
        "scale": "small",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_s.yaml",
        "checkpoint": "../sam2_minmax/minmax_qat_small_20251111_233441/checkpoints/checkpoint_sam2.pt",
    },
    {
        "key": "minmax_tiny",
        "label": "MinMax T",
        "family": "minmax",
        "scale": "tiny",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_t.yaml",
        "checkpoint": "../sam2_minmax/minmax_qat_tiny_20251111_165611/checkpoints/checkpoint_sam2.pt",
    },
    {
        "key": "baseonly_base_plus",
        "label": "ALPQ_SAM2 B+",
        "family": "baseonly",
        "scale": "base_plus",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "checkpoint": "../sam2_logs/ablations/adaptive_qat_toy_base_plus_20251112_101653/checkpoints/checkpoint.pt",
    },
    {
        "key": "baseonly_main_small",
        "label": "ALPQ_SAM2 S",
        "family": "baseonly",
        "scale": "small",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_s.yaml",
        "checkpoint": "../sam2_logs/main_small_20251113_212939/checkpoints/checkpoint.pt",
    },
    {
        "key": "baseonly_main_tiny",
        "label": "ALPQ_SAM2 T",
        "family": "baseonly",
        "scale": "tiny",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_t.yaml",
        "checkpoint": "../sam2_logs/main_tiny_20251113_212949/checkpoints/checkpoint.pt",
    },
    {
        "key": "alpq_base_plus",
        "label": "Explicit-ALPQ B+",
        "family": "alpq",
        "scale": "base_plus",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "checkpoint": "../sam2_logs/classic/adaptive_qat_toy_base_plus_20251110_155500/checkpoints/checkpoint.pt",
    },
    {
        "key": "alpq_small",
        "label": "Explicit-ALPQ S",
        "family": "alpq",
        "scale": "small",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_s.yaml",
        "checkpoint": "../sam2_logs/classic/adaptive_qat_toy_small_20251111_172858/checkpoints/checkpoint.pt",
    },
    {
        "key": "alpq_tiny",
        "label": "Explicit-ALPQ T",
        "family": "alpq",
        "scale": "tiny",
        "model_cfg": "configs/sam2.1/sam2.1_hiera_t.yaml",
        "checkpoint": "../sam2_logs/classic/adaptive_qat_toy_tiny_20251112_161453_importancefixed/checkpoints/checkpoint.pt",
    },
]

# Display order and short display names for scales and families.
SCALE_ORDER = ["base_plus", "small", "tiny"]
SCALE_DISPLAY = {"base_plus": "B+", "small": "S", "tiny": "T"}
FAMILY_ORDER = ["sam2", "minmax", "baseonly", "alpq"]
FAMILY_DISPLAY = {
    "sam2": "SAM2",
    "minmax": "MinMax",
    "baseonly": "ALPQ",
    "alpq": "Explicit-ALPQ",
}

# Lookup tables derived from MODEL_SPECS:
#   SPECS_BY_KEY    - key -> spec
#   SPECS_BY_SCALE  - scale -> list of specs, in MODEL_SPECS order
#   SPECS_BY_FAMILY - family -> {scale -> spec}
SPECS_BY_KEY = {spec["key"]: spec for spec in MODEL_SPECS}
SPECS_BY_SCALE = defaultdict(list)
SPECS_BY_FAMILY = defaultdict(dict)
for spec in MODEL_SPECS:
    SPECS_BY_SCALE[spec["scale"]].append(spec)
    SPECS_BY_FAMILY[spec["family"]][spec["scale"]] = spec

print(f"Registered {len(MODEL_SPECS)} model variants for comparison.")


Registered 12 model variants for comparison.


In [9]:
# Override inference helpers with cache-aware versions

def run_image_inference(model_specs: List[Dict], sa1b_samples: List[SA1BPromptSample], point_sizes: List[int]):
    """Run every available model on every SA-1B sample for each point count, with caching.

    Specs whose checkpoint file is missing are reported and skipped. The cache
    is keyed on the models that were actually available, so a cache written
    while a checkpoint was missing is not reused once that checkpoint appears.

    Args:
        model_specs: Model registry entries (``key``, ``label``, ``model_cfg``,
            ``checkpoint``).
        sa1b_samples: Samples produced by ``sample_sa1b_prompts``.
        point_sizes: Point counts to evaluate; each must be a key of every
            sample's ``points_by_count``.

    Returns:
        ``{sample_id: {model_key: {point_count: mask}}}``.
    """
    # Decide up front which models can run, so the cache key reflects them.
    available_specs = []
    for spec in model_specs:
        ckpt_path = Path(spec["checkpoint"])
        if ckpt_path.exists():
            available_specs.append(spec)
        else:
            print(f"[SKIP] Missing checkpoint for {spec['label']}: {ckpt_path}")

    cache_meta = {
        "type": "sa1b",
        "version": 1,
        "sample_ids": sorted(sample.sample_id for sample in sa1b_samples),
        "model_keys": sorted(spec["key"] for spec in available_specs),
        # Normalised to a list so tuple and list callers share one cache key.
        "point_sizes": list(point_sizes),
    }
    cached = load_prediction_cache(IMAGE_CACHE_PATH, cache_meta)
    if cached is not None:
        print(f"[cache] Using cached SA-1B predictions from {IMAGE_CACHE_PATH.name}")
        return cached

    predictions = {sample.sample_id: {} for sample in sa1b_samples}
    for spec in available_specs:
        print(f"[Image] Loading {spec['label']} from {Path(spec['checkpoint'])}")
        predictor = build_image_predictor(spec["model_cfg"], spec["checkpoint"])
        for sample in sa1b_samples:
            # The image is set once per sample; every prompt size reuses it.
            predictor.set_image(sample.image_np)
            predictions[sample.sample_id][spec["key"]] = {
                count: predict_with_points(predictor, sample.points_by_count[count])
                for count in point_sizes
            }
        # Drop the model before loading the next one to limit GPU memory use.
        del predictor
        if device.type == "cuda":
            torch.cuda.empty_cache()

    save_prediction_cache(IMAGE_CACHE_PATH, cache_meta, predictions)
    return predictions


@torch.inference_mode()
def run_video_inference(model_specs: List[Dict], sav_samples: List[SAVVideoSample]):
    """Propagate each sample's prompt mask through its video for every available model.

    Specs whose checkpoint file is missing are reported and skipped. The cache
    is keyed on the models that were actually available, so a cache written
    while a checkpoint was missing is not reused once that checkpoint appears.

    Args:
        model_specs: Model registry entries (``key``, ``label``, ``model_cfg``,
            ``checkpoint``).
        sav_samples: Samples produced by ``sample_sav_videos``.

    Returns:
        ``{sample_id: {model_key: {"timeline_indices": [...], "masks": {frame: mask}}}}``.
        Timeline frames for which no mask was produced are absent from
        ``"masks"`` and reported with a warning.
    """
    # Decide up front which models can run, so the cache key reflects them.
    available_specs = []
    for spec in model_specs:
        ckpt_path = Path(spec["checkpoint"])
        if ckpt_path.exists():
            available_specs.append(spec)
        else:
            print(f"[SKIP] Missing checkpoint for {spec['label']}: {ckpt_path}")

    cache_meta = {
        "type": "sav",
        "version": 1,
        "sample_ids": sorted(sample.sample_id for sample in sav_samples),
        "model_keys": sorted(spec["key"] for spec in available_specs),
    }
    cached = load_prediction_cache(VIDEO_CACHE_PATH, cache_meta)
    if cached is not None:
        print(f"[cache] Using cached SA-V predictions from {VIDEO_CACHE_PATH.name}")
        return cached

    predictions = {sample.sample_id: {} for sample in sav_samples}
    for spec in available_specs:
        print(f"[Video] Loading {spec['label']} from {Path(spec['checkpoint'])}")
        predictor = build_video_predictor(spec["model_cfg"], spec["checkpoint"])
        for sample in sav_samples:
            # Fall back to a timeline over all frames if the sample has none.
            timeline = sample.timeline_indices or _compute_timeline_indices(
                sample.frame_indices, SAV_TIMELINE_FRAMES, sample.prompt_frame_idx
            )
            # Timeline frames that still need a predicted mask.
            remaining_frames = set(timeline)
            inference_state = predictor.init_state(str(sample.video_dir))
            # The prompt is the ground-truth mask on the prompt frame.
            prompt_mask_tensor = torch.from_numpy(sample.prompt_mask.astype(np.uint8))
            frame_idx, obj_ids, prompt_masks = predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=sample.prompt_frame_idx,
                obj_id=sample.object_id,
                mask=prompt_mask_tensor,
            )
            mask_records: Dict[int, np.ndarray] = {}
            # Record the output returned for the prompt frame itself, thresholded at 0.
            if frame_idx in remaining_frames and sample.object_id in obj_ids:
                obj_position = obj_ids.index(sample.object_id)
                mask_records[frame_idx] = (
                    prompt_masks[obj_position].squeeze().gt(0).cpu().numpy()
                )
                remaining_frames.discard(frame_idx)
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
                if out_frame_idx not in remaining_frames:
                    continue
                for obj_position, obj_id in enumerate(out_obj_ids):
                    if obj_id != sample.object_id:
                        continue
                    # Logits above 0 count as foreground.
                    mask_records[out_frame_idx] = (
                        out_mask_logits[obj_position].squeeze().gt(0).cpu().numpy()
                    )
                    remaining_frames.discard(out_frame_idx)
                    break
            if remaining_frames:
                missing = ", ".join(str(idx) for idx in sorted(remaining_frames))
                print(f"[warn] Missing {len(remaining_frames)} frames for {sample.sample_id}: {missing}")
            predictions[sample.sample_id][spec["key"]] = {
                "timeline_indices": timeline,
                "masks": mask_records,
            }
        # Drop the model before loading the next one to limit GPU memory use.
        del predictor
        if device.type == "cuda":
            torch.cuda.empty_cache()

    save_prediction_cache(VIDEO_CACHE_PATH, cache_meta, predictions)
    return predictions



In [10]:
# (legacy definitions removed to avoid overriding cache-aware helpers)


In [None]:
# Override video timeline renderer for clearer left-side labels
ROW_LABEL_X = -0.28


def _draw_row_label(ax, text: str, x_offset: float = ROW_LABEL_X):
    if not text:
        return
    ax.text(
        x_offset,
        0.5,
        text,
        transform=ax.transAxes,
        rotation=90,
        ha="center",
        va="center",
        fontsize=10,
        fontweight="bold",
        clip_on=False,
    )


def render_sav_video_timelines(sample: SAVVideoSample, sample_preds: Dict[str, Dict[str, Any]], pdf: PdfPages):
    """Add one PDF page per model scale showing GT and predicted masks over time.

    Each page is a grid: the first row shows the ground-truth masks, followed
    by one row per model of that scale that has predictions in
    ``sample_preds``; columns are the sample's timeline frames. The prompt
    points are drawn on the prompt frame's column.

    Args:
        sample: Video sample to render; nothing is drawn if its timeline is empty.
        sample_preds: ``{model_key: {"masks": {frame: mask}, ...}}`` for this sample.
        pdf: Open ``PdfPages`` object that receives the figures.
    """
    timeline = sample.timeline_indices
    if not timeline:
        return

    # Frame cache handed to _load_frame_image (defined elsewhere in the notebook).
    cache: Dict[int, np.ndarray] = {}
    num_cols = len(timeline)
    for scale in SCALE_ORDER:
        row_specs = [spec for spec in SPECS_BY_SCALE[scale] if spec["key"] in sample_preds]
        if not row_specs:
            continue
        num_rows = 1 + len(row_specs)
        # squeeze=False keeps `axes` two-dimensional even for a single-frame
        # timeline, so the [row, col] indexing below is always valid.
        fig, axes = plt.subplots(
            num_rows,
            num_cols,
            figsize=(3.2 * num_cols, 2.4 * num_rows),
            squeeze=False,
        )

        _draw_row_label(axes[0, 0], "GT")
        for col_idx, frame_idx in enumerate(timeline):
            img = _load_frame_image(sample, frame_idx, cache)
            gt_mask = sample.gt_masks_by_frame.get(frame_idx)
            plot_panel(
                axes[0, col_idx],
                img,
                mask=gt_mask,
                mask_color=GT_COLOR,
                points=sample.prompt_points if frame_idx == sample.prompt_frame_idx else None,
                top_label=f"Frame {frame_idx:05d}",
            )

        for row_idx, spec in enumerate(row_specs, start=1):
            _draw_row_label(axes[row_idx, 0], spec["label"])
            pred = sample_preds[spec["key"]]
            masks = pred.get("masks", {})
            for col_idx, frame_idx in enumerate(timeline):
                img = _load_frame_image(sample, frame_idx, cache)
                # A frame without a prediction is shown as the plain image.
                mask = masks.get(frame_idx)
                plot_panel(
                    axes[row_idx, col_idx],
                    img,
                    mask=mask,
                    mask_color=PRED_COLOR,
                    points=sample.prompt_points if frame_idx == sample.prompt_frame_idx else None,
                )

        fig.subplots_adjust(left=0.1, right=0.98, top=0.95, bottom=0.08, wspace=0.02, hspace=0.02)
        pdf.savefig(fig, bbox_inches="tight")
        plt.close(fig)



In [None]:
def blend_mask(image_np: np.ndarray, mask: np.ndarray, color: tuple, alpha: float) -> np.ndarray:
    """Return a copy of ``image_np`` with ``color`` alpha-blended over the masked pixels.

    Masked pixels become ``pixel * (1 - alpha) + color * alpha`` (computed in
    ``float32``); all other pixels are unchanged. The result is cast to
    ``uint8``. The input image is not modified.
    """
    selected = mask.astype(bool)
    tint = np.array(color, dtype=np.float32)
    canvas = image_np.astype(np.float32).copy()
    canvas[selected] = canvas[selected] * (1 - alpha) + tint * alpha
    return canvas.astype(np.uint8)


def plot_points(ax, points: Optional[np.ndarray]):
    """Scatter prompt points on ``ax`` as star markers; no-op for ``None`` or empty input.

    Column 0 of ``points`` is used as x and column 1 as y. Marker size,
    colours and line width come from the module-level ``POINT_*`` constants.
    """
    if points is None or len(points) == 0:
        return
    xy = np.array(points)
    xs, ys = xy[:, 0], xy[:, 1]
    marker_style = {
        "s": POINT_MARKER_SIZE,
        "c": POINT_COLOR,
        "marker": "*",
        "edgecolors": POINT_EDGE_COLOR,
        "linewidths": POINT_LINEWIDTH,
    }
    ax.scatter(xs, ys, **marker_style)


def plot_panel(
    ax,
    image_np: np.ndarray,
    mask: Optional[np.ndarray] = None,
    mask_color: tuple = PRED_COLOR,
    points: Optional[np.ndarray] = None,
    top_label: str = "",
    bottom_label: str = "",
):
    """Draw one grid cell: the image, an optional mask overlay, points and labels.

    Args:
        ax: Matplotlib axes to draw into.
        image_np: Image shown in the cell.
        mask: When given, it is blended over the image with ``mask_color`` at
            opacity ``MASK_ALPHA`` before display.
        mask_color: Overlay color passed to ``blend_mask``.
        points: Prompt points forwarded to ``plot_points``.
        top_label: Bold axes title; skipped when empty.
        bottom_label: Caption centered just below the axes; skipped when empty.

    The axis is switched off at the end, so ticks, spines and axis labels of
    ``ax`` are not rendered.
    """
    shown = image_np if mask is None else blend_mask(image_np, mask, mask_color, MASK_ALPHA)
    ax.imshow(shown)
    plot_points(ax, points)

    if top_label:
        ax.set_title(top_label, fontsize=25, fontweight="bold")

    if bottom_label:
        # Anchored in axes coordinates so the caption sits right under the image.
        caption_style = {"ha": "center", "va": "top", "fontsize": 15, "transform": ax.transAxes}
        ax.text(0.5, -0.06, bottom_label, **caption_style)

    ax.axis("off")


def render_sa1b_one_point(sample: SA1BPromptSample, sample_preds: Dict[str, Dict[int, np.ndarray]], pdf: PdfPages):
    """Append one PDF page comparing all models on the single-point prompt.

    The grid has one row per model family and one column per model scale,
    preceded by a ground-truth reference column. Families and scales without
    any prediction in ``sample_preds`` are left out; a family that lacks one of
    the remaining scales gets a blank cell (titled "(missing)" on the first
    row). Nothing is written when no prediction is available.

    Args:
        sample: Prompt sample; ``image_np``, ``gt_mask`` and
            ``points_by_count[1]`` are read.
        sample_preds: Maps a model spec key to a dict of masks; the mask stored
            under key ``1`` is plotted.
        pdf: Open ``PdfPages`` the figure is saved into.
    """

    def _has_prediction(fam, scale) -> bool:
        # True when this family/scale pair is configured and was predicted.
        spec = SPECS_BY_FAMILY[fam].get(scale)
        return bool(spec) and spec["key"] in sample_preds

    families = [
        fam
        for fam in FAMILY_ORDER
        if any(_has_prediction(fam, scale) for scale in SCALE_ORDER)
    ]
    if not families:
        return
    scales = [
        scale
        for scale in SCALE_ORDER
        if any(_has_prediction(fam, scale) for fam in families)
    ]
    if not scales:
        return

    num_rows = len(families)
    num_cols = len(scales) + 1  # reference + scales
    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(3.0 * num_cols, 2.8 * num_rows),
        squeeze=False,
    )
    one_point = sample.points_by_count[1]

    for row_idx, fam in enumerate(families):
        is_first_row = row_idx == 0
        ref_ax = axes[row_idx, 0]
        plot_panel(
            ref_ax,
            sample.image_np,
            mask=sample.gt_mask,
            mask_color=GT_COLOR,
            points=one_point,
            top_label="GT" if is_first_row else "",
            bottom_label="GT",
        )
        # plot_panel turns the axis off, which also hides axis labels, so the
        # family name is drawn as plain text anchored to the axes instead of
        # via set_ylabel (which would never be rendered).
        ref_ax.text(
            -0.03,
            0.5,
            FAMILY_DISPLAY[fam],
            rotation=90,
            ha="right",
            va="center",
            fontsize=10,
            fontweight="bold",
            transform=ref_ax.transAxes,
        )

        for col_offset, scale in enumerate(scales, start=1):
            spec = SPECS_BY_FAMILY[fam].get(scale)
            ax = axes[row_idx, col_offset]
            if spec and spec["key"] in sample_preds:
                plot_panel(
                    ax,
                    sample.image_np,
                    mask=sample_preds[spec["key"]][1],
                    mask_color=PRED_COLOR,
                    points=one_point,
                    top_label=(f"{SCALE_DISPLAY[scale]}" if is_first_row else ""),
                    bottom_label=spec["label"],
                )
            else:
                # No model for this family/scale: leave the cell blank.
                ax.axis("off")
                if is_first_row:
                    ax.set_title(f"{SCALE_DISPLAY[scale]} (missing)", fontsize=25, fontweight="bold")

    fig.subplots_adjust(left=0.02, right=0.98, top=0.95, bottom=0.08, wspace=0.05, hspace=0.05)
    # NOTE: tight_layout() recomputes the subplot parameters, so its result
    # takes precedence over the explicit margins set just above.
    fig.tight_layout()
    pdf.savefig(fig, bbox_inches="tight")
    plt.close(fig)


def render_sa1b_multi_scale(
    sample: SA1BPromptSample,
    sample_preds: Dict[str, Dict[int, np.ndarray]],
    scale: str,
    pdf: PdfPages,
):
    """Append one PDF page showing a single scale across all point counts.

    The first row shows the ground-truth mask once per entry of
    ``POINT_SET_SIZES`` (with the matching prompt points); every following row
    shows one model family's predictions at ``scale`` for the same point
    counts. Nothing is written when no family has a prediction at ``scale``.

    Args:
        sample: Prompt sample; ``image_np``, ``gt_mask`` and
            ``points_by_count`` are read.
        sample_preds: Maps a model spec key to a dict of masks indexed by the
            point counts in ``POINT_SET_SIZES``.
        scale: Scale key looked up in ``SPECS_BY_FAMILY[fam]``.
        pdf: Open ``PdfPages`` the figure is saved into.
    """
    families = [
        fam
        for fam in FAMILY_ORDER
        if (spec := SPECS_BY_FAMILY[fam].get(scale)) and spec["key"] in sample_preds
    ]
    if not families:
        return

    num_rows = 1 + len(families)
    num_cols = len(POINT_SET_SIZES)
    # squeeze=False keeps ``axes`` 2-D even with a single point-count column;
    # without it plt.subplots returns a 1-D array and axes[row, col] fails.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(3.0 * num_cols, 2.8 * num_rows),
        squeeze=False,
    )

    for col_idx, count in enumerate(POINT_SET_SIZES):
        plot_panel(
            axes[0, col_idx],
            sample.image_np,
            mask=sample.gt_mask,
            mask_color=GT_COLOR,
            points=sample.points_by_count[count],
            top_label=f"{count} pt",
            bottom_label="GT",
        )

    for row_idx, fam in enumerate(families, start=1):
        spec = SPECS_BY_FAMILY[fam][scale]
        masks_by_count = sample_preds[spec["key"]]
        for col_idx, count in enumerate(POINT_SET_SIZES):
            plot_panel(
                axes[row_idx, col_idx],
                sample.image_np,
                mask=masks_by_count[count],
                mask_color=PRED_COLOR,
                points=sample.points_by_count[count],
                bottom_label=spec["label"],
            )
        # plot_panel turns the axis off, which also hides axis labels, so the
        # family name is drawn as plain text anchored to the first cell of the
        # row instead of via set_ylabel (which would never be rendered).
        label_ax = axes[row_idx, 0]
        label_ax.text(
            -0.03,
            0.5,
            FAMILY_DISPLAY[fam],
            rotation=90,
            ha="right",
            va="center",
            fontsize=15,
            fontweight="bold",
            transform=label_ax.transAxes,
        )

    fig.subplots_adjust(left=0.02, right=0.98, top=0.95, bottom=0.08, wspace=0.05, hspace=0.05)
    # NOTE: tight_layout() recomputes the subplot parameters, so its result
    # takes precedence over the explicit margins set just above.
    fig.tight_layout()
    pdf.savefig(fig, bbox_inches="tight")
    plt.close(fig)


def _load_frame_image(sample: SAVVideoSample, frame_idx: int, cache: Dict[int, np.ndarray]) -> np.ndarray:
    """Return the RGB array for ``frame_idx``, decoding it at most once.

    Frames already present in ``cache`` are returned as-is. Otherwise the path
    is looked up in ``sample.frame_idx_to_path``, the image is opened and
    converted to RGB, and the resulting array is stored in ``cache`` before
    being returned.

    Raises:
        FileNotFoundError: If the sample has no path for ``frame_idx``.
    """
    if frame_idx not in cache:
        frame_path = sample.frame_idx_to_path.get(frame_idx)
        if frame_path is None:
            raise FileNotFoundError(f"Missing frame {frame_idx} for {sample.sample_id}")
        rgb = Image.open(frame_path).convert("RGB")
        cache[frame_idx] = np.array(rgb)
    return cache[frame_idx]


def add_cover_page(pdf: PdfPages, title: str, sa1b_count: int, sav_count: int):
    """Write a text-only cover page summarizing the run into ``pdf``.

    The 8.27 x 11.69 inch (A4 portrait) page lists the title, the generation
    timestamp, the SA-1B / SA-V sample counts, ``POINT_SET_SIZES`` and one
    line per entry of ``MODEL_SPECS``.
    """
    fig = plt.figure(figsize=(8.27, 11.69))

    generated_at = datetime.now().isoformat(timespec='seconds')
    # (y position, text, text style) for the fixed header lines, top to bottom.
    header_rows = (
        (0.95, title, {"fontsize": 24, "weight": "bold"}),
        (0.91, f"Generated: {generated_at}", {"fontsize": 12}),
        (0.87, f"SA-1B samples: {sa1b_count} | SA-V videos: {sav_count}", {"fontsize": 14}),
        (0.83, f"Point settings: {POINT_SET_SIZES}", {"fontsize": 12}),
        (0.79, "Models", {"fontsize": 14, "weight": "bold"}),
    )
    for y_pos, text, style in header_rows:
        fig.text(0.05, y_pos, text, **style)

    # One bullet per model, stepping down the page by a fixed increment.
    line_y = 0.76
    for spec in MODEL_SPECS:
        family_name = FAMILY_DISPLAY[spec['family']]
        scale_name = SCALE_DISPLAY[spec['scale']]
        fig.text(0.06, line_y, f"- {spec['label']} ({family_name} {scale_name})", fontsize=11)
        line_y -= 0.02

    pdf.savefig(fig, bbox_inches="tight")
    plt.close(fig)


In [13]:
# Draw the SA-1B image prompts and the SA-V videos used for the qualitative PDFs.
# Both samplers receive the same RNG object, so the order of these two calls
# presumably affects which samples are drawn -- keep it stable for reproducibility.
sa1b_samples = sample_sa1b_prompts(SA1B_VAL_ROOT, SA1B_NUM_IMAGES, POINT_SET_SIZES, RNG)
sav_samples = sample_sav_videos(SAV_VAL_FRAMES, SAV_VAL_ANN, SAV_NUM_VIDEOS, RNG)

# Fail fast when either sampler returned nothing; the (Korean) messages tell the
# user that no samples were found and to check the configured dataset paths.
if not sa1b_samples:
    raise RuntimeError("SA-1B 샘플을 찾지 못했습니다. 경로 설정을 확인하세요.")
if not sav_samples:
    raise RuntimeError("SA-V 샘플을 찾지 못했습니다. 경로 설정을 확인하세요.")

print(f"Collected {len(sa1b_samples)} SA-1B samples and {len(sav_samples)} SA-V videos.")


Collected 8 SA-1B samples and 3 SA-V videos.


In [14]:
# print("[Images] Running SA-1B qualitative inference ...")
# sa1b_preds = run_image_inference(MODEL_SPECS, sa1b_samples, POINT_SET_SIZES)
# with PdfPages(PDF_IMAGES_PATH) as pdf:
#     add_cover_page(pdf, "SA-1B Qualitative", len(sa1b_samples), 0)
#     for sample in sa1b_samples:
#         render_sa1b_one_point(sample, sa1b_preds.get(sample.sample_id, {}), pdf)
#         for scale in SCALE_ORDER:
#             render_sa1b_multi_scale(sample, sa1b_preds.get(sample.sample_id, {}), scale, pdf)
# print(f"[Images] Saved SA-1B PDF to {PDF_IMAGES_PATH.resolve()}")


In [15]:
# Run video inference for MODEL_SPECS on the sampled SA-V videos, then write
# the results into the PDF at PDF_VIDEOS_PATH.
print("[Videos] Running SA-V qualitative inference ...")
sav_preds = run_video_inference(MODEL_SPECS, sav_samples)
with PdfPages(PDF_VIDEOS_PATH) as pdf:
    # The cover page reports 0 SA-1B samples: only videos go into this PDF.
    add_cover_page(pdf, "SA-V Qualitative", 0, len(sav_samples))
    for sample in sav_samples:
        # A sample with no entry in sav_preds is rendered with an empty dict.
        render_sav_video_timelines(sample, sav_preds.get(sample.sample_id, {}), pdf)
print(f"[Videos] Saved SA-V PDF to {PDF_VIDEOS_PATH.resolve()}")


[Videos] Running SA-V qualitative inference ...
[cache] Using cached SA-V predictions from sav_preds.pkl
[Videos] Saved SA-V PDF to /home/lji/SAM/sam2/qualitative_val_outputs/qualitative_val_videos.pdf
