# SAM 3 — Video Segmentation

Sample every Nth frame from a video, run text-prompted SAM 3 segmentation, and save RGB frames, per-frame mask overlays, and individual masks.

## 1 — Imports

In [1]:
import os
import time
import cv2
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
from transformers import Sam3Model, Sam3Processor

# Repository root: the directory two levels above the kernel's current working
# directory (i.e. the notebook is expected to run from analysis/tutorials/).
ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

%matplotlib inline

## 2 — Helper Functions

In [2]:
def load_model(model_name="facebook/sam3"):
    """Instantiate the SAM 3 model and its processor.

    The model is moved to "cuda" when ``torch.cuda.is_available()`` reports a
    GPU, otherwise it stays on "cpu".

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            model and the processor.

    Returns:
        A ``(model, processor, device)`` tuple, where ``device`` is the string
        "cuda" or "cpu".
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Loading model: {model_name} on {device} ...")

    # Model first, then processor (same order as the progress messages imply).
    pretrained = Sam3Model.from_pretrained(model_name)
    model = pretrained.to(device)
    processor = Sam3Processor.from_pretrained(model_name)

    print("Model and processor loaded successfully!")
    return model, processor, device


def extract_frames(video_path, sample_rate=30, start_skip = 0, end_skip = 0):
    """Extract every Nth frame from a video.

    Sampling happens first (raw frame indices 0, N, 2N, ...). ``start_skip``
    and ``end_skip`` are then applied to the *sampled* list, so they count
    sampled frames, not raw video frames: with ``sample_rate=30`` and
    ``start_skip=10`` the first returned frame is raw frame 300.

    Args:
        video_path: Path of the video file to open with OpenCV.
        sample_rate: Keep one frame out of every ``sample_rate`` raw frames.
            Must be positive.
        start_skip: Number of sampled frames to drop from the beginning.
        end_skip: Number of sampled frames to drop from the end. If the skips
            together cover every sampled frame, an empty list is returned.

    Returns:
        List of ``(frame_index, BGR numpy array)`` tuples, where
        ``frame_index`` is the raw index of the frame in the video.

    Raises:
        ValueError: If ``sample_rate`` is not positive.
        FileNotFoundError: If the video cannot be opened.
    """
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be a positive integer, got {sample_rate}")

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"Cannot open video: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Video: {total_frames} frames, {fps:.1f} fps, sampling every {sample_rate} frames")

        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % sample_rate == 0:
                frames.append((idx, frame))
            idx += 1
    finally:
        # Always give the capture handle back, even if reading raised.
        cap.release()

    # Clamp the stop index at 0: a negative stop would be interpreted by the
    # slice as "count from the end" and silently return frames when end_skip
    # exceeds the number of sampled frames.
    stop = max(len(frames) - end_skip, 0)
    frames = frames[start_skip:stop]
    print(f"Extracted {len(frames)} frames")
    return frames


def segment_frame(model, processor, rgb_frame, text_prompt, device, threshold=0.5):
    """Run SAM 3 segmentation on a single RGB numpy array.

    Args:
        model: Loaded SAM 3 model, called as ``model(**inputs)``.
        processor: Matching processor; used to build the inputs and to
            post-process the raw outputs.
        rgb_frame: Image as a numpy array accepted by ``Image.fromarray``.
        text_prompt: Text describing what to segment.
        device: Device the processor inputs are moved to before inference.
        threshold: Confidence threshold forwarded to post-processing.

    Returns:
        The post-processed results dict (with 'masks' and 'scores') for this
        frame. Tensor values are detached and moved to the CPU so the caller
        does not keep device memory alive between frames; calling ``.cpu()``
        on them again is a harmless no-op.
    """
    def _to_cpu(value):
        # Tensors (and lists/tuples of tensors) leave the device; anything
        # else is passed through untouched.
        if torch.is_tensor(value):
            return value.detach().cpu()
        if isinstance(value, (list, tuple)):
            return [_to_cpu(item) for item in value]
        return value

    image = Image.fromarray(rgb_frame)
    inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); post-processing is given (height, width).
    results = processor.post_process_instance_segmentation(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]]
    )[0]

    if not hasattr(results, "items"):
        # Unexpected result container: hand it back unchanged.
        return results
    return {key: _to_cpu(value) for key, value in results.items()}


def save_frame_results(rgb_frame, results, frame_name, rgb_dir, overlay_dir, mask_dir, alpha=0.45):
    """Save the RGB frame, a colored mask overlay, and one image per mask.

    Args:
        rgb_frame: H x W x 3 RGB image array; it is not modified.
        results: Segmentation results; ``results['masks']`` is iterated and
            each mask is converted with ``.cpu().numpy()``.
        frame_name: File name used for the RGB and overlay images. Mask files
            reuse its stem: ``<stem>_mask_<i>.png``.
        rgb_dir: Directory receiving the RGB frame.
        overlay_dir: Directory receiving the overlay image.
        mask_dir: Directory receiving the per-mask images.
        alpha: Blend weight of the mask color in the overlay (0 leaves the
            frame unchanged, 1 paints the solid color).
    """
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]

    # Copy every mask to host memory once; the overlay and the per-mask files
    # both reuse these arrays.
    mask_arrays = [mask.cpu().numpy() for mask in results['masks']]

    # Save RGB
    plt.imsave(os.path.join(rgb_dir, frame_name), rgb_frame)

    # Build & save overlay. Masks are blended in order and colors cycle, so a
    # later mask is drawn on top of an earlier one where they overlap.
    overlay = rgb_frame.copy()
    for i, mask_np in enumerate(mask_arrays):
        mask_bool = mask_np.astype(bool)
        color = np.asarray(colors[i % len(colors)], dtype=np.float64)
        blended = overlay[mask_bool] * (1 - alpha) + color * alpha
        # Cast back explicitly (truncating) to the frame's dtype.
        overlay[mask_bool] = blended.astype(overlay.dtype)
    plt.imsave(os.path.join(overlay_dir, frame_name), overlay)

    # Save individual masks. Deriving the name from the stem keeps the files
    # distinct even when frame_name does not end in ".png".
    stem = os.path.splitext(frame_name)[0]
    for i, mask_np in enumerate(mask_arrays):
        mask_u8 = (mask_np * 255).astype(np.uint8)
        mask_fname = f"{stem}_mask_{i}.png"
        Image.fromarray(mask_u8).save(os.path.join(mask_dir, mask_fname))


def process_video(model, processor, device, video_path, output_dir,
                  text_prompt="skin", sample_rate=30, start_skip = 0, end_skip = 0,
                  threshold=0.5, alpha=0.45):
    """
    Full pipeline: extract frames, run SAM 3 segmentation, save results.

    Args:
        model, processor, device: Objects returned by ``load_model``.
        video_path: Input video file.
        output_dir: Parent directory for the three output folders.
        text_prompt: Text describing what to segment.
        sample_rate, start_skip, end_skip: Forwarded to ``extract_frames``.
        threshold: Confidence threshold forwarded to ``segment_frame``.
        alpha: Overlay blend weight forwarded to ``save_frame_results``.

    Output files are numbered by their position in the sampled sequence
    (starting at 1), not by the raw frame index:
        <output_dir>/<video_name>-rgb/frame_000001.png
        <output_dir>/<video_name>-overlay/frame_000001.png
        <output_dir>/<video_name>-masks/frame_000001_mask_0.png, ...
    """
    video_name = Path(video_path).stem

    rgb_dir     = os.path.join(output_dir, f"{video_name}-rgb")
    overlay_dir = os.path.join(output_dir, f"{video_name}-overlay")
    mask_dir    = os.path.join(output_dir, f"{video_name}-masks")
    for directory in (rgb_dir, overlay_dir, mask_dir):
        os.makedirs(directory, exist_ok=True)

    frames = extract_frames(video_path, sample_rate, start_skip, end_skip)
    total = len(frames)
    seg_times = []
    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

    for i, (frame_idx, bgr_frame) in enumerate(frames):
        fname = f"frame_{i+1:06d}.png"
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)

        # Segment (perf_counter is a monotonic clock meant for intervals)
        t0 = time.perf_counter()
        results = segment_frame(model, processor, rgb_frame, text_prompt, device, threshold)
        seg_time = time.perf_counter() - t0
        seg_times.append(seg_time)

        num_obj = len(results['masks'])

        # Save
        save_frame_results(rgb_frame, results, fname, rgb_dir, overlay_dir, mask_dir, alpha)

        print(f"  [{i+1}/{total}] frame {frame_idx} -> {fname}  |  {num_obj} obj  |  {seg_time:.2f}s")

        # Drop this frame's results before the next forward pass and return
        # cached-but-unused blocks to the driver. This is a mitigation for
        # out-of-memory failures on a GPU shared with other processes; it does
        # not reduce the model's own peak usage.
        del results
        if on_cuda:
            torch.cuda.empty_cache()

    avg_time = np.mean(seg_times) if seg_times else 0
    print(f"\nDone! Processed {total} frames (avg {avg_time:.2f}s per frame)")
    print(f"  RGB:     {rgb_dir}")
    print(f"  Overlay: {overlay_dir}")
    print(f"  Masks:   {mask_dir}")

## 3 — Load Model & Set Constants (run once)

In [3]:
# ---- Run-wide settings: set once here, reused for every video -------
model_name = "facebook/sam3"   # Hugging Face model id
text_prompt = "blade"          # text prompt: what to segment
threshold = 0.5                # confidence threshold for kept detections
alpha = 0.45                   # weight of the mask color in the overlay
# ---------------------------------------------------------------------

# Load once; the video cell reuses these three objects.
model, processor, device = load_model(model_name=model_name)

Loading model: facebook/sam3 on cuda ...


Loading weights:   0%|          | 0/1468 [00:00<?, ?it/s]

Model and processor loaded successfully!


## 4 — Run on Video (change `video_path` and re-run this cell)

In [4]:
# ---- Change this per run --------------------------------------------
video_path   = "analysis/data/sam3/surgery_video.mp4"   # input video (relative to ROOT)
output_dir   = "analysis/outputs/sam3/"                 # output folder (relative to ROOT)
sample_rate  = 30                                       # sample every Nth frame
# Skips count *sampled* frames, not raw frames: with sample_rate = 30,
# start_skip = 10 drops raw frames 0..270 and starts at raw frame 300.
start_skip   = 10                                       # sampled frames to drop at the start
end_skip     = 0                                        # sampled frames to drop at the end
# ---------------------------------------------------------------------

video_path  = os.path.join(ROOT, video_path)
output_dir  = os.path.join(ROOT, output_dir)

process_video(model, processor, device, video_path, output_dir,
              text_prompt=text_prompt, sample_rate=sample_rate,
              start_skip=start_skip, end_skip=end_skip,
              threshold=threshold, alpha=alpha)

Video: 1638 frames, 24.0 fps, sampling every 30 frames
Extracted 45 frames
  [1/45] frame 300 -> frame_000001.png  |  2 obj  |  1.06s
  [2/45] frame 330 -> frame_000002.png  |  3 obj  |  0.68s
  [3/45] frame 360 -> frame_000003.png  |  1 obj  |  0.67s
  [4/45] frame 390 -> frame_000004.png  |  2 obj  |  0.67s
  [5/45] frame 420 -> frame_000005.png  |  2 obj  |  0.67s
  [6/45] frame 450 -> frame_000006.png  |  2 obj  |  0.67s
  [7/45] frame 480 -> frame_000007.png  |  1 obj  |  0.67s
  [8/45] frame 510 -> frame_000008.png  |  4 obj  |  0.66s
  [9/45] frame 540 -> frame_000009.png  |  2 obj  |  0.67s
  [10/45] frame 570 -> frame_000010.png  |  2 obj  |  0.66s
  [11/45] frame 600 -> frame_000011.png  |  3 obj  |  0.67s
  [12/45] frame 630 -> frame_000012.png  |  1 obj  |  0.67s
  [13/45] frame 660 -> frame_000013.png  |  1 obj  |  0.68s
  [14/45] frame 690 -> frame_000014.png  |  1 obj  |  0.67s
  [15/45] frame 720 -> frame_000015.png  |  1 obj  |  0.67s
  [16/45] frame 750 -> frame_00001

OutOfMemoryError: CUDA out of memory. Tried to allocate 94.00 MiB. GPU 0 has a total capacity of 11.77 GiB of which 49.69 MiB is free. Process 646106 has 7.44 GiB memory in use. Including non-PyTorch memory, this process has 4.26 GiB memory in use. Of the allocated memory 3.44 GiB is allocated by PyTorch, and 462.54 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)