# SAM2 Video Predictor
https://github.com/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb

## Environment Setup

### Check that `huggingface_hub` is installed

In [None]:
# Probe for the Hugging Face Hub client and report whether it can be imported.
try:
    import huggingface_hub
except ImportError:
    print("`huggingface_hub` is not installed.")
else:
    print("`huggingface_hub` is installed.")


In [None]:
# Gates the version report and dependency installation in the next cell.
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

### GPU Setup
- Connect the runtime to an A100 or L4 GPU before running the cells below

In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # enter a bfloat16 autocast context that stays active for the rest of the notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # compute capability 8+ (Ampere and newer) supports TF32
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

### SAM2 Checkpoint Model

In [None]:
    # !mkdir -p videos
    # !wget -P videos https://dl.fbaipublicfiles.com/segment_anything_2/assets/bedroom.zip
    # !unzip -d videos videos/bedroom.zip

    !mkdir -p ../checkpoints/
    !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

### Loading the SAM 2 video predictor

In [None]:
from sam2.build_sam import build_sam2_video_predictor

# Checkpoint path matches the destination used by the wget cell above;
# the config name selects the matching SAM 2.1 "hiera_l" model definition.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the video predictor on the device chosen earlier in the notebook.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a segmentation mask on an axes as a translucent colored layer.

    Args:
        mask: Array whose last two dimensions are (height, width) and that
            holds exactly height * width elements (e.g. shape (1, H, W)).
            Non-zero entries are painted; zero entries stay fully transparent.
        ax: Object exposing ``imshow`` (typically a matplotlib Axes).
        obj_id: Integer object id used to pick a color from the "tab10"
            palette. ``None`` selects the first color. Ids outside the palette
            wrap around (``obj_id % cmap.N``) instead of all collapsing onto
            one clamped color, so large ids remain distinguishable.
        random_color: When True, ignore ``obj_id`` and use a random RGB color.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        # wrap the id into the palette range; 0..N-1 keep their original color
        cmap_idx = 0 if obj_id is None else obj_id % cmap.N
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    # broadcast (h, w, 1) against (1, 1, 4) to get an RGBA image with alpha 0.6
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    """Plot click prompts as star markers: green for label 1, red for label 0.

    Args:
        coords: Array of shape (N, 2); column 0 is plotted as x, column 1 as y.
        labels: Array of length N used to split ``coords`` by label value.
        ax: Object exposing ``scatter`` (typically a matplotlib Axes).
        marker_size: Marker size forwarded as the ``s`` argument of ``scatter``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    for group, fill in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(group[:, 0], group[:, 1], color=fill, **star_style)


def show_box(box, ax):
    """Draw a box given as (x0, y0, x1, y1) as an unfilled green rectangle.

    Args:
        box: Indexable with four entries: two opposite corners (x0, y0, x1, y1).
        ax: Object exposing ``add_patch`` (typically a matplotlib Axes).
    """
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

## Sample Video

### Dataset
- temporary import (uploaded files get deleted once the runtime is recycled)

In [None]:
from google.colab import files

# Prompt for one or more files from the local machine
uploaded = files.upload()

# Echo the name of every file that was received
for file_name in uploaded:
    print(f"Uploaded file: {file_name}")

- Convert video to JPEG frames (for a single video for now)

In [None]:
# `clip` is the base name of the video: the extraction cell below reads
# /content/{clip}.mp4 and writes the frames into /content/{clip}/.
# create output folder (no error if it already exists)
clip = '2017_G6016_Q4_011'
os.makedirs(f"/content/{clip}", exist_ok=True)

In [None]:
# ffmpeg version check: confirms the binary is available before extracting frames
!ffmpeg -version

In [None]:
# Split the uploaded video into JPEG frames named 00000.jpg, 00001.jpg, ...
# (`-start_number 0` makes numbering start at 0; `%05d` zero-pads to five digits,
# which the frame-listing cell below parses back into an integer index).
# `-q:v 2` sets a fixed JPEG quality scale (lower values mean higher quality).
!ffmpeg -i /content/{clip}.mp4 -q:v 2 -start_number 0 /content/{clip}/%05d.jpg

- Set frame directory

In [None]:
# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = f"/content/{clip}/"

# keep only the JPEG files and order them by the integer encoded in the file name
frame_names = sorted(
    (
        entry
        for entry in os.listdir(video_dir)
        if os.path.splitext(entry)[1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
    ),
    key=lambda entry: int(os.path.splitext(entry)[0]),
)

- Display the first frame

In [None]:
# take a look at the first video frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
first_frame_path = os.path.join(video_dir, frame_names[frame_idx])
plt.imshow(Image.open(first_frame_path))

### Initialize Inference State
- SAM 2 requires stateful inference for interactive video segmentation, so we need to initialize an inference state on this video.

In [None]:
# Create the inference state for the JPEG frame directory prepared above;
# every later predictor call in this notebook takes this state as an argument.
inference_state = predictor.init_state(video_path=video_dir)

## Segment and Track a Single Object

In [None]:
# Reset the inference state so that re-running the cells below starts clean
# (presumably drops prompts and results from an earlier session; confirm in the SAM 2 API).
predictor.reset_state(inference_state)

### Identify xy-coordinates
- identify coordinates externally

### Add positive clicks on a frame

Multiple prompts are supported: each tracked object gets its own entry (multi-object tracking).

In [None]:
prompts = {} # maps obj_id -> (points, labels) for every click added; used when visualizing

Player 1

In [None]:
frame_idx = 0  # index of the frame that receives the click
obj_id = 1  # any integer works, as long as it is unique per tracked object

# one positive click on the object (label `1` = positive, label `0` = negative)
points = np.array([[602, 320]], dtype=np.float32)
labels = np.array([1], np.int32)

# remember the prompt so it can be redrawn alongside other objects later
prompts[obj_id] = (points, labels)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=frame_idx,
    obj_id=obj_id,
    points=points,
    labels=labels,
)

# draw the interacted frame with the clicks and the mask of every returned object
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
show_points(points, labels, plt.gca())
for out_obj_id, obj_logits in zip(out_obj_ids, out_mask_logits):
    show_points(*prompts[out_obj_id], plt.gca())
    # logits above zero are treated as foreground
    show_mask((obj_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)

Player 2

In [None]:
frame_idx = 0  # index of the frame that receives the click
obj_id = 2  # any integer works, as long as it is unique per tracked object

# one positive click on the object (label `1` = positive, label `0` = negative)
points = np.array([[488, 328]], dtype=np.float32)
labels = np.array([1], np.int32)

# remember the prompt so it can be redrawn alongside other objects later
prompts[obj_id] = (points, labels)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=frame_idx,
    obj_id=obj_id,
    points=points,
    labels=labels,
)

# draw the interacted frame with the clicks and the mask of every returned object
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
show_points(points, labels, plt.gca())
for out_obj_id, obj_logits in zip(out_obj_ids, out_mask_logits):
    show_points(*prompts[out_obj_id], plt.gca())
    # logits above zero are treated as foreground
    show_mask((obj_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)

Add second positive click for the same object (if needed)

In [None]:
# ann_frame_idx = 0  # the frame index we interact with
# ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

# # Let's add a 2nd positive click to refine the mask
# # sending all clicks (and their labels) to `add_new_points_or_box`
# points = np.array([[602, 320], [488, 328]], dtype=np.float32)
# # for labels, `1` means positive click and `0` means negative click
# labels = np.array([1, 1], np.int32)
# _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
#     inference_state=inference_state,
#     frame_idx=ann_frame_idx,
#     obj_id=ann_obj_id,
#     points=points,
#     labels=labels,
# )

# # show the results on the current (interacted) frame
# plt.figure(figsize=(9, 6))
# plt.title(f"frame {ann_frame_idx}")
# plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
# show_points(points, labels, plt.gca())
# show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

### Add negative clicks on a frame (if needed)

### Propagate the prompts to get masklet

Display every `vis_frame_stride`-th frame with its masklet

In [None]:
# Run propagation throughout the video and collect the results in a dict:
# video_segments[frame_index][obj_id] -> boolean mask (logits thresholded at 0).
video_segments = {}

for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 30
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    # close the image file as soon as it has been drawn
    with Image.open(os.path.join(video_dir, frame_names[out_frame_idx])) as frame_image:
        plt.imshow(frame_image)
    # frames that propagation did not produce results for are shown without
    # masks instead of raising KeyError (same guard as the GIF cell)
    for out_obj_id, out_mask in video_segments.get(out_frame_idx, {}).items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

Convert frames to GIF

In [None]:
# Initialize a list to store the rendered frames (PIL images) for the GIF
frames_list = []

# Render the segmentation results every few frames
vis_frame_stride = 10
plt.close("all")

for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    # Create figure and axis
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.set_title(f"Frame {out_frame_idx}")

    # Load and display the frame; the context manager releases the file handle
    # once imshow has read the pixel data.
    frame_path = os.path.join(video_dir, frame_names[out_frame_idx])
    with Image.open(frame_path) as img:
        ax.imshow(img)

    # Overlay segmentation masks (frames without results are drawn unmasked)
    for out_obj_id, out_mask in video_segments.get(out_frame_idx, {}).items():
        show_mask(out_mask, ax, obj_id=out_obj_id)

    # Rasterize the figure and copy its RGBA pixels through the public canvas
    # API rather than the backend-internal `canvas.renderer` attribute.
    fig.canvas.draw()
    frame_img = np.array(fig.canvas.buffer_rgba())  # copy, so closing the figure is safe
    frames_list.append(Image.fromarray(frame_img))

    plt.close(fig)  # Close figure to avoid memory leaks

print(f"Collected {len(frames_list)} frames for GIF generation.")


In [None]:
output_gif_path = "/content/SAM2Tracking.gif"

# Nothing is written when no frames were collected.
if len(frames_list) > 0:
    first_frame, *later_frames = frames_list
    # duration is the per-frame display time in milliseconds; loop=0 repeats forever
    first_frame.save(
        output_gif_path,
        save_all=True,
        append_images=later_frames,
        duration=100,
        loop=0,
    )
    print(f"GIF saved successfully: {output_gif_path}")
