In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2

In [None]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and deliberately never exit it, so it
    # stays active for the rest of the notebook.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Turn on TF32 for GPUs with compute capability major >= 8, i.e. Ampere or newer
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

In [None]:
# NOTE(review): this import sits outside the notebook's first (imports) cell;
# a fresh "Restart & Run All" still works because it precedes its first use.
from sam2.build_sam import build_sam2_video_predictor

# Checkpoint weights and model config; the file names suggest the SAM 2.1
# Hiera "large" variant for both -- presumably they must be changed together.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the video predictor on the `device` selected in the previous cell.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw `mask` on `ax` as a semi-transparent RGBA image.

    Args:
        mask: array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1), so it must hold exactly
            height * width elements.
        ax: object exposing `imshow` (e.g. a matplotlib Axes).
        obj_id: index into the "tab10" colormap; `None` selects index 0.
            Ignored when `random_color` is true.
        random_color: when true, use a random RGB colour drawn from
            `np.random` instead of the colormap.
    """
    alpha = 0.6
    if random_color:
        rgb = np.random.random(3)
    else:
        palette = plt.get_cmap("tab10")
        rgb = palette(0 if obj_id is None else obj_id)[:3]
    rgba = np.array([*rgb, alpha])

    height, width = mask.shape[-2:]
    # (height, width, 1) * (1, 1, 4) broadcasts to an RGBA image that is
    # fully zero (transparent) wherever the mask is zero.
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))


def show_points(coords, labels, ax, marker_size=200):
    """Scatter click prompts on `ax` as star markers.

    Points whose label equals 1 are drawn in green and points whose label
    equals 0 in red; points with any other label are not drawn.

    Args:
        coords: array indexed as `coords[:, 0]` (x) and `coords[:, 1]` (y).
        labels: array aligned with `coords`, used as a boolean selector.
        ax: object exposing `scatter` (e.g. a matplotlib Axes).
        marker_size: marker area forwarded as `s`.
    """
    # Select both groups before drawing anything, then draw positives first.
    groups = [
        (coords[labels == label_value], colour)
        for label_value, colour in ((1, 'green'), (0, 'red'))
    ]
    for pts, colour in groups:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )


def show_box(box, ax):
    """Outline `box` on `ax` with an unfilled green rectangle.

    Args:
        box: sequence read as (x_min, y_min, x_max, y_max).
        ax: object exposing `add_patch` (e.g. a matplotlib Axes).
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

### View First Frame

In [None]:
# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
video_dir = "/home/minq02/curly/sam2/CUSTOM/frames"

# collect the names of all JPEG frames in this directory
frame_names = [
    entry
    for entry in os.listdir(video_dir)
    if os.path.splitext(entry)[-1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
]
# order by the integer encoded in the file stem rather than lexicographically;
# a JPEG whose stem is not an integer raises ValueError here
frame_names.sort(key=lambda entry: int(os.path.splitext(entry)[0]))

# take a look at the first video frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

In [None]:
# Build the predictor's inference state for the frames in `video_dir`.
inference_state = predictor.init_state(video_path=video_dir)

# Alternative call, to be used INSTEAD of the one above
# (NOTE(review): presumably meant to reduce GPU memory / start-up time on long videos -- confirm):
# inference_state = predictor.init_state(
#     video_path=video_dir,
#     offload_video_to_cpu=True,
#     async_loading_frames=True
# )

### Segmenting using box prompt

In [None]:
# Reset `inference_state` before adding the prompts below, so re-running this
# section does not stack on earlier prompts.
# NOTE(review): exact reset semantics are defined by SAM 2's `reset_state` -- confirm.
predictor.reset_state(inference_state)

In [None]:
# Prompt a single object with a bounding box on one frame.
ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 0  # give a unique id to each object we interact with (it can be any integer)

# Let's add a box at (x_min, y_min, x_max, y_max) = (1100, 1310, 2200, 1730) to get started
box = np.array([1100, 1310, 2200, 1730], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    box=box,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
# logits above 0 are treated as foreground for the first returned object
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

In [None]:
# Refine the same object on the same frame with click prompts plus the box.
ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 0  # give a unique id to each object we interact with (it can be any integer)

# Let's add a positive click at (x, y) = (1900, 1500) and a negative click at
# (x, y) = (1200, 1300) to refine the mask
points = np.array([[1900, 1500], [1200, 1300]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1, 0], np.int32)
# note that we also need to send the original box input along with
# the new refinement clicks together into `add_new_points_or_box`
box = np.array([1100, 1310, 2200, 1730], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
    box=box,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
show_points(points, labels, plt.gca())
# logits above 0 are treated as foreground for the first returned object
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

### VISUALIZE TRACKING

In [None]:
# run propagation throughout the video and collect the results in a dict
# video_segments maps frame index -> {object id -> boolean mask (logits > 0)}
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 10
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    # Use .get(): a frame for which propagation yielded nothing is shown
    # without a mask instead of raising KeyError (the export cell treats
    # missing frames the same way).
    for out_obj_id, out_mask in video_segments.get(out_frame_idx, {}).items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

### SAVE RESULTS

In [None]:
# Root of the exported dataset.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
base = "/home/minq02/curly/sam2/CUSTOM"

# One sub-directory per artefact type.
img_dir = os.path.join(base, "images")
mask_dir = os.path.join(base, "annotated_masks")
ovl_dir = os.path.join(base, "overlays")
lbl_dir = os.path.join(base, "labels")
for d in (img_dir, mask_dir, ovl_dir, lbl_dir):
    os.makedirs(d, exist_ok=True)

# naming + stride
START_INDEX = 805  # CHANGE -- number used for the first exported file stem
PAD = 5            # zero-padded width of the numeric file stem
IMG_EXT = ".png"   # extension of the exported frame images
SAVE_STRIDE = 1    # CHANGE -- export every SAVE_STRIDE-th frame

alpha = 0.35       # blend weight of the coloured fill in the overlay images
CLASS_ID = 0       # class index written at the start of every label line
MIN_AREA_PX = 100  # contours with a smaller area (in pixels) are left out of the labels
MAX_PTS = 200      # polygons with more vertices than this are subsampled

def to_hw_u8(mask, H, W):
    """Binarise `mask` into a uint8 array holding only 0 and 255.

    Elements greater than 0 become 255 and everything else 0. Size-1 axes are
    dropped first; the remaining rank decides the layout:

    * 1-D: must contain exactly H * W elements and is reshaped to (H, W).
    * 2-D: returned as is. H and W are NOT checked against its shape.
    * 3-D: collapsed with a logical OR over the last axis.
      NOTE(review): this treats the input as channel-last; a channel-first
      (N, H, W) stack with N > 1 would come out as (N, H) -- confirm callers
      never pass one.
    * any other rank is passed through unchanged (only binarised).

    Args:
        mask: array-like of any numeric or boolean dtype.
        H: expected height, used only for the 1-D case.
        W: expected width, used only for the 1-D case.

    Returns:
        np.ndarray of dtype uint8 with values in {0, 255}.

    Raises:
        AssertionError: if a 1-D input does not have H * W elements.
    """
    m = np.squeeze((np.asarray(mask) > 0).astype(np.uint8))
    if m.ndim == 1:
        assert m.size == H*W, f"flat mask has {m.size} elements, expected {H}*{W}"
        m = m.reshape(H, W)
    elif m.ndim == 3:
        # After squeeze no axis has length 1, so the only case left is a
        # genuine multi-channel array; `m` is already 0/1, so any() suffices.
        m = m.any(axis=-1).astype(np.uint8)
    return (m * 255).astype(np.uint8)

# Running number for output file stems; advanced once per exported frame.
cur_idx = START_INDEX

for frame_idx in range(0, len(frame_names), SAVE_STRIDE):   # every SAVE_STRIDE-th frame
    # Stems count up consecutively from START_INDEX, independent of frame_idx.
    stem_new = f"{cur_idx:0{PAD}d}"
    cur_idx += 1

    img_path = os.path.join(video_dir, frame_names[frame_idx])
    img = Image.open(img_path).convert("RGB")
    img_np = np.array(img)
    H, W = img_np.shape[:2]

    # The frame itself is always exported, with or without a segmentation.
    # NOTE(review): `quality` presumably only matters if IMG_EXT is a JPEG extension.
    Image.fromarray(img_np).save(os.path.join(img_dir, f"{stem_new}{IMG_EXT}"), quality=95)

    # Binarise each object's mask exactly once.
    obj_dict = video_segments.get(frame_idx, {})
    masks_u8 = [to_hw_u8(m, H, W) for m in obj_dict.values()]

    if not masks_u8:
        # Nothing segmented on this frame: empty label file, overlay is the
        # plain frame, and no mask image is written.
        open(os.path.join(lbl_dir, f"{stem_new}.txt"), "w").close()
        Image.fromarray(img_np).save(os.path.join(ovl_dir, f"{stem_new}.png"))
        continue

    # Keep only the object with the largest foreground area; on a tie the
    # first object wins. All other objects on the frame are dropped.
    mask_u8 = max(masks_u8, key=np.count_nonzero)
    Image.fromarray(mask_u8, mode="L").save(os.path.join(mask_dir, f"{stem_new}.png"))

    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # One label line per sufficiently large outer contour:
    # "<CLASS_ID> x1 y1 x2 y2 ..." with coordinates normalised to [0, 1].
    lines = []
    for cnt in contours:
        if cv2.contourArea(cnt) < MIN_AREA_PX:
            continue
        # Simplify with a tolerance of 0.1% of the contour's perimeter.
        eps = 0.001 * cv2.arcLength(cnt, True)
        cnt = cv2.approxPolyDP(cnt, eps, True).reshape(-1, 2)
        if len(cnt) > MAX_PTS:
            # Uniform subsampling so at most MAX_PTS vertices remain.
            cnt = cnt[::int(np.ceil(len(cnt)/MAX_PTS))]
        if len(cnt) < 3:
            continue  # not a polygon
        cnt = cnt.astype(np.float32)
        cnt[:, 0] = np.clip(cnt[:, 0]/W, 0, 1)
        cnt[:, 1] = np.clip(cnt[:, 1]/H, 0, 1)
        lines.append(f"{CLASS_ID} " + " ".join(f"{x:.6f} {y:.6f}" for x, y in cnt))

    with open(os.path.join(lbl_dir, f"{stem_new}.txt"), "w") as f:
        f.write("\n".join(lines))

    # NOTE(review): the overlay draws ALL raw contours, including those the
    # label loop above filtered out or simplified, so it can show regions
    # that are absent from the label file -- confirm this is intended.
    overlay = img_np.copy()
    if len(contours) > 0:
        fill = np.zeros_like(overlay)
        cv2.fillPoly(fill, contours, (40, 220, 240))
        overlay = cv2.addWeighted(overlay, 1.0, fill, alpha, 0.0)
        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)
    Image.fromarray(overlay).save(os.path.join(ovl_dir, f"{stem_new}.png"))
