In [1]:
import sys, os, json, time, subprocess, pathlib
from pathlib import Path
from davis2017.davis import DAVIS
import imageio.v3 as iio
import numpy as np
from tqdm import tqdm
import torch

# ── USER‐CONFIGURABLE PATHS ──────────────────────────────────────────────────
DAVIS_ROOT  = Path("./data/davis/DAVIS")          # ← point this at your DAVIS folder
OUT_DIR     = Path("./data/sam2_preds_MA_0.15")           # ← where we’ll write out PNGs
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ── STEP 1: load the DAVIS “val” split (semi‐supervised task) ───────────────────
ds = DAVIS(str(DAVIS_ROOT), task="semi-supervised", subset="val", resolution="480p")
print(f"Loaded {len(ds.sequences)} validation sequences")

# ── STEP 2: build SAM 2 video predictor ─────────────────────────────────────────
from sam2.build_sam import build_sam2_video_predictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"   # ← adjust if needed
model_cfg       = "configs/sam2.1/sam2.1_hiera_l.yaml"   # ← adjust if needed

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
predictor.to(device)   # make sure model is on CUDA if available

Loaded 30 validation sequences


SAM2VideoPredictor(
  (image_encoder): ImageEncoder(
    (trunk): Hiera(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 144, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      )
      (blocks): ModuleList(
        (0-1): 2 x MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (attn): MultiScaleAttention(
            (qkv): Linear(in_features=144, out_features=432, bias=True)
            (proj): Linear(in_features=144, out_features=144, bias=True)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (mlp): MLP(
            (layers): ModuleList(
              (0): Linear(in_features=144, out_features=576, bias=True)
              (1): Linear(in_features=576, out_features=144, bias=True)
            )
            (act): GELU(approximate='none')
          )
        )
        (2): MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elemen

In [None]:
# ── STEP 3: for each DAVIS sequence, run SAM 2 and save outputs ───────────────
fps = 0.0
for seq in ds.sequences:
    print(f"\n=== Processing sequence: {seq} ===")

    # 3a) directories of images & GT masks for this sequence
    img_dir  = DAVIS_ROOT / "JPEGImages"  / "480p" / seq
    mask_dir = DAVIS_ROOT / "Annotations" / "480p" / seq

    img_paths  = sorted(img_dir.glob("*.jpg"))
    mask_paths = sorted(mask_dir.glob("*.png"))
    assert len(img_paths) == len(mask_paths), (
        f"Image/mask count mismatch in {seq}: {len(img_paths)} vs {len(mask_paths)}"
    )

    # 3b) initialize inference state by “loading” this sequence as a video
    #     We pass the *directory* of frames to init_state.  Internally, it will call
    #     `load_video_frames(video_path=video_dir, ...)` and store all frames in memory.
    video_dir = str(img_dir)  # e.g. "./data/davis/DAVIS/JPEGImages/480p/<seq>"
    inference_state = predictor.init_state(
        video_path=video_dir,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False
    )
    

# (b) load the single “00000.png” which contains two different colored regions
    rgb = iio.imread(str(mask_dir / "00000.png"))  # shape (H, W, 3)
    H, W, C = rgb.shape
    assert C == 3, "Expected a 3‐channel (RGB) first‐frame mask."

    # (c) find all unique RGB colors except black
    flat = rgb.reshape(-1, 3)                        # shape (H*W, 3)
    uniq_colors = np.unique(flat, axis=0)            # shape (K, 3), where K ≤ (H*W)
    # Remove the black color (0,0,0) if present
    non_black = [tuple(c) for c in uniq_colors if not np.all(c == 0)]
    if len(non_black) == 0:
        raise RuntimeError(f"No non‐black colors found in {seq}/00000.png")

    # (d) for each unique non‐black color, build a 2D boolean mask and register it
    print(f"Found {len(non_black)} unique non‐black colors in {seq}/00000.png")
    for idx, color in enumerate(non_black):
        # color is something like (200, 0, 0) or (0, 200, 0)
        R, G, B = color
        # build a binary mask: True where pixel == this color
        bin_mask = np.logical_and.reduce([
            rgb[:, :, 0] == R,
            rgb[:, :, 1] == G,
            rgb[:, :, 2] == B
        ])  # shape (H, W), dtype=bool

        # wrap as torch.bool on the same device as SAM 2
        mask2d = torch.from_numpy(bin_mask).to(device)

        # register this mask as object `idx`
        predictor.add_new_mask(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=idx,  # choose 0,1,2,… per color
            mask=mask2d
        )

    # 3e) now propagate through all frames.  As each new frame is processed,
    #     propagate_in_video yields (frame_idx, [obj_ids], video_res_masks).
    #
    #     We’ll save each mask as “00000.png”, “00001.png”, … under OUT_DIR/<seq>/
    seq_out_dir = OUT_DIR / seq
    seq_out_dir.mkdir(parents=True, exist_ok=True)

    t0 = time.time()
    total_time = 0.0
    for frame_idx, obj_ids, video_res_masks in tqdm(
        predictor.propagate_in_video(inference_state),
        total=len(img_paths),
        desc=f"Propagating {seq}"
    ):
        total_time += time.time() - t0
        # # ‣ frame_idx is an integer (1,2,3,…).  video_res_masks is a tensor of shape
        # #   (num_objects, H, W).  For DAVIS, num_objects==1.
        # #
        # # ‣ Thresholding has already happened internally; `video_res_masks` is
        # #   a float‐tensor where positive values correspond to predicted “object.”
        # mask_np = (video_res_masks[0].cpu().numpy() > 0.0).astype(np.uint8) * 255

        # # Save with zero‐padded five digits to match DAVIS naming:
        # save_name = f"{frame_idx:05d}.png"
        # save_path = seq_out_dir / save_name
        # iio.imwrite(str(save_path), mask_np)

        # Suppose `video_res_masks` is whatever you get from propagate_in_video:
        #   • If there is only one object, it may be a 2D tensor of shape (H, W)
        #   • If there are multiple objects, it will be a 3D tensor of shape (O, H, W)

        pred_np = video_res_masks.cpu().numpy()   # dtype=float32 or float; # ───────────────────────────────────────────────────────────────
        # Assume you already did:
        #   pred_np = video_res_masks.cpu().numpy()

        # 1) Check how many dimensions `pred_np` has:
        if pred_np.ndim == 2:
            # Case A: single object, shape = (H, W)
            H, W = pred_np.shape
            O = 1
            pred_np = pred_np[np.newaxis, ...]  # -> now shape (1, H, W)

        elif pred_np.ndim == 3:
            # Could be either:
            #  (A) shape = (1, H, W)   ← single object with a leading axis
            #  (B) shape = (O, H, W)   ← multiple objects, no extra channel axis
            if pred_np.shape[0] == 1:
                # Treat as “one‐object” → squeeze to (1, H, W) (already fits our convention)
                O, H, W = pred_np.shape
            else:
                # Multi‐object already: (O, H, W)
                O, H, W = pred_np.shape
            # (no need to reshape because it’s already (O, H, W))

        elif pred_np.ndim == 4:
            # Some SAM 2 builds return (O, 1, H, W). In that case:
            #   • pred_np.shape = (O, 1, H, W)
            #   → we want to drop the “channel” dimension (axis=1).
            O = pred_np.shape[0]
            H = pred_np.shape[2]
            W = pred_np.shape[3]
            pred_np = pred_np[:, 0, :, :]  # now shape (O, H, W)

        else:
            raise RuntimeError(f"Unexpected mask array with ndim={pred_np.ndim}, shape={pred_np.shape}")

        # At this point:
        #   • pred_np is guaranteed to have shape (O, H, W)
        #   • O, H, W are set correctly
        # ───────────────────────────────────────────────────────────────

        # Now you can build your colored output exactly as before:

        colored = np.zeros((H, W, 3), dtype=np.uint8)

        for i in range(O):
            mask_i = (pred_np[i] > 0.0)   # boolean mask (H, W)
            if not mask_i.any():
                continue
            R, G, B = non_black[i]  # the original RGB for object i
            colored[mask_i, 0] = R
            colored[mask_i, 1] = G
            colored[mask_i, 2] = B

        save_name = f"{frame_idx:05d}.png"
        save_path = seq_out_dir / save_name
        iio.imwrite(str(save_path), colored)
        t0 = time.time()

    # record FPS
    cur_fps = len(img_paths) / total_time if total_time > 0 else 0.0
    fps += cur_fps
    print(f"→ Processed {len(img_paths)} frames in {seq} at {cur_fps:.2f} FPS")

    print(f"→ Saved all predicted masks for {seq} in {seq_out_dir}")

# save FPS into a file
fps_file = OUT_DIR / "fps.txt"
with open(fps_file, "w") as f:
    f.write(f"{fps / len(ds.sequences):.2f}\n")

print("\nAll sequences processed.")
print(f"Average FPS across all sequences: {fps / len(ds.sequences):.2f}")
print(f"Your SAM 2 masks live under: {OUT_DIR}")



=== Processing sequence: bike-packing ===


frame loading (JPEG): 100%|██████████| 69/69 [00:01<00:00, 49.72it/s]

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(


Found 2 unique non‐black colors in bike-packing/00000.png


propagate in video: 100%|██████████| 69/69 [00:11<00:00,  5.95it/s]7it/s]
Propagating bike-packing: 100%|██████████| 69/69 [00:11<00:00,  5.94it/s]


Skipped 43 frames due to low MAD.
→ Processed 69 frames in bike-packing at 6.56 FPS
→ Saved all predicted masks for bike-packing in data/sam2_preds_MA_0.15/bike-packing

=== Processing sequence: blackswan ===


frame loading (JPEG): 100%|██████████| 50/50 [00:00<00:00, 57.19it/s]


Found 1 unique non‐black colors in blackswan/00000.png


propagate in video: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]/s]
Propagating blackswan: 100%|██████████| 50/50 [00:03<00:00, 13.68it/s]


Skipped 47 frames due to low MAD.
→ Processed 50 frames in blackswan at 15.72 FPS
→ Saved all predicted masks for blackswan in data/sam2_preds_MA_0.15/blackswan

=== Processing sequence: bmx-trees ===


frame loading (JPEG): 100%|██████████| 80/80 [00:01<00:00, 57.32it/s]


Found 2 unique non‐black colors in bmx-trees/00000.png


propagate in video: 100%|██████████| 80/80 [00:19<00:00,  4.13it/s]/s]
Propagating bmx-trees: 100%|██████████| 80/80 [00:19<00:00,  4.13it/s]


Skipped 14 frames due to low MAD.
→ Processed 80 frames in bmx-trees at 4.32 FPS
→ Saved all predicted masks for bmx-trees in data/sam2_preds_MA_0.15/bmx-trees

=== Processing sequence: breakdance ===


frame loading (JPEG): 100%|██████████| 84/84 [00:01<00:00, 56.32it/s]


Found 1 unique non‐black colors in breakdance/00000.png


propagate in video: 100%|██████████| 84/84 [00:16<00:00,  5.22it/s]t/s]
Propagating breakdance: 100%|██████████| 84/84 [00:16<00:00,  5.22it/s]


Skipped 11 frames due to low MAD.
→ Processed 84 frames in breakdance at 5.50 FPS
→ Saved all predicted masks for breakdance in data/sam2_preds_MA_0.15/breakdance

=== Processing sequence: camel ===


frame loading (JPEG): 100%|██████████| 90/90 [00:01<00:00, 60.27it/s]


Found 1 unique non‐black colors in camel/00000.png


propagate in video: 100%|██████████| 90/90 [00:07<00:00, 12.26it/s]
Propagating camel: 100%|██████████| 90/90 [00:07<00:00, 12.25it/s]


Skipped 79 frames due to low MAD.
→ Processed 90 frames in camel at 13.94 FPS
→ Saved all predicted masks for camel in data/sam2_preds_MA_0.15/camel

=== Processing sequence: car-roundabout ===


frame loading (JPEG): 100%|██████████| 75/75 [00:01<00:00, 60.50it/s]


Found 1 unique non‐black colors in car-roundabout/00000.png


propagate in video: 100%|██████████| 75/75 [00:06<00:00, 11.91it/s].37it/s]
Propagating car-roundabout: 100%|██████████| 75/75 [00:06<00:00, 11.90it/s]


Skipped 66 frames due to low MAD.
→ Processed 75 frames in car-roundabout at 13.28 FPS
→ Saved all predicted masks for car-roundabout in data/sam2_preds_MA_0.15/car-roundabout

=== Processing sequence: car-shadow ===


frame loading (JPEG): 100%|██████████| 40/40 [00:00<00:00, 57.49it/s]


Found 1 unique non‐black colors in car-shadow/00000.png


propagate in video: 100%|██████████| 40/40 [00:05<00:00,  7.26it/s]t/s]
Propagating car-shadow: 100%|██████████| 40/40 [00:05<00:00,  7.25it/s]


Skipped 22 frames due to low MAD.
→ Processed 40 frames in car-shadow at 7.70 FPS
→ Saved all predicted masks for car-shadow in data/sam2_preds_MA_0.15/car-shadow

=== Processing sequence: cows ===


frame loading (JPEG): 100%|██████████| 104/104 [00:01<00:00, 60.63it/s]


Found 1 unique non‐black colors in cows/00000.png


propagate in video: 100%|██████████| 104/104 [00:07<00:00, 13.13it/s]
Propagating cows: 100%|██████████| 104/104 [00:07<00:00, 13.12it/s]


Skipped 96 frames due to low MAD.
→ Processed 104 frames in cows at 14.93 FPS
→ Saved all predicted masks for cows in data/sam2_preds_MA_0.15/cows

=== Processing sequence: dance-twirl ===


frame loading (JPEG): 100%|██████████| 90/90 [00:01<00:00, 60.07it/s]


Found 1 unique non‐black colors in dance-twirl/00000.png


Propagating dance-twirl:  20%|██        | 18/90 [00:02<00:12,  5.85it/s]

Skipped 7 frames due to low MAD.
→ Processed 50 frames in kite-surf at 3.11 FPS
→ Saved all predicted masks for kite-surf in data/sam2_preds_MA_0.15/kite-surf

=== Processing sequence: lab-coat ===


frame loading (JPEG): 100%|██████████| 47/47 [00:00<00:00, 54.95it/s]


Found 5 unique non‐black colors in lab-coat/00000.png


propagate in video: 100%|██████████| 47/47 [00:07<00:00,  5.97it/s]s]
Propagating lab-coat: 100%|██████████| 47/47 [00:07<00:00,  5.95it/s]


Skipped 35 frames due to low MAD.
→ Processed 47 frames in lab-coat at 6.61 FPS
→ Saved all predicted masks for lab-coat in data/sam2_preds_MA_0.15/lab-coat

=== Processing sequence: libby ===


frame loading (JPEG): 100%|██████████| 49/49 [00:00<00:00, 58.30it/s]


Found 1 unique non‐black colors in libby/00000.png


propagate in video: 100%|██████████| 49/49 [00:05<00:00,  9.00it/s]
Propagating libby: 100%|██████████| 49/49 [00:05<00:00,  8.99it/s]


Skipped 33 frames due to low MAD.
→ Processed 49 frames in libby at 9.71 FPS
→ Saved all predicted masks for libby in data/sam2_preds_MA_0.15/libby

=== Processing sequence: loading ===


frame loading (JPEG): 100%|██████████| 50/50 [00:00<00:00, 60.46it/s]


Found 3 unique non‐black colors in loading/00000.png


propagate in video: 100%|██████████| 50/50 [00:05<00:00,  9.10it/s]]
Propagating loading: 100%|██████████| 50/50 [00:05<00:00,  9.08it/s]


Skipped 43 frames due to low MAD.
→ Processed 50 frames in loading at 10.56 FPS
→ Saved all predicted masks for loading in data/sam2_preds_MA_0.15/loading

=== Processing sequence: mbike-trick ===


frame loading (JPEG): 100%|██████████| 79/79 [00:01<00:00, 60.53it/s]


Found 2 unique non‐black colors in mbike-trick/00000.png


Propagating mbike-trick:  72%|███████▏  | 57/79 [00:08<00:03,  6.17it/s]