[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/healthonrails/annolid/blob/main/docs/tutorials/YOLO_SAHI_inference_for_ultralytics.ipynb)

## 0. Preparation

- Install the latest versions of SAHI and Ultralytics:

In [None]:
# Install PyTorch, SAHI, and Ultralytics, upgrading any already-installed copies (-U).
!pip install -U torch sahi ultralytics
# Install ipywidgets (no upgrade is forced if it is already present).
!pip install ipywidgets
# Install the supervision package (no upgrade is forced if it is already present).
!pip install supervision

## 1. Sliced Inference with an Ultralytics Model

- Instantiate a detection model by defining model weight path and other parameters:

In [None]:
import cv2
import numpy as np
import os
import torch
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from google.colab.patches import cv2_imshow
from IPython.display import clear_output

# Choose the inference device: use the GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    _inference_device = 'cuda'
else:
    _inference_device = 'cpu'

# Wrap the Ultralytics YOLO weights in a SAHI detection model.
# Detections scoring below the confidence threshold are discarded.
detection_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path="yolo11n.pt",  # Replace with the weight file you want to use
    confidence_threshold=0.35,
    device=_inference_device,
)

def process_video(input_video_path, output_video_path,
                  slice_height=256, slice_width=256,
                  overlap_height_ratio=0.2, overlap_width_ratio=0.2,
                  export_dir="/content/sahi_video_frames"):
    """
    Run SAHI sliced inference on every frame of a video and save an annotated copy.

    Each frame is passed to ``get_sliced_prediction`` together with the
    module-level ``detection_model``. The prediction is rendered with SAHI's
    ``export_visuals`` method, which saves a PNG file into ``export_dir``; that
    PNG is read back as a NumPy array and appended to the output video.
    The exported PNG files are left in ``export_dir`` after processing.
    Every 30 frames the annotated frame is displayed inline (suitable for Colab).

    Args:
        input_video_path (str): Path to the input video file.
        output_video_path (str): Path to save the annotated output video.
        slice_height (int): Height of each slice used during inference.
        slice_width (int): Width of each slice used during inference.
        overlap_height_ratio (float): Overlap ratio for slice height.
        overlap_width_ratio (float): Overlap ratio for slice width.
        export_dir (str): Directory to save temporary annotated frames.

    Returns:
        None. Problems opening the input or output video are reported with a
        printed error message and the function returns early.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        print("Error: Could not open video file.")
        return

    out = None
    try:
        # Create the export directory if it doesn't exist.
        os.makedirs(export_dir, exist_ok=True)

        # Retrieve video properties.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 or NaN for the frame rate; a writer created
        # with such a value produces an unplayable file, so use a sane default.
        # `not (fps > 0)` is also True for NaN.
        if not (fps > 0):
            fps = 30.0

        # Set up the video writer for the output video.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            print("Error: Could not open output video for writing.")
            return

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR, while SAHI treats NumPy images as
            # RGB (export_visuals converts RGB -> BGR when saving). Convert here
            # so the model sees correct colours and the exported PNG does not
            # end up with red and blue swapped.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Run sliced prediction on the current frame.
            # get_sliced_prediction returns a PredictionResult instance.
            result = get_sliced_prediction(
                rgb_frame,
                detection_model,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
            )

            # Use SAHI's export_visuals to annotate the frame.
            # A unique file name is generated for each frame.
            file_name = f"frame_{frame_count}"
            result.export_visuals(
                export_dir=export_dir,
                file_name=file_name,
                text_size=1,
                rect_th=2,
                hide_conf=False
            )

            # Read the annotated frame from disk.
            annotated_frame_path = os.path.join(export_dir, f"{file_name}.png")
            annotated_frame = cv2.imread(annotated_frame_path)
            if annotated_frame is None:
                # cv2.imread signals failure by returning None; writing None
                # would raise. Keep the output video in sync by falling back
                # to the unannotated frame.
                print(f"Warning: Could not read {annotated_frame_path}; "
                      "writing the unannotated frame instead.")
                annotated_frame = frame

            # Write the annotated frame to the output video.
            out.write(annotated_frame)

            # Every 30 frames, display the annotated frame inline.
            if frame_count % 30 == 0:
                # Clear the previous preview first; clearing after cv2_imshow
                # would wipe the new image as soon as the status line prints.
                clear_output(wait=True)
                cv2_imshow(annotated_frame)
                print(f"Processed {frame_count} frames...")

            frame_count += 1
    finally:
        # Always release the capture and writer, even if inference fails
        # part-way, so the partially written video is finalised and closed.
        cap.release()
        if out is not None:
            out.release()

    print("Video processing complete.")

In [None]:
# Example usage:
# Both paths live under /content, the Colab working directory, so the
# annotated video is written next to the input instead of the filesystem root.
input_video = "/content/video-4.mp4"
output_video = "/content/video-4_tracked.mp4"
process_video(input_video, output_video)