In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Video segmentation and tracking with SAM 3

This notebook demonstrates how to use SAM 3 for interactive video segmentation and dense tracking. It covers the following capabilities:

- **Text prompts**: Using natural language descriptions to segment objects (e.g., "person", "face", "visual")
- **Point prompts**: Adding positive/negative clicks to segment and refine objects
- **Box prompts**: Using bounding boxes combined with text for precise object localization
- **Interactive refinement**: Adding clicks on any frame to improve segmentation quality

We use the terms _segment_ or _mask_ to refer to the model prediction for an object on a single frame, and _masklet_ to refer to the spatio-temporal masks across the entire video.

<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/sam3/blob/main/notebooks/sam3_video_predictor_example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" /></a>

In [None]:
# set to True when running on Google Colab so the next cell installs the dependencies
using_colab = False

In [None]:
# Colab-only setup: print the torch/torchvision versions and CUDA availability,
# then install the extra dependencies and SAM 3 itself.
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    # `{sys.executable} -m pip` installs into the interpreter running this kernel
    !{sys.executable} -m pip install opencv-python matplotlib scikit-learn
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'

## Set-up

In [None]:
import sys
# add the parent directory to the import path — presumably so `sam3` can be
# imported from a repository checkout without installing it (TODO confirm)
sys.path.append('..')

In [None]:
import torch

# use all available GPUs on the machine
# NOTE(review): on a machine without CUDA devices this is an empty range —
# confirm the multi-GPU predictor below handles that case
gpus_to_use = range(torch.cuda.device_count())

# # use only a single GPU
# gpus_to_use = [torch.cuda.current_device()]

In [None]:
import os
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam3.visualization_utils import (
    visualize_formatted_frame_output,
    prepare_masks_for_visualization,
    visualize_prompt_overlay,
    show_points,
    draw_box_on_image,
    load_frame,
)

# helper functions
def propagate_in_video(predictor, session_id):
    """Stream a "propagate_in_video" request and collect the outputs of every frame.

    Args:
        predictor: object exposing ``handle_stream_request(request=...)`` that
            yields one response dict per processed frame.
        session_id: id of the inference session to propagate.

    Returns:
        dict mapping each response's ``"frame_index"`` to its ``"outputs"``
        (a later response for the same frame index overwrites an earlier one).
    """
    # we will just propagate from frame 0 to the end of the video
    stream = predictor.handle_stream_request(
        request={"type": "propagate_in_video", "session_id": session_id}
    )
    return {resp["frame_index"]: resp["outputs"] for resp in stream}

def abs_to_rel_coords(coords, IMG_WIDTH, IMG_HEIGHT, coord_type='point'):
    """Convert absolute pixel coordinates to relative coordinates (0-1 range).

    Args:
        coords: iterable of coordinates, each [x, y] (points) or [x, y, w, h]
            (boxes), in pixels.
        IMG_WIDTH: image width in pixels; x and w are divided by it.
        IMG_HEIGHT: image height in pixels; y and h are divided by it.
        coord_type: 'point' for [x, y] or 'box' for [x, y, w, h].

    Returns:
        A new list of lists with every value normalized by the image size.

    Raises:
        ValueError: if ``coord_type`` is neither 'point' nor 'box', or if an
            entry does not have the number of values ``coord_type`` requires.
    """
    if coord_type not in ('point', 'box'):
        raise ValueError(f"Unknown coord_type: {coord_type}")

    relative = []
    for coord in coords:
        if coord_type == 'point':
            x, y = coord
            relative.append([x / IMG_WIDTH, y / IMG_HEIGHT])
        else:
            x, y, w, h = coord
            relative.append([x / IMG_WIDTH, y / IMG_HEIGHT, w / IMG_WIDTH, h / IMG_HEIGHT])
    return relative

In [None]:
from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU

# Root of the SAM 3 checkout; override with the SAM3_ROOT environment variable.
sam3_root = os.environ.get("SAM3_ROOT", f"/home/{os.getenv('USER')}/sam3")

# NOTE(review): the default checkpoint below is a cluster-internal, per-user path
# that will not exist on other machines. Point the SAM3_CHECKPOINT environment
# variable at your own copy, e.g.
# f"{sam3_root}/assets/checkpoints/sam3_video_model_only.pt"
checkpoint_file = os.environ.get(
    "SAM3_CHECKPOINT",
    "/checkpoint/sam3/haithamkhedr/checkpoints/sam3_dense/sam3_v2_rc2.pt",
)
# fail early with an actionable message instead of deep inside model loading
if not os.path.isfile(checkpoint_file):
    raise FileNotFoundError(
        f"SAM 3 checkpoint not found at {checkpoint_file!r}; "
        "set the SAM3_CHECKPOINT environment variable to a valid checkpoint path."
    )

# model configuration flags passed to the predictor
# (presumably these must match how the checkpoint was trained — TODO confirm)
has_presence_token = True
geo_encoder_use_img_cross_attn = True
predictor = Sam3VideoPredictorMultiGPU(
    checkpoint_path=checkpoint_file,
    has_presence_token=has_presence_token,
    geo_encoder_use_img_cross_attn=geo_encoder_use_img_cross_attn,
    gpus_to_use=gpus_to_use,
)

### Loading an example video

We assume that the video is stored as a directory of JPEG frames with filenames like `<frame_index>.jpg`.

For your custom videos, you can extract their JPEG frames using ffmpeg (https://ffmpeg.org/) as follows:
```
ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'
```
where `-q:v 2` produces high-quality JPEG frames and `-start_number 0` makes ffmpeg number the output files starting from `00000.jpg`.

In [None]:
# this video has 6 objects
# a directory of JPEG frames inside the SAM 3 checkout
# (the frame loader in the next cell also accepts a path ending in ".mp4")
video_path = f"{sam3_root}/assets/videos/0001"

In [None]:
# load "video_frames_for_vis" for visualization purposes (they are not used by the model)
# - ".mp4" input: every frame is decoded into an RGB numpy array
# - otherwise: video_path is treated as a directory and we collect its "*.jpg" paths in frame order
if isinstance(video_path, str) and video_path.endswith(".mp4"):
    cap = cv2.VideoCapture(video_path)
    video_frames_for_vis = []
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"could not open video file {video_path!r}")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes frames as BGR; convert to RGB for matplotlib
            video_frames_for_vis.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # always release the decoder, even if reading fails part-way
        cap.release()
else:
    video_frames_for_vis = glob.glob(os.path.join(video_path, "*.jpg"))
    try:
        # integer sort instead of string sort (so that e.g. "2.jpg" is before "11.jpg")
        video_frames_for_vis.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
    except ValueError:
        # fallback to lexicographic sort if the format is not "<frame_index>.jpg"
        print(
            f'frame names are not in "<frame_index>.jpg" format: {video_frames_for_vis[:5]=}, '
            f"falling back to lexicographic sort."
        )
        video_frames_for_vis.sort()

# later cells index video_frames_for_vis[0]; fail here with a clear message instead
if not video_frames_for_vis:
    raise RuntimeError(f"no frames found for video_path={video_path!r}")

## Running inference
### Opening an inference session on this video
SAM 3 requires stateful inference for interactive video segmentation, so we need to initialize an **inference session** on this video.

During initialization, it loads all the JPEG frames in the video directory and stores their features in the session state.

In [None]:
# open an inference session on the video; later requests refer to it via session_id
response = predictor.handle_request(
    request={"type": "start_session", "resource_path": video_path}
)
session_id = response["session_id"]

## Video promptable concept segmentation with text

Using SAM 3, you can describe objects in natural language, and the model will automatically detect and track all instances of that object throughout the video.

### Adding a text prompt on frame 0 and propagation throughout the video

Note that the first call might be slower due to setting up buffers. **You can rerun all the cells below when measuring speed.**

In [None]:
# Switching to a different text prompt after one has already been run requires
# resetting the session first; otherwise the results would be wrong.
_ = predictor.handle_request(
    request={"type": "reset_session", "session_id": session_id}
)

### Adding a text prompt

Here we use the text prompt "person" to detect all people in the video. 

SAM 3 will automatically identify multiple person instances and assign each a unique object ID.

In [None]:
prompt_text_str = "person"
frame_idx = 0  # the text prompt is placed on frame 0
response = predictor.handle_request(
    request={
        "type": "add_prompt",
        "session_id": session_id,
        "frame_index": frame_idx,
        "text": prompt_text_str,
    }
)
# outputs for the prompted frame only (no propagation has happened yet)
out = response["outputs"]

In [None]:
# show the detections on the prompted frame before propagating; the single-frame
# output is wrapped as {frame_idx: out} so it has the per-frame mapping form that
# prepare_masks_for_visualization receives elsewhere in this notebook
plt.close("all")
visualize_formatted_frame_output(
    frame_idx,
    video_frames_for_vis,
    outputs_list=[prepare_masks_for_visualization({frame_idx: out})],
    titles=["SAM 3 Dense Tracking outputs"],
    figsize=(6, 4),
)

In [None]:
# now we propagate the outputs from frame 0 to the end of the video and collect all outputs
outputs_per_frame = propagate_in_video(predictor, session_id)

# finally, we reformat the outputs for visualization and plot every `vis_frame_stride` frames
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the prompt cell's `frame_idx` intact. Re-running
# an earlier cell that reads `frame_idx` would otherwise pick up the last plotted frame.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

### Removing objects

We can remove individual objects using their object ID. Let's remove object 2.

In [None]:
# we pick id 2, which is the dancer in the front
obj_id = 2
response = predictor.handle_request(
    request={"type": "remove_object", "session_id": session_id, "obj_id": obj_id}
)

In [None]:
# now we propagate the outputs from frame 0 to the end of the video and collect all outputs
outputs_per_frame = propagate_in_video(predictor, session_id)

# finally, we reformat the outputs for visualization and plot every `vis_frame_stride` frames
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the global `frame_idx` used by the prompt cells intact.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

### Adding new objects with point prompts

In addition to text prompts, SAM 3 supports **point prompts** to add specific objects that might not be detected automatically.
Here we add a new object by clicking on a specific location.

Note: label `1` indicates a *positive click (to add a region)* while label `0` indicates a *negative click (to remove a region)*.

#### Get image dimensions for coordinate conversion

We compute the image dimensions once at the beginning and use them throughout for coordinate conversions.

All coordinates in this example are defined in absolute pixel coordinates and then converted to relative coordinates (0-1 range) for the model.

In [None]:
# read the first frame once to get the pixel size used by every coordinate conversion
sample_img = Image.fromarray(load_frame(video_frames_for_vis[0]))
IMG_WIDTH = sample_img.width
IMG_HEIGHT = sample_img.height
print(f"Image dimensions: {IMG_WIDTH} x {IMG_HEIGHT}")

In [None]:
# let's add back the dancer via point prompts.
# we will only use positive clicks to add the dancer back.
points_abs = np.array([
    [740, 450],
    [760, 630],
    [760, 550],
])
# all positive clicks
labels = np.array([1, 1, 1])
assert len(points_abs) == len(labels), "need exactly one label per click"

frame_idx = 0
obj_id = 2

# preview the clicks on the frame they will actually be added to
image = Image.fromarray(load_frame(video_frames_for_vis[frame_idx]))
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(points_abs, labels, plt.gca())
plt.axis('on')
plt.show()

# Convert to relative coordinates for the model
points = np.array(abs_to_rel_coords(points_abs, IMG_WIDTH, IMG_HEIGHT, coord_type='point'))

# convert points and labels to tensors
points_tensor = torch.tensor(points, dtype=torch.float32)
points_labels_tensor = torch.tensor(labels, dtype=torch.int32)

In [None]:
# add the clicks to object `obj_id` on frame `frame_idx`
response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=frame_idx,
        points=points_tensor,
        point_labels=points_labels_tensor,
        obj_id=obj_id,
    )
)


# now we propagate the outputs from frame 0 to the end of the video and collect all outputs
outputs_per_frame = propagate_in_video(predictor, session_id)

# finally, we reformat the outputs for visualization and plot every `vis_frame_stride` frames
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the prompt's `frame_idx` intact. Re-running the
# request above would otherwise target the last plotted frame.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

### Interactive refinement with point prompts

In addition to using **point prompts** for adding objects, we can use them to refine objects.

This is particularly useful when the initial segmentation from text prompts needs fine adjustments.

In [None]:
# We want to only track the upper body of the dancer, so we add positive clicks to it
# and negative clicks to the pants.
points_abs = np.array([
    [740, 450],  # positive click
    [760, 630],  # negative click
    [760, 550],  # positive click
])
labels = np.array([1, 0, 1])
assert len(points_abs) == len(labels), "need exactly one label per click"

frame_idx = 0
refine_obj_id = 2

# preview the clicks on the frame they will actually be added to
image = Image.fromarray(load_frame(video_frames_for_vis[frame_idx]))
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(points_abs, labels, plt.gca())
plt.axis('on')
plt.show()

# Convert to relative coordinates for the model
points = np.array(abs_to_rel_coords(points_abs, IMG_WIDTH, IMG_HEIGHT, coord_type='point'))

# convert points and labels to tensors
points_tensor = torch.tensor(points, dtype=torch.float32)
points_labels_tensor = torch.tensor(labels, dtype=torch.int32)

In [None]:
# add the refinement clicks to object `refine_obj_id` on frame `frame_idx`
response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=frame_idx,
        points=points_tensor,
        point_labels=points_labels_tensor,
        obj_id=refine_obj_id,
    )
)

outputs_per_frame = propagate_in_video(predictor, session_id)
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)


vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the prompt's `frame_idx` intact across re-runs.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

### Reset the session

In [None]:
# clear all prompts and objects from the session before the box-prompt examples
_ = predictor.handle_request(
    request={"type": "reset_session", "session_id": session_id}
)

### Adding new objects with box prompts

In addition to text and point prompts, SAM 3 supports **box prompts** to add objects.

This is particularly useful if we cannot describe a concept, but we can draw a box around it.

#### Adding object via box only prompts

In this example, we prompt the model with a single box drawn around a concrete step visible in the video, without any descriptive text.

In [None]:
# one box around the concrete step, as [x, y, w, h] in pixels
boxes_abs = [
    [190, 365, 89, 80]
]

boxes = abs_to_rel_coords(boxes_abs, IMG_WIDTH, IMG_HEIGHT, 'box')

# we use "visual" as text prompt for a prompt that uses the box only
prompt_text_str = "visual"
frame_idx = 0

# preview the box on the frame that is actually being prompted
image = Image.fromarray(load_frame(video_frames_for_vis[frame_idx]))
image_with_box = image
for box in boxes_abs:
    image_with_box = draw_box_on_image(image_with_box, box, (0,255,0))
plt.imshow(image_with_box)
plt.axis('off')  # Hide the axis
plt.show()

response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=frame_idx,
        text=prompt_text_str,
        bounding_boxes=boxes,
        bounding_box_labels=[1] * len(boxes),  # every box is a positive prompt
    )
)

In [None]:
outputs_per_frame = propagate_in_video(predictor, session_id)
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)


vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the global `frame_idx` used by the prompt cells intact.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

In [None]:
# clear the box-only prompt before combining a box with a text prompt
_ = predictor.handle_request(
    request={"type": "reset_session", "session_id": session_id}
)

#### Adding object via box and text prompts

In this example we prompt the model with both a text and a box prompt for the face of the person.

In [None]:
# one box around the face, as [x, y, w, h] in pixels
boxes_abs = [
    [520, 200, 90, 40]
]

boxes = abs_to_rel_coords(boxes_abs, IMG_WIDTH, IMG_HEIGHT, 'box')
# one positive label per box; shared by the overlay and the request so they stay in sync
box_labels = [1] * len(boxes)

prompt_text_str = "face"
frame_idx = 0

# visualize the face box and text prompt on the prompted frame
visualize_prompt_overlay(
    frame_idx,
    video_frames_for_vis,
    bounding_boxes=boxes,
    box_labels=box_labels,
    text_prompt=prompt_text_str,
)


response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=frame_idx,
        text=prompt_text_str,
        bounding_boxes=boxes,
        bounding_box_labels=box_labels,
    )
)

In [None]:
outputs_per_frame = propagate_in_video(predictor, session_id)
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)


vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the global `frame_idx` used by the prompt cells intact.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

#### Refine multiple objects simultaneously

Refine object 0 (face region)

First, we refine the face region using negative and positive points to improve boundary precision.

In [None]:
# clicks on the face: one negative to trim an unwanted area, one positive to keep the face
points_abs = np.array([
    [580, 250],  # negative point to remove unwanted area
    [570, 220],  # positive point to keep desired area
])

point_labels = np.array([0, 1])
assert len(points_abs) == len(point_labels), "need exactly one label per click"
frame_idx = 0
object_id = 0

# preview the clicks on the frame they will actually be added to
image = Image.fromarray(load_frame(video_frames_for_vis[frame_idx]))
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(points_abs, point_labels, plt.gca())
plt.axis('on')
plt.show()


# Convert to relative coordinates for the model
points = np.array(abs_to_rel_coords(points_abs, IMG_WIDTH, IMG_HEIGHT, coord_type='point'))

# Convert points and labels to tensors
points_tensor = torch.tensor(points, dtype=torch.float32)
labels_tensor = torch.tensor(point_labels, dtype=torch.int32)

response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=frame_idx,
        points=points_tensor,
        point_labels=labels_tensor,
        obj_id=object_id,
    )
)

In [None]:
outputs_per_frame = propagate_in_video(predictor, session_id)
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)


vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the global `frame_idx` used by the prompt cells intact.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

Refine object 2 (body region)

Next, we refine the body region using multiple prompts.

In [None]:
# body clicks: positive clicks mark the region to keep, negative clicks the region to exclude
points_abs = np.array([
    [740, 450],  # positive click
    [770, 630],  # negative click
    [760, 550],  # positive click
    [760, 600],  # negative click
])
labels = np.array([1, 0, 1, 0])
point_labels = labels
assert len(points_abs) == len(point_labels), "need exactly one label per click"

# Define the target frame/object before plotting so the preview shows the frame
# that is actually prompted.
frame_idx = 0
object_id = 2

plt.close('all')
image = Image.fromarray(load_frame(video_frames_for_vis[frame_idx]))
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(points_abs, point_labels, plt.gca())
plt.axis('on')
plt.show()

# Convert to relative coordinates for the model
points = np.array(abs_to_rel_coords(points_abs, IMG_WIDTH, IMG_HEIGHT, coord_type='point'))

# Convert points and labels to tensors
points_tensor = torch.tensor(points, dtype=torch.float32)
labels_tensor = torch.tensor(point_labels, dtype=torch.int32)

response = predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=session_id,
        frame_index=frame_idx,
        points=points_tensor,
        point_labels=labels_tensor,
        obj_id=object_id,
    )
)

In [None]:
outputs_per_frame = propagate_in_video(predictor, session_id)
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)


vis_frame_stride = 60
plt.close("all")
# A dedicated loop variable keeps the global `frame_idx` used by the prompt cells intact.
# NOTE(review): assumes frame indices are contiguous from 0.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        video_frames_for_vis,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

### Close session

In [None]:
# Close the inference session to free its GPU resources; a new session can then
# be started on another video.
_ = predictor.handle_request(
    request={"type": "close_session", "session_id": session_id}
)

### Clean-up

In [None]:
# after all inference is done, we can shutdown the predictor
# to free up the multi-GPU process group
# (presumably `predictor` cannot serve further requests after this — TODO confirm)
predictor.shutdown()