Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

<div align="center">
  <h1 align="center">Scaling 4D Representations</h1>
  <p align="center">
    João Carreira, Dilara Gokay, Michael King, Chuhan Zhang, Ignacio Rocco, Aravindh Mahendran, Thomas Albert Keck, Joseph Heyward, Skanda Koppula, Etienne Pot, Goker Erdogan, Yana Hasson, Yi Yang, Klaus Greff, Guillaume Le Moing, Sjoerd van Steenkiste, Daniel Zoran, Drew A. Hudson, Pedro Vélez, Luisa Polanía, Luke Friedman, Chris Duvarney, Ross Goroshin, Kelsey Allen, Jacob Walker, Rishabh Kabra, Eric Aboussouan, Jennifer Sun, Thomas Kipf, Carl Doersch, Viorica Pătrăucean, Dima Damen, Pauline Luc, Mehdi S. M. Sajjadi, Andrew Zisserman
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2412.15212">Paper</a> | <a href="https://github.com/google-deepmind/representations4d">GitHub</a></h3>
</div>

<p align="center">
  <img src="https://storage.googleapis.com/representations4d/assets/architecture.png" alt="Model architecture overview" width="50%">
</p>



In [None]:
# @title Installation
# Clone the representations4d repository and pip-install it from source so the
# `representations4d.*` imports in the next cells resolve.

!git clone https://github.com/google-deepmind/representations4d.git
# `%cd` changes the notebook's working directory for all later cells, so the
# files downloaded next are saved inside the cloned repository.
# NOTE: re-running this cell fails at `git clone` once the directory exists.
%cd representations4d
!pip install .

In [None]:
# @title Download example input and checkpoint
!wget https://storage.googleapis.com/representations4d/checkpoints/scaling4d_dist_b_depth.npz
!wget https://storage.googleapis.com/representations4d/assets/horsejump-high.mp4

In [None]:
# @title Imports

from flax import linen as nn
import jax
import jax.numpy as jnp
from kauldron.modules import pos_embeddings
from kauldron.modules import vit as kd_vit
import mediapy
from representations4d.models import model as model_lib
from representations4d.models import readout
import numpy as np
from representations4d.utils import checkpoint_utils

In [None]:
# @title Hyperparameters
# Size of each input patch as (frames, height, width).
model_patch_size = (2, 16, 16)
# Spatial size (height, width) that every frame is resized to.
im_size = (224, 224)
# Transformer variant key; looked up in kd_vit.VIT_SIZES and passed to
# GeneralizedTransformer.from_variant_str when the model is defined.
model_size = "B"
dtype = jnp.float32
# Patch size (frames, height, width) used when decoding readout outputs
# (passed as `decoding_patch_size` to the readout head).
model_output_patch_size = (2, 8, 8)
# Number of output values in one decoding patch.
n_pixels_patch = (
    model_output_patch_size[0]
    * model_output_patch_size[1]
    * model_output_patch_size[2]
)
num_input_frames = 16
# Number of output values for a whole (single-channel) video.
n_pixels_video = num_input_frames * im_size[0] * im_size[1]

# The model cell derives the token grid (embedding_shape) and the number of
# readout queries with floor division. A non-divisible configuration would be
# silently truncated and only fail later with an opaque shape error, so it is
# checked up front here.
for _name, _patch in (
    ("model_patch_size", model_patch_size),
    ("model_output_patch_size", model_output_patch_size),
):
  if (
      num_input_frames % _patch[0]
      or im_size[0] % _patch[1]
      or im_size[1] % _patch[2]
  ):
    raise ValueError(
        "Video size (frames, height, width) = "
        f"{(num_input_frames, *im_size)} is not divisible by "
        f"{_name}={_patch}."
    )
del _name, _patch

In [None]:
# @title Define model
# Backbone: a patch tokenizer with learned positional embeddings over the axes
# (-4, -3, -2), followed by a transformer whose width/depth come from the
# `model_size` variant.
encoder = model_lib.Model(
    encoder=model_lib.Tokenizer(
        patch_embedding=model_lib.PatchEmbedding(
            patch_size=model_patch_size,
            num_features=kd_vit.VIT_SIZES[model_size][0],
        ),
        posenc=pos_embeddings.LearnedEmbedding(dtype=dtype),
        posenc_axes=(-4, -3, -2),
    ),
    processor=model_lib.GeneralizedTransformer.from_variant_str(
        variant_str=model_size,
        dtype=dtype,
    ),
)

# Adapter from encoder tokens to the readout. `embedding_shape` is the token
# grid: each of (frames, height, width) divided by the matching patch extent.
# NOTE(review): the meaning of readout_depth=0.95 is defined inside
# EncoderToReadout (not visible here) — confirm before changing it.
encoder2readout = model_lib.EncoderToReadout(
    embedding_shape=tuple(
        extent // patch
        for extent, patch in zip((num_input_frames, *im_size), model_patch_size)
    ),
    readout_depth=0.95,
    num_input_frames=num_input_frames,
)

# Attention readout head. It uses one query per decoding patch
# (n_pixels_video // n_pixels_patch queries) and n_pixels_patch outputs per
# query, and its output is shaped as a single-channel
# (frames, height, width, 1) video.
readout_head = readout.AttentionReadout(
    num_classes=n_pixels_patch,
    num_params=1024,
    num_heads=16,
    num_queries=n_pixels_video // n_pixels_patch,
    output_shape=(num_input_frames, *im_size, 1),
    decoding_patch_size=model_output_patch_size,
)

# Full pipeline: video -> encoder -> adapter -> readout head.
model = nn.Sequential([encoder, encoder2readout, readout_head])


def forward(params, vid):
  """Runs the encoder -> readout pipeline in inference mode.

  Args:
    params: Variables for the module-level `model`, in the form accepted by
      `flax.linen.Module.apply`.
    vid: Batch of videos. The init cell uses a float32 batch shaped
      (1, frames, height, width, 3).

  Returns:
    The output of `model`. For the readout head defined above this is
    presumably (batch, *output_shape) — TODO confirm.
  """
  predictions = model.apply(params, vid, is_training_property=False)
  return predictions

In [None]:
# @title Initialize model
key = jax.random.key(0)
# All-zeros batch containing one video; only its shape and dtype matter for
# `model.init`. The shape is derived from the hyperparameters so it stays
# consistent with `num_input_frames` and `im_size` when they change.
x = jnp.zeros((1, num_input_frames, *im_size, 3), dtype=jnp.float32)

# NOTE(review): the forward-pass cell uses `restored_params`, not these freshly
# initialized values. They are still useful as a reference tree to compare the
# checkpoint structure against.
model_params = model.init(key, x, is_training_property=True)

In [None]:
# @title Restore parameters

# `npload` reads the downloaded .npz archive. `recover_tree` presumably
# rebuilds the nested parameter tree from its flattened entries — see
# representations4d.utils.checkpoint_utils.
_flat_checkpoint = checkpoint_utils.npload("scaling4d_dist_b_depth.npz")
restored_params = checkpoint_utils.recover_tree(_flat_checkpoint)
del _flat_checkpoint

In [None]:
# @title Load example video from DAVIS

video = mediapy.read_video("horsejump-high.mp4")
# The model is built for exactly `num_input_frames` frames. Fail clearly here
# rather than with a shape error inside the model.
if len(video) < num_input_frames:
  raise ValueError(
      f"Need at least {num_input_frames} frames, but the video has"
      f" {len(video)}."
  )

# Keep only the frames the model consumes *before* resizing, so the frames
# that would be dropped are never resized. Then scale uint8 pixels to [0, 1].
video = mediapy.resize_video(video[:num_input_frames], im_size) / 255.0
# Add a leading batch axis of size 1 and cast to the model's float32 input.
video = video[np.newaxis].astype(np.float32)

In [None]:
# @title Run forward pass
# Inference with the checkpoint weights (not the freshly initialized ones).
outputs = forward(params=restored_params, vid=video)

In [None]:
# @title Visualize depth maps
# Take the first (only) video in the batch and move it to the host. The readout
# output_shape ends in a single channel, so this is presumably
# (frames, height, width, 1) — TODO confirm.
out = np.asarray(outputs[0])
# Repeat the single channel to RGB on the host with NumPy. The earlier jnp.tile
# round-tripped through the device for no benefit.
out = np.repeat(out, 3, axis=-1)
# Scale to [0, 1] for display. An all-zero prediction would otherwise divide
# by zero and render as NaN.
max_depth = out.max()
if max_depth > 0:
  out = out / max_depth
# Show input frames and depth side by side along the width axis.
vis = np.concatenate([video[0], out], axis=2)
mediapy.show_video(vis, fps=20)