Add BootsTAPIR to tapir_demo.ipynb
PiperOrigin-RevId: 605373628
Change-Id: Iba08d6c85be21e6bfe68e2b97e2bcc3d50c1e73f
yangyi02 committed Feb 8, 2024
1 parent 6c4e62e commit bde1631
Showing 2 changed files with 36 additions and 9 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -30,10 +30,10 @@ clone this repo and run TAPIR on your own hardware, including a real-time demo.
You can run colab demos to see how TAPIR works. You can also upload your own video and try point tracking with TAPIR.
We provide the following colab demos:

1. <a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/tapir_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Offline TAPIR"/></a> **Standard TAPIR**: This is the most powerful TAPIR model that runs on a whole video at once. We mainly report the results of this model in the paper.
1. <a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/tapir_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Offline TAPIR"/></a> **Standard TAPIR**: This is the most powerful TAPIR / BootsTAPIR model that runs on a whole video at once. We mainly report the results of this model in the paper.
2. <a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/causal_tapir_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Online TAPIR"/></a> **Online TAPIR**: This is the sequential causal TAPIR model that allows for online tracking of points; it can run in real time on a GPU platform.
3. <a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/tapir_rainbow_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="TAPIR Rainbow Visualization"/></a> **Rainbow Visualization**: This visualization is used in many of our teaser videos: it does automatic foreground/background segmentation and corrects the tracks for the camera motion, so you can visualize the paths objects take through real space.
4. <a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/torch_tapir_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="PyTorch BootsTAP"/></a> **Pytorch BootsTAPIR**: Check this BootsTAPIR model re-implemented in PyTorch, which follows the exact architecture as original BootsTAPIR model implemented in Jax.
4. <a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/torch_tapir_demo.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="PyTorch TAPIR"/></a> **PyTorch TAPIR**: This is the TAPIR / BootsTAPIR model re-implemented in PyTorch, with exactly the same architecture and weights as the Jax model.

### Live Demo

@@ -146,15 +146,15 @@ To validate that this is a better approach than a simple linear interpolation be

<a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/optical_flow_track_assist.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Point Track Annotation"/></a> **Flow assist point annotation**: You can run this colab demo to see how point tracks are annotated with optical flow assistance.

## RoboTAP Benchmark and Point Track based Clustering
## RoboTAP Benchmark and Point Track based Video Segmentation

[RoboTAP](https://robotap.github.io/) is a follow-up work to TAP-Vid and TAPIR that demonstrates that point tracking models are important for robotics.

The [RoboTAP dataset](https://storage.googleapis.com/dm-tapnet/robotap/robotap.zip) follows the same annotation format as TAP-Vid and is released as an addition to it. In terms of domain, the RoboTAP dataset is most similar to TAP-Vid-RGB-Stacking, with the key difference that all of its robotics videos are real and manually annotated; video sources and object categories are also more diverse. The benchmark includes 265 videos and is intended for evaluation only.

For details on downloading and visualizing the dataset, please see the [data section](https://github.com/deepmind/tapnet/tree/main/data).
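
For orientation, here is a minimal sketch of reading one of the RoboTAP pickles, assuming the TAP-Vid annotation convention described in that data documentation. The filename, the `video` / `points` / `occluded` keys, and the coordinate convention are assumptions to verify against the data docs, not guarantees.

```python
import pickle

# Hedged sketch: the filename below is hypothetical, and the field names/shapes
# assume the TAP-Vid pickle convention documented in the tapnet data section.
with open('robotap_split0.pkl', 'rb') as f:
  data = pickle.load(f)  # assumed: dict mapping video name -> example dict

for name, example in data.items():
  video = example['video']        # assumed [num_frames, height, width, 3] uint8 frames
  points = example['points']      # assumed [num_points, num_frames, 2] (x, y); check whether normalized or in pixels
  occluded = example['occluded']  # assumed [num_points, num_frames] bool occlusion flags
  print(name, video.shape, points.shape, float(occluded.mean()))
```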

<a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/tapir_clustering.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Point Clustering"/></a> **Point track based clustering**: You can run this colab demo to see how point track based clustering works. Given an input video, the point tracks are extracted from TAPIR and further separated into different clusters according to different motion patterns. This is purely based on the low level motion and does not depend on any semantics or segmentation labels. You can also upload your own video and try point track based clustering.
<a target="_blank" href="https://colab.research.google.com/github/deepmind/tapnet/blob/master/colabs/tapir_clustering.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Point Clustering"/></a> **Point track based video segmentation**: You can run this colab demo to see how point track based video segmentation works. Given an input video, point tracks are extracted with TAPIR and then separated into clusters according to their motion patterns. This is based purely on low-level motion and does not depend on any other cues (e.g., semantics). You can also upload your own video and try it.
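
The colab implements its own clustering procedure; purely as an illustration of the underlying idea (grouping tracks by low-level motion, not the demo's actual algorithm), a hedged sketch might cluster per-track displacement features with k-means:

```python
import numpy as np
from sklearn.cluster import KMeans


def cluster_tracks_by_motion(tracks: np.ndarray, num_clusters: int = 2) -> np.ndarray:
  """Illustrative only: groups tracks by frame-to-frame motion, not semantics.

  Args:
    tracks: assumed [num_points, num_frames, 2] (x, y) positions from TAPIR.
    num_clusters: number of motion groups to form.

  Returns:
    A cluster id per track, shape [num_points].
  """
  velocities = np.diff(tracks, axis=1)            # per-frame displacement of each track
  features = velocities.reshape(len(tracks), -1)  # flatten motion into one feature vector
  return KMeans(n_clusters=num_clusters, n_init=10).fit_predict(features)
```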

## Download Checkpoints

37 changes: 32 additions & 5 deletions colabs/tapir_demo.ipynb
@@ -83,6 +83,17 @@
"!pip install -r tapnet/requirements_inference.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zyEo9-Kv78S7"
},
"outputs": [],
"source": [
"MODEL_TYPE = 'bootstapir' # 'tapir' or 'bootstapir'"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -95,7 +106,10 @@
"\n",
"%mkdir tapnet/checkpoints\n",
"\n",
"!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy\n",
"if MODEL_TYPE == 'tapir':\n",
" !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy\n",
"else:\n",
" !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstapir_checkpoint.npy\n",
"\n",
"%ls tapnet/checkpoints"
]
@@ -110,6 +124,7 @@
"source": [
"# @title Imports {form-width: \"25%\"}\n",
"%matplotlib widget\n",
"import functools\n",
"import haiku as hk\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
@@ -136,7 +151,10 @@
"source": [
"# @title Load Checkpoint {form-width: \"25%\"}\n",
"\n",
"checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'\n",
"if MODEL_TYPE == 'tapir':\n",
" checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'\n",
"else:\n",
" checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint.npy'\n",
"ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()\n",
"params, state = ckpt_state['params'], ckpt_state['state']"
]
@@ -151,9 +169,17 @@
"source": [
"# @title Build Model {form-width: \"25%\"}\n",
"\n",
"def build_model(frames, query_points):\n",
"def build_model(frames, query_points, model_type='tapir'):\n",
" \"\"\"Compute point tracks and occlusions given frames and query points.\"\"\"\n",
" model = tapir_model.TAPIR(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)\n",
" if model_type == 'tapir':\n",
" model = tapir_model.TAPIR(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)\n",
" elif model_type == 'bootstapir':\n",
" model = tapir_model.TAPIR(\n",
" bilinear_interp_with_depthwise_conv=False,\n",
" pyramid_level=1,\n",
" extra_convs=True,\n",
" softmax_temperature=10.0,\n",
" )\n",
" outputs = model(\n",
" video=frames,\n",
" is_training=False,\n",
@@ -162,7 +188,8 @@
" )\n",
" return outputs\n",
"\n",
"model = hk.transform_with_state(build_model)\n",
"build_model_fn = functools.partial(build_model, model_type=MODEL_TYPE)\n",
"model = hk.transform_with_state(build_model_fn)\n",
"model_apply = jax.jit(model.apply)"
]
},
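
Putting the pieces of this diff together, a minimal end-to-end sketch might look like the following. The preprocessing to [-1, 1], the (t, y, x) query-point layout, and the `tracks` / `occlusion` output keys follow the existing demo cells and should be treated as assumptions here, not as part of this commit.

```python
import functools

import haiku as hk
import jax
import numpy as np

from tapnet import tapir_model

MODEL_TYPE = 'bootstapir'  # or 'tapir'

if MODEL_TYPE == 'tapir':
  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
else:
  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']


def build_model(frames, query_points, model_type='tapir'):
  """Computes point tracks and occlusions for frames and query points."""
  if model_type == 'tapir':
    model = tapir_model.TAPIR(
        bilinear_interp_with_depthwise_conv=False, pyramid_level=0)
  else:  # 'bootstapir'
    model = tapir_model.TAPIR(
        bilinear_interp_with_depthwise_conv=False,
        pyramid_level=1,
        extra_convs=True,
        softmax_temperature=10.0,
    )
  return model(
      video=frames,
      is_training=False,
      query_points=query_points,
      query_chunk_size=64,
  )


model = hk.transform_with_state(
    functools.partial(build_model, model_type=MODEL_TYPE))
model_apply = jax.jit(model.apply)

# Assumed input layout (from the demo cells): frames [1, num_frames, H, W, 3],
# resized to 256x256 and scaled to [-1, 1]; query points [1, num_points, 3]
# given as (t, y, x) in frame coordinates.
frames = np.zeros((1, 8, 256, 256, 3), dtype=np.float32)
query_points = np.array([[[0, 128.0, 128.0]]], dtype=np.float32)

rng = jax.random.PRNGKey(42)
outputs, _ = model_apply(params, state, rng, frames, query_points)
tracks = outputs['tracks']        # assumed [1, num_points, num_frames, 2] (x, y)
occlusion = outputs['occlusion']  # assumed [1, num_points, num_frames] logits
```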
