In [None]:
# @title Download Code {form-width: "25%"}
!git clone https://github.com/deepmind/tapnet.git

In [None]:
# @title Install Dependencies {form-width: "25%"}
!pip install -r tapnet/requirements_inference.txt

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/causal_tapir_checkpoint.npy

%ls tapnet/checkpoints

checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'

In [None]:
# @title Imports {form-width: "25%"}
%matplotlib widget
import functools

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tqdm import tqdm
import tree

from tapnet import tapir_clustering
from tapnet.utils import transforms
from tapnet.utils import viz_utils

In [None]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/robotap/for_clustering.mp4

video = media.read_video('tapnet/examplar_videos/for_clustering.mp4')
height, width = video.shape[1:3]
media.show_video(video[::5], fps=10)

In [None]:
# @title Run TAPIR to extract point tracks {form-width: "25%"}

# The tracking helper is called with a mapping from episode id to video plus
# the list of ids to process. This notebook has a single video, so it is
# registered under a placeholder id.
demo_videos = {'dummy_id': video}
demo_episode_ids = list(demo_videos)

# NOTE(review): point_batch_size / points_per_frame are forwarded unchanged;
# see tapir_clustering.track_many_points for their exact meaning.
track_dict = tapir_clustering.track_many_points(
    demo_videos, demo_episode_ids, checkpoint_path,
    point_batch_size=1024, points_per_frame=10,
)

In [None]:
# @title Run the clustering {form-width: "25%"}

# Entries of `track_dict` passed positionally to compute_clusters, listed in
# the order the call expects them.
cluster_input_keys = (
    'separation_tracks',
    'separation_visibility',
    'demo_episode_ids',
    'video_shape',
    'query_features',
)
cluster_inputs = [track_dict[key] for key in cluster_input_keys]

# NOTE(review): max_num_cats / final_num_cats presumably bound the number of
# candidate and retained clusters -- confirm in tapir_clustering.
clustered = tapir_clustering.compute_clusters(
    *cluster_inputs, max_num_cats=12, final_num_cats=7
)

In [None]:
# @title Display the inferred clusters {form-width: "25%"}

separation_tracks_trim = clustered['separation_tracks']
separation_visibility_trim = clustered['separation_visibility']

# Visualise the first episode id in the list.
episode_id = demo_episode_ids[0]

# The visibility values are inverted before plotting -- presumably
# plot_tracks_v2 expects an occlusion mask; confirm against viz_utils.
occlusion = 1.0 - separation_visibility_trim[episode_id]

pointtrack_video = viz_utils.plot_tracks_v2(
    demo_videos[episode_id].astype(np.uint8),
    separation_tracks_trim[episode_id],
    occlusion,
    trackgroup=clustered['classes'],
)
media.show_video(pointtrack_video, fps=20)