In [None]:
# @title Prepare environment

import os
import sys

# Fetch VideoPrism repository if Python does not know about it and install
# dependencies needed for this notebook.
if not os.path.exists("videoprism_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/everettVT/videoprism.git videoprism_repo
  os.chdir('./videoprism_repo')
  !pip install .
  os.chdir('..')

# Append VideoPrism code to Python import path.
if "videoprism_repo" not in sys.path:
  sys.path.append("videoprism_repo")

In [None]:
!pip install "daft>=0.6.1"

In [None]:
import daft
from daft import col, DataType as dt
import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import backend
import tensorflow as tf
from videoprism import models as vp

Shape notation used in this notebook:

- B: batch size (number of video clips in a batch).
- T: number of frames per video clip (typically 16).
- N: tokens per frame (a 288×288 frame split into 18×18-pixel patches gives a 16×16 grid, i.e. 256 tokens).
- D: embedding dimension (Base: 768; Large: 1024).

Video-text model returns:
- video_embeddings: [B, D] (global video embeddings).
- text_embeddings: [B, D] (global text embeddings).
- Optional: frame_embeddings [B, T, D]; tokens [B, T×N, D]

Retrieval:
- Cosine similarity reduces to a dot product because the output embeddings are L2-normalized.
- For a single video scored against K texts: [1, D] @ [K, D]^T → [1, K] (one similarity score per text).

In [None]:
# Videos to process.
# NOTE(review): this is an absolute path on one machine; point it at your own
# file(s) before running.
PATHS = [
    "/Users/everett-founder/Movies/digitlism.mp4",
]

# Clip geometry shared by the cells below.
B = 2    # clips per inference batch
T = 16   # frames per clip
H = 288  # frame height passed to the video reader
W = 288  # frame width passed to the video reader
C = 3    # channels per frame

MODEL_NAME = 'videoprism_lvt_public_v1_base' # or 'videoprism_lvt_public_v1_large'

In [None]:
# Read the frames of every video in PATHS, requesting H x W images, and preview
# the first three rows.
df_frames = daft.read_video_frames(PATHS, image_height=H, image_width=W)
df_frames.show(3)

### Sampling Strategies

In [None]:
@daft.func()
def normalize(image: np.ndarray) -> dt.tensor(dt.float32()):
    """Convert an image to a float32 array and divide every value by 255.

    Args:
        image: Array-like image data (anything ``np.asarray`` accepts).

    Returns:
        A float32 array with the same shape as the input.
    """
    pixels = np.asarray(image)
    as_float32 = pixels.astype(np.float32)
    return as_float32 / 255.0

# Attach the float32-scaled copy of each frame as "image_tensor" and preview it.
df_norm = df_frames.with_column(
    "image_tensor",
    normalize(col("data")),
)
df_norm.show(3)

In [None]:
# Bucket frames into groups of T by integer-dividing frame_index, then collect
# each (path, group_index) group's frames and frame indices into lists.
df_grouped = (
    df_frames
    .with_column("group_index", col("frame_index") // T)
    .groupby("path", "group_index")
    .agg_list("data", "frame_index")
)
df_grouped.show(3)

### Stack, Normalize, and Cast

In [None]:
@daft.func(return_dtype=dt.tensor(dt.float32(), shape=(1, T, H, W, C)))
def stack_clip(frames: list[np.ndarray], indices: list[int], clip_size: int):
    """Order, pad, stack, and normalize one group of frames into a clip tensor.

    Args:
        frames: The group's frames, each (H,W,3) with uint8 values in [0,255]
            (Daft images or array-likes).
        indices: The frame_index of each entry in ``frames``, in the same order.
        clip_size: Number of frames in the returned clip. Shorter groups are
            padded by repeating the last frame; longer groups are truncated.
            The declared return dtype assumes this equals T.

    Returns:
        (1,clip_size,H,W,3) float32 array with values in [0,1].

    Raises:
        ValueError: If ``frames`` is empty (there is no frame to pad with).

    In a parallel/distributed groupby, a pre-group sort isn't guaranteed
    to survive aggregation order; partitions can concatenate in
    non-deterministic order. That is why the frames are re-sorted by
    frame_index here instead of trusting the incoming order.

    Steps:
    1. Sort the frames by frame_index.
    2. Pad a short tail group by duplicating its last frame.
    3. Stack, cast to float32, and scale from [0,255] to [0,1] in one step.
    4. Add a leading batch dimension and return.
    """
    if len(frames) == 0:
        raise ValueError("stack_clip received an empty frame group")

    # Don't assume frames are sorted already:
    order = np.argsort(np.asarray(indices))

    # Convert Daft Image to np.ndarray
    def to_np(x):
        if hasattr(x, "to_numpy"):
            return x.to_numpy()          # Daft Image -> np.ndarray (H,W,C) uint8
        return np.asarray(x)

    # Sort frames by frame_index
    frames_sorted = [to_np(frames[i]) for i in order]

    # Ensure tails are padded with duplicates of the last frame
    missing = clip_size - len(frames_sorted)
    if missing > 0:
        frames_sorted.extend([frames_sorted[-1]] * missing)

    # Stack, normalize, and cast in one step
    x = np.stack(frames_sorted[:clip_size], axis=0).astype(np.float32) / 255.0 # (T,H,W,3) float32 in [0,1]

    return x[None, ...] # [1,T,H,W,C] where T=clip_size

# Build one clip tensor per (path, group_index) row. clip_size is T, the
# frames-per-clip constant from the config cell (NUM_FRAMES was never defined).
df_clips = df_grouped.with_column(
    "clip",
    stack_clip(df_grouped["data"], df_grouped["frame_index"], clip_size=T),
)
df_clips.show(3)


In [None]:
# NOTE(review): the embedding width is hard-coded to 768 (the Base model); a
# model_name with a different embedding dimension would not match this dtype.
@daft.udf(
    return_dtype = dt.list(dt.embedding(dt.float32(), 768)),
    batch_size=B, # clips per batch (tune for throughput)
)
class VideoPrismVideoUDF:
    """Daft class UDF that embeds video clips with a VideoPrism video-text model.

    The model and weights are loaded once in ``__init__``. Each call embeds a
    batch of clips and returns one ``[1, D]`` array per input row, i.e. a
    length-1 list of embeddings, which is what the declared return dtype
    (list of embedding) describes.
    """

    def __init__(self, model_name: str = "videoprism_lvt_public_v1_base"):
        """Load the model, weights and tokenizer, then warm up the jitted forward.

        Args:
            model_name: VideoPrism model identifier passed to ``vp.get_model``
                and ``vp.load_pretrained_weights``.
        """
        from videoprism import models as vp
        self.model = vp.get_model(model_name)
        self.params = vp.load_pretrained_weights(model_name)
        # Not used by the video path in this class; kept so the attribute
        # remains available to any code that reads it.
        self.text_tokenizer = vp.load_text_tokenizer('c4_en')

        @jax.jit
        def embed_video(x):  # [n,T,288,288,3] -> [n,D]
            # Text inputs are None, so only the first (video) output is used.
            v, _, _ = self.model.apply(self.params, x, None, None, train=False)
            return v

        # jax.jit compiles and caches one executable per input shape, so one
        # jitted function serves both the full-batch and single-clip paths.
        # Both attribute names are kept for compatibility.
        self.vf_b = embed_video
        self.vf_1 = embed_video

        # Warm up the two batch sizes used by __call__ (B and 1).
        for warm_batch in sorted({B, 1}):
            _ = embed_video(
                jnp.zeros((warm_batch, T, H, W, C), jnp.float32)
            ).block_until_ready()

    @staticmethod
    def _with_batch_dim(clip):
        """Return ``clip`` as a float32 array with a leading batch axis of 1."""
        x = np.asarray(clip, dtype=np.float32)
        # Clips normally arrive as [1,T,H,W,C]; also accept bare [T,H,W,C].
        return x if x.ndim == 5 else x[None, ...]

    def __call__(self,
        clips: list[np.ndarray],
    ):
        """Embed a batch of clips.

        Args:
            clips: The clip column for this batch. It may arrive as a
                ``daft.Series`` or a plain list; each clip is expected to be
                [1,T,H,W,C] (a bare [T,H,W,C] is also accepted).

        Returns:
            A list with one ``[1, D]`` float array per input clip.
        """
        rows = clips.to_pylist() if hasattr(clips, "to_pylist") else list(clips)
        batch = [self._with_batch_dim(clip) for clip in rows]
        n = len(batch)
        if n == B:
            # Batch inference. Each clip already carries a batch axis, so
            # concatenate along it ([B,T,H,W,C]); stacking would add a sixth axis.
            xb = jnp.concatenate(batch, axis=0)
            v = np.asarray(self.vf_b(xb))  # [B,D]
            # Slice (not index) so every row keeps the same [1,D] shape as the
            # single-clip path below.
            return [v[i:i + 1] for i in range(n)]
        # Partial batch: run clips one at a time through the warmed-up
        # single-clip shape.
        return [np.asarray(self.vf_1(jnp.asarray(x))) for x in batch]
