In [1]:
# @title Prepare environment

import os
import sys

# Fetch VideoPrism repository if Python does not know about it and install
# dependencies needed for this notebook.
if not os.path.exists("videoprism_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/everettVT/videoprism.git videoprism_repo
  os.chdir('./videoprism_repo')
  !pip install .
  os.chdir('..')

# Append VideoPrism code to Python import path.
if "videoprism_repo" not in sys.path:
  sys.path.append("videoprism_repo")

Processing /content/videoprism_repo
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting einshape (from videoprism==1.0.0)
  Downloading einshape-1.0-py3-none-any.whl.metadata (706 bytes)
Collecting tensorflow==2.19.0 (from videoprism==1.0.0)
  Downloading tensorflow-2.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tensorboard==2.19.0 (from videoprism==1.0.0)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorboard-data-server<0.8.0,>=0.7.0 (from tensorboard==2.19.0->videoprism==1.0.0)
  Downloading tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl.metadata (1.1 kB)
Collecting werkzeug>=1.0.1 (from tensorboard==2.19.0->videoprism==1.0.0)
  Downloading werkzeug-3.1.3-py3-none-any.whl.metadata (3.7 kB)
Collecting astunparse>=1.6.0 (from tensorflow==2.19.0->videoprism=

In [2]:
!pip install "daft>=0.6.1" av yt-dlp

Collecting daft>=0.6.1
  Downloading daft-0.6.1-cp39-abi3-manylinux_2_24_x86_64.whl.metadata (12 kB)
Collecting av
  Downloading av-15.1.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Collecting yt-dlp
  Downloading yt_dlp-2025.9.5-py3-none-any.whl.metadata (177 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.1/177.1 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Downloading daft-0.6.1-cp39-abi3-manylinux_2_24_x86_64.whl (47.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.4/47.4 MB[0m [31m54.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading av-15.1.0-cp312-cp312-manylinux_2_28_x86_64.whl (39.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m75.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading yt_dlp-2025.9.5-py3-none-any.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m147.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: yt-

In [3]:
import daft
from daft import col, DataType as dt
import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import backend
import tensorflow as tf
from videoprism import models as vp



- B: batch size (number of videos in a batch).
- T: number of frames per video clip (typically 16).
- N: tokens per frame (for 288×288 with 18×18 patches → 16×16 = 256).
- D: embedding dimension (Base: 768; Large: 1024).

Video-text model returns:
- video_embeddings: [B, D] (global video embeddings).
- text_embeddings: [B, D] (global text embeddings).
- Optional: frame_embeddings [B, T, D]; tokens [B, T×N, D]

Retrieval:
- cosine similarity reduces to dot product since outputs are L2-normalized.
- For a single video vs K texts: [1, D] @ [K, D]^T → [1, K].

In [4]:
PATHS = ["https://www.youtube.com/watch?v=wKOC_w4oKO8"]
B, T, H, W, C = 2, 16, 288, 288, 3
MODEL_NAME = 'videoprism_lvt_public_v1_base' # or 'videoprism_lvt_public_v1_large'

In [5]:
df_frames = daft.read_video_frames(
    PATHS,
    image_height=H,
    image_width=W,
).limit(64).collect()

🗡️ 🐟 Limit 64: 00:00 

🗡️ 🐟 PythonFunction Scan: 00:00 

In [6]:
df_frames.show(3)

path Utf8,frame_index Int64,frame_time Float64,frame_time_base Utf8,frame_pts Int64,frame_dts Int64,frame_duration Int64,is_key_frame Boolean,data Image[RGB; 288 x 288]
https://www.youtube.com/watch?v=wKOC_w4oKO8,43,1.4347681014347682,1/11988,17200,17200,400,False,
https://www.youtube.com/watch?v=wKOC_w4oKO8,44,1.4681348014681348,1/11988,17600,17600,400,False,
https://www.youtube.com/watch?v=wKOC_w4oKO8,45,1.5015015015015014,1/11988,18000,18000,400,False,


### Sampling Strategies

In [7]:
df_grouped = (
    df_frames
    .with_column("group_index", df_frames["frame_index"] // T)
    .groupby("path", "group_index")
    .agg_list("data", "frame_index")
)
df_grouped.show(3)

path Utf8,group_index Int64,data List[Image[RGB; 288 x 288]],frame_index List[Int64]
https://www.youtube.com/watch?v=wKOC_w4oKO8,1,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]"
https://www.youtube.com/watch?v=wKOC_w4oKO8,3,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]"
https://www.youtube.com/watch?v=wKOC_w4oKO8,0,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]"


### Stack, Normalize, and Cast

In [8]:
@daft.func(return_dtype=dt.tensor(dt.float32(), shape=(1, 16,288, 288, 3)))
def stack_clip(frames: list[np.ndarray], indices: list[int], clip_size: int):
    """Stacks a list of frames into a single numpy array

    Args:
        frames: List[T] of (H,W,3) float32
        indices: List[T] of int

    Returns:
        (1,T,H,W,3) float32 in [0,1]

    In a parallel/distributed groupby, a pre-group sort isn’t guaranteed
    to survive aggregation order; partitions can concatenate in
    non-deterministic order. Additionally, the image dtype is natively a
    list[uint8], so we need to cast to float32 before normalizing from
    [0,255] to [0,1].

    Steps:
    1. Aggregate both image_tensor and frame_index.
    2. Sort by frame_index inside the group-level UDF, then stack.
    3. Normalize and cast in one step.
    4. Add a batch dimension and return.

    """

    # Don't assume frames are sorted already:
    order = np.argsort(np.asarray(indices))

    # Convert Daft Image to np.ndarray
    def to_np(x):
        if hasattr(x, "to_numpy"):
            return x.to_numpy()          # Daft Image -> np.ndarray (H,W,C) uint8
        return np.asarray(x)

    # Sort frames by frame_index
    frames_sorted = [to_np(frames[i]) for i in order]

    # Ensure Tails are padded with duplicates
    if len(order) < clip_size:
        frames_sorted.extend([frames_sorted[-1]] * (clip_size - len(order)))

    # Stack, Normalize, and Cast in one step
    x = np.stack(frames_sorted[:clip_size], axis=0).astype(np.float32) / 255.0 # (T,H,W,3) float32 in [0,1]

    return x[None, ...] # [1,T,H,W,C] where T=clip_size

df_clips = df_grouped.with_column("clip", stack_clip(df_grouped["data"], df_grouped["frame_index"], clip_size=T))
df_clips.show(3)


path Utf8,group_index Int64,data List[Image[RGB; 288 x 288]],frame_index List[Int64],"clip FixedShapeTensor[Float32; [1, 16, 288, 288, 3]]"
https://www.youtube.com/watch?v=wKOC_w4oKO8,3,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]",<FixedShapeTensor>
https://www.youtube.com/watch?v=wKOC_w4oKO8,1,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]",<FixedShapeTensor>
https://www.youtube.com/watch?v=wKOC_w4oKO8,0,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]",<FixedShapeTensor>


In [12]:
@daft.udf(
    return_dtype = dt.embedding(dt.float32(), 768),
    batch_size=1, # clips per batch (tune for throughput)
)
class VideoPrismVideoUDF:
    def __init__(self, model_name: str = "videoprism_lvt_public_v1_base"):
        from videoprism import models as vp
        self.model = vp.get_model(model_name)
        self.params = vp.load_pretrained_weights(model_name)
        self.text_tokenizer = vp.load_text_tokenizer('c4_en')

        @jax.jit
        def vf_b(x):  # [B,T,288,288,3] -> [B,D]
            v, _, _ = self.model.apply(self.params, x, None, None, train=False)
            return v
        @jax.jit
        def vf_1(x):  # [1,T,288,288,3] -> [1,D]
            v, _, _ = self.model.apply(self.params, x, None, None, train=False)
            return v

        self.vf_b = vf_b
        self.vf_1 = vf_1

        # Warmup both
        _ = self.vf_b(jnp.zeros((B, T, H, W, C), jnp.float32)).block_until_ready()
        _ = self.vf_1(jnp.zeros((1, T, H, W, C), jnp.float32)).block_until_ready()

    def __call__(self,
        clips: list[np.ndarray],
    ):
        n = len(clips)
        if n == B:
            # Batch Inference
            xb = jnp.stack(clips, axis=0) # [B,T,H,W,C]
            v = self.vf_b(jnp.asarray(xb)) # [B,768]
            return [v[i] for i in range(B)]
        else:
            return [self.vf_1(clip) for clip in clips]


In [13]:
df_video_embs = df_clips.with_column("video_embeddings", VideoPrismVideoUDF(df_clips["clip"])).collect()




🗡️ 🐟 InMemorySource: 00:00 

🗡️ 🐟 Project: 00:00 

🗡️ 🐟 GroupedAggregate: 00:00 

🗡️ 🐟 UDF stack_clip: 00:00 

🗡️ 🐟 UDF VideoPrismVideoUDF: 00:00 

ERROR:daft_local_execution:Error when running pipeline node UDF VideoPrismVideoUDF


DaftCoreException: DaftError::ValueError We can only convert numeric python types to List, got Embedding[Float32; 768]

In [11]:
import jax
import jax.numpy as jnp

# Configure JAX to use the TPU backend
jax.config.update("jax_platform_name", "tpu")

# Verify that JAX can see the TPU device
try:
    tpu_devices = jax.devices("tpu")
    print(f"JAX sees {len(tpu_devices)} TPU device(s): {tpu_devices}")
except RuntimeError as e:
    print(f"Error detecting TPU devices: {e}")
    print("Please ensure you have selected a TPU runtime in Colab.")

JAX sees 1 TPU device(s): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
