In [0]:
# %pip install opencv-python

In [0]:
# %pip install imageio imageio[ffmpeg] imageio[pyav]
# %restart_python

In [0]:
from huggingface_hub import login
import os

hf_pat = dbutils.secrets.get("justinm-buildathon-secrets", "hf_pat")
os.environ["HF_TOKEN"] = hf_pat
login(token=hf_pat)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [0]:
import os
import mlflow.pyfunc
import torch
import numpy as np
import pandas as pd
import cv2
import base64
from io import BytesIO
from PIL import Image
from transformers import Sam3Processor, Sam3Model


class SAM3Video(mlflow.pyfunc.PythonModel):
    """
    MLflow wrapper for SAM3 image + video segmentation with batching
    """

    # -------------------------
    # Model loading
    # -------------------------
    def load_context(self, context):
        from huggingface_hub import login
        from transformers import logging

        logging.set_verbosity_error()
        logging.disable_progress_bar()

        hf_pat = os.environ["HF_TOKEN"]
        login(token=hf_pat)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        print("device:", self.device)

        print("loading model...")
        self.model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=dtype
        ).to(self.device)

        print("loading processor...")
        self.processor = Sam3Processor.from_pretrained("facebook/sam3")

        print("context loaded")

    # -------------------------
    # Utils
    # -------------------------
    def _encode_mask(self, mask: np.ndarray) -> str:
        """Encode float mask â†’ base64"""
        buf = BytesIO()
        np.save(buf, mask.astype(np.float32))
        return base64.b64encode(buf.getvalue()).decode()

    def _video_capture(self, path):
        if path.startswith("http"):
            return cv2.VideoCapture(path)
        return cv2.VideoCapture(os.path.expanduser(path))

    # -------------------------
    # Core video processing
    # -------------------------
    def _process_video(
        self,
        video_path: str,
        prompt: str,
        frame_stride: int,
        batch_size: int,
        threshold: float,
        mask_threshold: float
    ):
        cap = self._video_capture(video_path)

        frames = []
        frame_indices = []
        results = []
        idx = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if idx % frame_stride == 0:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(rgb))
                frame_indices.append(idx)

            if len(frames) == batch_size:
                results.extend(
                    self._run_batch(
                        frames,
                        frame_indices,
                        prompt,
                        threshold,
                        mask_threshold
                    )
                )
                frames, frame_indices = [], []

            idx += 1

        # leftover frames
        if frames:
            results.extend(
                self._run_batch(
                    frames,
                    frame_indices,
                    prompt,
                    threshold,
                    mask_threshold
                )
            )

        cap.release()
        return results

    # -------------------------
    # Batched SAM3 inference
    # -------------------------
    def _run_batch(
        self,
        images,
        frame_indices,
        prompt,
        threshold,
        mask_threshold
    ):
        inputs = self.processor(
            images=images,
            text=[prompt] * len(images),
            return_tensors="pt"
        ).to(self.device)

        for k in inputs:
            if inputs[k].dtype == torch.float32:
                inputs[k] = inputs[k].to(self.model.dtype)

        with torch.no_grad():
            outputs = self.model(**inputs)

        processed = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs["original_sizes"].tolist()
        )

        batch_results = []
        for i, res in enumerate(processed):
            batch_results.append({
                "frame_idx": frame_indices[i],
                "scores": res["scores"].cpu().tolist(),
                "masks": [
                    self._encode_mask(m.cpu().numpy())
                    for m in res["masks"]
                ]
            })

        return batch_results

    # -------------------------
    # MLflow predict
    # -------------------------
    def predict(self, context, model_input, params=None):
        if isinstance(model_input, pd.DataFrame):
            row = model_input.iloc[0].to_dict()
        else:
            row = model_input

        video_path = row["video_path"]
        prompt = row["prompt"]

        frame_stride = int(row.get("frame_stride", 1))
        batch_size = int(row.get("batch_size", 4))
        threshold = float(row.get("threshold", 0.5))
        mask_threshold = float(row.get("mask_threshold", 0.5))

        print(video_path)
        print(prompt)
        print(frame_stride)
        print(batch_size)
        print(threshold)

        return self._process_video(
            video_path=video_path,
            prompt=prompt,
            frame_stride=frame_stride,
            batch_size=batch_size,
            threshold=threshold,
            mask_threshold=mask_threshold
        )




In [0]:
# Load model and get predictions
# print("Loading MLflow model...")
# model = mlflow.pyfunc.load_model(MODEL_URI)
model = SAM3Video()
model.load_context(context=None)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


device: cuda
loading model...
loading processor...
context loaded


In [0]:
import mlflow
import imageio
import numpy as np
import pandas as pd
import cv2
import base64
import timeit
from io import BytesIO
from PIL import Image
import os

# Configuration
# MODEL_URI = "your_model_uri_here"  # e.g., "models:/sam3_video/1" or "runs:/run_id/model"
video_name = "maren_jack"

VIDEO_PATH = f"/Volumes/pubsec_video_processing/cv/auto_segment/images/{video_name}.MOV"  # Your input video
PROMPT = "boy in white sweater with black stripes"  # Your segmentation prompt
# OUTPUT_FRAMES_DIR = "/Volumes/pubsec_video_processing/cv/images/bruno1_output_dir/"
OUTPUT_VIDEO_PATH = f"/Volumes/pubsec_video_processing/cv/auto_segment/images/{video_name}_output.mp4"
FPS = 30  # Adjust to match your video's FPS

# Create output directory
# os.makedirs(OUTPUT_FRAMES_DIR, exist_ok=True)

# Prepare model input
model_input = pd.DataFrame([{
    "video_path": VIDEO_PATH,
    "prompt": PROMPT,
    "frame_stride": 5,  # Process every nth frame
    "batch_size": 4,
    "threshold": 0.5,
    "mask_threshold": 0.5
}])

print("Running inference...")
starting_time = timeit.default_timer()
results = model.predict(context=None, model_input=model_input)
print(f"Inference time: {round((timeit.default_timer() - starting_time))} secs")

Running inference...
/Volumes/pubsec_video_processing/cv/auto_segment/images/maren_jack.MOV
boy in white sweater with black stripes
5
4
0.5
Inference time: 54 secs


In [0]:
type(results)
result_map = {r["frame_idx"]: r for r in results}
print(result_map.keys())
print(result_map[5])

dict_keys([0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 200, 205, 210, 215, 220, 225, 230, 235, 240, 245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 295, 300, 305, 310, 315, 320, 325, 330, 335, 340, 345, 350, 355, 360, 365, 370, 375, 380, 385, 390, 395, 400, 405, 410, 415, 420, 425, 430, 435, 440, 445, 450, 455, 460, 465, 470, 475, 480, 485, 490, 495, 500, 505, 510, 515, 520, 525, 530, 535, 540, 545, 550, 555, 560, 565, 570, 575, 580, 585, 590, 595, 600, 605, 610, 615, 620, 625, 630, 635, 640, 645, 650, 655, 660, 665, 670, 675, 680, 685, 690, 695, 700, 705, 710, 715, 720, 725, 730, 735, 740, 745, 750, 755, 760, 765, 770, 775, 780, 785, 790, 795, 800, 805, 810, 815, 820, 825, 830, 835, 840, 845, 850, 855, 860, 865, 870, 875, 880, 885, 890, 895, 900, 905, 910, 915, 920, 925, 930, 935, 940, 945, 950, 955, 960, 965, 970])


dict_keys([0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 200, 205, 210, 215, 220, 225, 230, 235, 240, 245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 295, 300, 305, 310, 315, 320, 325, 330, 335, 340, 345, 350, 355, 360, 365, 370, 375, 380, 385, 390, 395, 400, 405, 410, 415, 420, 425, 430, 435, 440, 445, 450, 455, 460, 465, 470, 475, 480, 485, 490, 495, 500, 505, 510, 515, 520, 525, 530, 535, 540, 545, 550, 555, 560, 565, 570, 575, 580, 585, 590, 595, 600, 605, 610, 615, 620, 625, 630, 635, 640, 645, 650, 655, 660, 665, 670, 675, 680, 685, 690, 695, 700, 705, 710, 715, 720, 725, 730, 735, 740, 745, 750, 755, 760, 765, 770, 775, 780, 785, 790, 795, 800, 805, 810, 815, 820, 825, 830, 835, 840, 845, 850, 855, 860, 865, 870, 875, 880, 885, 890, 895, 900, 905, 910, 915, 920, 925, 930, 935, 940, 945, 950, 955, 960, 965, 970])
{'frame_idx': 5, 'scores': [0.96875

In [0]:
from mlflow.tracking import MlflowClient
from mlflow.models import infer_signature

# specify the location the model will be saved/registered in Unity Catalog
catalog = "pubsec_video_processing"
schema = "cv"
model_name = "transformers-sam3-video"
model_full_name = f"{catalog}.{schema}.{model_name}"
mlflow.set_registry_uri("databricks-uc")

signature = infer_signature(model_input=model_input, model_output=results)

# Define conda environment with dependencies
conda_env = {
    'channels': ['conda-forge', 'defaults'],
    'dependencies': [
        'python=3.12.3',
        'pip',
        {
            'pip': [
                'mlflow>=2.10.0',
                'torch>=2.0.0',
                'git+https://github.com/huggingface/transformers.git',
                'Pillow',
                'torchvision',
                "cloudpickle==3.0.0",
                # 'pillow>=9.0.0',
                'numpy>=1.23.0',
                'pandas>=1.5.0',
                'accelerate>=0.20.0'
            ]
        }
    ],
    'name': 'sam3_tracker_env'
}

with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=SAM3Video(),
        signature=signature,
        conda_env=conda_env,
        # extra_pip_requirements=[
        #   "torch",
        #   "git+https://github.com/huggingface/transformers.git",
        #   "Pillow"
        # ]
    )
    
    run_id = run.info.run_id
    print(f"Model registered! URI: runs:/{run_id}/model")

2026/01/12 22:31:28 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


device: cuda
loading model...
loading processor...
context loaded


 - accelerate (current: uninstalled, required: accelerate>=0.20.0)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.


Uploading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

Uploading /local_disk0/user_tmp_data/spark-9cfdd6f6-1f17-40aa-aa96-10/tmpwazimghl/model/python_model.pkl:   0%â€¦

2026/01/12 22:31:44 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...


Model registered! URI: runs:/30822138f2e644818e8653c047727d88/model


INFO:py4j.clientserver:Closing down clientserver connection
2026/01/12 22:31:44 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


In [0]:
# register the model using the "run" from above.
mlflow.register_model(model_uri=f"runs:/{run_id}/model", name=model_full_name)

Successfully registered model 'pubsec_video_processing.cv.transformers-sam3-video'.


Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

Downloading /local_disk0/user_tmp_data/spark-9cfdd6f6-1f17-40aa-aa96-10/tmpipw1eec0/model/python_model.pkl:   â€¦

INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection


Uploading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

Uploading /local_disk0/user_tmp_data/spark-9cfdd6f6-1f17-40aa-aa96-10/tmpipw1eec0/model/python_model.pkl:   0%â€¦

Created version '1' of model 'pubsec_video_processing.cv.transformers-sam3-video'.


<ModelVersion: aliases=[], creation_timestamp=1768257118133, current_stage=None, description='', last_updated_timestamp=1768257124465, name='pubsec_video_processing.cv.transformers-sam3-video', run_id='30822138f2e644818e8653c047727d88', run_link=None, source='dbfs:/databricks/mlflow-tracking/3404012171082729/30822138f2e644818e8653c047727d88/artifacts/model', status='READY', status_message='', tags={}, user_id='justin.monaldo@databricks.com', version='1'>

In [0]:
# 2. Preview sample frames
import matplotlib.pyplot as plt

def decode_mask(encoded_mask: str) -> np.ndarray:
    """Decode base64 mask back to numpy array"""
    buf = BytesIO(base64.b64decode(encoded_mask))
    return np.load(buf)

def overlay_masks_on_frame(frame, masks, scores, alpha=0.5, score_threshold=0.5):
    """Overlay segmentation masks on a frame with different colors"""
    overlay = frame.copy()
    
    # Filter masks by score
    valid_indices = [i for i, score in enumerate(scores) if score >= score_threshold]
    
    # Generate colors for each mask
    colors = plt.cm.rainbow(np.linspace(0, 1, len(valid_indices)))[:, :3] * 255
    
    for idx, mask_idx in enumerate(valid_indices):
        mask = masks[mask_idx]
        color = colors[idx].astype(np.uint8)
        
        # Create colored mask
        colored_mask = np.zeros_like(frame)
        colored_mask[mask > 0.5] = color
        
        # Blend with frame
        overlay = cv2.addWeighted(overlay, 1, colored_mask, alpha, 0)
        
        # Optional: Add contours
        contours, _ = cv2.findContours(
            (mask > 0.5).astype(np.uint8), 
            cv2.RETR_EXTERNAL, 
            cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(overlay, contours, -1, color.tolist(), 2)
    
    return overlay
def display_sample_frames(
    original_video_path: str,
    prediction_output: list,
    num_samples: int = 5,
    alpha: float = 0.5
):
    """Display sample frames with segmentation overlays"""
    
    cap = cv2.VideoCapture(original_video_path)
    prediction_map = {pred["frame_idx"]: pred for pred in prediction_output}
    
    # Select evenly spaced frames that have predictions
    available_frames = sorted(prediction_map.keys())
    if len(available_frames) == 0:
        print("No predictions found!")
        return
    
    step = max(1, len(available_frames) // num_samples)
    sample_indices = available_frames[::step][:num_samples]
    
    fig, axes = plt.subplots(1, len(sample_indices), figsize=(20, 4))
    if len(sample_indices) == 1:
        axes = [axes]
    
    for ax, frame_idx in zip(axes, sample_indices):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        
        if ret:
            pred = prediction_map[frame_idx]
            masks = [decode_mask(m) for m in pred["masks"]]
            scores = pred["scores"]
            
            overlay = overlay_masks_on_frame(frame, masks, scores, alpha)

            overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
            
            ax.imshow(overlay_rgb)
            ax.set_title(f"Frame {frame_idx}\n{len(scores)} objects")
            ax.axis('off')
    
    cap.release()
    plt.tight_layout()
    plt.show()

display_sample_frames(
    original_video_path=video_path,
    prediction_output=results,
    num_samples=5,
    alpha=0.6
)

INFO:matplotlib.font_manager:generated new fontManager


[0;31m---------------------------------------------------------------------------[0m
[0;31mNameError[0m                                 Traceback (most recent call last)
File [0;32m<command-7066239738725458>, line 85[0m
[1;32m     81[0m     plt[38;5;241m.[39mtight_layout()
[1;32m     82[0m     plt[38;5;241m.[39mshow()
[1;32m     84[0m display_sample_frames(
[0;32m---> 85[0m     original_video_path[38;5;241m=[39mvideo_path,
[1;32m     86[0m     prediction_output[38;5;241m=[39mresults,
[1;32m     87[0m     num_samples[38;5;241m=[39m[38;5;241m5[39m,
[1;32m     88[0m     alpha[38;5;241m=[39m[38;5;241m0.6[39m
[1;32m     89[0m )

[0;31mNameError[0m: name 'video_path' is not defined

In [0]:
print(len(results))

195


In [0]:
# OUTPUT_FRAMES_DIR = "/Volumes/pubsec_video_processing/cv/images/bruno1_output_dir/"
OUTPUT_VIDEO_PATH = f"/Volumes/pubsec_video_processing/cv/images/{video_name}_output2.mp4"

# Open original video to get frames
print("Processing frames and applying masks...")
cap = cv2.VideoCapture(VIDEO_PATH)
fps = cap.get(cv2.CAP_PROP_FPS) or FPS

# Create a mapping of frame_idx to results
result_map = {r["frame_idx"]: r for r in results}

frame_idx = 0
saved_frames = []
saved_images = []

i = 0 
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    # Convert BGR to RGB
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
    # If this frame has segmentation results, apply the mask
    if frame_idx in result_map:
        # i+=1
        # print('Processing frame:', i, 'of', len(result_map))
        res = result_map[frame_idx]
        
        if res["masks"]:
            # Get the first (highest score) mask
            mask = decode_mask(res["masks"][0])
            
            # Create visualization: overlay mask on original frame
            # Option 1: Show only segmented object
            # masked_frame = rgb_frame * mask[..., None]
            
            # Option 2: Overlay with transparency
            overlay = rgb_frame.copy()
            overlay[mask > 0.5] = [0, 255, 0]  # Green overlay
            masked_frame = cv2.addWeighted(rgb_frame, 0.7, overlay, 0.3, 0)
            
            # Option 3: Show mask as binary
            # masked_frame = (mask[..., None] * 255).astype(np.uint8).repeat(3, axis=2)
        else:
            masked_frame = rgb_frame
    # else:
    #     masked_frame = rgb_frame
    
    # Save frame
    saved_images.append(Image.fromarray(masked_frame))
    # frame_path = os.path.join(OUTPUT_FRAMES_DIR, f"frame_{frame_idx:05d}.png")
    # Image.fromarray(masked_frame).save(frame_path)
    # saved_frames.append(frame_path)
    
    frame_idx += 1

cap.release()
print(f"Saved {len(saved_images)} frames to memory") #{OUTPUT_FRAMES_DIR}")

# 3. Create full segmented video
import imageio
import os
import shutil
import tempfile

print("Writing video to temporary file...")
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
    temp_video_path = tmp_file.name

imageio.mimsave(
    temp_video_path,
    saved_images,
    fps=24,
    codec='libx264',
    pixelformat='yuv420p'
)

temp_size = os.path.getsize(temp_video_path)
print(f"Temporary video created: {temp_size:,} bytes ({temp_size/1024/1024:.2f} MB)")

# Copy to Volumes
print(f"Copying to Volumes: {OUTPUT_VIDEO_PATH}")
shutil.copy2(temp_video_path, OUTPUT_VIDEO_PATH)

final_size = os.path.getsize(OUTPUT_VIDEO_PATH)
print(f"âœ“ Video successfully saved to: {OUTPUT_VIDEO_PATH}")
print(f"  Final size: {final_size:,} bytes ({final_size/1024/1024:.2f} MB)")

# Clean up temporary file
if os.path.exists(temp_video_path):
    os.remove(temp_video_path)
    print("Cleaned up temporary file")