Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
# @title Installation {form-width: "10%"}

# Installs the packages this notebook needs that are not standard in Colab.
# NOTE(review): the notebook also imports einops and cv2, which are not listed
# here; they are assumed to be available already (preinstalled or pulled in as
# dependencies) — confirm in a fresh environment.
!pip install kauldron mediapy
!pip install git+https://github.com/google-deepmind/tapnet.git

In [None]:
# @title Imports {form-width: "10%"}

import io
import os
import random

import einops
from flax import linen as nn
import jax
import jax.numpy as jnp
import mediapy as media
import numpy as np
from typing import Callable, Optional

from kauldron import kd
from kauldron.typing import Float
from tapnet.utils import viz_utils

In [None]:
# @title Download checkpoints {form-width: "10%"}

# Pretrained weights stored as flat .npz archives: the MooG backbone and the
# point- and box-tracking readout heads. They are saved into the working
# directory and loaded by filename in the "Load parameters" cell.
!wget https://storage.googleapis.com/representations4d/checkpoints/moog_ego4d_backbone_ckpt_164335139.npz
!wget https://storage.googleapis.com/representations4d/checkpoints/moog_ego4d_point_track_head_ckpt_164335139.npz
!wget https://storage.googleapis.com/representations4d/checkpoints/moog_ego4d_box_track_head_ckpt_164335139.npz

In [None]:
# @title Attention module {form-width: "10%"}

class ImprovedMHDPAttention(nn.Module):
  """Multi-head dot-product attention with RMS-normalised queries and keys.

  Queries are projected from `inputs_q`; keys and values from `inputs_kv`.
  Queries and keys are RMS-normalised per head before the scaled dot product,
  and the per-head results are projected back to the width of `inputs_q`.

  Attributes:
    num_heads: Number of attention heads.
    qkv_size: Total query/key/value width; each head gets
      `qkv_size // num_heads` channels.
  """
  num_heads: int
  qkv_size: int

  @nn.compact
  def __call__(self, inputs_q, inputs_kv):
    head_shape = (self.num_heads, self.qkv_size // self.num_heads)

    def to_heads(x, layer_name):
      # Bias-free projection of the last axis to (heads, head_dim).
      return nn.DenseGeneral(features=head_shape, use_bias=False, name=layer_name)(x)

    q = to_heads(inputs_q, 'dense_query')
    k = to_heads(inputs_kv, 'dense_key')
    v = to_heads(inputs_kv, 'dense_value')
    q = nn.RMSNorm(name='norm_query')(q)
    k = nn.RMSNorm(name='norm_key')(k)

    # 1/sqrt(head_dim) scaling, folded into the queries.
    q = q / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum('...qhd,...khd->...hqk', q, k)
    weights = jax.nn.softmax(scores, axis=-1)
    mixed = jnp.einsum('...hqk,...khd->...qhd', weights, v)

    # Collapse (heads, head_dim) back into the query feature width.
    merge_heads = nn.DenseGeneral(
        features=inputs_q.shape[-1], axis=(-2, -1), use_bias=True, name='dense_out')
    return merge_heads(mixed)


class ImprovedTransformerBlock(nn.Module):
  """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

  Both attention branches read the same LayerNorm-ed copy of the block input
  and are added to the residual stream; the GELU MLP then reads a second
  LayerNorm of that residual stream and is added back to it.

  Attributes:
    mlp_size: Hidden width of the two-layer MLP.
    num_heads: Number of attention heads.
    qkv_size: Total query/key/value width of each attention layer.
  """
  mlp_size: int
  num_heads: int
  qkv_size: int

  @nn.compact
  def __call__(self, queries, inputs_kv=None):
    normed = nn.LayerNorm(use_bias=False, use_scale=True, name='norm_q')(queries)
    self_attention = ImprovedMHDPAttention(self.num_heads, self.qkv_size, name='self_att')
    residual = queries + self_attention(normed, normed)

    if inputs_kv is not None:
      # Cross-attention also reads the normalised block input, not `residual`.
      cross_attention = ImprovedMHDPAttention(self.num_heads, self.qkv_size, name='cross_att')
      residual = residual + cross_attention(normed, inputs_kv)

    mlp_input = nn.LayerNorm(use_bias=False, use_scale=True, name='norm_attn')(residual)
    hidden = nn.gelu(nn.Dense(self.mlp_size, name='MLP_in')(mlp_input))
    return residual + nn.Dense(queries.shape[-1], name='MLP_out')(hidden)


class ImprovedTransformer(nn.Module):
  """Stack of ImprovedTransformerBlocks with optional width projection.

  Attributes:
    qkv_size: Total attention width per block.
    num_heads: Attention heads per block.
    mlp_size: MLP hidden width per block.
    num_layers: Number of blocks (0 is allowed).
    hidden_size: If set, tokens are projected to this width before the
      blocks; when it differs from the input width, the output is projected
      back to the input width after the final LayerNorm.
  """
  qkv_size: int
  num_heads: int
  mlp_size: int
  num_layers: int
  hidden_size: Optional[int] = None

  @nn.compact
  def __call__(self, queries, inputs_kv=None, num_token_axes=1):
    """Run the blocks over `queries`.

    Args:
      queries: Array (..., *token_axes, features).
      inputs_kv: Optional context passed to every block for cross-attention.
      num_token_axes: How many axes before the feature axis are token axes;
        more than one are flattened into a single token axis for the blocks
        and restored afterwards.

    Returns:
      Array with the same leading/token shape as `queries`.
    """
    lead_shape = queries.shape[:-1 - num_token_axes]
    grid_shape = queries.shape[-1 - num_token_axes:-1]
    if num_token_axes > 1:
      flat_tokens = (np.prod(grid_shape), queries.shape[-1])
      queries = jnp.reshape(queries, lead_shape + flat_tokens)

    in_width = queries.shape[-1]
    project_back = self.hidden_size is not None and self.hidden_size != in_width
    if self.hidden_size is not None:
      queries = nn.Dense(features=self.hidden_size)(queries)

    for layer_index in range(self.num_layers):
      block = ImprovedTransformerBlock(
          qkv_size=self.qkv_size,
          num_heads=self.num_heads,
          mlp_size=self.mlp_size,
          name=f'layer_{layer_index}',
      )
      queries = block(queries, inputs_kv)

    queries = nn.LayerNorm(use_bias=False, use_scale=True, name='norm_encoder')(queries)
    if project_back:
      queries = nn.Dense(features=in_width)(queries)
    if len(grid_shape) > 1:
      queries = jnp.reshape(queries, lead_shape + grid_shape + (queries.shape[-1],))
    return queries

In [None]:
# @title Position embedding module {form-width: "10%"}

def _convert_to_fourier_features(inputs, basis_degree):
  """Map coordinates to sin/cos features at octave-spaced frequencies.

  Args:
    inputs: Array whose last axis holds the coordinate dimensions.
    basis_degree: Number of octaves; every coordinate is scaled by 2**k for
      k in [0, basis_degree).

  Returns:
    Array whose last axis has 2 * basis_degree * inputs.shape[-1] entries:
    the sines of all scaled coordinates (octave-major), followed by their
    cosines computed as sin(x + pi/2).
  """
  identity = jnp.eye(inputs.shape[-1])
  octaves = [identity * 2**k for k in range(basis_degree)]
  projected = inputs @ jnp.concatenate(octaves, 1)
  phases = jnp.concatenate([projected, projected + 0.5 * jnp.pi], axis=-1)
  return jnp.sin(phases)


def _create_gradient_grid(samples_per_dim, value_range=(-1.0, 1.0)):
  """Build a dense coordinate grid with evenly spaced samples per axis.

  Args:
    samples_per_dim: Number of samples along each grid axis.
    value_range: (low, high) endpoints, both included, shared by all axes.

  Returns:
    Array of shape (*samples_per_dim, len(samples_per_dim)) whose last axis
    holds the coordinates of each grid point ('ij' indexing).
  """
  low, high = value_range[0], value_range[1]
  axis_values = [jnp.linspace(low, high, count) for count in samples_per_dim]
  mesh = jnp.meshgrid(*axis_values, indexing='ij')
  return jnp.stack(mesh, axis=-1)


class FourierEmbedding(nn.Module):
  """Fixed Fourier embedding of a regular grid over selected input axes.

  A [-1, 1] grid is built over the input axes listed in `axes`, scaled by pi
  and expanded into sin/cos features. A singleton axis is inserted for every
  axis between min(axes) and the feature axis that is not itself in `axes`,
  so the embedding broadcasts against the input.

  Attributes:
    num_fourier_bases: Number of frequency octaves.
    update_type: 'project_add' (Dense-project to the input width, then add)
      or 'concat' (broadcast and append along the feature axis).
    axes: Negative input axes the grid spans.
  """
  num_fourier_bases: int
  update_type: str
  axes: tuple = (-2,)

  @nn.compact
  def __call__(self, inputs):
    grid_shape = tuple(inputs.shape[a] for a in self.axes)
    grid = _create_gradient_grid(grid_shape, value_range=(-1.0, 1.0))
    embedding = _convert_to_fourier_features(grid * jnp.pi, self.num_fourier_bases)

    lowest = min(self.axes)
    missing = tuple(a - lowest for a in range(lowest, -1) if a not in self.axes)
    embedding = jnp.expand_dims(embedding, axis=missing)

    if self.update_type == 'project_add':
      projected = nn.Dense(inputs.shape[-1], name='dense_pe')(embedding)
      return inputs + projected
    if self.update_type == 'concat':
      target_shape = inputs.shape[:-1] + embedding.shape[-1:]
      broadcast = jnp.broadcast_to(embedding, target_shape)
      return jnp.concatenate((inputs, broadcast), axis=-1)
    raise ValueError('Invalid update type provided.')


class SampleFourierEmbedding(nn.Module):
  """Fourier embedding of explicitly supplied coordinates.

  Attributes:
    num_fourier_bases: Number of frequency octaves.
    update_type: 'project_add' (Dense-project and add to `inputs`), 'concat'
      (broadcast and append to `inputs`) or 'replace' (return only the
      embedding, broadcast to the leading shape of `inputs`).
  """
  num_fourier_bases: int
  update_type: str

  @nn.compact
  def __call__(self, inputs, coords=None):
    # Without explicit coordinates, the inputs themselves are embedded.
    source = inputs if coords is None else coords
    embedding = _convert_to_fourier_features(source * jnp.pi, self.num_fourier_bases)
    full_shape = inputs.shape[:-1] + embedding.shape[-1:]

    mode = self.update_type
    if mode == 'project_add':
      return inputs + nn.Dense(inputs.shape[-1], name='dense_pe')(embedding)
    if mode == 'concat':
      return jnp.concatenate((inputs, jnp.broadcast_to(embedding, full_shape)), axis=-1)
    if mode == 'replace':
      return jnp.broadcast_to(embedding, full_shape)
    raise ValueError('Invalid update type provided.')

In [None]:
# @title Base module {form-width: "10%"}

class RandomStateInit(nn.Module):
  """Sample an initial state of shape `batch_shape + shape` from a scaled normal.

  Attributes:
    shape: Per-example state shape.
    random_init_scale: Standard deviation of the sampled values.
  """
  shape: tuple[int, ...]
  random_init_scale: float

  @nn.compact
  def __call__(self, inputs, batch_shape):
    # `inputs` is not used; only `batch_shape` and `shape` set the output shape.
    rng = self.make_rng("default")
    full_shape = batch_shape + self.shape
    noise = jax.random.normal(key=rng, shape=full_shape)
    return self.random_init_scale * noise


class ConvNet(nn.Module):
  """Convolution stack with GELU between layers and a final LayerNorm.

  Layer i uses `features[i]`, `kernel_sizes[i]` and `strides[i]`; no GELU
  follows the last convolution.
  """
  features: tuple[int, ...]
  kernel_sizes: tuple[tuple[int, ...], ...]
  strides: tuple[tuple[int, ...], ...]

  @nn.compact
  def __call__(self, x):
    last_index = len(self.features) - 1
    layer_specs = zip(self.features, self.kernel_sizes, self.strides)
    for index, (width, kernel, stride) in enumerate(layer_specs):
      x = nn.Conv(features=width, kernel_size=kernel, strides=stride)(x)
      if index < last_index:
        x = nn.gelu(x)
    return nn.LayerNorm(use_bias=False, use_scale=True)(x)


class MLP(nn.Module):
  """Two-layer GELU MLP.

  Attributes:
    hidden_size: Width of the hidden layer.
    output_size: Output width; when unset (or falsy) the input width is kept.
  """
  hidden_size: int
  output_size: Optional[int] = None

  @nn.compact
  def __call__(self, x):
    out_features = self.output_size if self.output_size else x.shape[-1]
    hidden = nn.gelu(nn.Dense(self.hidden_size)(x))
    return nn.Dense(out_features)(hidden)

In [None]:
# @title Model module {form-width: "10%"}

class SRTEncoder(nn.Module):
  """Encode images into a flat sequence of tokens.

  Runs the backbone and positional embedding, merges the two axes before the
  feature axis into one token axis and, if a transformer is configured,
  refines the tokens with it.
  """
  backbone: nn.Module
  transformer: nn.Module | None
  pos_embedding: nn.Module

  @nn.compact
  def __call__(self, image):
    features = self.pos_embedding(self.backbone(image))
    tokens = einops.rearrange(features, '... h w d -> ... (h w) d')
    if self.transformer is None:
      return tokens
    return self.transformer(tokens)


class SRTDecoder(nn.Module):
  """Decode latent tokens into a dense per-pixel prediction.

  Every output pixel gets a query built by `pos_embedding` from its
  (row, column) coordinate on a [-1, 1] grid. The queries attend to `inputs`
  through `transformer`, and a final Dense layer maps them to the channel
  count of `targets`.

  Attributes:
    transformer: Called as transformer(x, inputs_kv=inputs, num_token_axes=2).
    pos_embedding: Called as pos_embedding(inputs=..., coords=...).
  """
  transformer: nn.Module
  pos_embedding: nn.Module

  @nn.compact
  def __call__(self, inputs, targets):
    """Predict an output grid shaped like `targets` from `inputs`.

    Args:
      inputs: Latent tokens; all axes except the last two are batch axes.
      targets: Array whose last three axes are (height, width, channels);
        only its spatial size and channel count are used.

    Returns:
      (recon, targets): the prediction, whose last axis has
      targets.shape[-1] channels, and `targets` unchanged.
    """
    h, w = targets.shape[-3], targets.shape[-2]
    batch_shape, feat_dim = inputs.shape[:-2], inputs.shape[-1]
    coords = _create_gradient_grid((h, w), (-1.0, 1.0))
    coords = jnp.broadcast_to(coords, batch_shape + coords.shape)
    # Zero placeholder features, one per output pixel, sized from the target
    # grid itself (not from any outer-scope height/width).
    zero_feats = jnp.zeros(batch_shape + (h, w, feat_dim))
    x = self.pos_embedding(inputs=zero_feats, coords=coords)
    x = self.transformer(x, inputs_kv=inputs, num_token_axes=2)
    recon = nn.Dense(features=targets.shape[-1], name='recon')(x)
    return recon, targets


class MooG(nn.Module):
  """Recurrent video model with a set of latent state tokens.

  Frames are encoded into tokens, an initial state is produced by
  `initializer`, and the state is updated once per frame (see
  `scan_over_time`). The decoder reconstructs every frame from the state that
  preceded it, i.e. before that frame was observed.
  """
  encoder: SRTEncoder  # frames -> per-frame tokens; frame axis assumed at -3
  initializer: nn.Module  # builds the initial state
  predictor_sa: nn.Module  # residual state update, called without inputs_kv
  predictor_ca: nn.Module  # residual state update, cross-attends to frame tokens
  decoder: SRTDecoder  # states -> reconstructed frames
  # NOTE(review): a module instance used as a field default is shared by every
  # MooG that does not override it — confirm this is intended.
  state_layer_norm: nn.Module = nn.LayerNorm(epsilon=1e-4, use_scale=True, use_bias=False)

  def scan_over_time(self, encoded_inputs, state):
    """Update `state` once per frame of `encoded_inputs`.

    Each step layer-norms the state, adds the self-attention predictor's
    output, then adds the cross-attention predictor's output for that frame.

    Args:
      encoded_inputs: Encoder output; axis -3 indexes frames.
      state: Initial state.

    Returns:
      The initial state followed by the state after each frame, stacked along
      axis -3 (T + 1 entries for T frames).
    """
    states = [state]
    for t in range(encoded_inputs.shape[-3]):
      state = self.state_layer_norm(state)
      state = state + self.predictor_sa(state)
      state = state + self.predictor_ca(queries=state, inputs_kv=encoded_inputs[..., t, :, :])
      states.append(state)
    return jnp.stack(states, axis=-3)

  @nn.compact
  def __call__(self, video):
    """Encode, roll the state forward and decode every frame of `video`.

    Returns:
      dict with `video_predicted` (decoder output) and
      `subsampled_targets_predicted` (the targets returned by the decoder).
    """
    x = self.encoder(video)
    state = self.initializer(x[..., 0, :, :], batch_shape=x.shape[:-3])
    states = self.scan_over_time(x, state)
    # Drop the last state: frame t is decoded from the state before frame t.
    out, targets = self.decoder(inputs=states[..., :-1, :, :], targets=video)
    return dict(video_predicted=out, subsampled_targets_predicted=targets)


class EvalWrapper(nn.Module):
  """Run a pretrained MooG's encoder and state recursion, without decoding.

  Returns per-frame states as features, optionally continuing from a state
  supplied by the caller.
  """
  pretrained_model: MooG  # model whose submodules are re-exposed in setup()

  def setup(self):
    # Re-expose the wrapped model's submodules (and its scan_over_time method)
    # so __call__ can use them directly.
    # NOTE(review): the parameter-tree scopes these end up under depend on
    # Flax's submodule adoption rules — confirm against the checkpoint keys.
    self.encoder = self.pretrained_model.encoder
    self.initializer = self.pretrained_model.initializer
    self.predictor_ca = self.pretrained_model.predictor_ca
    self.predictor_sa = self.pretrained_model.predictor_sa
    self.decoder = self.pretrained_model.decoder
    self.state_layer_norm = self.pretrained_model.state_layer_norm
    self.scan_over_time = self.pretrained_model.scan_over_time

  @nn.compact
  def __call__(self, video, state=None):
    """Encode `video` and roll the state forward over its frames.

    Args:
      video: Input frames, passed straight to the encoder.
      state: Optional initial state; when None it is produced by the
        initializer from the first frame's tokens and the batch shape.

    Returns:
      dict with `features`, the states after each frame (the initial state
      is dropped), and `state`, the state after the last frame.
    """
    x = self.encoder(video)
    if state is None:
      state = self.initializer(x[..., 0, :, :], batch_shape=x.shape[:-3])
    states = self.scan_over_time(x, state)
    return dict(features=states[..., 1:, :, :], state=states[..., -1, :, :])

In [None]:
# @title Readout module {form-width: "10%"}

class AttentionReadout(nn.Module):
  """Cross-attention readout from spatio-temporal features to per-query outputs.

  Features are LayerNorm-ed, given a learned embedding along the time axis and
  flattened over (time, tokens). Each query is projected to `num_params`
  channels and attends to the flattened features with `num_heads` heads. The
  result then passes through a residual projection, a residual MLP and a final
  Dense layer with `num_classes` outputs.

  NOTE(review): several layers below are auto-named by Flax in creation order
  (Dense_*, LayerNorm_*, MLP_0); presumably the checkpoint keys depend on that
  order, so do not reorder them.
  """
  num_classes: int  # output channels per query
  num_params: int  # query/key/value width, split across heads
  num_heads: int  # attention heads

  @nn.compact
  def __call__(self, inputs, queries):
    # Shapes: h = num_heads, d = num_params // num_heads, Q = number of queries.
    feats = nn.LayerNorm()(inputs) # (..., T, N, C)
    # Learned embedding indexed along axis -3 (T); presumably broadcast over N.
    feats += kd.nn.LearnedEmbedding(name='temporal_posenc')(feats.shape, axis=-3)
    feats = einops.rearrange(feats, '... T N C -> ... (T N) C') # (..., T*N, C)
    query = nn.Dense(self.num_params)(queries) # (..., Q, num_params)
    query = einops.rearrange(query, '... Q (h n) -> ... Q h n', h=self.num_heads) # (..., Q, h, d)
    key = nn.DenseGeneral(features=(self.num_heads, self.num_params // self.num_heads), axis=-1, use_bias=True, name='key_embedding')(feats) # (..., T*N, h, d)
    val = nn.DenseGeneral(features=(self.num_heads, self.num_params // self.num_heads), axis=-1, use_bias=True, name='value_embedding')(feats) # (..., T*N, h, d)
    token = nn.dot_product_attention(query=query, key=key, value=val) # (..., Q, h, d)
    token = einops.rearrange(token, '... Q N c -> ... Q (N c)') # (..., Q, num_params)
    query = einops.rearrange(query, '... Q N c -> ... Q (N c)') # (..., Q, num_params)
    # Residual around the attention output projection, then a pre-norm residual MLP.
    token = query + nn.Dense(self.num_params)(token) # (..., Q, num_params)
    token = token + MLP(self.num_params * 4)(nn.LayerNorm()(token)) # (..., Q, num_params)
    out = nn.Dense(self.num_classes)(token) # (..., Q, num_classes)
    return out


class TrackingReadoutWrapper(nn.Module):
  """Turn backbone features plus per-track queries into per-frame predictions.

  Queries are optionally embedded by `query_initializer`. When
  `temporal_tile_size > 1`, each query is also replicated that many times,
  with a learned per-copy embedding. `attention_readout` produces
  `num_classes` outputs per (tiled) query; these are unpacked into
  `temporal_tile_size * num_frames_per_query` frames of
  `num_classes // num_frames_per_query` channels each.

  With `predict_visibility`, the last channel (last two with `use_certainty`)
  holds the visibility (and certainty) logits; the rest are `values`.

  Returns:
    dict with `values`, `logits_visible`, `logits_certainty` and `visible`.
    The last three are None unless `predict_visibility` is set. With
    `use_certainty`, `visible` is a 0/1 float mask from thresholding
    sigmoid(visible) * sigmoid(certainty); otherwise it is sigmoid(visible).
  """
  attention_readout: AttentionReadout
  query_initializer: nn.Module | None = None
  output_activation: Optional[Callable[[Float['*b']], Float['*b']]] = None
  predict_visibility: bool = False
  use_certainty: bool = True
  certainty_threshold: float = 0.5
  num_frames_per_query: int = 1
  temporal_tile_size: int = 1

  @nn.compact
  def __call__(self, inputs, queries):
    if self.query_initializer is not None:
      queries = self.query_initializer(queries)

    tiles = self.temporal_tile_size
    if tiles > 1:
      # Each of the `tiles` copies of a query gets its own learned offset.
      copies = einops.repeat(queries, 'B Q D -> B Q k D', k=tiles)
      offsets = kd.nn.LearnedEmbedding(name='temporal_tile_posenc')(copies.shape, axis=-2)
      queries = einops.rearrange(copies + offsets.astype(copies.dtype), 'B Q k D -> B (Q k) D')

    readout = self.attention_readout(inputs, queries)
    readout = einops.rearrange(readout, 'B (Q k) C -> B Q (k C)', k=tiles)
    frames = self.num_frames_per_query
    out = einops.rearrange(
      readout, '... Q (k t c)->...(k t) Q c',
      k=tiles,
      t=frames,
      c=self.attention_readout.num_classes // frames)

    values = out
    logits_visible = logits_certainty = visible = None
    if self.predict_visibility:
      num_logits = 2 if self.use_certainty else 1
      values = out[..., :-num_logits]
      logits = out[..., -num_logits:]
      logits_visible = logits[..., :1]
      visible = jax.nn.sigmoid(logits_visible)
      if self.use_certainty:
        logits_certainty = logits[..., 1:]
        confidence = visible * jax.nn.sigmoid(logits_certainty)
        visible = (confidence > self.certainty_threshold).astype(jnp.float32)

    if self.output_activation is not None:
      values = self.output_activation(values)

    return dict(values=values,
                logits_visible=logits_visible,
                logits_certainty=logits_certainty,
                visible=visible)

In [None]:
# @title Define model {form-width: "10%"}

# Eval-mode MooG: EvalWrapper runs the encoder and the recurrent state update
# of the wrapped model but does not call its decoder. The parameter shapes
# implied by these settings have to match the .npz checkpoints loaded in the
# "Load parameters" cell.
model = EvalWrapper(
    pretrained_model=MooG(
        encoder=SRTEncoder(
            # Six 3x3 convs; the two stride-2 layers reduce spatial resolution
            # by 4x overall.
            backbone=ConvNet(
                features=(64, 128, 128, 256, 256, 512),
                kernel_sizes=((3, 3),) * 6,
                strides=((2, 2), (1, 1), (1, 1), (2, 2), (1, 1), (1, 1)),
            ),
            # NOTE: with num_layers=0 no blocks run. The input width is 512 conv
            # channels + 80 concatenated Fourier channels (2 axes x 20 bases x
            # sin/cos) = 592 != hidden_size, so this applies Dense(->512),
            # LayerNorm, then Dense back to 592.
            transformer=ImprovedTransformer(
                qkv_size=64 * 8,
                num_heads=8,
                mlp_size=2048,
                hidden_size=512,
                num_layers=0,
            ),
            pos_embedding=FourierEmbedding(
                num_fourier_bases=20,
                axes=(-3, -2),
                update_type="concat",
            ),
        ),
        # 1024 state tokens of width 512, sampled with std 1e-4.
        initializer=RandomStateInit(
            shape=(1024, 512),
            random_init_scale=1e-4,
        ),
        state_layer_norm=nn.LayerNorm(
            epsilon=1e-4, use_scale=True, use_bias=False
        ),
        # Called with the frame tokens as inputs_kv (cross-attention).
        predictor_ca=ImprovedTransformer(
            qkv_size=64 * 8,
            num_heads=8,
            mlp_size=2048,
            num_layers=2,
        ),
        # Called on the state alone (self-attention only).
        predictor_sa=ImprovedTransformer(
            qkv_size=64 * 4,
            num_heads=4,
            mlp_size=2048,
            num_layers=3,
        ),
        # Part of MooG, but not called by EvalWrapper.__call__.
        decoder=SRTDecoder(
            transformer=ImprovedTransformer(
                qkv_size=64 * 2,
                num_heads=2,
                mlp_size=2048,
                num_layers=6,
            ),
            pos_embedding=SampleFourierEmbedding(
                num_fourier_bases=16,
                update_type="concat",
            ),
        ),
    ),
)

# Point head: each (tiled) query yields 8 outputs = 2 frames x 4 channels,
# i.e. 2 coordinates (sigmoid-squashed) + visibility logit + certainty logit
# per frame. With temporal_tile_size=8, every query covers 8 x 2 = 16 frames.
point_readout_head = TrackingReadoutWrapper(
    attention_readout=AttentionReadout(
        num_classes=8,
        num_params=1024,
        num_heads=8,
    ),
    query_initializer=kd.nn.Sequential(
        layers=[
            SampleFourierEmbedding(
                num_fourier_bases=16,
                update_type='replace',
            ),
            MLP(
                hidden_size=512,
                output_size=512,
            ),
        ],
    ),
    output_activation=nn.sigmoid,
    num_frames_per_query=2,
    temporal_tile_size=8,
    predict_visibility=True,
    use_certainty=True,
)

# Box head: each query yields 64 outputs = 16 frames x 4 box values, with no
# visibility channels and no output activation.
box_readout_head = TrackingReadoutWrapper(
    attention_readout=AttentionReadout(
        num_classes=64,
        num_params=1024,
        num_heads=4,
    ),
    query_initializer=kd.nn.Sequential(
        layers=[
            SampleFourierEmbedding(
                num_fourier_bases=16,
                update_type='replace',
            ),
            MLP(
                hidden_size=512,
                output_size=512,
            ),
        ],
    ),
    output_activation=None,
    num_frames_per_query=16,
    temporal_tile_size=1,
    predict_visibility=False,
    use_certainty=False,
)

In [None]:
# @title Load parameters {form-width: "10%"}

def npload(fname):
  """Load a .npy array or an .npz archive from a local path.

  Args:
    fname: Path to the file.

  Returns:
    The array for a .npy file, or a plain dict mapping names to arrays for an
    .npz archive. Archives are read completely and closed before returning.

  Raises:
    FileNotFoundError: If `fname` does not exist.
  """
  loaded = np.load(fname, allow_pickle=False)
  if isinstance(loaded, np.ndarray):
    return loaded
  # np.load returns a lazy NpzFile that keeps the zip open; materialise every
  # array into a dict and close the archive so the file handle is not leaked.
  with loaded:
    return dict(loaded)

def recover_tree(flat_dict):
  """Rebuild a nested dict from a flat mapping with '/'-joined keys.

  Example: {"a/b": 1, "a/c": 2, "d": 3} -> {"a": {"b": 1, "c": 2}, "d": 3}.
  """
  root = {}
  for path, leaf in flat_dict.items():
    *parents, last = path.split("/")
    cursor = root
    for segment in parents:
      cursor = cursor.setdefault(segment, {})
    cursor[last] = leaf
  return root


# Each archive stores parameters under flat '/'-joined keys; recover_tree
# rebuilds the nested dicts that are passed as {'params': ...} to `apply`.
ckpt_path = 'moog_ego4d_backbone_ckpt_164335139.npz'
backbone_params = recover_tree(npload(ckpt_path))

ckpt_path = 'moog_ego4d_point_track_head_ckpt_164335139.npz'
point_readout_params = recover_tree(npload(ckpt_path))

ckpt_path = 'moog_ego4d_box_track_head_ckpt_164335139.npz'
box_readout_params = recover_tree(npload(ckpt_path))

## Point Tracking

In [None]:
# @title Load data {form-width: "10%"}

!wget https://storage.googleapis.com/representations4d/assets/kubric_batch.npy

# NOTE(review): allow_pickle=True unpickles the downloaded file, and unpickling
# can execute arbitrary code — only load files from a trusted source.
# `.item()` unwraps the single stored object, used below as a dict of arrays.
batch = np.load('kubric_batch.npy', allow_pickle=True).item()

In [None]:
# @title Compute model backbone {form-width: "10%"}

# The 'default' rng seeds RandomStateInit's initial state.
# `is_training_property` is not a parameter of EvalWrapper.__call__;
# presumably kauldron (imported as kd) consumes it to select eval mode — confirm.
# capture_intermediates=True makes `apply` return (outputs, variables), so
# `state` here is the variables collection; the recurrent model state is
# model_preds['state'].
key = jax.random.PRNGKey(0)
model_preds, state = model.apply({'params': backbone_params}, batch['video'], is_training_property=False, rngs={'default': key}, capture_intermediates=True)

In [None]:
# @title Compute point readout {form-width: "10%"}

# Read point tracks out of the backbone's per-frame state features, one
# prediction per query in batch['query_points_video'].
key = jax.random.PRNGKey(0)
point_readout = point_readout_head.apply({'params': point_readout_params}, model_preds['features'], batch['query_points_video'], is_training_property=False, rngs={'default': key})

In [None]:
# @title Visualize predicted keypoints {form-width: "10%"}

# First batch element as uint8 frames (assumes video values lie in [0, 1]).
video = (np.array(batch['video'][0]) * 255).astype(np.uint8)
height, width = video.shape[1:3]
# Transpose (T, num_points, ...) to track-major (num_points, T, ...) and scale
# to pixels. The (width, height) factor assumes normalised (x, y) coordinates.
# NOTE(review): confirm the coordinate order matches the readout targets.
pred_tracks = np.array(point_readout['values'][0].transpose(1, 0, 2)) * np.array((width, height))
pred_visibles = np.array(point_readout['visible'][0, ..., 0].transpose(1, 0))

gt_tracks = np.array(batch['target_points_video'][0].transpose(1, 0, 2)) * np.array((width, height))
gt_visibles = np.array(batch['target_points_visible_video'][0, ..., 0].transpose(1, 0))

# Same colormap for both panels, so track i has the same colour in GT and prediction.
num_points = gt_tracks.shape[0]
colormap = viz_utils.get_colors(num_points)
video_viz = viz_utils.paint_point_track(video, pred_tracks, pred_visibles, colormap)
gt_video_viz = viz_utils.paint_point_track(video, gt_tracks, gt_visibles, colormap)
print('GT Points (left), Points Predictions (right)')
media.show_video(np.concatenate([gt_video_viz, video_viz], axis=2), height=256, fps=8)

## Box Tracking

In [None]:
# @title Load data {form-width: "10%"}

!wget https://storage.googleapis.com/representations4d/assets/waymo_batch.npy

# NOTE(review): allow_pickle=True unpickles the downloaded file, and unpickling
# can execute arbitrary code — only load files from a trusted source.
batch = np.load('waymo_batch.npy', allow_pickle=True).item()

In [None]:
# @title Compute model backbone {form-width: "10%"}

# Same backbone call as for the Kubric batch; `state` is again the captured
# variables collection, not the recurrent state.
key = jax.random.PRNGKey(0)
model_preds, state = model.apply({'params': backbone_params}, batch['video'], is_training_property=False, rngs={'default': key}, capture_intermediates=True)

In [None]:
# @title Compute box readout {form-width: "10%"}

# Read box tracks out of the backbone features, one prediction per query box.
key = jax.random.PRNGKey(0)
box_readout = box_readout_head.apply({'params': box_readout_params}, model_preds['features'], batch['query_boxes_video'], is_training_property=False, rngs={'default': key})

In [None]:
# @title Visualize predicted boxes {form-width: "10%"}

import cv2

def draw_boxes_on_video(video, boxes, seed=0):
  """Return a copy of `video` with one rectangle per box drawn on every frame.

  Args:
    video: uint8 frames indexed as video[t], drawn onto with cv2.
    boxes: Array indexed as boxes[t, i] -> (y1, x1, y2, x2) in pixels; values
      are truncated to int32 before drawing.
    seed: Seed for the per-box colours. Colours depend only on `seed` and the
      box index, so separate calls (e.g. predictions and ground truth) draw
      box i in the same colour and reruns are reproducible.

  Returns:
    The annotated copy; `video` itself is not modified.
  """
  rng = random.Random(seed)
  colors = [[rng.randint(0, 255) for _ in range(3)] for _ in range(boxes.shape[1])]

  video = video.copy()
  for t in range(video.shape[0]):
    frame_boxes = boxes[t].astype(np.int32)
    for color, (y1, x1, y2, x2) in zip(colors, frame_boxes):
      # cv2 points are given in (x, y) order.
      cv2.rectangle(video[t], (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  return video

video = (np.array(batch['video'][0]) * 255).astype(np.uint8)
height, width = video.shape[1:3]
# draw_boxes_on_video unpacks each box as (y1, x1, y2, x2), so normalised box
# coordinates are scaled by (height, width, height, width). Scaling by
# (width, height, ...) would distort boxes whenever frames are not square.
box_scale = np.array((height, width, height, width))
pred_boxes = np.array(box_readout['values'][0]) * box_scale

gt_boxes = np.array(batch['boxes_video'][0]) * box_scale

video_viz = draw_boxes_on_video(video, pred_boxes)
video_gt_viz = draw_boxes_on_video(video, gt_boxes)
media.show_videos([video_viz, video_gt_viz], fps=8)