<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/videoprism_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

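Reference: [VideoPrism: A Foundational Visual Encoder for Video Understanding](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/) (Google Research blog)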

In [7]:
# Report the GPU attached to this runtime and its current memory usage.
!nvidia-smi

Fri Jun 27 03:59:55 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   75C    P0             34W /   72W |   18577MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Clone the VideoPrism repository
!git clone https://github.com/google-deepmind/videoprism.git

# Navigate into the directory
%cd videoprism

# Install the package and its dependencies
!pip install . -q

In [1]:
import jax
import numpy as np
from videoprism import models as vp

# 1. Pick the checkpoint to run.
# Available variants: 'videoprism_public_v1_base' (smaller) and
# 'videoprism_public_v1_large' (larger).
model_name = 'videoprism_public_v1_large'

# 2. Instantiate the model for that variant (weights are loaded separately below).
print(f"Loading VideoPrism model: {model_name}...")
model_factory = vp.MODELS[model_name]
flax_model = model_factory()
print("Model configuration loaded.")

# 3. Fetch the pre-trained weights that belong to the same variant.
print("Loading pre-trained weights (this may take a moment)...")
loaded_state = vp.load_pretrained_weights(model_name)
print("Pre-trained weights loaded.")

# 4. Define the forward pass function for inference
# It's crucial to wrap this in jax.jit for correct results and performance
@jax.jit
def forward_fn(inputs):
  """Run the VideoPrism model on `inputs` in inference mode.

  Args:
    inputs: Video batch forwarded unchanged to `flax_model.apply`.

  Returns:
    The result of `flax_model.apply` called with the pre-trained
    `loaded_state` and `train=False`.
  """
  model_outputs = flax_model.apply(loaded_state, inputs, train=False)
  return model_outputs

# 5. Prepare your input video data
# VideoPrism expects input videos with shape (batch_size, num_frames, height, width, 3)
# The RGB values should be normalized to [0.0, 1.0].
# The recommended input resolution is 288x288.
# num_frames can be arbitrary, as the model interpolates temporal positional embeddings.
batch_size = 1
num_frames = 16  # Example: 16 frames
height = 288
width = 288
channels = 3
# Side length of the per-frame token grid at the 288x288 resolution used here.
spatial_grid = 16
# Fixed seed so the dummy clip, and therefore the printed embeddings, are reproducible.
seed = 0

print(f"Generating dummy video input of shape: ({batch_size}, {num_frames}, {height}, {width}, {channels})")
# Create a dummy video tensor with random float32 data in [0.0, 1.0)
dummy_video_data = np.random.default_rng(seed).random(
    (batch_size, num_frames, height, width, channels), dtype=np.float32
)

model_inputs = dummy_video_data

# 6. Run inference
print(f"Running inference with input shape: {model_inputs.shape}")
outputs_tuple = forward_fn(model_inputs)

# Access the primary output (embeddings) which is typically the first element of the tuple
outputs = outputs_tuple[0]
print("Inference complete.")

# 7. Process the outputs
# The output shape is [batch_size, num_tokens, feature_channels].
# The `num_tokens` is `num_frames * 16 * 16` for spatiotemporal representations.
# You can reshape it to `(batch_size, num_frames, 16, 16, feature_channels)`
# for more intuitive spatiotemporal features.
print(f"Raw output embeddings shape: {outputs.shape}")

# Fail with a clear message (instead of an opaque reshape error) if the token
# count does not match the grid assumed above.
expected_num_tokens = num_frames * spatial_grid * spatial_grid
if outputs.shape[1] != expected_num_tokens:
    raise ValueError(
        f"Unexpected number of tokens in raw output. "
        f"Expected {expected_num_tokens}, got {outputs.shape[1]}."
    )

# Example: Reshaping for spatiotemporal features
feature_channels = outputs.shape[-1]
reshaped_outputs = outputs.reshape(
    batch_size, num_frames, spatial_grid, spatial_grid, feature_channels
)
print(f"Reshaped spatiotemporal embeddings shape: {reshaped_outputs.shape}")

print("\nVideoPrism inference demonstration complete.")
print(f"Example of the first few values of the generated embeddings:\n{reshaped_outputs[0, 0, 0, 0, :5]}")

Loading VideoPrism model: videoprism_public_v1_large...
Model configuration loaded.
Loading pre-trained weights (this may take a moment)...
Pre-trained weights loaded.
Generating dummy video input of shape: (1, 16, 288, 288, 3)
Running inference with input shape: (1, 16, 288, 288, 3)
Inference complete.
Raw output embeddings shape: (1, 4096, 1024)
Reshaped spatiotemporal embeddings shape: (1, 16, 16, 16, 1024)

VideoPrism inference demonstration complete.
Example of the first few values of the generated embeddings:
[ 0.88851106 -0.01135523 -0.01655489 -0.36440778  0.15668297]


In [2]:
!pip install opencv-python numpy jax jaxlib videoprism imageio imageio-ffmpeg -q

In [4]:
import jax
import numpy as np
import cv2
import os
import imageio.v2 as imageio # Use imageio.v2 for the newer API


import jax
import numpy as np
from videoprism import models as vp

# --- Configuration ---
# Key into `vp.MODELS`; the same name is passed to `vp.load_pretrained_weights`.
model_name = 'videoprism_public_v1_large'
# Path for the video file we will generate; the main script deletes it again in its `finally` block.
generated_video_path = '/tmp/generated_test_video.mp4' # Use /tmp for temporary file
# Frame size used both when writing the test video and when resizing decoded frames.
target_height = 288
target_width = 288
num_frames_to_generate = 30 # Number of frames for our generated video
# Passed as `max_frames` to the loader, which stops reading after this many frames.
max_frames_to_process = 16 # Number of frames VideoPrism will actually process from the generated video

# --- Function to Generate a Simple MP4 Video ---
def generate_simple_mp4_video(output_path, num_frames, height, width, fps=10):
    """
    Generates a simple MP4 video with changing colors.

    Every frame is filled with one solid RGB color: the red component grows
    with the frame index, the green component shrinks, and blue stays at 100.

    Args:
        output_path: Destination path handed to ``imageio.get_writer``.
        num_frames: Number of frames to write.
        height: Frame height in pixels.
        width: Frame width in pixels.
        fps: Frame rate passed to the writer.

    Returns:
        None. The frames are written to ``output_path``.
    """
    # The context manager closes the writer even when appending a frame
    # raises, so the writer's resources are never leaked.
    with imageio.get_writer(output_path, fps=fps) as writer:
        print(f"Generating a simple {width}x{height} MP4 video with {num_frames} frames to '{output_path}'...")
        for i in range(num_frames):
            progress = i / num_frames
            # Red component varies with frame index
            r = int(255 * progress)
            # Green component varies inversely
            g = int(255 * (1 - progress))
            b = 100 # Keep blue constant

            # Solid-color frame; the (r, g, b) triple is broadcast to every pixel.
            frame = np.full((height, width, 3), (r, g, b), dtype=np.uint8)

            writer.append_data(frame)
    print("Video generation complete.")

# --- Video Loading and Preprocessing Function ---
def load_and_preprocess_video(file_path, target_h, target_w, max_frames=None):
    """
    Loads video frames, resizes them, converts to RGB, normalizes,
    and returns them as a NumPy array suitable for VideoPrism.

    Args:
        file_path: Path of the video file to decode.
        target_h: Height every frame is resized to.
        target_w: Width every frame is resized to.
        max_frames: Upper bound on the number of frames read from the start
            of the video; ``None`` reads until the video ends.

    Returns:
        A float32 array with a leading batch dimension of 1, i.e. shape
        ``(1, frames_read, target_h, target_w, 3)``, holding the pixel
        values divided by 255.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        IOError: If OpenCV cannot open the file.
        ValueError: If no frame was read (empty/invalid video, or
            ``max_frames`` is zero or negative).
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Video file not found at: {file_path}")

    cap = cv2.VideoCapture(file_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video file: {file_path}")

        frames = []
        # Checking the limit *before* reading means max_frames=0 reads nothing
        # instead of still returning one frame.
        while max_frames is None or len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break  # No more frames

            # Convert BGR (OpenCV default) to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Resize frame to target dimensions (cv2.resize takes (width, height))
            frame = cv2.resize(frame, (target_w, target_h))

            # Normalize pixel values to [0.0, 1.0]
            frames.append(frame / 255.0)
    finally:
        # Always release the capture, including when decoding raises.
        cap.release()

    if not frames:
        raise ValueError(f"No frames were loaded from the video: {file_path}. "
                         "Check if the video is valid or if max_frames is too low.")

    # Convert list of frames to a NumPy array and add batch dimension
    video_tensor = np.expand_dims(np.array(frames, dtype=np.float32), axis=0)
    return video_tensor

# --- Main Script ---
try:
    # 1. Generate a simple video file first
    generate_simple_mp4_video(
        generated_video_path,
        num_frames=num_frames_to_generate,
        height=target_height,
        width=target_width
    )

    # 2. Load and preprocess the newly generated video
    print(f"\nAttempting to load and preprocess video from: {generated_video_path}")
    model_inputs = load_and_preprocess_video(
        generated_video_path, target_height, target_width, max_frames=max_frames_to_process
    )
    print(f"Video loaded. Input shape for VideoPrism: {model_inputs.shape}")

    # 3. Load the VideoPrism model
    print(f"\nLoading VideoPrism model: {model_name}...")
    flax_model = vp.MODELS[model_name]()
    print("Model configuration loaded.")

    # 4. Load the pre-trained weights
    print("Loading pre-trained weights (this may take a moment)...")
    # !!! IMPORTANT: ALLOW THIS STEP TO COMPLETE WITHOUT INTERRUPTION !!!
    loaded_state = vp.load_pretrained_weights(model_name)
    print("Pre-trained weights loaded.")

    # 5. Define the forward pass function for inference
    @jax.jit
    def forward_fn(inputs):
        """Applies the VideoPrism model to input video frames."""
        return flax_model.apply(loaded_state, inputs, train=False)

    # 6. Run inference
    print(f"Running inference with input shape: {model_inputs.shape}")
    outputs_tuple = forward_fn(model_inputs)

    # Access the primary output (embeddings) which is typically the first element of the tuple
    outputs = outputs_tuple[0]
    print("Inference complete.")

    # 7. Process the outputs
    num_frames_processed = model_inputs.shape[1] # Actual number of frames from the video
    batch_size_actual = model_inputs.shape[0]

    print(f"Raw output embeddings shape: {outputs.shape}")

    feature_channels = outputs.shape[-1]
    expected_num_tokens = num_frames_processed * 16 * 16

    if outputs.shape[1] != expected_num_tokens:
        print(f"Warning: Unexpected number of tokens in raw output. Expected {expected_num_tokens}, got {outputs.shape[1]}.")

    reshaped_outputs = outputs.reshape(
        batch_size_actual, num_frames_processed, 16, 16, feature_channels
    )
    print(f"Reshaped spatiotemporal embeddings shape: {reshaped_outputs.shape}")

    print("\nVideoPrism inference with programmatically generated video complete.")
    print(f"Example of the first few values of the generated embeddings for the first frame's top-left patch:\n{reshaped_outputs[0, 0, 0, 0, :5]}")

    !cp -pr /tmp/generated_test_video.mp4 /content/

except FileNotFoundError as e:
    print(f"Error: {e}. Something went wrong with file generation or loading.")
except IOError as e:
    print(f"Error: {e}. Could not open or read the generated video file.")
except ValueError as e:
    print(f"Error during video processing: {e}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Clean up the generated video file
    if os.path.exists(generated_video_path):
        os.remove(generated_video_path)
        print(f"\nCleaned up generated video file: {generated_video_path}")

Generating a simple 288x288 MP4 video with 30 frames to '/tmp/generated_test_video.mp4'...
Video generation complete.

Attempting to load and preprocess video from: /tmp/generated_test_video.mp4
Video loaded. Input shape for VideoPrism: (1, 16, 288, 288, 3)

Loading VideoPrism model: videoprism_public_v1_large...
Model configuration loaded.
Loading pre-trained weights (this may take a moment)...
Pre-trained weights loaded.
Running inference with input shape: (1, 16, 288, 288, 3)
Inference complete.
Raw output embeddings shape: (1, 4096, 1024)
Reshaped spatiotemporal embeddings shape: (1, 16, 16, 16, 1024)

VideoPrism inference with programmatically generated video complete.
Example of the first few values of the generated embeddings for the first frame's top-left patch:
[-0.02193046 -0.07946897  0.41855606  0.01070144 -0.08500562]

Cleaned up generated video file: /tmp/generated_test_video.mp4


In [6]:
# List the MP4 files in /content, where the main script copied the generated video before deleting it from /tmp.
!ls /content/*.mp4

/content/generated_test_video.mp4
