# Model Footprint & Throughput Profiler

Adaptive QAT / Min-Max QAT 등 다양한 체크포인트를 불러와서

* 파라미터 수
* 가중치 비트 기반 예상 모델 사이즈
* FPS / Latency

를 비교하기 위한 템플릿입니다. 아래 설정 블록을 자신의 실험 경로에 맞게 수정한 뒤 셀을 순서대로 실행하세요.

In [None]:
import math
import time
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import torch

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Profiling is inference-only: disable autograd globally for every later cell.
torch.set_grad_enabled(False)

# When the notebook is launched from the sam2/notebooks folder, move up one
# directory so the relative checkpoint paths used below resolve from there.
# NOTE: `%cd` is an IPython magic, so this cell only runs inside Jupyter/IPython.
if Path.cwd().name == "notebooks":
    REPO_ROOT = Path.cwd().parent
    %cd ..
else:
    REPO_ROOT = Path.cwd()

In [None]:
# ==== User settings ====
# Replace each ckpt_path with the location of your actual checkpoint.


def _checkpoint_entry(
    *,
    name: str,
    ckpt_path: str,
    model_cfg: str,
    bitwidth_key: Optional[str],
    notes: str,
) -> Dict[str, Any]:
    """Build one CHECKPOINTS record.

    ``bitwidth_key`` names the key inside the checkpoint that holds per-layer
    bit widths; pass None when the checkpoint has no such key.
    """
    return {
        "name": name,
        "ckpt_path": ckpt_path,
        "model_cfg": model_cfg,
        "bitwidth_key": bitwidth_key,
        "notes": notes,
    }


CHECKPOINTS = [
    # Reference (unquantized) SAM 2.1 checkpoints.
    _checkpoint_entry(
        name="original_base_plus",
        ckpt_path="checkpoints/sam2.1_hiera_base_plus.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_b+.yaml",
        bitwidth_key=None,
        notes="Base plus model",
    ),
    _checkpoint_entry(
        name="original_small",
        ckpt_path="checkpoints/sam2.1_hiera_small.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_s.yaml",
        bitwidth_key=None,
        notes="Small model",
    ),
    _checkpoint_entry(
        name="original_tiny",
        ckpt_path="checkpoints/sam2.1_hiera_tiny.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_t.yaml",
        bitwidth_key=None,
        notes="Tiny model",
    ),
    # Adaptive-QAT training outputs.
    _checkpoint_entry(
        name="ours_base_plus",
        ckpt_path="sam2_logs/adaptive_qat_toy_20251110_155500/checkpoints/checkpoint.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_b+.yaml",
        bitwidth_key="bitwidths",
        notes="Base plus model with adaptive QAT",
    ),
    _checkpoint_entry(
        name="ours_small",
        ckpt_path="sam2_logs/adaptive_qat_toy_small_20251111_172858/checkpoints/checkpoint.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_s.yaml",
        bitwidth_key="bitwidths",
        notes="Small model with adaptive QAT",
    ),
    _checkpoint_entry(
        name="ours_tiny",
        ckpt_path="sam2_logs/adaptive_qat_toy_tiny_20251111_173031/checkpoints/checkpoint.pt",
        model_cfg="configs/sam2.1/sam2.1_hiera_t.yaml",
        bitwidth_key="bitwidths",
        notes="Tiny model with adaptive QAT",
    ),
]

# Torch device used for the FPS benchmark.
DEVICE = "cuda:7"
# Shared bit-width key used when an entry's own "bitwidth_key" is None.
BITWIDTH_KEY_FALLBACK = None
# Untimed / timed iteration counts for the FPS benchmark.
WARMUP_RUNS = 5
TIMED_RUNS = 20
# Side length of the square input used when measuring FPS.
IMAGE_SIZE = 1024

SUMMARY_RESULTS: list[Dict[str, Any]] = []

In [None]:
BENCHMARK_IMAGES = sorted([
    *Path('../datasets/benchmark/coco/images').glob('*.jpg'),
    *Path('../datasets/benchmark/sa1b').glob('*.jpg'),
])
BENCHMARK_VIDEO_FRAMES = []
for video_dir in Path('../datasets/benchmark/sav/JPEGImages_24fps').glob('sav_*'):
    frames = sorted(video_dir.glob('*.jpg'))
    BENCHMARK_VIDEO_FRAMES.extend(frames[:3])  # first 3 frames per video
BENCHMARK_IMGS = [Path(p) for p in BENCHMARK_IMAGES + BENCHMARK_VIDEO_FRAMES]
print(f'Benchmark images collected: {len(BENCHMARK_IMGS)}')


In [None]:
# Not read by any function defined in this notebook; kept so later cells can use it.
fallback_bits = list()
# Module/parameter names treated as non-quantized when estimating model size.
NON_QUANTIZED_MODULES = {
    "image_encoder.trunk.pos_embed",
    "image_encoder.trunk.pos_embed_window",
}


In [None]:
# State-dict entries whose names begin with this prefix hold per-layer bit widths.
BIT_CONTROLLER_PREFIX: str = "bit_controller."
# Token that stands in for "." inside bit-controller entry names.
BIT_CONTROLLER_DOT_TOKEN: str = "_DOT_"


def _resolve_model_state(state: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    candidates = ("model", "student", "ema", "state_dict", "module")
    for key in candidates:
        value = state.get(key)
        if isinstance(value, dict):
            return value
    return {k: v for k, v in state.items() if torch.is_tensor(v)}


def _merge_bit_controller_state(state: Dict[str, Any], model_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of *model_state* extended with bit-controller tensors.

    The checkpoint sections ``bit_controller`` and then ``model_aux`` (when they
    are dicts) are scanned; any string key starting with BIT_CONTROLLER_PREFIX
    whose value is a tensor is copied in, later sections overriding earlier
    ones. *model_state* itself is not modified.

    NOTE(review): keys without the prefix are ignored, so a section saved as a
    bare ``bit_controller.state_dict()`` (unprefixed keys) contributes
    nothing — confirm how the training code saves it.
    """
    combined = dict(model_state)
    for section_name in ("bit_controller", "model_aux"):
        section = state.get(section_name)
        if not isinstance(section, dict):
            continue
        combined.update({
            key: value
            for key, value in section.items()
            if isinstance(key, str)
            and key.startswith(BIT_CONTROLLER_PREFIX)
            and torch.is_tensor(value)
        })
    return combined


def _extract_bits_from_bit_controller(model_state: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Map layer names to bit widths using ``bit_controller.*`` tensors.

    For every tensor whose name starts with BIT_CONTROLLER_PREFIX, the rest of
    the name (with BIT_CONTROLLER_DOT_TOKEN turned back into ".") becomes the
    layer name, and the tensor's mean value becomes its bit width. Names whose
    remainder starts with "_" and empty tensors are skipped.

    Returns:
        ``{layer_name: mean_bits}``; empty when no bit-controller tensor exists.
    """
    bit_map: Dict[str, float] = {}
    for name, tensor in model_state.items():
        if not torch.is_tensor(tensor) or not name.startswith(BIT_CONTROLLER_PREFIX):
            continue
        suffix = name[len(BIT_CONTROLLER_PREFIX):]
        # "_"-prefixed entries are not treated as per-layer bits; an empty
        # tensor would produce NaN from mean().
        if suffix.startswith("_") or tensor.numel() == 0:
            continue
        try:
            # Cast to float first: Tensor.mean() raises on integer dtypes, which
            # previously made integer-valued bit buffers vanish silently.
            value = float(tensor.detach().cpu().float().mean().item())
        except RuntimeError:
            # Best effort: skip tensors that still cannot be reduced to a scalar.
            continue
        layer_name = suffix.replace(BIT_CONTROLLER_DOT_TOKEN, ".")
        bit_map[layer_name] = value
    return bit_map


def _get_bitwidth_map(state: Dict[str, Any], model_state: Dict[str, torch.Tensor], entry: Dict[str, Any]) -> Dict[str, float]:
    """Choose the per-layer bit-width mapping for a checkpoint.

    Priority:
      1. bit widths decoded from ``bit_controller.*`` tensors in *model_state*;
      2. the dict stored in *state* under ``entry["bitwidth_key"]`` (or
         BITWIDTH_KEY_FALLBACK when that is falsy);
      3. an empty dict.
    """
    from_controller = _extract_bits_from_bit_controller(model_state)
    if from_controller:
        return from_controller

    lookup_key = entry.get("bitwidth_key") or BITWIDTH_KEY_FALLBACK
    stored = state[lookup_key] if lookup_key and lookup_key in state else None
    return stored if isinstance(stored, dict) else {}


def _bytes_to_megabytes(num_bytes: float) -> float:
    return num_bytes / (1024 ** 2)


def _lookup_layer_bits(name: str, base_name: str, bitwidth_map: Dict[str, Any]) -> Optional[Any]:
    """Return the bit width for one parameter, most specific key first.

    Tries the full parameter name, then its module name (``.weight``/``.bias``
    stripped), then each enclosing module prefix. Entries whose value is None
    are treated as absent. Returns None when nothing matches.
    """
    candidates = [name, base_name]
    prefix = base_name
    while '.' in prefix:
        prefix = prefix.rsplit('.', 1)[0]
        candidates.append(prefix)
    for candidate in candidates:
        bits = bitwidth_map.get(candidate)
        if bits is not None:
            return bits
    return None


def summarize_checkpoint(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Estimate parameter count and weight-storage size of one checkpoint.

    Every tensor in the resolved model state (excluding ``bit_controller.*``
    entries) is counted. Its bit width is taken from the per-layer bit map via
    ``_lookup_layer_bits``; modules listed in NON_QUANTIZED_MODULES, and layers
    without a bit entry, use ``entry["default_bits"]["weight"]`` (32 if unset).

    Args:
        entry: a CHECKPOINTS item; ``name`` and ``ckpt_path`` are required,
            ``bitwidth_key``, ``default_bits`` and ``notes`` are optional.

    Returns:
        dict with ``params``, the FP32 size, the estimated quantized size (both
        in MiB), the average weight bits, the default bits, the number of
        parsed bit entries and the notes.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is not an existing file.
    """
    ckpt_path = Path(entry["ckpt_path"]).expanduser()
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Missing checkpoint: {ckpt_path}")

    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state = torch.load(ckpt_path, map_location="cpu")
    model_state = _merge_bit_controller_state(state, _resolve_model_state(state))
    bitwidth_map = _get_bitwidth_map(state, model_state, entry) or {}
    non_quantized = set(NON_QUANTIZED_MODULES)

    # `or {}` also tolerates an explicit "default_bits": None in the entry.
    default_w_bits = (entry.get("default_bits") or {}).get("weight", 32)
    total_params = 0
    total_bits = 0.0

    for name, tensor in model_state.items():
        if not torch.is_tensor(tensor) or name.startswith(BIT_CONTROLLER_PREFIX):
            continue
        numel = tensor.numel()
        base_name = name
        for suffix in (".weight", ".bias"):
            if name.endswith(suffix):
                base_name = name[:-len(suffix)]
                break
        total_params += numel
        if base_name in non_quantized:
            # Checked before the bit map so a prefix match (e.g. a bit entry
            # for the whole trunk) cannot mark these as quantized.
            bits = float(default_w_bits)
        else:
            # NOTE(review): the layer's bits are applied to its bias as well.
            layer_bits = _lookup_layer_bits(name, base_name, bitwidth_map)
            bits = float(default_w_bits if layer_bits is None else layer_bits)
        total_bits += numel * bits

    fp32_size_mb = _bytes_to_megabytes(total_params * 4)
    avg_bits = total_bits / total_params if total_params else float(default_w_bits)
    est_size_mb = _bytes_to_megabytes(total_bits / 8)

    return {
        "name": entry["name"],
        "params": total_params,
        "fp32_size_mb": fp32_size_mb,
        "est_quant_size_mb": est_size_mb,
        "avg_weight_bits": avg_bits,
        "default_weight_bits": default_w_bits,
        "parsed_bit_layers": len(bitwidth_map),
        "notes": entry.get("notes", "")
    }


In [None]:
# Summarize every checkpoint. A missing file only skips that entry, matching
# the per-entry handling of the FPS cell, instead of aborting the whole cell.
SUMMARY_RESULTS = []
for entry in CHECKPOINTS:
    try:
        SUMMARY_RESULTS.append(summarize_checkpoint(entry))
    except FileNotFoundError as exc:
        print(f"[WARN] {entry['name']} 요약 건너뜀: {exc}")

# A bare expression inside `except` is never rendered by Jupyter (only a cell's
# final top-level expression is), so display the fallback explicitly.
try:
    import pandas as pd
except ImportError:
    display(SUMMARY_RESULTS)
else:
    display(pd.DataFrame(SUMMARY_RESULTS))

In [None]:
def _sync(device: str) -> None:
    if device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize(torch.device(device))

def build_predictor(entry: Dict[str, Any], device: str = DEVICE) -> SAM2ImagePredictor:
    """Build a SAM2ImagePredictor from a CHECKPOINTS entry.

    ``entry["model_cfg"]`` and ``entry["ckpt_path"]`` are forwarded to
    ``build_sam2`` with post-processing enabled, on *device*.

    NOTE(review): for QAT checkpoints, presumably the weights load into the
    regular floating-point model, so timings would not reflect quantized
    kernels — confirm against build_sam2.
    """
    build_kwargs = dict(
        config_file=entry["model_cfg"],
        ckpt_path=entry["ckpt_path"],
        device=device,
        apply_postprocessing=True,
    )
    return SAM2ImagePredictor(build_sam2(**build_kwargs))

def benchmark_entry(entry: Dict[str, Any], *, device: str = DEVICE, image_size: int = IMAGE_SIZE,
                    warmup: int = WARMUP_RUNS, runs: int = TIMED_RUNS,
                    include_image_encoder: bool = True) -> Dict[str, Any]:
    """Measure latency / FPS of one checkpoint on a random square RGB image.

    Args:
        entry: a CHECKPOINTS item (``model_cfg``, ``ckpt_path``, ``name`` are read).
        device: torch device string the predictor is built on.
        image_size: side length of the random uint8 test image; the box prompt
            covers the whole image.
        warmup: untimed iterations before measuring (at least 1 runs).
        runs: timed iterations (at least 1 runs).
        include_image_encoder: when True (default), every iteration runs
            ``set_image`` (image encoder) and then ``predict``, so the numbers
            reflect the whole pipeline and differ across backbones. When False,
            ``set_image`` runs once up front and only ``predict`` is timed.

    Returns:
        dict with mean and median latency (ms), FPS (1 / mean latency), image
        size, the number of timed runs actually executed, and the encoder flag.
    """
    warmup_iters = max(1, warmup)
    timed_iters = max(1, runs)
    dummy_image = (np.random.rand(image_size, image_size, 3) * 255).astype(np.uint8)
    full_box = np.array([0, 0, image_size, image_size])

    def _run_once(pred: SAM2ImagePredictor) -> None:
        # One inference step; the encoder is optionally part of the timed work.
        if include_image_encoder:
            pred.set_image(dummy_image)
        pred.predict(box=full_box, multimask_output=False)

    predictor = build_predictor(entry, device=device)
    try:
        if not include_image_encoder:
            predictor.set_image(dummy_image)

        for _ in range(warmup_iters):
            _run_once(predictor)

        timings = []
        for _ in range(timed_iters):
            # Synchronize around the timer so asynchronous CUDA work is counted.
            _sync(device)
            start = time.perf_counter()
            _run_once(predictor)
            _sync(device)
            timings.append(time.perf_counter() - start)
    finally:
        # Drop the model before the next entry is built so GPU memory from
        # several checkpoints does not accumulate.
        predictor = None
        if device.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.empty_cache()

    timings_arr = np.asarray(timings)
    avg_latency = float(timings_arr.mean())
    fps = float(1.0 / avg_latency) if avg_latency > 0 else float("inf")
    return {
        "name": entry["name"],
        "latency_ms": avg_latency * 1000,
        "latency_p50_ms": float(np.median(timings_arr)) * 1000,
        "fps": fps,
        "image_size": image_size,
        "runs": timed_iters,
        "includes_image_encoder": include_image_encoder,
    }

In [None]:
benchmark_results = []

# torch.device parses "cuda:N" without touching the GPU; index is None for "cuda"/"cpu".
_cuda_requested = DEVICE.startswith("cuda")
_cuda_index = torch.device(DEVICE).index if _cuda_requested else None

if _cuda_requested and not torch.cuda.is_available():
    print(f"CUDA 디바이스({DEVICE})를 사용할 수 없어 FPS 측정을 건너뜁니다.")
elif _cuda_index is not None and _cuda_index >= torch.cuda.device_count():
    # e.g. DEVICE="cuda:7" on a machine with fewer GPUs: every entry would fail.
    print(f"CUDA 디바이스({DEVICE})가 존재하지 않아 FPS 측정을 건너뜁니다 "
          f"(감지된 GPU 수: {torch.cuda.device_count()}).")
else:
    for entry in CHECKPOINTS:
        try:
            benchmark_results.append(benchmark_entry(entry))
        except Exception as exc:
            # Top-level best effort: report the failure and continue with the
            # remaining checkpoints.
            print(f"[WARN] {entry['name']} 측정 실패: {exc}")

# A bare expression inside `except` is never rendered, so display explicitly.
if benchmark_results:
    try:
        import pandas as pd
    except ImportError:
        display(benchmark_results)
    else:
        display(pd.DataFrame(benchmark_results))