# VideoPrism Video Encoder Demo

[![Paper](https://img.shields.io/badge/arXiv-2402.13217-red.svg)](https://arxiv.org/abs/2402.13217)
[![Blog](https://img.shields.io/badge/Google_Research-Blog-green.svg)](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

This notebook provides an example of video feature extraction with a pre-trained VideoPrism video encoder.

Please run this demo on Google Colab, either with a TPU runtime (faster) or without one.

## Set up

In [None]:
# @title Prepare environment

import os
import sys

# Fetch VideoPrism repository if Python does not know about it and install
# dependencies needed for this notebook.
if not os.path.exists("videoprism_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-deepmind/videoprism.git videoprism_repo
  os.chdir('./videoprism_repo')
  !pip install .
  os.chdir('..')

# Append VideoPrism code to Python import path.
if "videoprism_repo" not in sys.path:
  sys.path.append("videoprism_repo")

# Install missing dependencies.
!pip install mediapy

import jax
from jax.extend import backend
import tensorflow as tf

# Hide GPUs and TPUs from TensorFlow so it does not claim accelerator memory
# that JAX needs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.get_backend().platform}")
print(f"JAX devices:  {jax.device_count()}")


In [None]:
# @title Load dependencies and define utilities

import mediapy
import numpy as np
from PIL import Image


def read_and_preprocess_video(
    filename: str, target_num_frames: int, target_frame_size: tuple[int, int]
):
  """Reads and preprocesses a video."""

  frames = mediapy.read_video(filename)

  # Uniformly sample `target_num_frames` frames across the whole clip.
  frame_indices = np.linspace(
      0, len(frames), num=target_num_frames, endpoint=False, dtype=np.int32
  )
  frames = np.array([frames[i] for i in frame_indices])

  # Resize to target size.
  original_height, original_width = frames.shape[-3:-1]
  target_height, target_width = target_frame_size
  if original_height * target_width != original_width * target_height:
    raise ValueError(
        'Aspect-ratio changes are not supported: target_frame_size must have '
        'the same aspect ratio as the input video.'
    )
  frames = mediapy.resize_video(frames, shape=target_frame_size)

  # Normalize pixel values to [0.0, 1.0].
  frames = mediapy.to_float01(frames)

  return frames
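
The temporal sampling in `read_and_preprocess_video` picks `target_num_frames` evenly spaced indices with `np.linspace(..., endpoint=False)` and converts them to integers, and the resize step requires the target size to keep the video's aspect ratio. The sketch below isolates both steps on plain numbers so it runs without a video file. `sample_frame_indices` and `aspect_ratio_matches` are hypothetical helper names for illustration, not part of VideoPrism or mediapy.

```python
import numpy as np


def sample_frame_indices(num_frames: int, target_num_frames: int) -> np.ndarray:
  """Hypothetical helper: evenly spaced frame indices in [0, num_frames).

  Mirrors the np.linspace call in read_and_preprocess_video. With
  endpoint=False the last index is always below num_frames, and clips
  shorter than target_num_frames get repeated indices.
  """
  return np.linspace(
      0, num_frames, num=target_num_frames, endpoint=False, dtype=np.int32
  )


def aspect_ratio_matches(
    original_hw: tuple[int, int], target_hw: tuple[int, int]
) -> bool:
  """Hypothetical helper: the cross-multiplied (height, width) ratio check."""
  (orig_h, orig_w), (target_h, target_w) = original_hw, target_hw
  return orig_h * target_w == orig_w * target_h


print(sample_frame_indices(100, 16))  # 100 frames -> 16 evenly spaced.
print(sample_frame_indices(10, 16))   # 10 frames -> 16, some repeated.
print(aspect_ratio_matches((720, 720), (288, 288)))   # square -> square
print(aspect_ratio_matches((720, 1280), (288, 288)))  # 16:9 -> square
```

Cross-multiplying the heights and widths keeps the check in integer arithmetic, which avoids the floating-point comparison issues of testing `h1 / w1 == h2 / w2`.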


In [None]:
# @title Load model

import jax
import jax.numpy as jnp
from videoprism import models as vp

# Models available: ['videoprism_public_v1_base', 'videoprism_public_v1_large']
MODEL_NAME = 'videoprism_public_v1_base'
NUM_FRAMES = 16
FRAME_SIZE = 288

flax_model = vp.MODELS[MODEL_NAME]()
loaded_state = vp.load_pretrained_weights(MODEL_NAME)


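def forward_fn(inputs, train=False):
  return flax_model.apply(loaded_state, inputs, train=train)


# Compile with `train` marked static, so it stays a Python bool at trace time
# instead of becoming a traced value.
forward_fn = jax.jit(forward_fn, static_argnames=('train',))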
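
The encoder returns one embedding per spatiotemporal patch, so the output shape can be predicted from the input configuration. The sketch below does that arithmetic under assumptions this notebook does not confirm: an 18×18 patch size (so a 288×288 frame becomes a 16×16 grid) and feature widths of 768 for the base model and 1024 for the large model (the standard ViT-B and ViT-L widths). `expected_embedding_shape`, `ASSUMED_PATCH_SIZE`, and `ASSUMED_FEATURE_DIMS` are hypothetical names. Treat the result as a sanity check to compare against the printed `Encoded embedding shape`, not as a specification.

```python
# Hypothetical helper: predicts the encoder output shape from the input
# configuration. The patch size and feature widths are ASSUMPTIONS, not
# values read from the VideoPrism code; verify them against the real output.
ASSUMED_PATCH_SIZE = 18
ASSUMED_FEATURE_DIMS = {
    'videoprism_public_v1_base': 768,    # assumed ViT-B width
    'videoprism_public_v1_large': 1024,  # assumed ViT-L width
}


def expected_embedding_shape(
    model_name: str,
    batch_size: int,
    num_frames: int,
    frame_size: int,
    patch_size: int = ASSUMED_PATCH_SIZE,
) -> tuple[int, int, int]:
  """Returns (batch_size, num_tokens, feature_channels) under the assumptions."""
  if frame_size % patch_size:
    raise ValueError(f'{frame_size=} is not divisible by {patch_size=}.')
  patches_per_side = frame_size // patch_size
  # One token per (frame, patch row, patch column).
  num_tokens = num_frames * patches_per_side * patches_per_side
  return (batch_size, num_tokens, ASSUMED_FEATURE_DIMS[model_name])


print(expected_embedding_shape('videoprism_public_v1_base', 1, 16, 288))
```

Under these assumptions, 16 frames of 288×288 yield 16 × 16 × 16 = 4096 tokens per video.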


## Example: Video feature extraction

In this example, we extract spatiotemporal embeddings from a sample video bundled with the VideoPrism repository.

In [None]:
VIDEO_FILE_PATH = (
    './videoprism_repo/videoprism/assets/water_bottle_drumming.mp4'
)
frames = read_and_preprocess_video(
    VIDEO_FILE_PATH,
    target_num_frames=NUM_FRAMES,
    target_frame_size=(FRAME_SIZE, FRAME_SIZE),
)
mediapy.show_video(frames, fps=6.0)

frames = jnp.asarray(frames[None, ...])  # Add batch dimension.
print(f'Input shape: {frames.shape}')

embeddings, _ = forward_fn(frames)
print(f'Encoded embedding shape: {embeddings.shape}')
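
The encoder output is a flat sequence of tokens per video. A common next step is to restore the (time, height, width) layout and average-pool it into per-frame or whole-video vectors. The sketch below does this with NumPy on a random stand-in array. It assumes the tokens are ordered frame by frame, each frame as a row-major 16×16 patch grid. Check that assumption against the VideoPrism code before relying on per-frame results. `tokens_to_grid` and `pool_embeddings` are hypothetical helpers. To use them on the real output, convert it first with `np.asarray(embeddings)`.

```python
import numpy as np


def tokens_to_grid(
    tokens: np.ndarray, num_frames: int, grid_size: int
) -> np.ndarray:
  """Hypothetical helper: reshapes [B, T*H*W, C] tokens to [B, T, H, W, C].

  ASSUMPTION: tokens are laid out time-major (all patches of frame 0, then
  frame 1, ...), each frame as a row-major grid_size x grid_size grid.
  """
  batch_size, num_tokens, channels = tokens.shape
  if num_tokens != num_frames * grid_size * grid_size:
    raise ValueError(
        f'Cannot reshape {num_tokens} tokens to '
        f'{num_frames}x{grid_size}x{grid_size}.'
    )
  return tokens.reshape(batch_size, num_frames, grid_size, grid_size, channels)


def pool_embeddings(grid: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  """Mean-pools a [B, T, H, W, C] grid into per-frame and per-video vectors."""
  per_frame = grid.mean(axis=(2, 3))  # [B, T, C]
  per_video = per_frame.mean(axis=1)  # [B, C]
  return per_frame, per_video


# Random stand-in for `np.asarray(embeddings)`: 1 video, 16 frames, a 16x16
# patch grid per frame, and only 8 channels to keep the example small.
rng = np.random.default_rng(0)
fake_embeddings = rng.standard_normal((1, 16 * 16 * 16, 8)).astype(np.float32)

grid = tokens_to_grid(fake_embeddings, num_frames=16, grid_size=16)
per_frame, per_video = pool_embeddings(grid)
print(grid.shape, per_frame.shape, per_video.shape)
```

Averaging over all tokens does not depend on their order, so the per-video vector is valid even if the layout assumption is wrong. The per-frame vectors are only meaningful if the layout assumption holds.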
