# VideoPrism Video-Text Encoder Demo

[![Paper](https://img.shields.io/badge/arXiv-2402.13217-red.svg)](https://arxiv.org/abs/2402.13217)
[![Blog](https://img.shields.io/badge/Google_Research-Blog-green.svg)](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

This notebook provides an example of video and text feature extraction with a pre-trained VideoPrism video-text model for zero-shot video classification/retrieval.

Run this demo on Google Colab. A TPU runtime is faster but not required.

## Set up

In [1]:
# @title Prepare environment

import os
import sys

# Fetch VideoPrism repository if Python does not know about it and install
# dependencies needed for this notebook.
if not os.path.exists("videoprism_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-deepmind/videoprism.git videoprism_repo
  os.chdir('./videoprism_repo')
  !pip install .
  os.chdir('..')

# Append VideoPrism code to Python import path.
if "videoprism_repo" not in sys.path:
  sys.path.append("videoprism_repo")

# Install missing dependencies.
!pip install mediapy

import jax
from jax.extend import backend
import tensorflow as tf

# Do not let TF use the GPU or TPUs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.get_backend().platform}")
print(f"JAX devices:  {jax.device_count()}")

JAX version:  0.5.2
JAX platform: tpu
JAX devices:  8


In [2]:
# @title Load dependencies and define utilities

import mediapy
import numpy as np
from PIL import Image


def read_and_preprocess_video(
    filename: str, target_num_frames: int, target_frame_size: tuple[int, int]
):
  """Reads and preprocesses a video."""

  frames = mediapy.read_video(filename)

  # Sample to target number of frames.
  frame_indices = np.linspace(
      0, len(frames), num=target_num_frames, endpoint=False, dtype=np.int32
  )
  frames = np.array([frames[i] for i in frame_indices])

  # Resize to target size.
  original_height, original_width = frames.shape[-3:-1]
  target_height, target_width = target_frame_size
  assert (
      original_height * target_width == original_width * target_height
  ), 'Currently does not support aspect ratio mismatch.'
  frames = mediapy.resize_video(frames, shape=target_frame_size)

  # Normalize pixel values to [0.0, 1.0].
  frames = mediapy.to_float01(frames)

  return frames


def compute_similarity_matrix(
    video_embeddings,
    text_embeddings,
    temperature: float,
    apply_softmax: str | None = None,
) -> np.ndarray:
  """Computes cosine similarity matrix."""
  assert apply_softmax in [None, 'over_texts', 'over_videos']
  emb_dim = video_embeddings[0].shape[-1]
  assert emb_dim == text_embeddings[0].shape[-1]

  video_embeddings = np.array(video_embeddings).reshape(-1, emb_dim)
  text_embeddings = np.array(text_embeddings).reshape(-1, emb_dim)
  similarity_matrix = np.dot(video_embeddings, text_embeddings.T)

  if temperature is not None:
    similarity_matrix /= temperature

  if apply_softmax == 'over_videos':
    similarity_matrix = np.exp(similarity_matrix)
    similarity_matrix = similarity_matrix / np.sum(
        similarity_matrix, axis=0, keepdims=True
    )
  elif apply_softmax == 'over_texts':
    similarity_matrix = np.exp(similarity_matrix)
    similarity_matrix = similarity_matrix / np.sum(
        similarity_matrix, axis=1, keepdims=True
    )

  return similarity_matrix
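
The softmax branches in `compute_similarity_matrix` exponentiate the temperature-scaled similarities directly. If the embeddings are float32, `np.exp` overflows to `inf` once a scaled value exceeds roughly 88.7, which a cosine similarity of 0.9 at a temperature of 0.01 (a scaled value of 90) would do. The following is a minimal sketch of the usual remedy, subtracting the per-axis maximum before exponentiating. The helper name `stable_softmax` and the input values are illustrative assumptions, not part of VideoPrism.

```python
import numpy as np


def stable_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
  """Softmax that subtracts the per-axis max before exponentiating."""
  logits = np.asarray(logits, dtype=np.float32)
  shifted = logits - np.max(logits, axis=axis, keepdims=True)
  exp = np.exp(shifted)
  return exp / np.sum(exp, axis=axis, keepdims=True)


# Hypothetical scaled similarities: one video (row) against three texts.
scaled = np.array([[95.0, 90.0, 10.0]], dtype=np.float32)
probs = stable_softmax(scaled, axis=1)  # axis=1 corresponds to 'over_texts'.
print(probs, probs.sum(axis=1))
```

Subtracting the maximum does not change the result mathematically, because the shift cancels between numerator and denominator.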

In [3]:
# @title Load model

import jax
import jax.numpy as jnp
from videoprism import models as vp

# MODEL_NAME = 'videoprism_lvt_public_v1_base'  # @param ['videoprism_lvt_public_v1_base', 'videoprism_lvt_public_v1_large'] {allow-input: false}
MODEL_NAME = 'videoprism_lvt_public_v1_large'
NUM_FRAMES = 16
FRAME_SIZE = 288

flax_model = vp.get_model(MODEL_NAME)
loaded_state = vp.load_pretrained_weights(MODEL_NAME)
text_tokenizer = vp.load_text_tokenizer('c4_en')


@jax.jit
def forward_fn(inputs, text_token_ids, text_paddings, train=False):
  return flax_model.apply(
      loaded_state,
      inputs,
      text_token_ids,
      text_paddings,
      train=train,
  )


# Example: Zero-shot Video Classification/Retrieval

In this example, we extract the embedding of an input video and the embeddings of five sentences, then measure the cosine similarity between the video and each sentence.
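
To make the similarity measure concrete: for L2-normalized vectors, cosine similarity reduces to a dot product, so a single matrix multiplication yields all video-text pairs at once. The sketch below uses random stand-in vectors rather than model outputs, and the helper name `cosine_similarity_matrix` is an assumption for illustration. It normalizes explicitly because this notebook does not state whether the model's embeddings are already unit length.

```python
import numpy as np


def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
  """Returns the [len(a), len(b)] matrix of pairwise cosine similarities."""
  a = a / np.linalg.norm(a, axis=-1, keepdims=True)
  b = b / np.linalg.norm(b, axis=-1, keepdims=True)
  return a @ b.T


rng = np.random.default_rng(0)
video_vecs = rng.normal(size=(1, 8))  # Stand-in for one video embedding.
text_vecs = rng.normal(size=(5, 8))   # Stand-in for five text embeddings.
sims = cosine_similarity_matrix(video_vecs, text_vecs)
print(sims.shape)  # (1, 5)
```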

In [4]:
# @title Specify input video
VIDEO_FILE_PATH = 'videoprism_repo/videoprism/assets/water_bottle_drumming.mp4'  # @param {type: "string"}

frames = read_and_preprocess_video(
    VIDEO_FILE_PATH,
    target_num_frames=NUM_FRAMES,
    target_frame_size=[FRAME_SIZE, FRAME_SIZE],
)
frames = jnp.asarray(frames[None, ...])  # Add batch dimension.
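
`read_and_preprocess_video` selects frames with `np.linspace(..., endpoint=False, dtype=np.int32)`, which produces evenly spaced start positions truncated to integer indices. The sketch below reproduces only that index computation for a hypothetical 100-frame clip, so the sampling pattern can be inspected without decoding a video.

```python
import numpy as np

num_source_frames = 100  # Hypothetical clip length.
target_num_frames = 16   # Same value as NUM_FRAMES in this notebook.

# endpoint=False keeps every index strictly below num_source_frames.
frame_indices = np.linspace(
    0, num_source_frames, num=target_num_frames, endpoint=False, dtype=np.int32
)
print(frame_indices)
```

If a clip has fewer frames than `target_num_frames`, the same computation returns repeated indices, so some frames are duplicated rather than an error being raised.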

In [5]:
# @title Specify input text queries
TEXT_QUERY_CSV = 'playing drums,sitting,playing flute,playing at playground,concert'  # @param {type: "string"}
PROMPT_TEMPLATE = 'a video of {}.'

text_queries = TEXT_QUERY_CSV.split(',')
text_queries = [PROMPT_TEMPLATE.format(t) for t in text_queries]
text_ids, text_paddings = vp.tokenize_texts(text_tokenizer, text_queries)

print('Input text queries:')
for i, text in enumerate(text_queries):
  print(f'({i + 1}) {text}')

Input text queries:
(1) a video of playing drums.
(2) a video of sitting.
(3) a video of playing flute.
(4) a video of playing at playground.
(5) a video of concert.



In [6]:
# @title Compute video-to-text retrieval results
video_embeddings, text_embeddings, _ = forward_fn(
    frames, text_ids, text_paddings)

TEMPERATURE = 0.01  # @param {type: "number"}
similarity_matrix = compute_similarity_matrix(
    video_embeddings,
    text_embeddings,
    temperature=TEMPERATURE,
    apply_softmax='over_texts',
)
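
A `TEMPERATURE` of 0.01 multiplies every similarity gap by 100 before the softmax, so small differences in cosine similarity turn into large probability ratios. The sketch below illustrates this with made-up similarity values (they are not model outputs), and `softmax_with_temperature` is a hypothetical helper, not the notebook's `compute_similarity_matrix`.

```python
import numpy as np


def softmax_with_temperature(
    sims: np.ndarray, temperature: float
) -> np.ndarray:
  """Softmax over the last axis of `sims / temperature`."""
  logits = np.asarray(sims, dtype=np.float64) / temperature
  exp = np.exp(logits)
  return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical cosine similarities of one video against three texts.
sims = np.array([0.20, 0.19, 0.15])
for temperature in (1.0, 0.1, 0.01):
  print(temperature, np.round(softmax_with_temperature(sims, temperature), 3))
```

At a temperature of 1.0 the three probabilities are nearly uniform, while at 0.01 most of the mass moves to the highest-scoring text, which is worth keeping in mind when reading the reported values as confidences.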

In [7]:
v2t_similarity_vector = similarity_matrix[0]
top_indices = np.argsort(v2t_similarity_vector)[::-1]

print(f'Query video: {os.path.basename(VIDEO_FILE_PATH)}')
mediapy.show_video(frames[0], fps=6.0)

for k, j in enumerate(top_indices):
  print(
      'Top-%d retrieved text: %s [Similarity = %0.4f]'
      % (k + 1, text_queries[j], v2t_similarity_vector[j])
  )
print(f'\nThis is {text_queries[top_indices[0]]}')

Query video: water_bottle_drumming.mp4




Top-1 retrieved text: a video of sitting. [Similarity = 0.6414]
Top-2 retrieved text: a video of playing drums. [Similarity = 0.2726]
Top-3 retrieved text: a video of playing flute. [Similarity = 0.0442]
Top-4 retrieved text: a video of playing at playground. [Similarity = 0.0409]
Top-5 retrieved text: a video of concert. [Similarity = 0.0009]

This is a video of sitting.


In [12]:
!pip install wget openpyxl

Collecting openpyxl
  Downloading openpyxl-3.1.5-py2.py3-none-any.whl.metadata (2.5 kB)
Collecting et-xmlfile (from openpyxl)
  Downloading et_xmlfile-2.0.0-py3-none-any.whl.metadata (2.7 kB)
Downloading openpyxl-3.1.5-py2.py3-none-any.whl (250 kB)
Downloading et_xmlfile-2.0.0-py3-none-any.whl (18 kB)
Installing collected packages: et-xmlfile, openpyxl
Successfully installed et-xmlfile-2.0.0 openpyxl-3.1.5


In [9]:
import pandas as pd

In [13]:
data_file = '/content/video_rating_results.xlsx'
df = pd.read_excel(data_file)
df.loc[:, 'md5']
df.info()
df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53 entries, 0 to 52
Data columns (total 26 columns):
 #   Column                              Non-Null Count  Dtype  
---  ------                              --------------  -----  
 0   caption                             53 non-null     object 
 1   video_url                           53 non-null     object 
 2   sa                                  53 non-null     int64  
 3   pc                                  53 non-null     int64  
 4   joint                               53 non-null     int64  
 5   action                              52 non-null     object 
 6   is_hard                             53 non-null     int64  
 7   physics_rules_followed              53 non-null     object 
 8   physics_rules_unfollowed            53 non-null     object 
 9   physics_rules_cannot_be_determined  53 non-null     object 
 10  human_violated_rules                53 non-null     object 
 11  model_name                          53 non-null

Unnamed: 0,caption,video_url,sa,pc,joint,action,is_hard,physics_rules_followed,physics_rules_unfollowed,physics_rules_cannot_be_determined,...,prompt_sa,prompt_physics,prompt_sa_score,sa_gemini_2.5_flash,prompt_sa_usage,prompt_physics_score,pc_gemini_2.5_flash,prompt_physics_usage,prompt_sa_cost,prompt_physics_cost
0,"A hedge trimmer is used to shape a tall, dense...",https://videophysics2testvideos.s3.us-east-2.a...,4,3,0,trimming shrubs,0,"['The blades of the hedge trimmer move.', 'The...",['The cut pieces of the hedge fall to the grou...,[],...,The following is a conversation between a curi...,The following is a conversation between a curi...,This video is an excellent match for the descr...,5,"{""completion_tokens"":197,""prompt_tokens"":1415,...",This video appears to fully adhere to the phys...,5,"{""completion_tokens"":68,""prompt_tokens"":1388,""...",0.003726,0.003554
1,A leaf blower is pointed at a patch of leaves ...,https://videophysics2testvideos.s3.us-east-2.a...,4,4,1,blowing leaves,0,"['The leaves move away from the leaf blower.',...",[],[],...,The following is a conversation between a curi...,The following is a conversation between a curi...,This video is an excellent match for the descr...,5,"{""completion_tokens"":141,""prompt_tokens"":1151,...",This video appears to adhere very well to the ...,5,"{""completion_tokens"":269,""prompt_tokens"":1125,...",0.003022,0.003001
2,"A seamstress pulls the fabric taut, guiding th...",https://videophysics2testvideos.s3.us-east-2.a...,4,4,1,sewing,0,['The needle moves up and down in a repetitive...,[],[],...,The following is a conversation between a curi...,The following is a conversation between a curi...,This video is an excellent match for the descr...,5,"{""completion_tokens"":177,""prompt_tokens"":885,""...",This video depicts a common and realistic acti...,5,"{""completion_tokens"":282,""prompt_tokens"":862,""...",0.002379,0.00237
3,"A seamstress pulls the fabric taut, guiding th...",https://videophysics2testvideos.s3.us-east-2.a...,4,5,1,sewing,0,"[""The woman's headscarf and shirt remain in a ...",[],[],...,The following is a conversation between a curi...,The following is a conversation between a curi...,"The video shows a woman handling red fabric, w...",2,"{""completion_tokens"":220,""prompt_tokens"":1411,...","Based on the video provided, there are no appa...",5,"{""completion_tokens"":75,""prompt_tokens"":1388,""...",0.003857,0.003541
4,"Hands fold a map, showing the creases forming ...",https://videophysics2testvideos.s3.us-east-2.a...,1,4,0,folding paper,0,['The hands maintain a relatively static posit...,[],[],...,The following is a conversation between a curi...,The following is a conversation between a curi...,"Based on the video provided, the description ""...",1,"{""completion_tokens"":102,""prompt_tokens"":879,""...","Based on the video provided, which shows hands...",5,"{""completion_tokens"":78,""prompt_tokens"":862,""t...",0.002333,0.002239


In [14]:
# @title Download videos

# Create a directory to save the videos
if not os.path.exists('videos'):
    os.makedirs('videos')

# Iterate through the DataFrame and download each video
for index, row in df.iterrows():
    video_url = row['video_url']
    md5 = row['md5']
    output_filename = f'videos/{md5}.mp4'  # Assuming the videos are mp4 format

    # Check if the file already exists to avoid re-downloading
    if not os.path.exists(output_filename):
        print(f"Downloading {video_url} to {output_filename}...")
        !wget -O "{output_filename}" "{video_url}"
    else:
        print(f"File {output_filename} already exists, skipping download.")

Downloading https://videophysics2testvideos.s3.us-east-2.amazonaws.com/cosmos_videophy2_test_challenging/A_hedge_trimmer_is_used_to_shape_a_tall,_dense_privet_hedge,_the_blades_visibly_cutting_through_the_leaves_and_stems.mp4 to videos/d667e27fee2f6a7a67aa6047b7708d03.mp4...
--2025-07-31 07:16:14--  https://videophysics2testvideos.s3.us-east-2.amazonaws.com/cosmos_videophy2_test_challenging/A_hedge_trimmer_is_used_to_shape_a_tall,_dense_privet_hedge,_the_blades_visibly_cutting_through_the_leaves_and_stems.mp4
Resolving videophysics2testvideos.s3.us-east-2.amazonaws.com (videophysics2testvideos.s3.us-east-2.amazonaws.com)... 52.219.98.98, 52.219.228.122, 3.5.130.145, ...
Connecting to videophysics2testvideos.s3.us-east-2.amazonaws.com (videophysics2testvideos.s3.us-east-2.amazonaws.com)|52.219.98.98|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1279915 (1.2M) [video/mp4]
Saving to: ‘videos/d667e27fee2f6a7a67aa6047b7708d03.mp4’


2025-07-31 07:16:14 (10.9 MB/s
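
The download loop relies on the `wget` shell command. As an alternative, here is a minimal pure-Python sketch using the standard library's `urllib.request`. The helper name `download_if_missing` is an assumption for illustration. It writes to a temporary `.part` file and renames it on success, so an interrupted download is not later mistaken for a complete file by the existence check.

```python
import os
import urllib.request


def download_if_missing(url: str, output_path: str) -> bool:
  """Downloads `url` to `output_path`; returns False if it already exists."""
  if os.path.exists(output_path):
    return False
  os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
  partial_path = output_path + '.part'
  urllib.request.urlretrieve(url, partial_path)
  os.replace(partial_path, output_path)  # Rename only after a full download.
  return True


# Usage inside the loop above (not executed here):
#   download_if_missing(row['video_url'], f"videos/{row['md5']}.mp4")
```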

In [19]:
# Binary classification queries: does the video violate or follow physical principles?
text_queries = [
    'a video showing physically impossible scenarios with violations of fundamental physics principles',
    'a video showing realistic physics and natural motion following conservation laws'
]

text_ids, text_paddings = vp.tokenize_texts(text_tokenizer, text_queries)


# Specify the input video (row `idx` of the ratings DataFrame).
idx = 0
VIDEO_FILE_PATH = './videos/' + df.loc[idx, 'md5'] + '.mp4'

frames = read_and_preprocess_video(
    VIDEO_FILE_PATH,
    target_num_frames=NUM_FRAMES,
    target_frame_size=[FRAME_SIZE, FRAME_SIZE],
)
frames = jnp.asarray(frames[None, ...])  # Add batch dimension.

In [None]:
# @title Compute video-to-text retrieval results
video_embeddings, text_embeddings, _ = forward_fn(
    frames, text_ids, text_paddings)

TEMPERATURE = 0.01  # @param {type: "number"}
similarity_matrix = compute_similarity_matrix(
    video_embeddings,
    text_embeddings,
    temperature=TEMPERATURE,
    apply_softmax='over_texts',
)
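
With the two physics queries and `apply_softmax='over_texts'`, each row of `similarity_matrix` holds two probabilities that sum to 1: index 0 for the "physically impossible" query and index 1 for the "realistic physics" query. The sketch below shows one way such a row could be turned into a label. The 0.5 threshold, the label strings, and the helper name `classify_row` are assumptions, and the example row is made up rather than produced by the model.

```python
import numpy as np

# Order matches `text_queries` in the physics cell above.
LABELS = ('violates physics', 'follows physics')


def classify_row(
    probs: np.ndarray, threshold: float = 0.5
) -> tuple[str, float]:
  """Maps a 2-way softmax row to (label, probability of 'follows physics')."""
  p_follows = float(probs[1])
  label = LABELS[1] if p_follows >= threshold else LABELS[0]
  return label, p_follows


example_row = np.array([0.3, 0.7])  # Hypothetical values, not model output.
print(classify_row(example_row))
```

In the notebook, the input would be `similarity_matrix[0]`. The returned probability could then be compared against the human `pc` ratings in the DataFrame, though how those ratings map to a binary label is a separate design decision.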

