# Video Shot Boundary Detection with SigLIP 2 Embeddings

<a target="_blank" href="https://colab.research.google.com/github/everettVT/daft-video-embeddings/blob/main/workload/sbd_image_embeddings_siglip.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [71]:
!pip install -q "daft[huggingface]" transformers numpy

In [None]:

# Frame geometry. H and W are the decoded frame height/width passed to
# read_video_frames. T and C are not referenced in the visible cells
# (presumably frames-per-clip and channel count — TODO confirm or remove).
T, H, W, C = 16, 288, 288, 3
# Maximum number of frame rows kept (applied with .limit() before .collect()).
ROW_LIMIT = 2048


# Chi-Squared SBD Params
# Number of bins per color channel in each per-frame histogram.
CHSQ_HISTOGRAM_BINS = 32
# Chi-squared distance threshold; presumably the "high" threshold at which a
# transition is considered a cut — confirm where it is applied.
CHSQ_THRESHOLD = 0.3
CHSQ_MIN_SHOT_DURATION = 0.5 # seconds

# Source videos (YouTube URLs); each trailing comment gives the video title.
PATHS = [
    "https://www.youtube.com/watch?v=WAsmZJ2kff0", # GPU Pipeline Optimization Explained
    "https://www.youtube.com/watch?v=BLcKDQRTFKY", # Wrangle PDFs with Custom UDFs
    "https://www.youtube.com/watch?v=Qnw6059ddgE", # Data and AI Processing at Scale
    "https://www.youtube.com/watch?v=eYXDSuNpKTk", # Life after Apache Spark
    "https://www.youtube.com/watch?v=3JWrg1DitaA", # Scaling Data Processing and ML Training with Daft + Ray
]

In [None]:
import daft
from daft.functions import embed_image
from daft import col, lit, DataType as dt

import numpy as np

In [None]:
# Decode every video in PATHS to H x W frames, keep the first ROW_LIMIT rows
# and materialize them, so later cells work from memory instead of
# re-downloading from YouTube.
# NOTE(review): .limit() applies to the combined DataFrame, not per video, so
# these rows may all come from the first video(s) in PATHS — confirm intent.
df_frames = (
    daft.read_video_frames(PATHS, image_height=H, image_width=W)
    .limit(ROW_LIMIT)
    .collect()
)
df_frames.show(3)

### Generate SigLIP 2 Embeddings

In [None]:
# Attach one SigLIP 2 image embedding per frame (computed from the "data"
# image column) as "emb_siglip2_base_patch_16_512".
# NOTE(review): verify against the installed Daft version whether embed_image
# expects `model=` or `model_name=`; an unrecognized keyword could be ignored
# and a default model used instead.
df_emb = df_frames.with_column(
    "emb_siglip2_base_patch_16_512",
    embed_image(
        df_frames["data"],
        provider="transformers",
        model_name="google/siglip2-base-patch16-512",
    ),
)

### Chi-Squared Shot Boundary Detection

In [None]:
# Variable-shape tensor: the row count is always 3 (one per channel) but the
# column count follows the `bins` argument, so no module constant is needed.
@daft.func(return_dtype=dt.tensor(dt.int64()))
def histogram(
    data: np.ndarray,
    bins: int,
    range: tuple[float, float] = (0.0, 256.0),
) -> np.ndarray:
    """Compute a per-channel intensity histogram for one RGB frame.

    The frame is flattened to ``(num_pixels, 3)`` and each of the three
    channels is binned independently with ``np.histogram``.

    Args:
        data: One decoded frame; must be reshapeable to ``(-1, 3)``
            (presumably an ``(H, W, 3)`` uint8 image — confirm upstream).
        bins: Number of equal-width bins per channel.
        range: ``(low, high)`` value range forwarded to ``np.histogram``.
            Defaults to ``(0.0, 256.0)`` so every 8-bit value lands in a bin.
            Note that this name shadows the builtin inside the function body.

    Returns:
        An ``int64`` array of shape ``(3, bins)``: one row of counts per channel.
    """
    pixels = np.asarray(data).reshape(-1, 3)
    # Iterate over channels via the transpose rather than `range(3)`: the
    # `range` parameter shadows the builtin, so calling range() here would
    # raise TypeError ('tuple' object is not callable).
    per_channel = [
        np.histogram(channel, bins=bins, range=range)[0] for channel in pixels.T
    ]
    return np.stack(per_channel).astype(np.int64, copy=False)


@daft.func(return_dtype=dt.list(dt.int64()))
def detect_sbd_chsq_hysteresis(
    hists: list[np.ndarray],
    indices: list[int],
    pts: list[int],
    high_threshold: float = 0.4,
    low_threshold: float = 0.25,
    min_shot_duration: float = 0.5,  # seconds
    ) -> list[int]:
    """Detect shot boundaries from per-frame histograms with hysteresis.

    Frames are ordered by ``indices``. For each consecutive pair, the
    chi-squared distance between the row-normalized histograms is computed and
    averaged over rows (channels). A "cluster" opens when a distance reaches
    ``high_threshold`` and closes at the first distance ``<= low_threshold``.
    The frame with the largest distance inside the cluster is reported as a
    boundary, unless it falls less than ``min_shot_duration`` after the
    previously reported boundary. A cluster still open at the end is also
    committed (subject to the same spacing rule).

    Args:
        hists: One histogram per frame. 1-D arrays are treated as a single row;
            2-D arrays as ``(rows, bins)`` (presumably ``(3, bins)`` RGB).
        indices: Frame index for each histogram; used for ordering and as the
            returned values.
        pts: Per-frame timestamps. They are compared against
            ``min_shot_duration * 1_000_000``, so they are assumed to be in
            microseconds — NOTE(review): confirm units with the caller.
        high_threshold: Distance that opens a cluster.
        low_threshold: Distance at or below which an open cluster closes.
            If not strictly below ``high_threshold``, the two are reordered.
        min_shot_duration: Minimum spacing, in seconds, between reported
            boundaries.

    Returns:
        Frame indices (taken from ``indices``) of the frames that start a new
        shot, in ascending order. Empty when fewer than two frames are given.
    """

    if len(hists) < 2:
        return []

    # Validate thresholds
    if not (low_threshold < high_threshold):
        # Swap to enforce proper ordering if misconfigured
        low_threshold, high_threshold = min(low_threshold, high_threshold), max(low_threshold, high_threshold)

    # Convert and sort by frame index
    # (the aggregated lists are not assumed to arrive in frame order).
    h_arr = [np.asarray(h, dtype=np.float32) for h in hists]
    idx_arr = np.asarray(indices, dtype=np.int64)
    pts_arr = np.asarray(pts, dtype=np.int64)
    order = np.argsort(idx_arr)
    h_arr = [h_arr[i] for i in order]
    idx_arr = idx_arr[order]
    pts_arr = pts_arr[order]

    # Compute per-transition chi-squared distances
    # dists[k] compares sorted frames k and k+1, so len(dists) == n - 1.
    eps = 1e-8  # guards against division by zero for empty bins/histograms
    dists = []
    for i in range(1, len(h_arr)):
        h1 = h_arr[i-1]
        h2 = h_arr[i]
        # Promote 1-D histograms to a single row so the axis=1 math below works.
        if h1.ndim == 1:
            h1 = h1[None, :]
        if h2.ndim == 1:
            h2 = h2[None, :]
        # Normalize each row to sum to ~1 so frame size does not affect the distance.
        h1n = h1 / (np.sum(h1, axis=1, keepdims=True) + eps)
        h2n = h2 / (np.sum(h2, axis=1, keepdims=True) + eps)
        # Symmetric chi-squared: 0.5 * sum((a - b)^2 / (a + b)) per row, then mean over rows.
        num = (h1n - h2n) ** 2
        den = h1n + h2n + eps
        chisq_per_channel = 0.5 * np.sum(num / den, axis=1)
        d = float(np.mean(chisq_per_channel))
        dists.append(d)
    dists = np.asarray(dists, dtype=np.float32)

    # Hysteresis scan: choose the local peak within a high-threshold crossing, end when below low-threshold
    boundaries: list[int] = []
    min_us = int(min_shot_duration * 1_000_000.0)  # seconds -> microseconds (see pts note)
    last_boundary_pts = None

    inside_cluster = False
    peak_dist = -1.0
    peak_i = -1  # position in the sorted arrays of the frame AFTER the peak transition

    for i in range(1, len(idx_arr)):
        d = dists[i-1]  # distance between sorted frames i-1 and i
        if not inside_cluster:
            if d >= high_threshold:
                inside_cluster = True
                peak_dist = d
                peak_i = i
        else:
            # Track peak while in cluster
            if d > peak_dist:
                peak_dist = d
                peak_i = i
            # Exit when we drop below low threshold; commit the peak as boundary
            if d <= low_threshold:
                cand_pts = int(pts_arr[peak_i])
                # Enforce minimum spacing from the previously committed boundary.
                if last_boundary_pts is None or (cand_pts - last_boundary_pts) >= min_us:
                    boundaries.append(int(idx_arr[peak_i]))
                    last_boundary_pts = cand_pts
                inside_cluster = False
                peak_dist = -1.0
                peak_i = -1

    # If we ended still inside a cluster, commit the peak
    # (last_boundary_pts is not updated here because the function returns next).
    if inside_cluster and peak_i >= 0:
        cand_pts = int(pts_arr[peak_i])
        if last_boundary_pts is None or (cand_pts - last_boundary_pts) >= min_us:
            boundaries.append(int(idx_arr[peak_i]))

    return boundaries


In [None]:
# Per-frame RGB histograms -> one row per video with aggregated lists ->
# chi-squared hysteresis shot boundaries for that whole video.
#
# Boundaries are detected over each full video (groupby "path") rather than
# inside fixed 0.5 s buckets: bucketing would hide cuts that fall on bucket
# edges and make the minimum-shot-duration rule meaningless.
df_shots = (
    df_emb
    .with_column(
        "histogram",
        histogram(col("data"), bins=CHSQ_HISTOGRAM_BINS, range=(0.0, 256.0)),
    )
    # The detector compares timestamps against min_shot_duration * 1e6, so feed
    # it microseconds derived from frame_time (seconds).
    .with_column("frame_time_us", (col("frame_time") * lit(1_000_000.0)).cast(dt.int64()))
    .groupby("path")
    # Omit image data; the detector re-sorts each list by frame_index itself,
    # so aggregation order does not matter.
    .agg_list("frame_index", "frame_time_us", "histogram", "emb_siglip2_base_patch_16_512")
    .with_column(
        "shot_boundaries",
        detect_sbd_chsq_hysteresis(
            col("histogram"),
            col("frame_index"),
            col("frame_time_us"),
            high_threshold=CHSQ_THRESHOLD,
            min_shot_duration=CHSQ_MIN_SHOT_DURATION,
        ),
    )
)
df_shots.show(3)


In [None]:
# Display the shot-boundary results (df_sbd was never defined; the pipeline
# result lives in df_shots).
df_shots