# Video prediction with SAM 3

## 1. Setup

In [None]:
import os

# Root of the SAM 3 checkout; a later cell builds the example video path from it.
# Set the SAM3_ROOT environment variable to override it without editing this cell.
sam3_root = os.environ.get("SAM3_ROOT", "/home/ronghanghu/workspace/sam3")

# Checkpoint file handed to the video predictor as `checkpoint_path`.
# Set the SAM3_CHECKPOINT environment variable to override it.
checkpoint_file = os.environ.get(
    "SAM3_CHECKPOINT",
    "/checkpoint/sam3/ronghanghu/sam3_release/ckpts/sam3_video_model_only.pt",
)

# Model options forwarded unchanged to the predictor constructor.
# NOTE(review): presumably these have to match the checkpoint being loaded -- confirm.
has_presence_token = True
geo_encoder_use_img_cross_attn = True

In [2]:
import torch

# Default: run on every CUDA device that torch reports on this machine.
num_visible_gpus = torch.cuda.device_count()
gpus_to_use = range(num_visible_gpus)

# Alternative: restrict the predictor to the current device only.
# gpus_to_use = [torch.cuda.current_device()]

In [3]:
from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU

# Collect the constructor arguments from the configuration cells above,
# then build the multi-GPU video predictor from them.
predictor_kwargs = {
    "checkpoint_path": checkpoint_file,
    "has_presence_token": has_presence_token,
    "geo_encoder_use_img_cross_attn": geo_encoder_use_img_cross_attn,
    "gpus_to_use": gpus_to_use,
}
predictor = Sam3VideoPredictorMultiGPU(**predictor_kwargs)

[0m[32mINFO 2025-10-02 22:53:03,103 153519 sam3_video_predictor.py: 282:[0m using the following GPU IDs: [0, 1, 2, 3, 4, 5, 6, 7]
[0m[32mINFO 2025-10-02 22:53:03,104 153519 sam3_video_predictor.py: 298:[0m 


	*** START loading model on all ranks ***


[0m[32mINFO 2025-10-02 22:53:03,105 153519 sam3_video_predictor.py: 300:[0m loading model on rank=0 with world_size=8 -- this could take a while ...


Enabled the use of perflib.


[0m[32mINFO 2025-10-02 22:53:10,518 153519 sam3_video_base.py: 156:[0m `setting max_num_objects` to 128 -- creating num_obj_for_compile=16 objects for torch.compile cache
[0m[32mINFO 2025-10-02 22:53:13,216 153519 sam3_video_predictor.py: 302:[0m loading model on rank=0 with world_size=8 -- DONE locally
[0m[32mINFO 2025-10-02 22:53:13,217 153519 sam3_video_predictor.py: 359:[0m spawning 7 worker processes


Enabled the use of perflib.
Enabled the use of perflib.
Enabled the use of perflib.


[0m[32mINFO 2025-10-02 22:53:20,291 153723 sam3_video_predictor.py: 439:[0m starting worker process rank=2 with world_size=8


Enabled the use of perflib.
Enabled the use of perflib.
Enabled the use of perflib.
Enabled the use of perflib.


[0m[32mINFO 2025-10-02 22:53:20,974 153725 sam3_video_predictor.py: 439:[0m starting worker process rank=4 with world_size=8
[0m[32mINFO 2025-10-02 22:53:21,115 153726 sam3_video_predictor.py: 439:[0m starting worker process rank=5 with world_size=8
[0m[32mINFO 2025-10-02 22:53:21,416 153722 sam3_video_predictor.py: 439:[0m starting worker process rank=1 with world_size=8
[0m[32mINFO 2025-10-02 22:53:21,419 153723 sam3_video_predictor.py: 300:[0m loading model on rank=2 with world_size=8 -- this could take a while ...
[0m[32mINFO 2025-10-02 22:53:21,603 153724 sam3_video_predictor.py: 439:[0m starting worker process rank=3 with world_size=8
[0m[32mINFO 2025-10-02 22:53:21,628 153728 sam3_video_predictor.py: 439:[0m starting worker process rank=7 with world_size=8
[0m[32mINFO 2025-10-02 22:53:21,634 153727 sam3_video_predictor.py: 439:[0m starting worker process rank=6 with world_size=8
[0m[32mINFO 2025-10-02 22:53:22,240 153725 sam3_video_predictor.py: 300:[0m l

NCCL version 2.25.1+cuda12.6


[0m[32mINFO 2025-10-02 22:55:06,212 153519 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=0 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153725 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=4 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153723 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=2 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153722 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=1 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153727 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=6 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153724 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=3 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153728 sam3_video_predictor.py: 412:[0m started NCCL process group on rank=7 with world_size=8
[0m[32mINFO 2025-10-02 22:55:06,212 153726 sam3_video_predic

### Visualization utils

In [4]:
import os
import glob
import matplotlib.pyplot as plt
from utils import visualize_formatted_frame_output, prepare_masks_for_visualization

Converting 5000 RGB samples to LAB color space...
Conversion to LAB complete.
Fitting KMeans with 128 clusters on 5000 samples...
KMeans fitting complete.


## 2. Running inference

In [5]:
# Example video, given as a directory of frames located under `sam3_root`.
# this video has 6 objects
video_frames_dir = f"{sam3_root}/assets/videos/0001"
# Text prompt sent to the model for this video.
prompt_text_str = "person"

# Alternative example: uncomment both lines below to use it instead.
# this video has ~80 objects
# NOTE(review): this is an absolute path that looks site-specific; it may not exist on your machine.
# video_frames_dir = "/checkpoint/sam3/shared/webdemo/data/ta/static/gallery/selected_examples/0018/rgb"
# prompt_text_str = "horse"

In [6]:
# load "image_files" for visualization purposes (they are not used by the model)
image_files = glob.glob(os.path.join(video_frames_dir, "*.jpg"))
if not image_files:
    # Report an empty match right away: the visualization cells take their images
    # from this list, so a wrong directory or a non-".jpg" frame format would
    # otherwise only show up later as a less obvious failure.
    print(
        f"no '*.jpg' frames found under {video_frames_dir!r}; "
        f"check the path before running the visualization cells."
    )
try:
    # integer sort instead of string sort (so that e.g. "2.jpg" is before "11.jpg")
    image_files.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
except ValueError:
    # int() raised on at least one file stem, so the names are not pure frame indices:
    # fall back to lexicographic sort if the format is not "<frame_index>.jpg"
    print(
        f'frame names are not in "<frame_index>.jpg" format: {image_files[:5]=}, '
        f"falling back to lexicographic sort."
    )
    image_files.sort()

### Opening an inference session on this video

In [None]:
# Ask the predictor to open a session on the selected video and remember
# the session id it returns; the later requests refer to the session by this id.
start_session_request = {
    "type": "start_session",
    "resource_path": video_frames_dir,
}
response = predictor.handle_request(request=start_session_request)
session_id = response["session_id"]

[0m[32mINFO 2025-10-02 22:55:09,523 153722 sam3_video_predictor.py: 468:[0m worker rank=1 received request request['type']='start_session'
[0m[32mINFO 2025-10-02 22:55:09,524 153727 sam3_video_predictor.py: 468:[0m worker rank=6 received request request['type']='start_session'
[0m[32mINFO 2025-10-02 22:55:09,524 153723 sam3_video_predictor.py: 468:[0m worker rank=2 received request request['type']='start_session'
[0m[32mINFO 2025-10-02 22:55:09,524 153726 sam3_video_predictor.py: 468:[0m worker rank=5 received request request['type']='start_session'
[0m[32mINFO 2025-10-02 22:55:09,524 153725 sam3_video_predictor.py: 468:[0m worker rank=4 received request request['type']='start_session'
[0m[32mINFO 2025-10-02 22:55:09,524 153728 sam3_video_predictor.py: 468:[0m worker rank=7 received request request['type']='start_session'
[0m[32mINFO 2025-10-02 22:55:09,524 153724 sam3_video_predictor.py: 468:[0m worker rank=3 received request request['type']='start_session'
frame 

### Adding a text prompt on frame 0 and propagating it throughout the video

Note that the first call might be slower because it has to set up buffers. **To measure speed, rerun all the cells below and time the later runs rather than the first one.**

In [None]:
# Reset the session before switching text prompts: if one text prompt was already
# run and another one is added without a reset, the results would be wrong.
reset_session_request = {
    "type": "reset_session",
    "session_id": session_id,
}
_ = predictor.handle_request(request=reset_session_request)

In [None]:
# The text prompt is attached to frame 0.
frame_idx = 0
add_prompt_request = {
    "type": "add_prompt",
    "session_id": session_id,
    "frame_index": frame_idx,
    "text": prompt_text_str,
}
response = predictor.handle_request(request=add_prompt_request)
# Keep the outputs from the response; the next cell visualizes them on `frame_idx`.
out = response["outputs"]

In [None]:
# Close any figures left open by earlier cells, then display the outputs
# obtained on the prompted frame.
plt.close("all")
prompt_frame_outputs = prepare_masks_for_visualization({frame_idx: out})
visualize_formatted_frame_output(
    frame_idx,
    image_files,
    outputs_list=[prompt_frame_outputs],
    titles=["SAM 3 Dense Tracking outputs"],
    figsize=(6, 4),
)

In [None]:
# Propagate from frame 0 through the end of the video, collecting the outputs
# of each streamed response under its frame index.
propagate_request = {
    "type": "propagate_in_video",
    "session_id": session_id,
}
outputs_per_frame = {
    frame_response["frame_index"]: frame_response["outputs"]
    for frame_response in predictor.handle_stream_request(request=propagate_request)
}

# Convert the collected outputs into the form expected by the visualization helper.
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

In [None]:
# Display every `vis_frame_stride`-th frame of the propagated results.
vis_frame_stride = 60
plt.close("all")
# Use a dedicated loop variable so that `frame_idx` (the prompt frame index set in an
# earlier cell) is not overwritten here. Otherwise, re-running the single-frame
# visualization cell after this one would pair the frame-0 outputs in `out` with
# the last index visited by this loop.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        image_files,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

In [None]:
# finally, close the inference session to free its GPU resources
# (you may start a new session on another video)
_ = predictor.handle_request(
    request=dict(
        type="close_session",
        session_id=session_id,
    )
)

## 3. Clean up

In [None]:
# Once all inference is done, shut down the predictor
# to free up the multi-GPU process group.
# NOTE(review): presumably `predictor` cannot be used after this call and the
# setup cell has to be re-run to create a new one -- confirm.
predictor.shutdown()