In [1]:
import daft
from daft import col, DataType as dt
import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import backend
import tensorflow as tf
from videoprism import models as vp

# Hide every GPU and TPU from TensorFlow by passing an empty visible-device list.
# NOTE(review): presumably this keeps TensorFlow from reserving accelerator
# memory that JAX needs for the model — confirm.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

- B: batch size (number of videos in a batch).
- T: number of frames per video clip (typically 16).
- N: tokens per frame (for 288×288 with 18×18 patches → 16×16 = 256).
- D: embedding dimension (Base: 768; Large: 1024).

Video-text model returns:
- video_embeddings: [B, D] (global video embeddings).
- text_embeddings: [B, D] (global text embeddings).
- Optional: frame_embeddings [B, T, D]; tokens [B, T×N, D]

Retrieval: 
- Cosine similarity reduces to a dot product because the outputs are L2-normalized.
- For a single video vs K texts: [1, D] @ [K, D]^T → [1, K].

In [None]:
# Notebook configuration.
# NOTE(review): PATHS holds an absolute path on one machine; point it at your
# own video file(s) before running.
PATHS = ["/Users/everett-founder/Movies/digitlism.mp4"]
# Frames per clip; also the bucket width used when grouping frame indices.
NUM_FRAMES = 16
# Square frame edge length in pixels.
FRAME_SIZE = 288
MODEL_NAME = 'videoprism_lvt_public_v1_base' # or 'videoprism_lvt_public_v1_large'

In [25]:
# Decode the videos listed in PATHS into one row per frame. The frame size is
# taken from FRAME_SIZE so it stays in sync with the config cell instead of
# repeating the literal 288.
df_frames = daft.read_video_frames(
    PATHS,
    image_height=FRAME_SIZE,
    image_width=FRAME_SIZE,
)
df_frames.show(3)

path Utf8,frame_index Int64,frame_time Float64,frame_time_base Utf8,frame_pts Int64,frame_dts Int64,frame_duration Int64,is_key_frame Boolean,data Image[RGB; 288 x 288]
file:///Users/everett-founder/Movies/digitlism.mp4,0,0.0,1/57600,0,0,960,True,
file:///Users/everett-founder/Movies/digitlism.mp4,1,0.0166666666666666,1/57600,960,960,960,False,
file:///Users/everett-founder/Movies/digitlism.mp4,2,0.0333333333333333,1/57600,1920,1920,960,False,


### Sampling Strategies

In [34]:
@daft.func()
def normalize(image: np.ndarray) -> dt.tensor(dt.float32()):
    """Scale one image's pixel values by 1/255 and return them as float32.

    Args:
        image: Array-like image data (converted with ``np.asarray``).

    Returns:
        A float32 array with the same shape as ``image``, divided by 255.0.
    """
    pixels = np.asarray(image)
    as_float = pixels.astype(np.float32)
    return as_float / 255.0

# Add a per-frame float32 tensor column computed by `normalize`.
# NOTE(review): `df_norm` is not referenced by the later cells visible here —
# the grouping cell starts from `df_frames`, and `stack_clip` divides the raw
# frames by 255 itself — so this cell appears to be illustrative only.
df_norm = df_frames.with_column("image_tensor", normalize(col("data")))
df_norm.show(3)

path Utf8,frame_index Int64,frame_time Float64,frame_time_base Utf8,frame_pts Int64,frame_dts Int64,frame_duration Int64,is_key_frame Boolean,data Image[RGB; 288 x 288],image_tensor Tensor(Float32)
file:///Users/everett-founder/Movies/digitlism.mp4,0,0.0,1/57600,0,0,960,True,,"<Tensor shape=(288, 288, 3)>"
file:///Users/everett-founder/Movies/digitlism.mp4,1,0.0166666666666666,1/57600,960,960,960,False,,"<Tensor shape=(288, 288, 3)>"
file:///Users/everett-founder/Movies/digitlism.mp4,2,0.0333333333333333,1/57600,1920,1920,960,False,,"<Tensor shape=(288, 288, 3)>"


In [27]:
# Bucket frames into consecutive runs of NUM_FRAMES (frame_index // NUM_FRAMES),
# then collect each bucket's images and frame indices into per-clip lists.
df_grouped = (
    df_frames
    .with_column("group_index", col("frame_index") // NUM_FRAMES)
    .groupby("path", "group_index")
    .agg_list("data", "frame_index")
)
df_grouped.show(3)

path Utf8,group_index Int64,data List[Image[RGB; 288 x 288]],frame_index List[Int64]
file:///Users/everett-founder/Movies/digitlism.mp4,20,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335]"
file:///Users/everett-founder/Movies/digitlism.mp4,18,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303]"
file:///Users/everett-founder/Movies/digitlism.mp4,25,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415]"


### Stack, Normalize, and Cast

In [39]:
@daft.func(return_dtype=dt.tensor(dt.float32(), shape=(1, NUM_FRAMES, FRAME_SIZE, FRAME_SIZE, 3)))
def stack_clip(frames: list[np.ndarray], indices: list[int], clip_size: int):
    """Build one model-ready clip tensor from a group of raw video frames.

    Args:
        frames: The group's frames, each (H, W, 3) with uint8 values in
            [0, 255] (Daft images or array-likes). Order is not assumed.
        indices: The frame index of each entry in ``frames`` (same length).
        clip_size: Number of frames T in the returned clip. It must equal
            NUM_FRAMES for the result to match the declared return dtype.

    Returns:
        A (1, T, H, W, 3) float32 array with values in [0, 1], T = clip_size.

    Raises:
        ValueError: If ``frames`` is empty or its length differs from
            ``indices``.

    In a parallel/distributed groupby, a pre-group sort isn't guaranteed
    to survive aggregation order; partitions can concatenate in
    non-deterministic order. So the frames are re-sorted here using the
    aggregated frame indices.

    Steps:
    1. Sort the frames by frame index.
    2. Pad a short (tail) group by repeating its last frame.
    3. Stack, cast to float32 and scale from [0, 255] to [0, 1].
    4. Add a leading batch dimension and return.
    """
    n_frames = len(frames)
    if n_frames == 0:
        # Without this guard the tail padding below fails with a bare IndexError.
        raise ValueError("stack_clip received an empty list of frames")
    if n_frames != len(indices):
        raise ValueError(
            f"stack_clip got {n_frames} frames but {len(indices)} frame indices"
        )

    # Don't assume frames are sorted already:
    order = np.argsort(np.asarray(indices))

    # Convert Daft Image to np.ndarray
    def to_np(x):
        if hasattr(x, "to_numpy"):
            return x.to_numpy()          # Daft Image -> np.ndarray (H,W,C) uint8
        return np.asarray(x)

    # Sort frames by frame_index
    frames_sorted = [to_np(frames[i]) for i in order]

    # Ensure tails are padded with duplicates of the last frame
    if len(frames_sorted) < clip_size:
        frames_sorted.extend([frames_sorted[-1]] * (clip_size - len(frames_sorted)))

    # Stack, cast, and normalize in one step
    x = np.stack(frames_sorted[:clip_size], axis=0).astype(np.float32) / 255.0 # (T,H,W,3) float32 in [0,1]

    return x[None, ...] # [1,T,H,W,C] where T=clip_size

# Turn each group's frame list into a single clip tensor column.
df_clips = df_grouped.with_column(
    "clip",
    stack_clip(col("data"), col("frame_index"), clip_size=NUM_FRAMES),
)
df_clips.show(3)


path Utf8,group_index Int64,data List[Image[RGB; 288 x 288]],frame_index List[Int64],"clip FixedShapeTensor[Float32; [1, 16, 288, 288, 3]]"
file:///Users/everett-founder/Movies/digitlism.mp4,20,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335]",<FixedShapeTensor>
file:///Users/everett-founder/Movies/digitlism.mp4,18,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303]",<FixedShapeTensor>
file:///Users/everett-founder/Movies/digitlism.mp4,16,"[<FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>, <FixedShapeImage>]","[256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271]",<FixedShapeTensor>


In [None]:
@daft.udf(
    # One global video embedding per clip. 768 is the Base model's embedding
    # dimension; the Large model needs 1024 here.
    return_dtype=dt.embedding(dt.float32(), 768),
    batch_size=2, # clips per batch (tune for throughput)
)
class VideoPrismUDF:
    """Daft class UDF that embeds video clips with a VideoPrism video-text model.

    The model and its pretrained weights are loaded once in ``__init__``.
    Each call embeds a batch of clips and returns one global video embedding
    per clip. The declared return dtype is a single embedding because that is
    what ``__call__`` produces; text embeddings and intermediate features are
    not returned.
    """

    def __init__(self, model_name: str):
        """Load the model and weights for ``model_name`` and jit the forward pass."""
        self.model_name = model_name
        self.model = vp.get_model(model_name)
        self.loaded_state = vp.load_pretrained_weights(model_name)

        # Jit the *bound* method so `self` is closed over, not traced.
        # Decorating the method itself with @jax.jit would make JAX treat
        # `self` as an array argument and fail.
        self.forward_fn = jax.jit(self._vp_forward)

    def _vp_forward(self, video_inputs: jax.Array, text_ids: jax.Array | None, text_paddings: jax.Array | None):
        """Apply the model in inference mode (``train=False``) to one batch."""
        return self.model.apply(self.loaded_state, video_inputs, text_ids, text_paddings, train=False)

    def __call__(self,
        clips: list[np.ndarray],
        text_ids: np.ndarray | None = None,
        text_paddings: np.ndarray | None = None,
    ):
        """Embed a batch of clips.

        Args:
            clips: Batch of clips, each [1, T, H, W, 3] or [T, H, W, 3]
                float32. May be a Daft Series or a plain list of arrays.
            text_ids: Accepted for interface compatibility; not forwarded to
                the model (video-only inference).
            text_paddings: Accepted for interface compatibility; not forwarded
                to the model.

        Returns:
            A list with one [D] float32 video embedding per input clip.
        """
        clip_list = clips.to_pylist() if hasattr(clips, "to_pylist") else list(clips)
        if not clip_list:
            return []

        x = np.stack([np.asarray(c, dtype=np.float32) for c in clip_list], axis=0)
        # Clips from `stack_clip` carry their own leading batch dim of 1;
        # drop it so the model sees [B, T, H, W, 3] rather than [B, 1, T, H, W, 3].
        if x.ndim == 6 and x.shape[1] == 1:
            x = x[:, 0]

        # `train=False` is already fixed inside `_vp_forward`, so it is not passed here.
        v_embs, _, _ = self.forward_fn(jnp.asarray(x), None, None)
        v_embs = np.asarray(v_embs)                    # [B, D]
        return [v_embs[i] for i in range(v_embs.shape[0])]