# Demo

A simple demo that runs the tracker on videos from the TAP-Vid DAVIS benchmark.

First, enter the paths to the model weights and the DAVIS pickle file.

In [1]:
# User-editable paths consumed by the cells below.
# NOTE(review): these are relative paths, so they presumably resolve against
# the directory the notebook kernel was started in -- confirm for your setup.

# Checkpoint handed to AsymmetricMASt3R.from_pretrained().
DYN_MAST3R_WEIGHTS = "checkpoints/dynamic_50e.pth"
# File read with torch.load() and applied to the Refinement module as its state dict.
REFINER_WEIGHTS = "checkpoints/refiner.pth"
# Pickle file handed to create_davis_dataset() to build the DAVIS videos.
DAVIS_PKL = "benchmarks/tapvid/davis/tapvid_davis.pkl"

Load the models and initialize the tracker.

In [None]:
import torch

from tracker.utils import path_to_mast3r  # noqa: F401
from tracker.refinement import Refinement
from tracker.tracker import Tracker
from mast3r.model import AsymmetricMASt3R

# Dynamic MASt3R model: restore the checkpoint, switch to eval mode and move
# the module to the GPU.
dynamic_mast3r = AsymmetricMASt3R.from_pretrained(DYN_MAST3R_WEIGHTS).eval().cuda()

# Refinement module, likewise in eval mode on the GPU.
refiner = Refinement().eval().cuda()
# Deserialize the weights onto the CPU first: without map_location, torch.load
# restores every tensor onto the device it was saved from, which fails when
# that device (e.g. "cuda:1") does not exist on this machine.
# load_state_dict() then copies the values into the module's CUDA parameters.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
refiner.load_state_dict(torch.load(REFINER_WEIGHTS, map_location="cpu"))

# Tracker built from both models; the two keyword flags are forwarded as-is.
# NOTE(review): presumably use_refined_coords selects the refiner's output and
# display_progress enables a progress display -- confirm against Tracker.
tracker = Tracker(
    dynamic_mast3r,
    refiner,
    use_refined_coords=True,
    display_progress=True,
)

Load the DAVIS videos and select the desired one

In [3]:
from benchmarks.tapvid.evaluation_datasets import create_davis_dataset

# Merge the per-sample mappings yielded by the dataset into a single lookup.
# If the same key appears in more than one sample, the later sample wins.
davis_videos = {}
for sample in create_davis_dataset(DAVIS_PKL, query_mode="first"):
    davis_videos.update(sample)

In [4]:
import numpy as np

video_name = "car-roundabout"
video_data = davis_videos[video_name]

# Each entry is indexed with [0] to take its first element.
# NOTE(review): presumably that leading axis is a batch axis of size 1 -- confirm.

# The video should be in [0, 255] uint8 format.
# The affine map below sends -1 -> 0 and +1 -> 255.
# NOTE(review): assumes the stored frames are floats in [-1, 1] -- confirm.
# Round before casting: astype(np.uint8) truncates toward zero, so a value such
# as 127.99999 (float round-off) would otherwise become 127 instead of 128.
# Clipping guards the cast against values that land marginally outside [0, 255].
video_scaled = (video_data["video"][0] + 1) * 127.5
video = np.clip(np.rint(video_scaled), 0, 255).astype(np.uint8)

query_points = video_data["query_points"][0]
target_points = video_data["target_points"][0]
gt_occluded = video_data["occluded"][0]

Run the tracker on the video

In [None]:
# Run the tracker on the uint8 video for the selected query points.
# The call returns two values: the predicted tracks and the predicted
# occlusion flags. The later cells index both with [None] and negate
# `occluded` with ~.
# NOTE(review): presumably both are arrays/tensors and `occluded` is
# boolean -- confirm against Tracker.track.
trajs, occluded = tracker.track(video, query_points)

Calculate TAP-Vid metrics and visualize results

In [6]:
from benchmarks.tapvid.evaluation_datasets import compute_tapvid_metrics

# The ground-truth entries are passed exactly as stored in video_data, while
# the predictions are indexed with [None] to add a leading axis first.
# NOTE(review): presumably this matches the batch axis of video_data -- confirm.
pred_occluded_batched = occluded[None]
pred_tracks_batched = trajs[None]

metrics = compute_tapvid_metrics(
    query_points=video_data["query_points"],
    gt_occluded=video_data["occluded"],
    gt_tracks=video_data["target_points"],
    pred_occluded=pred_occluded_batched,
    pred_tracks=pred_tracks_batched,
    query_mode="first",
)

# Report the first element of each of the three metrics.
occlusion_acc = metrics["occlusion_accuracy"][0]
avg_jaccard = metrics["average_jaccard"][0]
avg_pts = metrics["average_pts_within_thresh"][0]
print(f"OA: {occlusion_acc:.2f}, AJ: {avg_jaccard:.2f}, AVG PTS: {avg_pts:.2f}")

OA: 0.80, AJ: 0.45, AVG PTS: 0.73


In [8]:
import mediapy
from tracker.utils.video import plot_tracks

# Build the visualization from the video, the predicted tracks and the
# ground-truth tracks. Both masks are the negated occlusion flags
# (~occluded, ~gt_occluded) because plot_tracks takes visibility.
track_overlay = plot_tracks(
    video,
    trajs,
    target_points,
    visible=~occluded,
    gt_visible=~gt_occluded,
)

# Show the result inline at 5 frames per second.
mediapy.show_video(track_overlay, fps=5)

0
This browser does not support the video tag.
