<a href="https://colab.research.google.com/github/deepmind/perception_test/blob/main/baselines/single_point_tracking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to load the point tracking annotations in the validation split of the Perception Test and run the evaluation for a simple baseline model which predicts static points for all tracks.

Copyright 2023 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at [https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0).
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Single Point Tracking Static Baseline

GitHub: https://github.com/deepmind/perception_test

## The Perception Test
[Perception Test: A Diagnostic Benchmark for Multimodal Video Models](https://arxiv.org/abs/2305.13786) is a multimodal benchmark designed to comprehensively evaluate the perception and reasoning skills of multimodal video models. The Perception Test dataset introduces real-world videos designed to show perceptually interesting situations and defines multiple computational tasks (object and point tracking, action and sound localisation, multiple-choice and grounded video question-answering). Here, we provide details and a simple baseline for the single point tracking task.

[![Perception Test Overview Presentation](https://img.youtube.com/vi/8BiajMOBWdk/maxresdefault.jpg)](https://youtu.be/8BiajMOBWdk?t=10)

## Single Point Tracking
In this task, given the 2D coordinates of a point at the beginning of a video, the model should track the point throughout the video. Performance is evaluated using the average [Jaccard metric](https://arxiv.org/abs/2211.03726), which jointly measures long-term point position accuracy and occlusion prediction accuracy.
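
As a rough illustration of the metric, the sketch below computes the Jaccard score for a single track at one pixel threshold. The function name `jaccard_at_threshold` and the toy values are assumptions made for this example. Unlike `compute_tapvid_metrics`, which this notebook uses for the actual evaluation, the sketch does not exclude the query frame and does not average over the thresholds of 1, 2, 4, 8 and 16 pixels.

```python
import numpy as np


def jaccard_at_threshold(gt_points, gt_visible, pred_points, pred_visible,
                         thresh):
  """Jaccard = TP / (TP + FP + FN) for one track at one pixel threshold."""
  # A prediction is close enough when its squared distance to the ground
  # truth is below the squared threshold.
  within = np.sum(np.square(pred_points - gt_points), axis=-1) < thresh ** 2
  true_pos = np.sum(within & gt_visible & pred_visible)
  # False positives: predicted visible, but the ground truth is occluded
  # or too far from the prediction.
  false_pos = np.sum(pred_visible & (~gt_visible | ~within))
  # True positives plus false negatives is the number of visible
  # ground-truth points.
  gt_pos = np.sum(gt_visible)
  return true_pos / (gt_pos + false_pos)


# Toy track of 4 frames in pixel coordinates; the point is occluded in the
# last frame.
gt_points = np.array([[10., 10.], [12., 10.], [20., 10.], [30., 10.]])
gt_visible = np.array([True, True, True, False])
# Static prediction: the first point repeated and always predicted visible.
pred_points = np.tile(gt_points[0], (4, 1))
pred_visible = np.ones(4, dtype=bool)

score = jaccard_at_threshold(gt_points, gt_visible, pred_points,
                             pred_visible, thresh=4)
print(score)  # 2 true positives / (3 visible + 2 false positives) = 0.4
```

A static prediction is penalised twice here: once for drifting away from the moving point, and once for claiming visibility in the occluded frame.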

The image below shows examples of point tracking annotations. Note that the annotations match the frame rate of the original videos, around 30fps.

![image collage](https://storage.googleapis.com/dm-perception-test/img/ptpt_collage.png)

## Static baseline
This notebook demonstrates how to load the point tracking annotations in the validation split of the Perception Test, and run the evaluation for a dummy baseline model. This model assumes static points for all point tracks in a video.
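
In essence, the static baseline repeats the query point for every remaining frame. The following is a minimal sketch of that idea; `static_track` is a hypothetical helper written for illustration, whereas the evaluation in this notebook relies on the `PointTracker` class.

```python
import numpy as np


def static_track(start_point, start_frame_id, num_frames):
  """Repeats the start point from start_frame_id up to the last frame."""
  frame_ids = np.arange(start_frame_id, num_frames)
  # One copy of the start point per tracked frame: shape [len(frame_ids), 2].
  points = np.tile(np.asarray(start_point, dtype=float), (len(frame_ids), 1))
  return points, frame_ids


points, frame_ids = static_track([0.25, 0.5], start_frame_id=2, num_frames=5)
print(points.shape)  # (3, 2)
print(frame_ids)     # [2 3 4]
```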

In [None]:
# @title Prerequisites
import colorsys
import json
import os
import random
from typing import Tuple, List, Dict, Any, Mapping
import zipfile

import cv2
import imageio
import moviepy.editor as mvp
import numpy as np
import requests

In [None]:
# @title Utility functions
def download_and_unzip(url: str, destination: str):
  """Downloads and unzips a .zip file to a destination.

  Downloads a file from the specified URL, saves it to the destination
  directory, and then extracts its contents.

  If the file is larger than 1GB, it will be downloaded in chunks,
  and the download progress will be displayed.

  Args:
    url (str): The URL of the file to download.
    destination (str): The destination directory to save the file and
      extract its contents.
  """
  if not os.path.exists(destination):
    os.makedirs(destination)

  filename = url.split('/')[-1]
  file_path = os.path.join(destination, filename)

  if os.path.exists(file_path):
    print(f'{filename} already exists. Skipping download.')
    return

  response = requests.get(url, stream=True)
  response.raise_for_status()  # fail early on HTTP errors (e.g. 404)
  total_size = int(response.headers.get('content-length', 0))
  gb = 1024*1024*1024

  if total_size / gb > 1:
    print(f'{filename} is larger than 1GB, downloading in chunks')
    chunk_flag = True
    chunk_size = int(total_size/100)
  else:
    chunk_flag = False
    # Fall back to 1MB chunks if the server did not report a content length.
    chunk_size = total_size or 1024*1024

  with open(file_path, 'wb') as file:
    for chunk_idx, chunk in enumerate(
        response.iter_content(chunk_size=chunk_size)):
      if chunk:
        if chunk_flag:
          print(f'{chunk_idx}% downloading '
                f'{round((chunk_idx*chunk_size)/gb, 1)}GB '
                f'/ {round(total_size/gb, 1)}GB')
        file.write(chunk)
  print(f"'{filename}' downloaded successfully.")

  with zipfile.ZipFile(file_path, 'r') as zip_ref:
    zip_ref.extractall(destination)
  print(f"'{filename}' extracted successfully.")

  os.remove(file_path)


def load_db_json(db_file: str) -> Dict[str, Any]:
  """Loads a JSON file as a dictionary.

  Args:
    db_file (str): Path to the JSON file.

  Returns:
    Dict: Loaded JSON data as a dictionary.

  Raises:
    FileNotFoundError: If the specified file doesn't exist.
    TypeError: If the JSON file is not formatted as a dictionary.
  """
  if not os.path.isfile(db_file):
    raise FileNotFoundError(f'No such file: {db_file}')

  with open(db_file, 'r') as f:
    db_file_dict = json.load(f)
    if not isinstance(db_file_dict, dict):
      raise TypeError('JSON file is not formatted as a dictionary.')
    return db_file_dict


def load_mp4_to_frames(filename: str) -> np.array:
  """Loads an MP4 video file and returns its frames as a NumPy array.

  Args:
    filename (str): Path to the MP4 video file.

  Returns:
    np.array: Frames of the video as a NumPy array.
  """
  assert os.path.exists(filename), f'File {filename} does not exist.'
  cap = cv2.VideoCapture(filename)

  num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

  # Preallocate the output array; frames are filled in as they are decoded.
  vid_frames = np.empty((num_frames, height, width, 3), dtype=np.uint8)

  idx = 0
  while True:
    ret, vid_frame = cap.read()
    if not ret:
      break

    vid_frames[idx] = vid_frame
    idx += 1

  cap.release()
  return vid_frames


def get_video_frames(data_item: Dict[str, Any], vid_path: str) -> np.array:
  """Loads frames of a video specified by an item dictionary.

  Assumes format of annotations used in the Perception Test Dataset.

  Args:
    data_item (Dict): Item from dataset containing metadata.
    vid_path (str): Path to the directory containing videos.

  Returns:
    np.array: Frames of the video as a NumPy array.
  """
  video_file_path = os.path.join(vid_path,
                                 data_item['metadata']['video_id']) + '.mp4'
  vid_frames = load_mp4_to_frames(video_file_path)
  assert data_item['metadata']['num_frames'] == vid_frames.shape[0]
  return vid_frames

In [None]:
# @title Download data
data_path = './data/'
db_json_path = './data/sample.json'
video_path = './data/videos/'

# sample annotations and videos to visualise the annotations later
sample_annot_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/sample_annotations.zip'
download_and_unzip(sample_annot_url, data_path)

sample_videos_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/sample_videos.zip'
download_and_unzip(sample_videos_url, data_path)

# This is the validation set of annotations for point tracking
valid_annot_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/point_tracking_valid_annotations.zip'
download_and_unzip(valid_annot_url, data_path)

'sample_annotations.zip' downloaded successfully.
'sample_annotations.zip' extracted successfully.
'sample_videos.zip' downloaded successfully.
'sample_videos.zip' extracted successfully.
'point_tracking_valid_annotations.zip' downloaded successfully.
'point_tracking_valid_annotations.zip' extracted successfully.


In [None]:
# @title Dataset class
class PerceptionDataset():
  """Dataset class to store video items from dataset.

  Attributes:
    video_folder_path: Path to the folder containing the videos.
    task: Task type for annotations.
    split: Dataset split to load.
    task_db: List containing annotations for dataset according to
      split and task availability.
  """

  def __init__(self, db_path: str, video_folder_path: str,
               task: str, split: str) -> None:
    """Initializes the PerceptionDataset class.

    Args:
      db_path (str): Path to the annotation file.
      video_folder_path (str): Path to the folder containing the videos.
      task (str): Task type for annotations.
      split (str): Dataset split to load.
    """
    self.video_folder_path = video_folder_path
    self.task = task
    self.split = split
    self.task_db = self.load_dataset(db_path)

  def load_dataset(self, db_path: str) -> List:
    """Loads the dataset from the annotation file to dictionary.

    Dict is processed according to split and task.

    Args:
      db_path (str): Path to the annotation file.

    Returns:
      List: List of database items containing annotations.
    """
    db_dict = load_db_json(db_path)
    db_list = []
    for _, val in db_dict.items():
      if val['metadata']['split'] == self.split:
        if val[self.task]:  # If video has annotations for this task
          db_list.append(val)

    return db_list

  def __len__(self) -> int:
    """Returns the total number of videos in the dataset.

    Returns:
      int: Total number of videos.
    """
    return len(self.task_db)

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    """Returns the video and annotations for a given index.

    Args:
      idx (int): Index of the video.

    Returns:
      Dict: Dictionary containing the video frames, metadata, annotations.
    """
    data_item = self.task_db[idx]
    annot = data_item[self.task]

    metadata = data_item['metadata']
    # Here we load a placeholder instead of the real frames, since the
    # static baseline never looks at the pixels. To load the actual frames,
    # use the commented-out call below instead.
    vid_frames = np.zeros((metadata['num_frames'], 1, 1, 1))
    # vid_frames = get_video_frames(data_item, self.video_folder_path)

    return {'metadata': metadata,
            self.task: annot,
            'frames': vid_frames}

In [None]:
# @title Point tracking model (static baseline)
class PointTracker():
  """Point tracker class that tracks a given point in a video.

  This model assumes static point: given the starting point in a video
  which should be tracked in a sequence, for every frame in the
  remaining sequence it will return the same coordinates.

  """

  def __init__(self):
    """Initializes the PointTracker class."""
    pass

  def track(self, frames: np.ndarray, start_point: List[float],
            start_frame_id: int) -> Tuple[np.ndarray, np.ndarray]:
    """Tracks a point in a video.

    Tracks a point given a sequence of frames and initial information about
    the coordinates and frame ID of the point.

    Args:
      frames (np.ndarray): Array of frames representing the video.
      start_point (List): List containing the start point in [y, x] format.
      start_frame_id (int): Integer value for starting frame ID for point track.

    Returns:
      Tuple: The predicted points, with shape [num_tracked_frames, 2], and
        the corresponding frame IDs, with shape [num_tracked_frames].
    """
    # initially take starting point for tracking
    prev_point = start_point
    output_points = []
    output_frame_ids = []

    for frame_id in range(start_frame_id, frames.shape[0]):
      # here is where the per frame tracking is done by the model
      # we just return the starting coords in this dummy baseline
      frame = frames[frame_id]
      point = self.track_frame(frame, prev_point)
      prev_point = point  # feed the latest estimate to the next frame
      output_points.append(point)
      output_frame_ids.append(frame_id)
    output_points = np.stack(output_points, axis=0)
    output_frame_ids = np.array(output_frame_ids)
    return output_points, output_frame_ids

  # model inference would be inserted here!!
  def track_frame(self, frame: np.array,
                  prev_point: List[float]) -> List[float]:
    """Tracks a point in a single frame.

    Tracks a point in a single frame based on the previous point
    coordinates. Placeholder function that just returns the coords it is given,
    assumes a static point which is never occluded.

    Args:
      frame (np.array): The current frame.
      prev_point(List): Previous point coordinates. [y, x] format.

    Returns:
      List: The point coordinates in the current frame.
    """
    del frame  # unused
    return prev_point

In [None]:
# @title Evaluation functions

# from https://github.com/deepmind/tapnet/blob/main/evaluation_datasets.py
def compute_tapvid_metrics(
    query_points: np.ndarray,
    gt_occluded: np.ndarray,
    gt_tracks: np.ndarray,
    pred_occluded: np.ndarray,
    pred_tracks: np.ndarray,
    query_mode: str,
) -> Mapping[str, np.ndarray]:
  """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)

  See the TAP-Vid paper for details on the metric computation.  All inputs are
  given in raster coordinates.  The first three arguments should be the direct
  outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
  The paper metrics assume these are scaled relative to 256x256 images.
  pred_occluded and pred_tracks are your algorithm's predictions.

  This function takes a batch of inputs, and computes metrics separately for
  each video.  The metrics for the full benchmark are a simple mean of the
  metrics across the full set of videos.  These numbers are between 0 and 1,
  but the paper multiplies them by 100 to ease reading.

  Args:
     query_points: The query points, in the format [t, y, x].  Its size is
       [b, n, 3], where b is the batch size and n is the number of queries
     gt_occluded: A boolean array of shape [b, n, t], where t is the number of
       frames.  True indicates that the point is occluded.
     gt_tracks: The target points, of shape [b, n, t, 2].  Each point is in the
       format [x, y]
     pred_occluded: A boolean array of predicted occlusions, in the same format
       as gt_occluded.
     pred_tracks: An array of track predictions from your algorithm, in the same
       format as gt_tracks.
     query_mode: Either 'first' or 'strided', depending on how queries are
       sampled.  If 'first', we assume the prior knowledge that all points
       before the query point are occluded, and these are removed from the
       evaluation.

  Returns:
      A dict with the following keys:

      occlusion_accuracy: Accuracy at predicting occlusion.
      pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
        predicted to be within the given pixel threshold, ignoring occlusion
        prediction.
      jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
        threshold
      average_pts_within_thresh: average across pts_within_{x}
      average_jaccard: average across jaccard_{x}
  """

  metrics = {}

  # Don't evaluate the query point.  Numpy doesn't have one_hot, so we
  # replicate it by indexing into an identity matrix.
  one_hot_eye = np.eye(gt_tracks.shape[2])
  query_frame = query_points[..., 0]
  query_frame = np.round(query_frame).astype(np.int32)
  evaluation_points = one_hot_eye[query_frame] == 0

  # If we're using the first point on the track as a query, don't evaluate the
  # other points.
  if query_mode == 'first':
    for i in range(gt_occluded.shape[0]):
      index = np.where(gt_occluded[i] == 0)[0][0]
      evaluation_points[i, :index] = False
  elif query_mode != 'strided':
    raise ValueError('Unknown query mode ' + query_mode)

  # Occlusion accuracy is simply how often the predicted occlusion equals the
  # ground truth.
  occ_acc = np.sum(
      np.equal(pred_occluded, gt_occluded) & evaluation_points,
      axis=(1, 2),
  ) / np.sum(evaluation_points)
  metrics['occlusion_accuracy'] = occ_acc

  # Next, convert the predictions and ground truth positions into pixel
  # coordinates.
  visible = np.logical_not(gt_occluded)
  pred_visible = np.logical_not(pred_occluded)
  all_frac_within = []
  all_jaccard = []
  for thresh in [1, 2, 4, 8, 16]:
    # True positives are points that are within the threshold and where both
    # the prediction and the ground truth are listed as visible.
    within_dist = np.sum(
        np.square(pred_tracks - gt_tracks),
        axis=-1,
    ) < np.square(thresh)
    is_correct = np.logical_and(within_dist, visible)

    # Compute the frac_within_threshold, which is the fraction of points
    # within the threshold among points that are visible in the ground truth,
    # ignoring whether they're predicted to be visible.
    count_correct = np.sum(
        is_correct & evaluation_points,
        axis=(1, 2),
    )
    count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
    frac_correct = count_correct / count_visible_points
    metrics['pts_within_' + str(thresh)] = frac_correct
    all_frac_within.append(frac_correct)

    true_positives = np.sum(
        is_correct & pred_visible & evaluation_points, axis=(1, 2)
    )

    # The denominator of the jaccard metric is the true positives plus
    # false positives plus false negatives.  However, note that true positives
    # plus false negatives is simply the number of points in the ground truth
    # which is easier to compute than trying to compute all three quantities.
    # Thus we just add the number of points in the ground truth to the number
    # of false positives.
    #
    # False positives are simply points that are predicted to be visible,
    # but the ground truth is not visible or too far from the prediction.
    gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
    false_positives = (~visible) & pred_visible
    false_positives = false_positives | ((~within_dist) & pred_visible)
    false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
    jaccard = true_positives / (gt_positives + false_positives)
    metrics['jaccard_' + str(thresh)] = jaccard
    all_jaccard.append(jaccard)
  metrics['average_jaccard'] = np.mean(
      np.stack(all_jaccard, axis=1),
      axis=1,
  )
  metrics['average_pts_within_thresh'] = np.mean(
      np.stack(all_frac_within, axis=1),
      axis=1,
  )
  return metrics


def evaluate(results: Dict[str, Any], label_dict: Dict[str, Any],
             scale: float = 256.0) -> float:
  """Calculates the Average Jaccard for each video in the results.

  Args:
    results: A dictionary containing the results for each video.
      The keys are video IDs, and the values are dictionaries with
      'point_tracking' information.
    label_dict: A dictionary containing the ground truth labels for each video.
      The keys are video IDs, and the values are dictionaries with
      'point_tracking' information.
    scale: A float value representing the scaling factor (default is 256.0).

  Returns:
    The average Jaccard across all videos.

  Raises:
    AssertionError: If the lengths of predicted tracks and ground truth tracks
      do not match.

  """
  avg_jacs = []
  static_avg_jacs = []
  moving_avg_jacs = []
  for video_id, video_item in results.items():
    pred_tracks = video_item['point_tracking']
    gt_tracks = label_dict[video_id]['point_tracking']
    num_frames = label_dict[video_id]['metadata']['num_frames']
    assert len(pred_tracks) == len(gt_tracks)
    num_tracks = len(pred_tracks)

    query_points = np.zeros((1, num_tracks, 3))
    gt_occluded = np.ones((1, num_tracks, num_frames))
    gt_points = np.zeros((1, num_tracks, num_frames, 2))
    pred_occluded = np.ones((1, num_tracks, num_frames))
    pred_points = np.zeros((1, num_tracks, num_frames, 2))

    for track_idx, pred_track in enumerate(pred_tracks):
      gt_track = gt_tracks[pred_track['id']]
      gt_track_points = np.array(gt_track['points']).T
      pred_track_points = np.array(pred_track['points']).T

      start_point = gt_track_points[0]
      start_frame_id = gt_track['frame_ids'][0]
      query_points[0, track_idx, 0] = start_frame_id
      query_points[0, track_idx, 1:] = start_point
      gt_occluded[0, track_idx][gt_track['frame_ids']] = 0
      pred_occluded[0, track_idx] = 0

      gt_points[0, track_idx][gt_track['frame_ids']] = gt_track_points
      pred_points[0, track_idx, :, :][pred_track['frame_ids']] = (
          pred_track_points
      )

    gt_points *= scale
    pred_points *= scale

    metrics = compute_tapvid_metrics(query_points, gt_occluded,
                                     gt_points, pred_occluded,
                                     pred_points, 'first')
    avg_jacs.append(metrics['average_jaccard'])

    if label_dict[video_id]['metadata']['is_camera_moving']:
      moving_avg_jacs.append(metrics['average_jaccard'])
    else:
      static_avg_jacs.append(metrics['average_jaccard'])

  average_jaccard = np.mean(avg_jacs)
  static_average_jaccard = np.mean(static_avg_jacs)
  moving_average_jaccard = np.mean(moving_avg_jacs)
  print(f'Average Jaccard across all videos: {average_jaccard}')
  print(f'Average Jaccard across static videos: {static_average_jaccard}')
  print(f'Average Jaccard across moving videos: {moving_average_jaccard}')
  return average_jaccard

In [None]:
# @title Evaluate static baseline
label_path = './data/point_tracking_valid.json'
cfg = {'video_folder_path': './data/videos/',
       'task': 'point_tracking',
       'split': 'valid'}

# init dataset
tracking_dataset = PerceptionDataset(label_path, **cfg)

# init tracking model
point_tracker = PointTracker()

# run model across full dataset
results = {}
for video_item in tracking_dataset:
  video_id = video_item['metadata']['video_id']
  video_pred_tracks = []
  for gt_track in video_item['point_tracking']:
    points = np.array(gt_track['points']).T
    start_point = points[0]
    start_frame_id = gt_track['frame_ids'][0]
    pred_points, pred_frame_ids = (
        point_tracker.track(video_item['frames'], start_point,
                            start_frame_id)
    )

    pred_track = {}
    # .tolist() to serialise without error
    pred_track['points'] = pred_points.T.tolist()
    pred_track['frame_ids'] = pred_frame_ids.tolist()
    pred_track['id'] = gt_track['id']
    video_pred_tracks.append(pred_track)

  results[video_id] = {'point_tracking': video_pred_tracks}

In [None]:
# @title Serialise example results file
# Writing model outputs in the expected competition format. This
# JSON file contains the predicted tracks for all videos in the validation
# split. As produced by the cell above, 'points' holds two lists of equal
# length (all y values, then all x values), with one entry per frame ID:

# example_submission = {'video_1009' :{'point_tracking':[
#     {'id': 0, 'points': [[y,...], [x,...]], 'frame_ids': [n,...]},
#     {'id': 1, 'points': [[y,...], [x,...]], 'frame_ids': [n,...]},...]}}

# This file could be used directly as a submission for the Eval.ai challenge
with open(f'{cfg["task"]}_{cfg["split"]}_results.json', 'w') as my_file:
  json.dump(results, my_file)
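
Before uploading, it can help to sanity-check the structure of the results dictionary. The sketch below is one possible check; `check_submission` is a hypothetical helper, not part of the official challenge tooling, and the layout it verifies (two equal-length rows of y and x values per track, matching `frame_ids`) is inferred from how `results` is built in this notebook.

```python
from typing import Any, Dict


def check_submission(results: Dict[str, Any]) -> int:
  """Validates the assumed results layout and returns the number of tracks."""
  num_tracks = 0
  for video_id, video_item in results.items():
    for track in video_item['point_tracking']:
      rows = track['points']
      # 'points' is assumed to hold two rows: all y values, then all x values.
      if len(rows) != 2:
        raise ValueError(f'{video_id}: expected 2 rows of points.')
      if not len(rows[0]) == len(rows[1]) == len(track['frame_ids']):
        raise ValueError(f'{video_id}: points and frame_ids differ in length.')
      num_tracks += 1
  return num_tracks


# Toy results dictionary with one video and one track over three frames.
toy_results = {'video_0': {'point_tracking': [
    {'id': 0, 'points': [[0.5, 0.5, 0.5], [0.2, 0.2, 0.2]],
     'frame_ids': [0, 1, 2]}]}}
print(check_submission(toy_results))  # 1
```

In the notebook, the same call could be applied to `results` before it is written to disk.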

In [None]:
# @title Compute average Jaccard
label_dict = load_db_json(label_path)
aj = evaluate(results, label_dict)

Average Jaccard across all videos: 0.3606894073177557
Average Jaccard across static videos: 0.4092416410388438
Average Jaccard across moving videos: 0.08703136270798537


In [None]:
# @title Visualisation functions
def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
  """Generate random colormaps for visualizing different objects and points.

  Args:
    num_colors (int): The number of colors to generate.

  Returns:
    List[Tuple[int, int, int]]: A shuffled list of RGB tuples, one per
      generated color.
  """
  colors = []
  for j in np.arange(0., 360., 360. / num_colors):
    hue = j / 360.
    lightness = (50 + np.random.rand() * 10) / 100.
    saturation = (90 + np.random.rand() * 10) / 100.
    color = colorsys.hls_to_rgb(hue, lightness, saturation)
    color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
    colors.append(color)
  random.seed(0)
  random.shuffle(colors)
  return colors


def display_video(vid_frames: np.array, fps: int = 30):
  """Create and display temporary video from numpy array frames.

  Args:
    vid_frames (np.array): The frames of the video as a
      numpy array. Format of frames should be:
      (num_frames, height, width, channels)
    fps (int): Frames per second for the video playback. Default is 30.
  """
  kwargs = {'macro_block_size': None}
  imageio.mimwrite('tmp_video_display.mp4',
                   vid_frames[:, :, :, ::-1], fps=fps, **kwargs)
  display(mvp.ipython_display('tmp_video_display.mp4'))


def paint_point(video: np.array, track: dict,
                color: Tuple[int, int, int] = (255, 0, 0),
                label: str = None) -> np.array:
  """Paints a single tracked point on each frame of a video.

  Args:
    video (np.array): The video frames as a numpy array.
    track (dict): The track containing frame IDs and corresponding points.
    color (tuple, optional): The color of the painted point.
      Defaults to (255, 0, 0).
    label (str): string to be added to point label annotation.

  Returns:
    np.array: The video frames with painted points.
  """
  _, height, width, _ = video.shape
  for idx, frame_id in enumerate(track['frame_ids']):
    vid_frame = video[frame_id]
    y = int(round(track['points'][0][idx] * height))
    x = int(round(track['points'][1][idx] * width))
    vid_frame = cv2.circle(vid_frame, (x, y), radius=10,
                           color=color, thickness=-1)
    if label is not None:
      vid_frame = cv2.putText(vid_frame, label, (x, y + 20),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
    video[frame_id] = vid_frame
  return video


def paint_points(video: np.array, tracks: List[dict]) -> np.array:
  """Paints multiple tracked points on each frame of a video.

  Args:
    video (np.array): The video frames as a numpy array.
    tracks (List[dict]): The list of tracks containing
      frame IDs and corresponding points.

  Returns:
    np.array: The video frames with painted points.
  """
  for idx, track in enumerate(tracks):
    video = paint_point(video, track, COLORS[idx])
  return video

In [None]:
# @title Show Example Annotations
COLORS = get_colors(num_colors=100)

db_dict = load_db_json(db_json_path)

video_id = list(db_dict.keys())[6]
example_data = db_dict[video_id]

if example_data['point_tracking']:
  frames = get_video_frames(example_data, video_path)
  tracks_to_show = 0.5  # @param {type:"slider", min:0, max:1, step:0.05}
  num_tracks = int(len(example_data['point_tracking']) * tracks_to_show)
  frames = paint_points(frames, example_data['point_tracking'][0:num_tracks])
  display_video(frames, example_data['metadata']['frame_rate'])
  del frames

In [None]:
# @title Model outputs visualised (static points)
# here we show how actual inference would work with video frames loaded
if example_data['point_tracking']:
  frames = get_video_frames(example_data, video_path)

  pred_tracks = []
  for gt_track in example_data['point_tracking']:
    points = np.array(gt_track['points']).T
    start_point = points[0]
    start_frame_id = gt_track['frame_ids'][0]
    pred_points, pred_frame_ids = (
        point_tracker.track(frames, start_point,
                            start_frame_id)
    )
    pred_track = {}
    pred_track['points'] = pred_points.T.tolist()
    pred_track['frame_ids'] = pred_frame_ids.tolist()
    pred_track['label'] = gt_track['label']
    pred_track['id'] = gt_track['id']
    pred_tracks.append(pred_track)

  tracks_to_show = 0.5  # @param {type:"slider", min:0, max:1, step:0.05}
  num_tracks = int(len(pred_tracks) * tracks_to_show)
  frames = paint_points(frames, pred_tracks[0: num_tracks])
  display_video(frames, example_data['metadata']['frame_rate'])
  del frames

In [None]:
# @title Comparing ground truth labels vs model outputs
if example_data['point_tracking']:

  track_to_compare = 1  # @param {type:"integer"}
  if track_to_compare >= len(example_data['point_tracking']):
    raise ValueError(f'Track {track_to_compare} is not in the video')

  frames = get_video_frames(example_data, video_path)
  frames = paint_point(frames, example_data['point_tracking'][track_to_compare],
                       color=COLORS[0], label=': gt')
  frames = paint_point(frames, pred_tracks[track_to_compare], color=COLORS[1],
                       label=': pred')

  display_video(frames, example_data['metadata']['frame_rate'])
  del frames