# Video prediction with SAM 3

## 1. Setup

In [None]:
# Location of the local SAM 3 checkout (the example video frames used below live under it).
sam3_root = "/home/ronghanghu/workspace/sam3"

# --- Active configuration: prod v14 model ---
# (s3://fb-fair-onevision/ronghanghu/sam3_release/ckpts/sam3_v14_rc1.pt)
# These two flags are passed to the predictor together with the checkpoint,
# so they have to be switched along with it.
has_presence_token = True
geo_encoder_use_img_cross_attn = True
checkpoint_file = "/checkpoint/sam3/ronghanghu/sam3_release/ckpts/sam3_v14_rc1.pt"

# --- Alternative configuration: prod v12 model (uncomment to use instead of v14) ---
# (s3://fb-fair-onevision/ronghanghu/sam3_release/ckpts/prod_v12_sam2_internal_shared_dth0.5_newdet0.7_assth0.1_nms0.1.pt)
# has_presence_token = False
# geo_encoder_use_img_cross_attn = False
# checkpoint_file = "/checkpoint/sam3/ronghanghu/sam3_release/ckpts/prod_v12_sam2_internal_shared_dth0.5_newdet0.7_assth0.1_nms0.1.pt"

In [None]:
import torch

# Default: hand every GPU that torch reports on this machine to the predictor.
num_gpus = torch.cuda.device_count()
gpus_to_use = range(num_gpus)

# Alternative: restrict inference to the current device only.
# gpus_to_use = [torch.cuda.current_device()]

In [None]:
from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU

# Collect the settings chosen in the setup cells, then build the multi-GPU video predictor from them.
predictor_kwargs = {
    "checkpoint_path": checkpoint_file,
    "has_presence_token": has_presence_token,
    "geo_encoder_use_img_cross_attn": geo_encoder_use_img_cross_attn,
    "gpus_to_use": gpus_to_use,
}
predictor = Sam3VideoPredictorMultiGPU(**predictor_kwargs)

### Visualization utils

In [None]:
import os
import glob
import matplotlib.pyplot as plt
from utils import visualize_formatted_frame_output, prepare_masks_for_visualization

## 2. Running inference

In [None]:
# Default example: a clip with 6 objects, located under the SAM 3 checkout.
prompt_text_str = "person"  # text prompt that is added on frame 0 further below
video_frames_dir = "{}/assets/videos/0001".format(sam3_root)

# Alternative example: a clip with ~80 objects (uncomment both lines to use it).
# prompt_text_str = "horse"
# video_frames_dir = "/checkpoint/sam3/shared/webdemo/data/ta/static/gallery/selected_examples/0018/rgb"

In [None]:
# Collect the JPEG frame paths. "image_files" is only needed for visualization;
# the model itself does not use this list.
image_files = glob.glob(os.path.join(video_frames_dir, "*.jpg"))
try:
    # Parse the numeric frame index out of every file name up front
    # (so that e.g. "2.jpg" is ordered before "11.jpg").
    frame_number_of = {
        path: int(os.path.splitext(os.path.basename(path))[0]) for path in image_files
    }
except ValueError:
    # At least one name is not "<frame_index>.jpg": report it and order the paths as plain strings.
    print(
        f'frame names are not in "<frame_index>.jpg" format: {image_files[:5]=}, '
        f"falling back to lexicographic sort."
    )
    image_files.sort()
else:
    # Every name parsed as an integer: order the paths numerically by frame index.
    image_files.sort(key=frame_number_of.__getitem__)

### Opening an inference session on this video

In [None]:
# Ask the predictor to open a session on the frame directory; the returned
# session id is what all later requests use to refer to this video.
response = predictor.handle_request(
    request={"type": "start_session", "resource_path": video_frames_dir}
)
session_id = response["session_id"]

### Adding a text prompt on frame 0 and propagating it throughout the video

Note that the first call might be slower because it sets up buffers. **When measuring speed, rerun all the cells below so that this one-time setup is not included in the timing.**

In [None]:
# Reset the session before adding a text prompt. This is required whenever a
# different text prompt was already run in this session; without the reset
# the results for the new prompt would be wrong.
_ = predictor.handle_request(
    request={"type": "reset_session", "session_id": session_id}
)

In [None]:
# The text prompt is attached to the first frame of the video.
frame_idx = 0
response = predictor.handle_request(
    request={
        "type": "add_prompt",
        "session_id": session_id,
        "frame_index": frame_idx,
        "text": prompt_text_str,
    }
)
# Keep the outputs returned for the prompted frame so they can be visualized.
out = response["outputs"]

In [None]:
# Close any figures left open by earlier runs before drawing a new one.
plt.close("all")
# Wrap the single prompted frame as {frame index: outputs} and convert it for visualization.
prompt_frame_outputs = prepare_masks_for_visualization({frame_idx: out})
visualize_formatted_frame_output(
    frame_idx,
    image_files,
    outputs_list=[prompt_frame_outputs],
    titles=["SAM 3 Dense Tracking outputs"],
    figsize=(6, 4),
)

In [None]:
# Propagate from frame 0 to the end of the video, collecting each streamed
# response's outputs under the frame index it reports.
outputs_per_frame = {}
for response in predictor.handle_stream_request(
    request={"type": "propagate_in_video", "session_id": session_id}
):
    frame_outputs = response["outputs"]
    outputs_per_frame[response["frame_index"]] = frame_outputs

# Convert the collected raw outputs into the form used by the visualization helper.
outputs_per_frame = prepare_masks_for_visualization(outputs_per_frame)

In [None]:
# Visualize every `vis_frame_stride`-th frame of the propagated results.
vis_frame_stride = 60
plt.close("all")
# Use a dedicated loop variable instead of `frame_idx`: a `for` loop leaves its
# target bound after it finishes, so reusing `frame_idx` here would overwrite the
# prompt frame index set earlier. Re-running the prompt-frame visualization cell
# afterwards would then pair `out` (the frame 0 outputs) with the wrong frame.
# NOTE(review): this assumes the frames in `outputs_per_frame` are indexed
# 0..len-1, which holds when propagating from frame 0 — confirm for other setups.
for vis_frame_idx in range(0, len(outputs_per_frame), vis_frame_stride):
    print(f"frame {vis_frame_idx}")
    visualize_formatted_frame_output(
        vis_frame_idx,
        image_files,
        outputs_list=[outputs_per_frame],
        titles=["SAM 3 Dense Tracking outputs"],
        figsize=(6, 4),
    )

In [None]:
# Done with this video: closing the inference session frees its GPU resources.
# (A new session can be started on another video afterwards.)
_ = predictor.handle_request(
    request={"type": "close_session", "session_id": session_id}
)

## 3. Clean up

In [None]:
# Once all inference is done, shut down the predictor
# to free up the multi-GPU process group.
# NOTE(review): presumably `predictor` cannot serve further requests after this
# call and must be re-created to run inference again — confirm.
predictor.shutdown()