In [1]:
import sys, os, json, time, subprocess, pathlib
from pathlib import Path
from davis2017.davis import DAVIS
import imageio.v3 as iio
import numpy as np
from tqdm import tqdm
import torch

# --- Paths ---------------------------------------------------------------
# Point DAVIS_ROOT at the local DAVIS dataset folder before running.
DAVIS_ROOT = Path("./data/davis/DAVIS")
# Directory reserved for predicted mask PNGs; created up front if missing.
OUT_DIR = Path("./data/sam2_preds")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Dataset -------------------------------------------------------------
# Validation split of the semi-supervised task (first-frame GT mask is given),
# at 480p resolution.
ds = DAVIS(
    str(DAVIS_ROOT),
    task="semi-supervised",
    subset="val",
    resolution="480p",
)
print(f"Loaded {len(ds.sequences)} validation sequences")


Loaded 30 validation sequences


In [2]:
from sam2.build_sam import build_sam2_video_predictor

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint and matching model config, as named by the paths below.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Video predictor shared by every sequence processed later in the notebook.
predictor = build_sam2_video_predictor(
    model_cfg,
    sam2_checkpoint,
    device=device,
)

In [3]:
import matplotlib.pyplot as plt

def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw a semi-transparent colored overlay of ``mask`` on ``ax``.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being multiplied by the color.
        ax: Object with an ``imshow`` method (e.g. a matplotlib axes).
        obj_id: Index into the "tab10" colormap; ``None`` selects index 0.
            Ignored when ``random_color`` is true.
        random_color: When true, draw the RGB color from ``np.random``
            instead of the colormap.
    """
    alpha = np.array([0.6])
    if not random_color:
        palette = plt.get_cmap("tab10")
        palette_index = obj_id if obj_id is not None else 0
        rgba = np.array([*palette(palette_index)[:3], 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def run_sequence(seq_name: str) -> float:
    """Run the SAM 2 video predictor on one DAVIS sequence and plot the result.

    The predictor is prompted on frame 0 with one binary mask per object id
    found in the first ground-truth annotation, then propagated through the
    video. Every frame is plotted with its predicted masks overlaid.

    Args:
        seq_name: Key into ``ds.sequences`` selecting the sequence to process.

    Returns:
        Propagation throughput in frames per second. Only the propagation loop
        is timed; prompting and plotting are excluded.
    """
    image_paths = ds.sequences[seq_name]['images']
    mask_paths = ds.sequences[seq_name]['masks']
    n_frames = len(image_paths)

    img_dir = Path(image_paths[0]).parent
    inference_state = predictor.init_state(str(img_dir))
    predictor.reset_state(inference_state)

    # --- first-frame GT mask as a 2-D array of object ids ---
    # mode="P" asks the reader to keep palette indices instead of expanding the
    # palette to RGB; after expansion the channel values are colors, not ids.
    first_gt = iio.imread(mask_paths[0], mode="P")

    # Prompt only with ids that actually occur (0 is background). Integer ids
    # are used so show_mask can pass them to the colormap as an index.
    object_ids = [int(k) for k in np.unique(first_gt) if k != 0]
    for k in object_ids:
        predictor.add_new_mask(
            inference_state,
            frame_idx=0,
            obj_id=k,
            mask=(first_gt == k).astype("uint8"),
        )

    # --- propagate (timed) ---
    # video_segments maps frame index -> {object id -> boolean mask array}.
    video_segments = {}
    t0 = time.perf_counter()
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    # Stop the clock before plotting so rendering cost does not deflate FPS.
    elapsed = time.perf_counter() - t0

    # NOTE: predictions are only visualized here; nothing is written to OUT_DIR.

    # --- visualize ---
    vis_frame_stride = 1
    plt.close("all")
    for out_frame_idx in range(0, n_frames, vis_frame_stride):
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_title(f"frame {out_frame_idx}")
        ax.imshow(iio.imread(image_paths[out_frame_idx]))
        # .get() tolerates a frame the predictor did not yield results for.
        for out_obj_id, out_mask in video_segments.get(out_frame_idx, {}).items():
            show_mask(out_mask, ax, obj_id=out_obj_id)
        # Show and release each figure immediately; otherwise one open figure
        # per frame accumulates across all sequences.
        plt.show()
        plt.close(fig)

    return n_frames / elapsed


In [None]:
# Materialize the sequence names so tqdm receives a sized collection and can
# show a total/ETA, whether get_sequences() returns a list or a lazy iterator.
sequences = list(ds.get_sequences())

# Per-sequence throughput, keyed by sequence name. Filled incrementally so
# partial results survive if a later sequence fails.
fps_vals = {}
for seq in tqdm(sequences, desc="SAM2 on DAVIS-val"):
    fps_vals[seq] = run_sequence(seq)

# Avoid a ZeroDivisionError when no sequence was processed.
if fps_vals:
    print(f"Mean FPS: {sum(fps_vals.values())/len(fps_vals):.2f}")
else:
    print("No sequences were processed; mean FPS is undefined.")


frame loading (JPEG): 100%|██████████| 69/69 [00:01<00:00, 60.71it/s]

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(


Skipping frame 1 due to low MAD (0.04)
Skipping frame 2 due to low MAD (0.06)
Skipping frame 3 due to low MAD (0.07)
Skipping frame 4 due to low MAD (0.09)
