Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

<p align="center">
  <h1 align="center">TRAJAN: Direct Motion Models for Assessing Generated Videos</h1>
  <p align="center">
    <a href="https://k-r-allen.github.io/">Kelsey Allen*</a>
    ·
    <a href="http://www.carldoersch.com/">Carl Doersch</a>
    ·
    <a href="https://stanniszhou.github.io/">Guangyao Zhou</a>
    ·
    <a href="https://mohammedsuhail.net/">Mohammed Suhail</a>
    ·
    <a href="https://dannydriess.github.io/">Danny Driess</a>
    ·
    <a href="https://www.irocco.info/">Ignacio Rocco</a>
    ·
    <a href="https://yuliarubanova.github.io/">Yulia Rubanova</a>
    ·
    <a href="https://tkipf.github.io/">Thomas Kipf</a>
    ·
    <a href="https://msajjadi.com/">Mehdi S. M. Sajjadi</a>
    ·
    <a href="https://scholar.google.com/citations?user=MxxZkEcAAAAJ&hl=en">Kevin Murphy</a>
    ·
    <a href="https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ">Joao Carreira</a>
    ·
    <a href="https://www.sjoerdvansteenkiste.com/">Sjoerd van Steenkiste*</a>
  </p>
  <h3 align="center"><a href="">Paper</a> | <a href="https://trajan-paper.github.io">Project Page</a> | <a href="https://github.com/deepmind/tapnet">GitHub</a> </h3>
  <div align="center"></div>
</p>

<p align="center">
  <a href="">
    <img src="https://storage.googleapis.com/dm-tapnet/swaying_gif.gif" alt="Logo" width="50%">
  </a>
</p>



In [None]:
# @title Install Code and Dependencies {form-width: "25%"}
# Install the tapnet package directly from its GitHub repository (no version
# or commit is pinned here).
!pip install git+https://github.com/google-deepmind/tapnet.git

In [None]:
# Selects which point-tracker checkpoint is used. The download and load cells
# treat any value other than 'tapir' as the BootsTAPIR checkpoint.
MODEL_TYPE = 'bootstapir'  # 'tapir' or 'bootstapir'

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

if MODEL_TYPE == "tapir":
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy
else:
  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy

!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/trajan/track_autoencoder_ckpt.npz

%ls tapnet/checkpoints

In [None]:
# @title Imports {form-width: "25%"}
from google.colab import output
import jax
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tapnet.models import tapir_model
from tapnet.utils import model_utils
from tapnet.utils import transforms
from tapnet.utils import viz_utils
from tapnet.tapvid import evaluation_datasets
from tapnet.trajan import track_autoencoder

# Turn on Colab's custom widget manager for this notebook.
# NOTE(review): presumably required for interactive widget output — confirm it
# is still needed by the cells below.
output.enable_custom_widget_manager()

In [None]:
# @title Load Checkpoint {form-width: "25%"}

# Pick the tracker checkpoint that matches MODEL_TYPE.
checkpoint_path = (
    'tapnet/checkpoints/tapir_checkpoint_panning.npy'
    if MODEL_TYPE == 'tapir'
    else 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'
)
# NOTE: allow_pickle=True unpickles arbitrary objects; only load checkpoints
# from a trusted source.
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params = ckpt_state['params']
state = ckpt_state['state']

# Base TAPIR options; BootsTAPIR overrides the pyramid level and adds two more.
kwargs = {'bilinear_interp_with_depthwise_conv': False, 'pyramid_level': 0}
if MODEL_TYPE == 'bootstapir':
  kwargs.update(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)
tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)

# Location of the TRAJAN weights fetched by the download cell.
trajan_checkpoint_path = 'tapnet/checkpoints/track_autoencoder_ckpt.npz'

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

video = media.read_video("tapnet/examplar_videos/horsejump-high.mp4")
media.show_video(video, fps=10)

In [None]:
# @title Utility Functions {form-width: "25%"}


@jax.jit
def inference(frames, query_points):
  """Tracks the given query points through one video with TAPIR (jitted).

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: first batch element of the model's 'tracks' output, i.e. the
      per-frame position of every query point (presumably
      [num_points, num_frames, 2] — confirm against tapir_model).
    visibles: [num_points, num_frames], bool
  """
  # Bring both inputs into the model's format and add a batch dimension.
  batched_frames = model_utils.preprocess_frames(frames)[None]
  batched_queries = query_points.astype(np.float32)[None]

  predictions = tapir(
      video=batched_frames,
      query_points=batched_queries,
      is_training=False,
      query_chunk_size=32,
  )

  # Turn occlusion / expected-distance outputs into a binary visibility mask.
  visibles = model_utils.postprocess_occlusions(
      predictions['occlusion'], predictions['expected_dist']
  )
  return predictions['tracks'][0], visibles[0]


def sample_random_points(frame_max_idx, height, width, num_points):
  """Draws `num_points` uniform random integer points ordered as (t, y, x).

  Args:
    frame_max_idx: Largest frame index that may be drawn (inclusive).
    height: Exclusive upper bound for the y coordinate.
    width: Exclusive upper bound for the x coordinate.
    num_points: Number of points to draw.

  Returns:
    An int32 array of shape [num_points, 3] with columns (t, y, x).
  """
  column_shape = (num_points, 1)
  # Draw order (y, x, t) is kept so a seeded RNG yields the same points.
  rows = np.random.randint(0, height, column_shape)
  cols = np.random.randint(0, width, column_shape)
  frame_ids = np.random.randint(0, frame_max_idx + 1, column_shape)
  return np.hstack([frame_ids, rows, cols]).astype(np.int32)

In [None]:
# @title Efficient Chunked Point Track Prediction {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 4096  # @param {type: "integer"}

# Resize the video, add a batch dimension and preprocess it to the model's
# input format.
frames = media.resize_video(video, (resize_height, resize_width))
frames = model_utils.preprocess_frames(frames[None])
# Feature grids are computed once here and reused for every query chunk.
feature_grids = tapir.get_feature_grids(frames, is_training=False)
# `frames` carries a leading batch dimension, so shape[1:4] is
# (num_frames, height, width). sample_random_points treats its first argument
# as an *inclusive* maximum frame index, so pass num_frames - 1; passing
# num_frames would allow query times one past the last frame.
query_points = sample_random_points(
    frames.shape[1] - 1, frames.shape[2], frames.shape[3], num_points
)
chunk_size = 32


@jax.jit
def chunk_inference(query_points):
  """Tracks one chunk of query points using the precomputed feature grids.

  Reads the notebook globals `frames`, `feature_grids`, `chunk_size` and
  `tapir`.

  Args:
    query_points: Query points for this chunk, one row per point.

  Returns:
    A (tracks, visibles) pair holding the first batch element of the model's
    tracks and of the binarized visibility mask.
  """
  batched_queries = query_points.astype(np.float32)[None]

  predictions = tapir(
      video=frames,
      query_points=batched_queries,
      is_training=False,
      query_chunk_size=chunk_size,
      feature_grids=feature_grids,
  )

  # Turn occlusion / expected-distance outputs into a binary visibility mask.
  visibles = model_utils.postprocess_occlusions(
      predictions["occlusion"], predictions["expected_dist"]
  )
  return predictions["tracks"][0], visibles[0]

# Track the query points one chunk at a time, copying each result to NumPy.
all_tracks, all_visibles = [], []
for chunk in range(0, len(query_points), chunk_size):
  chunk_queries = query_points[chunk : chunk + chunk_size]
  tracks, visibles = chunk_inference(chunk_queries)
  all_tracks += [np.array(tracks)]
  all_visibles += [np.array(visibles)]

# Stitch the per-chunk results back together along the point axis.
tracks = np.concatenate(all_tracks, axis=0)
visibles = np.concatenate(all_visibles, axis=0)

# Visualize sparse point tracks: rescale from the resized frame to the
# original video resolution before painting.
height, width = video.shape[1:3]
source_size = (resize_width, resize_height)
tracks = transforms.convert_grid_coordinates(
    tracks, source_size, (width, height)
)
video_viz = viz_utils.paint_point_track(video, tracks, visibles)
media.show_video(video_viz, fps=10)

## Apply TRAJAN

In [None]:
# @title Imports

from __future__ import annotations

import dataclasses
import einops
import os
import io

In [None]:
# @title Load parameters {form-width: "25%"}


def npload(fname):
  """Loads a NumPy file from `fname` with pickled objects disallowed.

  Args:
    fname: Path (or file-like object) accepted by `np.load`.

  Returns:
    The array itself when the file holds a single array; otherwise a plain
    dict mapping every stored name to its array.
  """
  loaded = np.load(fname, allow_pickle=False)
  if isinstance(loaded, np.ndarray):
    return loaded
  # Archives come back as a lazy, file-backed handle: read every entry
  # eagerly, then close the handle instead of leaking the open file.
  with loaded:
    return dict(loaded)


def recover_tree(flat_dict):
  """Rebuilds a nested dict from a flat dict keyed by '/'-separated paths.

  Args:
    flat_dict: Mapping from path strings such as 'a/b/c' to leaf values.

  Returns:
    A nested dict in which every path component except the last is a
    sub-dict and the last component maps to the value.
  """
  tree = {}
  for path, leaf in flat_dict.items():
    *parents, leaf_name = path.split('/')
    subtree = tree
    # Walk down the tree, creating intermediate levels on first use.
    for name in parents:
      subtree = subtree.setdefault(name, {})
    subtree[leaf_name] = leaf
  return tree

# Load the TRAJAN checkpoint and rebuild its nested parameter dict. Note that
# this rebinds the global `params`, which previously held the TAPIR parameters
# from the tracker checkpoint.
params = recover_tree(npload(trajan_checkpoint_path))

In [None]:
# @title Preprocessor for Tracks

@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ProcessTracksForTrackAutoencoder:
  """Samples tracks and fills out support_tracks, query_points etc.

  TrackAutoencoder format which will be output from this transform:
   video: float["*B T H W 3"]
   support_tracks: float["*B QS T 2"]
   support_tracks_visible: float["*B QS T 1"]
   query_points: float["*B Q 3"]
  """

  # note that we do not use query points in the encoding, so it is expected
  # that num_support_tracks >> num_target_tracks

  num_support_tracks: int
  num_target_tracks: int

  # If true, assume that everything after the boundary_frame is padding,
  # so don't sample any query points after the boundary_frame, and only sample
  # target tracks that have at least one visible frame before the boundary.
  before_boundary: bool = True
  episode_length: int = 150

  # Keys.
  video_key: str = "video"
  tracks: str = "tracks"  # [time, num_points, 2]
  visible_key: str = "visible"  # [time, num_points, 1]

  def random_map(self, features):

    # set tracks xy and compute visibility
    tracks_xy = features[self.tracks][..., :2]
    tracks_xy = np.asarray(tracks_xy, np.float32)
    boundary_frame = features["video"].shape[0]

    # visibles already post-processed by compute_point_tracks.py
    visibles = np.asarray(features[self.visible_key], np.float32)

    # pad to 'episode_length' frames
    if self.before_boundary:
      # if input video is longer than episode_length, crop to episode_length
      if self.episode_length - visibles.shape[0] < 0:
        visibles = visibles[: self.episode_length]
        tracks_xy = tracks_xy[: self.episode_length]

      visibles = np.pad(
          visibles,
          [[0, self.episode_length - visibles.shape[0]], [0, 0]],
          constant_values=0.0,
      )
      tracks_xy = np.pad(
          tracks_xy,
          [[0, self.episode_length - tracks_xy.shape[0]], [0, 0], [0, 0]],
          constant_values=0.0,
      )

    # Samples indices for support tracks and target tracks.
    num_input_tracks = tracks_xy.shape[1]
    idx = np.arange(num_input_tracks)
    np.random.shuffle(idx)

    assert (
        num_input_tracks >= self.num_support_tracks + self.num_target_tracks
    ), (
        (
            f"num_input_tracks {num_input_tracks} must be greater than"
            f" num_support_tracks {self.num_support_tracks} + num_target_tracks"
            f" {self.num_target_tracks}"
        ),
    )

    idx_support = idx[-self.num_support_tracks :]
    idx_target = idx[: self.num_target_tracks]

    # Gathers support tracks from `features`.  Features are of shape
    # [time, num_points, 2]
    support_tracks = tracks_xy[..., idx_support, :]
    support_tracks_visible = visibles[..., idx_support]

    # Gathers target tracks from `features`.
    target_tracks = tracks_xy[..., idx_target, :]
    target_tracks_visible = visibles[..., idx_target]

    # transpose to [num_points, time, ...]
    support_tracks = np.transpose(support_tracks, [1, 0, 2])
    support_tracks_visible = np.expand_dims(
        np.transpose(support_tracks_visible, [1, 0]), axis=-1
    )

    # [time, point_id, x/y] -> [point_id, time, x/y]
    target_tracks = np.transpose(target_tracks, [1, 0, 2])
    target_tracks_visible = np.transpose(target_tracks_visible, [1, 0])

    # Sample query points as random visible points
    num_target_tracks = target_tracks_visible.shape[0]
    target_queries = self.sample_query_from_targets(
       num_target_tracks, target_tracks, target_tracks_visible)

    # Add channel dimension to target_tracks_visible
    target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)

    # Updates `features` to contain these *new* features and add batch dim.
    features_new = {
        "support_tracks": support_tracks[None, :],
        "support_tracks_visible": support_tracks_visible[None, :],
        "query_points": target_queries[None, :],
        "target_points": target_tracks[None, :],
        "boundary_frame": np.array([boundary_frame]),
        "target_tracks_visible": target_tracks_visible[None, :],
    }
    features.update(features_new)
    return features
  
  def sample_query_from_targets(
        self,
        num_query_tracks: int,
        target_tracks: np.ndarray,
        target_tracks_visible: np.ndarray,
  ) -> np.ndarray:
    """Samples query points from target tracks."""
    random_frame = np.zeros(num_query_tracks, dtype=np.int64)
    num_frames = target_tracks_visible.shape[1]
    for i in range(num_query_tracks):
      visible_indices = np.where(target_tracks_visible[i] > 0)[0]
      if len(visible_indices) > 0:
          # Choose a random frame index from the visible ones
          random_frame[i] = np.random.choice(visible_indices)
      else:
          # If no frame is visible for a track, default to frame 0
          # (or handle as appropriate for your use case)
          random_frame[i] = 0
  
      # Create one-hot encoding based on the randomly selected frame for each track
      idx = np.eye(num_frames, dtype=np.float32)[
          random_frame
      ]  # [num_query_tracks, num_frames]

    # Use the one-hot index to select the coordinates at the chosen frame
    target_queries_xy = np.sum(
        target_tracks * idx[..., np.newaxis], axis=1
    )  # [num_query_tracks, 2]

    # Stack frame index and coordinates: [t, x, y]
    target_queries = np.stack(
        [
            random_frame.astype(np.float32),
            target_queries_xy[..., 0],
            target_queries_xy[..., 1],
        ],
        axis=-1,
    )  # [num_query_tracks, 3]
    return target_queries

In [None]:
# @title Run Model

# Build the track autoencoder, then wrap its apply step in a jitted function.
model = track_autoencoder.TrackAutoEncoder(
    decoder_scan_chunk_size=32,  # If passing large queries
    # decoder_scan_chunk_size=None,  # If passing arbitrary small queries
)


@jax.jit
def forward(params, inputs):
  """Applies the global `model` to `inputs` using the given `params`."""
  variables = {'params': params}
  return model.apply(variables, inputs)

# Preprocessor that splits the tracks into 2048 support and 2048 target tracks.
preprocessor = ProcessTracksForTrackAutoencoder(
    num_support_tracks=2048,
    num_target_tracks=2048,
    video_key="video",
    before_boundary=True,
)

# Rescale the pixel-space tracks (shifted by half a pixel) onto a (1, 1) grid.
normalized_tracks = transforms.convert_grid_coordinates(
    tracks + 0.5, (width, height), (1, 1)
)

# Assemble the batch in time-major layout and run the preprocessor on it.
batch = preprocessor.random_map({
    "video": video,
    "tracks": einops.rearrange(normalized_tracks, "q t c -> t q c"),
    "visible": einops.rearrange(visibles, "q t -> t q"),
})
# Drop the raw tracks entry before the batch is passed to the model.
batch.pop("tracks", None)

# Run forward pass
outputs = forward(params, batch)

In [None]:
# @title Run Model on Custom Query Points

# Rebuild the track autoencoder without decoder chunking, then wrap its apply
# step in a jitted function.
model = track_autoencoder.TrackAutoEncoder(
    # decoder_scan_chunk_size=32,  # If passing large queries
    decoder_scan_chunk_size=None,  # If passing arbitrary small queries
)


@jax.jit
def forward(params, inputs):
  """Applies the global `model` to `inputs` using the given `params`."""
  variables = {'params': params}
  return model.apply(variables, inputs)

# [Optional] Define custom query points [t, x, y], where t is the frame
# index, and x and y are normalized coordinates e.g. [12., 0.01, 0.9]
# Comment this out to use the query points sampled from the target tracks.
# NOTE(review): once the query points are overridden here, `outputs` holds one
# track per custom query, while batch['target_points'] still holds the sampled
# target tracks — the metrics cell below presumably only makes sense with the
# sampled queries; confirm before running it after this cell.
query_pts = np.array(
    [
        [10., 0.5, 0.5],  # center of the image
        [10., 0.05, 0.05],  # top left corner
        [10., 0.95, 0.95],  # bottom right corner
        [10., 0.05, 0.95],  # bottom left corner
    ]
)  # [num_query_points, 3]
# Add the batch dimension and replace the sampled queries in the batch.
batch['query_points'] = query_pts[None, :]

# Run forward pass
outputs = forward(params, batch)

In [None]:
# @title Visualize reconstructed point tracks
# Size of the original video, used to map normalized coordinates back onto it.
height, width = video.shape[1:3]

# Rescale the model's predicted tracks (first batch element) from the (1, 1)
# grid to (width, height).
reconstructed_tracks = transforms.convert_grid_coordinates(
    outputs.tracks[0], (1, 1), (width, height)
)

# Same rescaling for the support tracks that were fed to the model.
support_tracks_vis = transforms.convert_grid_coordinates(
    batch['support_tracks'][0], (1, 1), (width, height)
)

# Same rescaling for the target tracks.
target_tracks_vis = transforms.convert_grid_coordinates(
    batch['target_points'][0], (1, 1), (width, height)
)

# Binarize the model's visibility prediction.
# NOTE(review): elsewhere in this file postprocess_occlusions is called with
# the tracker's 'occlusion' and 'expected_dist' outputs; here it receives
# visible_logits / certain_logits — confirm the polarity is what is intended.
reconstructed_visibles = model_utils.postprocess_occlusions(
    outputs.visible_logits, outputs.certain_logits
)

# NOTE: uncomment the lines below to also visualize the support & target tracks.
# Tracks are sliced to the video's own length before painting.
video_length = video.shape[0]

# video_viz = viz_utils.paint_point_track(
#     video,
#     support_tracks_vis[:, :video_length],
#     batch['support_tracks_visible'][0, :, :video_length],
# )
# media.show_video(video_viz, fps=10)

# video_viz = viz_utils.paint_point_track(
#     video,
#     target_tracks_vis[:, :video_length],
#     batch['target_tracks_visible'][0, :, :video_length],
# )
# media.show_video(video_viz, fps=10)

# Paint the reconstructed tracks on the video and display the result.
video_viz = viz_utils.paint_point_track(
    video,
    reconstructed_tracks[:, :video_length],
    reconstructed_visibles[0, :, :video_length],
    # np.ones_like(reconstructed_visibles[0, :, :video_length]),
)
media.show_video(video_viz, fps=10)

In [None]:
#@title Compute Metrics

# Every track is queried from frame 0: one all-zero query entry per
# (batch element, target track).
num_batch = reconstructed_visibles.shape[0]
num_targets = batch['target_tracks_visible'].shape[1]
query_points = np.zeros((num_batch, num_targets, 1))

# Ground truth and predictions, all cut to the real video length.
gt_occluded = 1 - batch['target_tracks_visible'][..., :video_length, 0]
gt_tracks = target_tracks_vis[None, ..., :video_length, :]
# NOTE(review): the binarized visibility output is passed as `pred_occluded`
# as-is — confirm this matches the polarity compute_tapvid_metrics expects.
pred_occluded = reconstructed_visibles[..., :video_length, 0]
pred_tracks = reconstructed_tracks[..., :video_length, :]

# Compute TapVid metrics
metrics = evaluation_datasets.compute_tapvid_metrics(
    query_points=query_points,
    gt_occluded=gt_occluded,
    gt_tracks=gt_tracks,
    pred_occluded=pred_occluded,
    pred_tracks=pred_tracks,
    query_mode='strided',
    get_trackwise_metrics=False,
)

# Average the Jaccard scores reported for each of the listed thresholds.
jaccard_thresholds = (1, 2, 4, 8, 16)
jaccard = np.mean([metrics[f'jaccard_{d}'] for d in jaccard_thresholds])
print('Average Jaccard:', jaccard)
print('Occlusion Accuracy:', metrics['occlusion_accuracy'].mean())