<p align="center">
  <h1 align="center">TAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement</h1>
  <p align="center">
    <a href="http://www.carldoersch.com/">Carl Doersch</a>
    ·
    <a href="https://yangyi02.github.io/">Yi Yang</a>
    ·
    <a href="https://scholar.google.com/citations?user=Jvi_XPAAAAAJ">Mel Vecerik</a>
    ·
    <a href="https://scholar.google.com/citations?user=cnbENAEAAAAJ">Dilara Gokay</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~ankush/">Ankush Gupta</a>
    ·
    <a href="http://people.csail.mit.edu/yusuf/">Yusuf Aytar</a>
    ·
    <a href="https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ">Joao Carreira</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~az/">Andrew Zisserman</a>
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2306.08637">Paper</a> | <a href="https://deepmind-tapir.github.io">Project Page</a> | <a href="https://github.com/deepmind/tapnet">GitHub</a> | <a href="https://github.com/deepmind/tapnet/tree/main#running-tapir-locally">Live Demo</a> </h3>
  <div align="center"></div>
</p>

<p align="center">
  <img src="https://storage.googleapis.com/dm-tapnet/horsejump_rainbow.gif" width="70%"/>
</p>


In [None]:
# @title Download Code {form-width: "25%"}
!git clone https://github.com/deepmind/tapnet.git

In [None]:
# @title Install Dependencies {form-width: "25%"}
!pip install -r tapnet/requirements_inference.txt

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy

%ls tapnet/checkpoints


In [None]:
# @title Imports {form-width: "25%"}

import jax
import jax.numpy as jnp
import haiku as hk
import mediapy as media
import numpy as np
import tree


In [None]:
from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils
from tapnet.utils import model_utils

# @title Load Checkpoint {form-width: "25%"}

checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
# NOTE: allow_pickle=True unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source such as the URL downloaded above.
# The .npy holds a 0-d object array; indexing it with () yields the stored dict.
ckpt_state = np.load(checkpoint_path, allow_pickle=True)[()]
# Parameters and state consumed by `model_apply` below.
params = ckpt_state['params']
state = ckpt_state['state']

# @title Build Model {form-width: "25%"}

def build_model(frames, query_points):
  """TAPIR forward pass, meant to be wrapped by `hk.transform_with_state`.

  Args:
    frames: batched, preprocessed video passed to TAPIR as `video`.
    query_points: batched query points passed to TAPIR as `query_points`.

  Returns:
    The TAPIR output structure. This notebook reads its 'tracks', 'occlusion'
    and 'expected_dist' entries.
  """
  tapir = tapir_model.TAPIR(
      bilinear_interp_with_depthwise_conv=False,
      pyramid_level=0,
  )
  # Inference mode. query_chunk_size presumably bounds how many queries the
  # model processes at once internally.
  return tapir(
      video=frames,
      is_training=False,
      query_points=query_points,
      query_chunk_size=64,
  )

# Convert the Haiku function into a pure (init, apply) pair, then compile the
# apply function with jax.jit (it is recompiled only when input shapes change).
model = hk.transform_with_state(build_model)
model_apply = jax.jit(model.apply)

## Load and Build Model

In [None]:
# @title Inference function {form-width: "25%"}

def inference(frames, query_points):
  """Run TAPIR on one (unbatched) video.

  Args:
    frames: [num_frames, height, width, 3] np.uint8 video, values in [0, 255].
    query_points: [num_points, 3] query points in (t, y, x) order: t is a frame
      index, y/x are pixel coordinates in `frames`.

  Returns:
    tracks: predicted positions of every query point on every frame.
      NOTE(review): presumably [num_points, num_frames, 2] in (x, y) pixel
      coordinates of `frames` -- the notebook rescales them with
      `convert_grid_coordinates` using (width, height) sizes; confirm against
      `tapir_model.TAPIR`.
    visibles: [num_points, num_frames], bool.
  """
  # Preprocess the video into the model's input format and add a leading batch
  # axis to both inputs.
  frames = model_utils.preprocess_frames(frames)[None]
  query_points = query_points.astype(np.float32)[None]

  # Fixed PRNG key so any randomness in the forward pass is reproducible
  # across calls.
  rng = jax.random.PRNGKey(42)
  outputs, _ = model_apply(params, state, rng, frames, query_points)
  # Strip the batch axis and copy every output back to a host numpy array.
  outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)

  # Combine the occlusion and expected-distance predictions into boolean
  # visibility flags.
  visibles = model_utils.postprocess_occlusions(
      outputs['occlusion'], outputs['expected_dist'])
  return outputs['tracks'], visibles

In [None]:
# @title Utilities for model inference {form-width: "25%"}

def sample_grid_points(frame_idx, height, width, stride=1):
  """Sample a regular grid of query points on one frame, in (t, y, x) order.

  Grid coordinates start at `stride // 2` along each spatial axis and are
  spaced `stride` pixels apart. Points are ordered row-major (y outer, x inner).

  Args:
    frame_idx: index of the frame the query points lie on (stored as the t
      column; non-integers are truncated by the int32 cast).
    height: frame height in pixels.
    width: frame width in pixels.
    stride: spacing in pixels between neighbouring grid points.

  Returns:
    np.int32 array of shape [num_points, 3] with columns (t, y, x). It is empty
    (shape [0, 3]) when no grid coordinate fits inside the frame.
  """
  ys, xs = np.mgrid[stride // 2:height:stride, stride // 2:width:stride]
  ts = np.full(ys.shape, frame_idx)
  points = np.stack((ts, ys, xs), axis=-1).astype(np.int32)
  return points.reshape(-1, 3)  # [out_height * out_width, 3]

## Inference on DAVIS

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

orig_frames = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')
height, width = orig_frames.shape[1:3]
media.show_video(orig_frames, fps=10)

In [None]:
# @title Predict semi-dense point tracks {form-width: "25%"}
# Timed explicitly instead of with `%%time`: a cell magic is only recognised on
# the first line of a cell, which here holds the form title.
import time
_start_time = time.perf_counter()

resize_height = 512  # @param {type: "integer"}
resize_width = 512  # @param {type: "integer"}
stride = 16  # @param {type: "integer"}

height, width = orig_frames.shape[1:3]
frames = media.resize_video(orig_frames, (resize_height, resize_width))
query_points = sample_grid_points(0, resize_height, resize_width, stride)

# Feed queries in fixed-size chunks so the jitted model always sees the same
# query shape (presumably avoiding recompilation). The last chunk is
# zero-padded, and the padded rows are dropped from the outputs.
batch_size = 64
tracks = []
visibles = []
for i in range(0, query_points.shape[0], batch_size):
  query_points_chunk = query_points[i:i + batch_size]
  num_extra = batch_size - query_points_chunk.shape[0]
  if num_extra > 0:
    query_points_chunk = np.concatenate(
        [query_points_chunk, np.zeros([num_extra, 3])], axis=0)
  tracks2, visibles2 = inference(frames, query_points_chunk)
  if num_extra > 0:
    tracks2 = tracks2[:-num_extra]
    visibles2 = visibles2[:-num_extra]
  tracks.append(tracks2)
  visibles.append(visibles2)
# `inference` returns host numpy arrays, so keep the results in numpy.
tracks = np.concatenate(tracks, axis=0)
visibles = np.concatenate(visibles, axis=0)

# Rescale tracks from the resized-video resolution back to the original one.
tracks = transforms.convert_grid_coordinates(
    tracks, (resize_width, resize_height), (width, height))

# We show the point tracks without rainbows so you can see the input.
video = viz_utils.plot_tracks_v2(orig_frames, tracks, 1.0 - visibles)
media.show_video(video, fps=10)
print(f'Wall time: {time.perf_counter() - _start_time:.2f} s')


In [None]:
# Separate foreground from background: fit per-frame homographies to the
# tracks, then mark as foreground the points that are homography inliers on
# only a small fraction of the frames where they are visible.
INLIER_ERR_THRESHOLD = 0.07  # squared before comparing with `err`
FG_MAX_INLIER_RATIO = 0.60  # tracks at or below this inlier ratio are foreground

occluded = 1.0 - visibles
homogs, err, canonical = viz_utils.get_homographies_wrt_frame(
    tracks,
    occluded,
    [width, height]
)

# Sort by position in the canonical frame. In this demo the points are already
# essentially sorted, but if you query points from multiple frames or choose
# them randomly, they won't be.
ordr = np.argsort(canonical[:, 1])
sorted_tracks = tracks[ordr]
sorted_occ = occluded[ordr]
sorted_err = err[ordr]
sorted_vis = 1 - sorted_occ
# Fraction of each track's visible frames on which it is a homography inlier.
inlier_ct = np.sum(
    (sorted_err < np.square(INLIER_ERR_THRESHOLD)) * sorted_vis, axis=-1)
ratio = inlier_ct / np.maximum(1.0, np.sum(sorted_vis, axis=-1))
is_fg = ratio <= FG_MAX_INLIER_RATIO
video = viz_utils.plot_tracks_tails(
    orig_frames,
    sorted_tracks[is_fg],
    sorted_occ[is_fg],
    homogs
)
media.show_video(video, fps=16)