<p align="center">
  <h1 align="center">TAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement</h1>
  <p align="center">
    <a href="http://www.carldoersch.com/">Carl Doersch</a>
    ·
    <a href="https://yangyi02.github.io/">Yi Yang</a>
    ·
    <a href="https://scholar.google.com/citations?user=Jvi_XPAAAAAJ">Mel Vecerik</a>
    ·
    <a href="https://scholar.google.com/citations?user=cnbENAEAAAAJ">Dilara Gokay</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~ankush/">Ankush Gupta</a>
    ·
    <a href="http://people.csail.mit.edu/yusuf/">Yusuf Aytar</a>
    ·
    <a href="https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ">Joao Carreira</a>
    ·
    <a href="https://www.robots.ox.ac.uk/~az/">Andrew Zisserman</a>
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/">Paper</a> | <a href="https://deepmind-tapir.github.io">Project Page</a> | <a href="https://github.com/deepmind/tapnet">GitHub</a> | <a href="https://github.com/deepmind/tapnet">Demo</a> </h3>
  <div align="center"></div>
</p>

<p align="center">
  <a href="">
    <img src="https://storage.googleapis.com/dm-tapnet/swaying_gif.gif" alt="Logo" width="50%">
  </a>
</p>

In [None]:
# @title Download Code {form-width: "25%"}
!git clone https://github.com/deepmind/tapnet.git

Cloning into 'tapnet'...
remote: Enumerating objects: 419, done.
remote: Counting objects: 100% (419/419), done.
remote: Compressing objects: 100% (240/240), done.
remote: Total 419 (delta 239), reused 324 (delta 154), pack-reused 0
Receiving objects: 100% (419/419), 249.96 KiB | 6.75 MiB/s, done.
Resolving deltas: 100% (239/239), done.


In [None]:
# @title Install Dependencies {form-width: "25%"}
!pip install -r tapnet/requirements_inference.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jupyter_http_over_ws (from -r tapnet/requirements_inference.txt (line 2))
  Downloading jupyter_http_over_ws-0.0.8-py2.py3-none-any.whl (18 kB)
Collecting jaxline (from -r tapnet/requirements_inference.txt (line 5))
  Downloading jaxline-0.0.5.tar.gz (32 kB)
  Preparing metadata (setup.py) ... done
Collecting dm-haiku (from -r tapnet/requirements_inference.txt (line 7))
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 352.1/352.1 kB 7.3 MB/s eta 0:00:00
Collecting mediapy (from -r tapnet/requirements_inference.txt (line 11))
  Downloading mediapy-1.1.6-py3-none-any.whl (24 kB)
Collecting einshape (from -r tapnet/requirements_inference.txt (line 13))
  Downloading einshape-1.0-py3-none-any.whl (21 kB)
Collecting ml_collections>=0.1 (from jaxline->-r tapnet/requirements_inferenc

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir -p tapnet/checkpoints

!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy

%ls tapnet/checkpoints

--2023-06-15 21:20:17--  https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.215.128, 173.194.216.128, 173.194.212.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.215.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 124408122 (119M) [application/octet-stream]
Saving to: ‘tapnet/checkpoints/tapir_checkpoint.npy’


2023-06-15 21:20:18 (141 MB/s) - ‘tapnet/checkpoints/tapir_checkpoint.npy’ saved [124408122/124408122]

tapir_checkpoint.npy


In [None]:
# @title Imports {form-width: "25%"}

import haiku as hk
import jax
import mediapy as media
import numpy as np
import tensorflow_datasets as tfds
import tree

from tapnet import tapir_model
from tapnet.configs import tapir_config
from tapnet.utils import transforms
from tapnet.utils import viz_utils

In [None]:
# @title Load Checkpoint {form-width: "25%"}

checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']
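
The `.item()` call above is needed because NumPy stores a Python dict as a zero-dimensional object array, and `.item()` unwraps it back into the dict. The sketch below shows that round trip on a small made-up dictionary. `toy_ckpt`, its keys, and the temporary file name are hypothetical stand-ins, not the contents of the TAPIR checkpoint.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for a checkpoint: a dict of dicts of arrays.
toy_ckpt = {
    'params': {'layer/w': np.ones((2, 3), dtype=np.float32)},
    'state': {'layer/mean': np.zeros((3,), dtype=np.float32)},
}

with tempfile.TemporaryDirectory() as tmp_dir:
  path = os.path.join(tmp_dir, 'toy_checkpoint.npy')
  np.save(path, toy_ckpt)  # The dict is pickled inside a 0-d object array.
  loaded = np.load(path, allow_pickle=True)
  loaded_shape, loaded_dtype = loaded.shape, loaded.dtype
  restored = loaded.item()  # Unwrap the 0-d array back into the dict.

toy_params, toy_state = restored['params'], restored['state']
print(loaded_shape, loaded_dtype, sorted(restored))
```

Because `allow_pickle=True` executes pickled content, it is generally only appropriate for checkpoint files from a source you trust.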

In [None]:
# @title Build Model {form-width: "25%"}

def build_model(frames, query_points):
  """Compute point tracks and occlusions given frames and query points."""
  model = tapir_model.TAPIR()
  outputs = model(
      video=frames,
      is_training=False,
      query_points=query_points,
      query_chunk_size=64,
  )
  return outputs

model = hk.transform_with_state(build_model)
model_apply = jax.jit(model.apply)

In [None]:
# @title Utility Functions {form-width: "25%"}

def preprocess_frames(frames):
  """Preprocess frames to model inputs.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8

  Returns:
    frames: [num_frames, height, width, 3], [-1, 1], np.float32
  """
  frames = frames.astype(np.float32)
  frames = frames / 255 * 2 - 1
  return frames


def postprocess_occlusions(occlusions, expected_dist):
  """Postprocess occlusions to boolean visible flag.

  Args:
    occlusions: [num_points, num_frames], [-inf, inf], np.float32
    expected_dist: [num_points, num_frames], [-inf, inf], np.float32

  Returns:
    visibles: [num_points, num_frames], bool
  """
  # visibles = occlusions < 0
  visibles = (1 - jax.nn.sigmoid(occlusions)) * (1 - jax.nn.sigmoid(expected_dist)) > 0.5
  return visibles

def inference(frames, query_points):
  """Inference on one video.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: [num_points, num_frames, 2], [0, width/height], [x, y]
    visibles: [num_points, num_frames], bool
  """
  # Preprocess video to match the model's input format
  frames = preprocess_frames(frames)
  num_frames, height, width = frames.shape[0:3]
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension

  # Model inference
  rng = jax.random.PRNGKey(42)
  outputs, _ = model_apply(params, state, rng, frames, query_points)
  outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)
  tracks, occlusions, expected_dist = outputs['tracks'], outputs['occlusion'], outputs['expected_dist']

  # Binarize occlusions
  visibles = postprocess_occlusions(occlusions, expected_dist)
  return tracks, visibles


def sample_random_points(frame_max_idx, height, width, num_points):
  """Sample random points with (time, height, width) order."""
  y = np.random.randint(0, height, (num_points, 1))
  x = np.random.randint(0, width, (num_points, 1))
  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  points = np.concatenate((t, y, x), axis=-1).astype(np.int32)  # [num_points, 3]
  return points
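
The visibility rule in `postprocess_occlusions` multiplies two probabilities, so it is stricter than the commented-out `occlusions < 0` test. The NumPy-only sketch below works through it on made-up logits. Reading `expected_dist` as a logit for "the predicted position is probably inaccurate" is an assumption based on how the formula uses it, and the `sigmoid` helper and all input values are hypothetical.

```python
import numpy as np


def sigmoid(x):
  """Plain NumPy logistic function, standing in for jax.nn.sigmoid."""
  return 1.0 / (1.0 + np.exp(-x))


# Hypothetical logits for four (point, frame) pairs.
occlusion_logits = np.array([-4.0, -4.0, -0.5, 3.0])
expected_dist_logits = np.array([-4.0, 1.0, -0.5, -4.0])

p_visible = 1.0 - sigmoid(occlusion_logits)       # Probability of not being occluded.
p_accurate = 1.0 - sigmoid(expected_dist_logits)  # Assumed: probability the position is accurate.

combined = p_visible * p_accurate > 0.5  # Rule used in postprocess_occlusions.
occlusion_only = occlusion_logits < 0    # Simpler rule, for comparison.

print(combined.tolist())
print(occlusion_only.tolist())
```

Under these toy values, the second and third pairs pass the occlusion-only test but are rejected by the combined rule, because either the accuracy term is low or neither term is confident enough on its own.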

In [None]:
# @title Load an Example Video {form-width: "25%"}

video_id = 'horsejump-high'  # @param

ds, ds_info = tfds.load('davis', split='validation', with_info=True)
davis_dataset = tfds.as_numpy(ds)

for sample in davis_dataset:
  video_name = sample['metadata']['video_name'].decode()
  if video_name == video_id:
    break  # stop at the requested video id
else:
  raise ValueError(f'Video {video_id!r} not found in the DAVIS validation split.')

Downloading and preparing dataset 794.19 MiB (download: 794.19 MiB, generated: 792.26 MiB, total: 1.55 GiB) to /root/tensorflow_datasets/davis/480p/2.1.0...





Dataset davis downloaded and prepared to /root/tensorflow_datasets/davis/480p/2.1.0. Subsequent calls will reuse this data.


In [None]:
# @title Predict Sparse Point Tracks {form-width: "25%"}

resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}
num_points = 20  # @param {type: "integer"}

orig_frames = sample['video']['frames']
height, width = orig_frames.shape[1:3]
frames = media.resize_video(orig_frames, (resize_height, resize_width))
query_points = sample_random_points(0, frames.shape[1], frames.shape[2], num_points)
tracks, visibles = inference(frames, query_points)

# Visualize sparse point tracks
tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))
video = viz_utils.paint_point_track(orig_frames, tracks, visibles)
media.show_video(video, fps=10)
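
The cell above mixes two coordinate conventions: query points are `[t, y, x]` in the resized frames, while the tracks passed to `convert_grid_coordinates` are treated as `(x, y)` pairs with `(width, height)` grid sizes. The sketch below mimics both conversions in plain NumPy, assuming the conversion is simple per-axis scaling. `rescale_xy`, `query_to_resized`, and all numeric values are hypothetical illustrations, not part of the tapnet API.

```python
import numpy as np


def rescale_xy(points_xy, from_size_wh, to_size_wh):
  """Hypothetical helper: scale [..., 2] (x, y) points between grid sizes."""
  scale = np.asarray(to_size_wh, np.float32) / np.asarray(from_size_wh, np.float32)
  return np.asarray(points_xy, np.float32) * scale


def query_to_resized(query_tyx, orig_hw, resized_hw):
  """Hypothetical helper: map [t, y, x] queries from original to resized frames."""
  query = np.array(query_tyx, dtype=np.float32)  # np.array makes a copy.
  query[..., 1] *= resized_hw[0] / orig_hw[0]  # y scales with height.
  query[..., 2] *= resized_hw[1] / orig_hw[1]  # x scales with width.
  return query


# Toy tracks in a 256x256 grid: [num_points=1, num_frames=2, 2] as (x, y).
toy_tracks = np.array([[[128.0, 64.0], [256.0, 256.0]]], dtype=np.float32)
orig_tracks = rescale_xy(toy_tracks, (256, 256), (854, 480))

# A query picked at frame 0, (y=240, x=427) in a made-up 480x854 video.
toy_query = query_to_resized([[0, 240, 427]], (480, 854), (256, 256))

print(orig_tracks[0].tolist())
print(toy_query.tolist())
```

A helper like `query_to_resized` would let you pick query points on the original frames and still pass resized-frame coordinates to `inference`.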



That's it!