<a href="https://colab.research.google.com/github/deepmind/perception_test/blob/main/baselines/grounded_vqa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to load the grounded video question answering annotations in the validation split of the Perception Test and run the evaluation for a simple baseline model, which first runs MDETR on the middle frame of a video conditioned on the question, and then assumes static boxes throughout the video.

Copyright 2023 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at [https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0).
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Grounded Video Question Answering Baseline for the Perception Test

GitHub: https://github.com/deepmind/perception_test

## The Perception Test
[Perception Test: A Diagnostic Benchmark for Multimodal Video Models](https://arxiv.org/abs/2305.13786) is a multimodal benchmark designed to comprehensively evaluate the perception and reasoning skills of multimodal video models. The Perception Test dataset introduces real-world videos designed to show perceptually interesting situations and defines multiple computational tasks (object and point tracking, action and sound localisation, multiple-choice and grounded video question-answering). Here, we provide details and a simple baseline for the grounded video question-answering task.

[![Perception Test Overview Presentation](https://img.youtube.com/vi/8BiajMOBWdk/maxresdefault.jpg)](https://youtu.be/8BiajMOBWdk?t=10)

## Grounded Video Question Answering
This task is similar to conditional multiple-object tracking, with the conditioning given as a language task or question rather than a class label. The answers are object tracks defined throughout the video.

The image below shows an example of a video-question-answer tuple for the grounded video QA task.

![image collage](https://storage.googleapis.com/dm-perception-test/img/gqa.png)

## MDETR-Static baseline
This notebook demonstrates how to load the annotations in the validation split of the Perception Test, and run the evaluation for a simple model where the objects to track are detected by MDETR on the middle frame of the video, conditioned by the question. The detections are then kept static throughout the video.
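
As a rough illustration of the "static" assumption, the sketch below copies one detected box to every frame of a video. The helper name `make_static_track` and the example values are hypothetical; only the output keys mirror the prediction format used later in this notebook.

```python
from typing import Any, Dict, List


def make_static_track(box: List[float], num_frames: int,
                      score: float, track_id: int = 0) -> Dict[str, Any]:
  """Repeats one normalised [xmin, ymin, xmax, ymax] box for every frame."""
  frame_ids = list(range(num_frames))
  return {
      'id': track_id,
      'score': score,
      # The same coordinates are reused for each frame: no motion is modelled.
      'bounding_boxes': [list(box) for _ in frame_ids],
      'frame_ids': frame_ids,
  }


track = make_static_track([0.37, 0.0, 0.83, 0.44], num_frames=4, score=0.9)
print(track['frame_ids'])
print(len(track['bounding_boxes']))
```

Any object that moves will drift away from its static box, which is why this baseline is only a lower bound for tracking-based approaches.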

In [None]:
# @title Setup and install

! pip install timm transformers

!git clone --recursive https://github.com/facebookresearch/multimodal.git multimodal
%cd multimodal
!pip install -e .
%cd ..
!cp -r /content/multimodal/examples/mdetr/data .
!cp -r /content/multimodal/examples/mdetr/utils .

# VOT toolkit required for evaluation
!pip install git+https://github.com/votchallenge/vot-toolkit-python

# TrackEval required for HOTA
!git clone https://github.com/JonathonLuiten/TrackEval.git
%cd TrackEval
!pip install -e .
%cd ..

# RESTART RUNTIME AFTER THESE COMMANDS

In [1]:
# @title Prerequisites
import collections
import colorsys
import json
import os
import random
from typing import Tuple, List, Dict, Any
import zipfile

import cv2

from google.colab.patches import cv2_imshow
import imageio
import matplotlib.pyplot as plt
import moviepy.editor as mvp
import numpy as np
from PIL import Image
import requests
import torch
from torch import nn
from torch.utils.data import Dataset
from torchmultimodal.models.mdetr.model import mdetr_for_vqa
from torchvision.ops.boxes import box_convert
import torchvision.transforms as T
from trackeval.metrics import HOTA
from transformers import RobertaTokenizerFast

In [2]:
# @title Utility functions
def download_and_unzip(url: str, destination: str):
  """Downloads and unzips a .zip file to a destination.

  Downloads a file from the specified URL, saves it to the destination
  directory, and then extracts its contents.

  If the file is larger than 1GB, it will be downloaded in chunks,
  and the download progress will be displayed.

  Args:
    url (str): The URL of the file to download.
    destination (str): The destination directory to save the file and
      extract its contents.
  """
  if not os.path.exists(destination):
    os.makedirs(destination)

  filename = url.split('/')[-1]
  file_path = os.path.join(destination, filename)

  if os.path.exists(file_path):
    print(f'{filename} already exists. Skipping download.')
    return

  response = requests.get(url, stream=True)
  response.raise_for_status()
  total_size = int(response.headers.get('content-length', 0))
  gb = 1024*1024*1024

  if total_size / gb > 1:
    print(f'{filename} is larger than 1GB, downloading in chunks')
    chunk_flag = True
    chunk_size = int(total_size/100)
  else:
    chunk_flag = False
    # Fall back to 1MB chunks if the server does not report a size.
    chunk_size = total_size or 1024*1024

  with open(file_path, 'wb') as file:
    for chunk_idx, chunk in enumerate(
        response.iter_content(chunk_size=chunk_size)):
      if chunk:
        if chunk_flag:
          print(f"""{chunk_idx}% downloading
          {round((chunk_idx*chunk_size)/gb, 1)}GB
          / {round(total_size/gb, 1)}GB""")
        file.write(chunk)
  print(f"'{filename}' downloaded successfully.")

  with zipfile.ZipFile(file_path, 'r') as zip_ref:
    zip_ref.extractall(destination)
  print(f"'{filename}' extracted successfully.")

  os.remove(file_path)


def load_db_json(db_file: str) -> Dict[str, Any]:
  """Loads a JSON file as a dictionary.

  Args:
    db_file (str): Path to the JSON file.

  Returns:
    Dict: Loaded JSON data as a dictionary.

  Raises:
    FileNotFoundError: If the specified file doesn't exist.
    TypeError: If the JSON file is not formatted as a dictionary.
  """
  if not os.path.isfile(db_file):
    raise FileNotFoundError(f'No such file: {db_file}')

  with open(db_file, 'r') as f:
    db_file_dict = json.load(f)
    if not isinstance(db_file_dict, dict):
      raise TypeError('JSON file is not formatted as a dictionary.')
    return db_file_dict


def load_mp4_to_frames(filename: str) -> np.array:
  """Loads an MP4 video file and returns its frames as a NumPy array.

  Args:
    filename (str): Path to the MP4 video file.

  Returns:
    np.array: Frames of the video as a NumPy array.
  """
  assert os.path.exists(filename), f'File {filename} does not exist.'
  cap = cv2.VideoCapture(filename)

  num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

  vid_frames = np.empty((num_frames, height, width, 3), dtype=np.uint8)

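  idx = 0
  while idx < num_frames:
    ret, frame = cap.read()
    if not ret:
      break

    vid_frames[idx] = frame
    idx += 1

  cap.release()
  # The reported frame count can exceed the number of decodable frames;
  # drop any rows that were never filled.
  return vid_frames[:idx]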


def get_video_frames(data_item: Dict[str, Any], video_path: str) -> np.array:
  """Loads frames of a video specified by an item dictionary.

  Assumes format of annotations used in the Perception Test Dataset.

  Args:
    data_item (Dict): Item from dataset containing metadata.
    video_path (str): Path to the directory containing videos.

  Returns:
    np.array: Frames of the video as a NumPy array.
  """
  video_file_path = os.path.join(video_path,
                                 data_item['metadata']['video_id']) + '.mp4'
  vid_frames = load_mp4_to_frames(video_file_path)
  assert data_item['metadata']['num_frames'] == vid_frames.shape[0]
  return vid_frames


def load_single_frame(filename: str, frame_idx: int) -> np.array:
  """Loads an MP4 video file and returns a single frame as a NumPy array.

  This function loads a specific frame from the given MP4 video file and returns
  it as a NumPy array.

  Args:
    filename (str): Path to the MP4 video file.
    frame_idx (int): The index of the frame to be loaded.

  Returns:
    np.array: A single frame of the video as a NumPy array.

  Raises:
    AssertionError: If the given file does not exist.
  """
  assert os.path.exists(filename), f'File {filename} does not exist.'
  cap = cv2.VideoCapture(filename)
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
  _, frame = cap.read()
  cap.release()
  return frame
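
The chunked download in `download_and_unzip` follows a common streaming pattern: read a fixed number of bytes at a time, write them out, and report progress. The sketch below reproduces that pattern on an in-memory stream so it can be run without network access. The helper `copy_in_chunks` and the 1 KiB chunk size are illustrative assumptions, not part of the notebook.

```python
import io
from typing import BinaryIO, List


def copy_in_chunks(src: BinaryIO, dst: BinaryIO, total_size: int,
                   chunk_size: int = 1024) -> List[int]:
  """Copies src to dst chunk by chunk and returns the progress percentages."""
  progress = []
  copied = 0
  while True:
    chunk = src.read(chunk_size)
    if not chunk:  # An empty read signals the end of the stream.
      break
    dst.write(chunk)
    copied += len(chunk)
    progress.append(round(100 * copied / total_size))
  return progress


payload = b'x' * 2500
source, target = io.BytesIO(payload), io.BytesIO()
percentages = copy_in_chunks(source, target, total_size=len(payload))
print(percentages)
```

Tracking the bytes actually copied, rather than multiplying the chunk index by the chunk size, keeps the progress figure accurate when the final chunk is shorter than the rest.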

In [None]:
# @title Download data
data_path = './data/'

split = 'valid'  # @param ['sample', 'valid']

if split == 'valid':
  # full validation set - takes approx 30 mins to download
  valid_annot_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/grounded_question_valid_annotations.zip'
  download_and_unzip(valid_annot_url, data_path)
  valid_videos_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/grounded_question_valid_videos.zip'
  download_and_unzip(valid_videos_url, data_path)

elif split == 'sample':
  # sample annotations and videos (small subset for demo)
  sample_annot_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/sample_annotations.zip'
  download_and_unzip(sample_annot_url, data_path)
  sample_videos_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/sample_videos.zip'
  download_and_unzip(sample_videos_url, data_path)

In [4]:
# @title Dataset class
class PerceptionGQADataset(Dataset):
  """Dataset class to store video items from dataset.

  Attributes:
    video_folder_path: Path to the folder containing the videos.
    task: Task type for annotations.
    split: Dataset split to load.
    task_db: List containing annotations for dataset according to
      split and task availability.
  """

  def __init__(self, db_path: str, video_folder_path: str,
               task: str, split: str) -> None:
    """Initializes the PerceptionGQADataset class.

    Args:
      db_path (str): Path to the annotation file.
      video_folder_path (str): Path to the folder containing the videos.
      task (str): Task type for annotations.
      split (str): Dataset split to load.
    """
    self.video_folder_path = video_folder_path
    self.task = task
    self.split = split
    self.task_db = self.load_dataset(db_path)

  def load_dataset(self, db_path: str) -> List:
    """Loads the dataset from the annotation file and processes.

    Dict is processed according to split and task.

    Args:
      db_path (str): Path to the annotation file.

    Returns:
      List: List of database items containing annotations.
    """
    db_dict = load_db_json(db_path)
    db_list = []
    for _, val in db_dict.items():
      if val['metadata']['split'] == self.split:
        if val[self.task]:  # If video has annotations for this task
          db_list.append(val)

    return db_list

  def __len__(self) -> int:
    """Returns the total number of videos in the dataset.

    Returns:
      int: Total number of videos.
    """
    return len(self.task_db)

  def __getitem__(self, idx: int) -> Dict[str, Any]:
    """Returns the video and annotations for a given index.

    Args:
      idx (int): Index of the video.

    Returns:
      Dict: Dictionary containing the video frames, metadata, annotations.
    """
    data_item = self.task_db[idx]
    annot = data_item[self.task]

    metadata = data_item['metadata']
    frame_idx = round(data_item['metadata']['num_frames']/2)
    video_file_path = os.path.join(self.video_folder_path,
                                   data_item['metadata']['video_id']) + '.mp4'
    frame = load_single_frame(video_file_path, frame_idx)
    frame = torch.tensor(frame[:,:,::-1].copy())
    frame = frame.permute(2, 0, 1)[None, ...].float()/255.

    vid_frames = np.zeros((metadata['num_frames'], 1, 1, 1))

    return {'metadata': metadata,
            'grounded_question': annot,
            'object_tracking': data_item['object_tracking'],
            'vqa_frame': frame,
            'frames': vid_frames}
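
`load_dataset` keeps only the videos that belong to the requested split and have at least one annotation for the requested task. A minimal sketch of that filter on a toy annotation dictionary follows. The video IDs and field values are made up for illustration; only the `metadata`/`split` and task keys mirror the structure used above, and `filter_items` is a hypothetical helper.

```python
from typing import Any, Dict, List


def filter_items(db: Dict[str, Any], task: str, split: str) -> List[Dict]:
  """Returns items from the given split that have annotations for the task."""
  return [item for item in db.values()
          if item['metadata']['split'] == split and item.get(task)]


toy_db = {
    'video_a': {'metadata': {'split': 'valid', 'video_id': 'video_a'},
                'grounded_question': [{'id': 0, 'answers': [1]}]},
    'video_b': {'metadata': {'split': 'valid', 'video_id': 'video_b'},
                'grounded_question': []},  # No annotations: filtered out.
    'video_c': {'metadata': {'split': 'train', 'video_id': 'video_c'},
                'grounded_question': [{'id': 0, 'answers': [0]}]},
}

kept = filter_items(toy_db, task='grounded_question', split='valid')
print([item['metadata']['video_id'] for item in kept])
```

The sketch uses `item.get(task)` so that a missing task key is treated like an empty annotation list, whereas the class above indexes the key directly.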

In [5]:
# @title Object tracking model (static)
class ObjectTracker():
  """Object tracker class that tracks a given object in a video.

  This model assumes static boxes: given the first bounding box to be
  tracked in a sequence, it returns the same coordinates for every
  remaining frame.

  """

  def __init__(self):
    """Initializes the ObjectTracker class."""
    pass

  def track_object_in_video(self, frames: np.array, start_info: Dict[str, Any]
                            ) -> Tuple[np.ndarray, np.ndarray]:
    """Tracks an object in a video.

    Tracks an object given a sequence of frames and initial information about
    the coordinates and frame ID of the object.

    Args:
      frames (np.array): Array of frames representing the video.
      start_info (Dict): Dictionary containing the start bounding box and
        frame ID.

    Returns:
      Tuple[np.ndarray, np.ndarray]: Array of tracked bounding boxes and
        array of the corresponding frame IDs.
    """
    # initially take starting bounding box for tracking
    prev_bb = start_info['start_bounding_box']
    output_bounding_boxes = []
    output_frame_ids = []

    for frame_id in range(start_info['start_id'], frames.shape[0]):
      frame = frames[frame_id]
      # here is where the per frame tracking is done by the model
      # we just return the starting coords in this dummy baseline
      bb = self.track_object_in_frame(frame, prev_bb)
      output_bounding_boxes.append(bb)
      output_frame_ids.append(frame_id)

    output_bounding_boxes = np.stack(output_bounding_boxes, axis=0)
    output_frame_ids = np.array(output_frame_ids)
    return output_bounding_boxes, output_frame_ids

  # model inference would be inserted here!!
  def track_object_in_frame(self, frame: np.array,
                            prev_bb: List[float]) -> List[float]:
    """Tracks an object in a single frame.

    Tracks an object in a single frame based on the previous bounding box
    coordinates. Placeholder function that just returns the coords it is given,
    assumes a static object.

    Args:
      frame (np.array): The current frame.
      prev_bb (List): Previous bounding box coordinates [x1, y1, x2, y2].

    Returns:
      List: The tracked bounding box coordinates in the current frame.
    """
    del frame  # unused
    return prev_bb

In [None]:
# @title Evaluation functions
def get_start_frame(track_arr: List[List[float]]) -> int:
  """Returns index of the first non-zero element in a track array.

  Args:
    track_arr (list): one-hot vector corresponding to annotations,
      showing which index to start tracking.

  Returns:
    int: Index of the first non-zero element in the track array.

  Raises:
    ValueError: Raises error if the length of the array is 0
      or if there is no one-hot value.
  """
  if not track_arr or np.count_nonzero(track_arr) == 0:
    raise ValueError('Track is empty or has no non-zero elements')
  return np.nonzero(track_arr)[0][0]


def get_start_info(track: Dict[str, Any]) -> Dict[str, Any]:
  """Retrieve information about the start frame of a track.

  Args:
    track (Dict): A dictionary containing information about the track.

  Returns:
    Dict[str, Any]: A dictionary with the following keys:
      'start_id': The frame ID of the start frame.
      'start_bounding_box': The bounding box coordinates of the start
        frame.
      'start_idx': The index of the start frame in the
        'bounding_boxes' list.
  """
  track_start_idx = get_start_frame(track['initial_tracking_box'])
  track_start_id = track['frame_ids'][track_start_idx]
  track_start_bb = track['bounding_boxes'][track_start_idx]

  return {'start_id': track_start_id,
          'start_bounding_box': track_start_bb,
          'start_idx': track_start_idx}


def filter_pred_boxes(pred_bb: np.ndarray, pred_fid: np.ndarray,
                      gt_fid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  """Filter bounding boxes and frame IDs based on ground truth frame IDs.

  Args:
    pred_bb (np.ndarray): Array of predicted bounding boxes.
    pred_fid (np.ndarray): Array of frame IDs for predicted bounding boxes.
    gt_fid (np.ndarray): Array of frame IDs for ground truth bounding boxes.

  Returns:
    Tuple[np.ndarray, np.ndarray]: Filtered predicted bounding boxes and
      their corresponding frame IDs.
  """
  pred_idx = np.isin(pred_fid, gt_fid).nonzero()[0]
  filter_pred_bb = pred_bb[pred_idx]
  filter_pred_fid = pred_fid[pred_idx]
  return filter_pred_bb, filter_pred_fid


def rescale_tensor_boxes(boxes: torch.Tensor,
                         size: Tuple[int, int]) -> torch.Tensor:
  """Rescales predicted boxes to match the image size.

  Args:
    boxes (torch.Tensor): Tensor of bounding boxes in the format
      "cxcywh" (center x, center y, width, height).
    size (Tuple): A tuple representing the height and width of the
      image.

  Returns:
    torch.Tensor: Tensor of rescaled bounding boxes in the format "xyxy"
      (xmin, ymin, xmax, ymax).
  """
  img_h, img_w = size
  b = box_convert(boxes, 'cxcywh', 'xyxy')
  # Build the scale tensor on the same device as the boxes.
  b = b * torch.tensor([img_w, img_h, img_w, img_h],
                       dtype=torch.float32, device=boxes.device)
  return b


def rescale_list_boxes(box: List[float], size: Tuple[int, int]) -> List[float]:
  """Rescales a list bounding box to match the image size.

  Args:
    box (List): List of bounding box coordinates
      [xmin, ymin, xmax, ymax].
    size (Tuple): A tuple representing the height and width of the
      image.

  Returns:
    List: List of rescaled bounding box coordinates
      [xmin_rescaled, ymin_rescaled, xmax_rescaled, ymax_rescaled].
  """
  [img_h,img_w] = size
  return [box[0] / img_w, box[1] / img_h, box[2] / img_w, box[3] / img_h]


def postprocess(outputs: Any, text: str, tokenized: Any,
                img_size: Tuple[int, int], keep_prob: float = 0.7
                ) -> Tuple[List[float], List[torch.Tensor], List[str]]:
  """Postprocessing to rescale boxes and extract predicted spans of tokens.

  Args:
    outputs (Any): The outputs from the model.
    text (str): The original text of the question.
    tokenized (Any): The tokenized version of the text.
    img_size (Tuple): A tuple representing the height and width of
      the image.
    keep_prob (float, optional): Probability threshold for keeping predicted
      boxes. Defaults to 0.7.

  Returns:
    Tuple[List[float], List[torch.Tensor], List[str]]: A tuple containing:
        - List of probabilities for each predicted box.
        - List of rescaled bounding boxes for each predicted box.
        - List of labels for each predicted box.

  Note:
    The type of 'outputs' and 'tokenized' parameters should have specific
      attributes and methods, as used in the function.
  """
  model_outputs = outputs.model_output
  probs = 1 - model_outputs.pred_logits.softmax(-1)[0, :, -1]
  keep = (probs > keep_prob)

  # convert boxes from [0; 1] to image scales
  boxes_scaled = rescale_tensor_boxes(model_outputs.pred_boxes[0, keep],
                                      img_size)

  # Extract the text spans predicted by each box
  positive_tokens = model_outputs.pred_logits[0, keep].softmax(-1) > 0.1
  positive_tokens = positive_tokens.nonzero().tolist()
  pred_spans = collections.defaultdict(str)
  for tok in positive_tokens:
    item, pos = tok
    if pos < 255:
      span = tokenized.token_to_chars(0, pos)
      if span is not None:
        pred_spans[item] += ' ' + text[span.start:span.end]

  probs = probs[keep]
  probs = [probs[int(k)] for k in sorted(list(pred_spans.keys()))]
  boxes_scaled = [boxes_scaled[int(k)] for k in sorted(list(pred_spans.keys()))]
  labels = [pred_spans[k] for k in sorted(list(pred_spans.keys()))]

  return probs, boxes_scaled, labels


def search_frame_ids(frame_id_dict: Dict[int, List[int]]
                    ) -> Dict[int, Any]:
  """Search for frame IDs in tracks to get per Frame ID track ID dict.

  Args:
    frame_id_dict (Dict): A dictionary containing track IDs as
      keys and lists of corresponding frame IDs as values.

  Returns:
    Dict[int, List]: A combined dictionary with frame IDs as keys and
      lists of track IDs containing each frame ID as values.
  """
  combined_id_dict = {}

  for k, v in frame_id_dict.items():
    for i in v:
      if i not in combined_id_dict:
        combined_id_dict[i] = [k]
      else:
        combined_id_dict[i].append(k)

  return combined_id_dict


def build_gt_ids(tracks: List[Dict[str, Any]], gt: bool,
                 track_ids: List[int] = None) -> Dict[int, List[int]]:
  """Build ground truth track IDs dict based on input tracks and track IDs.

  Args:
    tracks (List): A list of track
      dictionaries, each containing 'id' and 'frame_ids' as keys.
    gt (bool): A boolean indicating if ground truth track IDs are to be built.
    track_ids (List, optional): A list of track IDs to be considered.
      If None, all tracks will be considered. Defaults to None.

  Returns:
    Dict: A dictionary containing track IDs as keys and lists
      of corresponding frame IDs as values.
  """
  track_dict = {}
  if track_ids is None:
    nid = 0
    for track in tracks:
      if gt:
        start_info = get_start_info(track)
        track_dict[nid] = track['frame_ids'][start_info['start_idx']:]
      else:
        track_dict[nid] = track['frame_ids']
      nid += 1

  else:
    nid = 0
    for t in track_ids:
      track = tracks[t]
      assert track['id'] == t
      if gt:
        start_info = get_start_info(track)
        track_dict[nid] = track['frame_ids'][start_info['start_idx']:]
      else:
        track_dict[nid] = track['frame_ids']
      nid += 1

  frame_id_dict = search_frame_ids(track_dict)
  return frame_id_dict


def calculate_iou(boxes1: np.array, boxes2: np.array) -> np.ndarray:
  """Calculate Intersection over Union (IoU) for two sets of bounding boxes.

  Boxes are paired row by row, so both arrays must have the same shape.

  Args:
    boxes1 (np.array): Bounding boxes in the format [x1, y1, x2, y2],
      shape (n, 4)
    boxes2 (np.array): Bounding boxes in the format [x1, y1, x2, y2],
      shape (n, 4)

  Returns:
    iou (np.ndarray): Per-pair IoU values, shape (n, 1)
  """
  x1_1, y1_1, x2_1, y2_1 = np.split(boxes1, 4, axis=1)
  x1_2, y1_2, x2_2, y2_2 = np.split(boxes2, 4, axis=1)

  # Find intersection coordinates
  y1_inter = np.maximum(y1_1, y1_2)
  x1_inter = np.maximum(x1_1, x1_2)
  y2_inter = np.minimum(y2_1, y2_2)
  x2_inter = np.minimum(x2_1, x2_2)

  # Calculate area of intersection
  h_inter = np.maximum(0, y2_inter - y1_inter)
  w_inter = np.maximum(0, x2_inter - x1_inter)
  area_inter = h_inter * w_inter

  # Calculate area of union
  area_boxes1 = (y2_1 - y1_1) * (x2_1 - x1_1)
  area_boxes2 = (y2_2 - y1_2) * (x2_2 - x1_2)
  union = area_boxes1 + area_boxes2 - area_inter

  return area_inter / union


def top_k_tracks(tracks: List[Dict[str, Any]], max_num: int
                 ) -> List[Dict[str, Any]]:
  """Select the top-k tracks based on their score.

  Args:
    tracks (List): A list of dictionaries representing tracks with 'score' as
      one of the keys.
    max_num (int): The maximum number of tracks to select.

  Returns:
    List: A list of the top-k tracks with the highest scores.
  """
  sorted_tracks = sorted(tracks, key=lambda d: d['score'], reverse=True)
  return sorted_tracks[0:max_num]


def run_iou(results: Dict[Any, Any], label_dict: Dict[Any, Any],
            max_num_tracks: int = 10) -> Dict[Any, Any]:
  """Run IoU calculations for ground truth and predicted tracks.

  Args:
    results (Dict): A dictionary containing results for different videos,
      with video IDs as keys and video results as values.
    label_dict (Dict): A dictionary containing label information for different
      videos, with video IDs as keys and label data as values.
    max_num_tracks (int, optional): The maximum number of tracks to consider.
      Defaults to 10.

  Returns:
    Dict: A dictionary containing video data with IoU scores.
  """
  data = {}
  for video_id, video_results in results.items():
    video_data = {}
    gt_tracks = label_dict[video_id]['object_tracking']
    gt_qs = label_dict[video_id]['grounded_question']
    pred_answers = video_results['grounded_question']

    for q in gt_qs:
      q_data = {}
      answer_gt_tracks = [gt_tracks[t] for t in q['answers']]
      answer_pred_tracks = pred_answers[q['id']]
      if len(answer_pred_tracks) > max_num_tracks:
        answer_pred_tracks = top_k_tracks(answer_pred_tracks, max_num_tracks)

      gt_frame_id_dict = build_gt_ids(gt_tracks, True, q['answers'])
      pred_frame_id_dict = build_gt_ids(answer_pred_tracks, False)

      q_data['num_gt_ids'] = len(q['answers'])
      q_data['num_gt_dets'] = (
          sum([len(gt_tracks[t]['bounding_boxes']) for t in q['answers']])
      )
      q_data['gt_ids'] = [np.array(x) for x in list(gt_frame_id_dict.values())]

      q_data['num_tracker_ids'] = len(answer_pred_tracks)
      q_data['num_tracker_dets'] = (
          sum([len(t['bounding_boxes']) for t in answer_pred_tracks])
      )
      q_data['tracker_ids'] = (
          [np.array(x) for x in list(pred_frame_id_dict.values())]
      )

      sim_score_dict = {k: [] for k in list(gt_frame_id_dict.keys())}
      sim_scores = []
      for gt_track in answer_gt_tracks:
        track_sim_scores = []
        for pred_track in answer_pred_tracks:
          start_info = get_start_info(gt_track)
          start_idx = start_info['start_idx']
          gt_bb = np.array(gt_track['bounding_boxes'])[start_idx:]
          gt_fid = gt_track['frame_ids'][start_idx:]

          # case where only one box is labelled
          if not gt_fid:
            continue

          pred_bb = np.array(pred_track['bounding_boxes'])
          pred_fid = np.array(pred_track['frame_ids'])
          # filter predicted trajectory for frame IDs where we have annotations
          pred_bb, pred_fid = filter_pred_boxes(pred_bb, pred_fid, gt_fid)
          # clip 0 -> 1
          pred_bb = np.minimum(np.maximum(pred_bb, 0), 1)
          iou = calculate_iou(gt_bb, pred_bb)
          track_sim_scores.append(iou)
          for frame_iou, frame_id in zip(iou, pred_fid):
            sim_score_dict[frame_id].append(frame_iou.item())

        sim_scores.append(np.array(track_sim_scores))

      if q_data['num_tracker_dets'] == 0 or q_data['num_gt_dets'] == 0:
        sim_scores = []
      else:
        sim_scores = []
        for idx, frame_scores in enumerate(sim_score_dict.values()):
          sim_scores.append(np.array(frame_scores).reshape(
              (len(q_data['gt_ids'][idx]), len(q_data['tracker_ids'][idx])))
                           )

      q_data['similarity_scores'] = sim_scores
      video_data[q['id']] = q_data

    data[video_id] = video_data
  return data


def eval_sequences(evaluator: Any, data: Dict[Any, Any]) -> Dict[str, float]:
  """Evaluate sequences using the evaluator.

  Args:
    evaluator (Any): The evaluator object to be used for evaluation.
    data (Dict): A dictionary containing the data to be evaluated, with video
      IDs as keys and video results as values.

  Returns:
    Dict: A dictionary containing the average evaluation results for the
      sequences, with keys 'HOTA', 'DetA', 'AssA', and 'LocA'.
  """
  hota = []
  deta = []
  assa = []
  loca = []
  for video_results in data.values():
    for question_results in video_results.values():
      res = evaluator.eval_sequence(question_results)
      hota.append(np.mean(res['HOTA']))
      deta.append(np.mean(res['DetA']))
      assa.append(np.mean(res['AssA']))
      loca.append(np.mean(res['LocA']))

  ave_hota = np.mean(hota)
  print('HOTA: ', ave_hota)
  ave_deta = np.mean(deta)
  print('DetA: ', ave_deta)
  ave_assa = np.mean(assa)
  print('AssA: ', ave_assa)
  ave_loca = np.mean(loca)
  print('LocA: ', ave_loca)

  return {'HOTA': ave_hota, 'DetA': ave_deta,
          'AssA': ave_assa, 'LocA': ave_loca}
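
To make the IoU computation in `calculate_iou` concrete, here is a worked single-pair example in plain Python. It assumes boxes in `[xmin, ymin, xmax, ymax]` order with normalised coordinates. The helper `box_iou` is illustrative and, unlike the vectorised function above, guards against a zero-area union.

```python
from typing import List


def box_iou(box_a: List[float], box_b: List[float]) -> float:
  """IoU of two [xmin, ymin, xmax, ymax] boxes."""
  # Corners of the intersection rectangle.
  x1 = max(box_a[0], box_b[0])
  y1 = max(box_a[1], box_b[1])
  x2 = min(box_a[2], box_b[2])
  y2 = min(box_a[3], box_b[3])
  # Width and height are clamped at zero when the boxes do not overlap.
  inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
  area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
  area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
  union = area_a + area_b - inter
  return inter / union if union > 0 else 0.0


gt_box = [0.0, 0.0, 0.5, 0.5]
pred_box = [0.25, 0.25, 0.75, 0.75]
print(box_iou(gt_box, pred_box))
```

Here the intersection is 0.25 x 0.25 = 0.0625 and the union is 0.25 + 0.25 - 0.0625 = 0.4375, giving an IoU of 1/7.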

In [7]:
# @title Initialise dataset
if split == 'sample':
  label_path = './data/sample.json'

elif split == 'valid':
  label_path = './data/grounded_question_valid.json'

cfg = {'video_folder_path': './data/videos/',
       'task': 'grounded_question',
       'split': 'valid'}

# init dataset
gqa_dataset = PerceptionGQADataset(label_path, **cfg)

In [None]:
# @title Initialise detector and tracker models
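# Use the GPU when available; fall back to CPU otherwise (much slower).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'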
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# The default values correspond to the GQA dataset.
mdetr_vqa = mdetr_for_vqa()
# But to perform VQA on another dataset, one can simply pass a
# different set of heads, e.g.
other_vqa_heads = nn.ModuleDict({'head1': nn.Linear(256, 3),
                                 'head2': nn.Linear(256, 12)})
mdetr_other_vqa_dataset = mdetr_for_vqa(vqa_heads=other_vqa_heads)

checkpoint_url = 'https://pytorch.s3.amazonaws.com/models/multimodal/mdetr/gqa_resnet101_checkpoint.pth'
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url,
                                                map_location='cpu',
                                                check_hash=True)

mdetr_vqa.load_state_dict(checkpoint['model_ema'], strict=False)
mdetr_vqa.eval().to(DEVICE)

img_transform = T.Compose([
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

# init tracking model
object_tracker = ObjectTracker()

In [None]:
# @title Get predictions for MDETR + static baseline
filter_outputs = True
results = {}
count_check = 25

for cidx, video_item in enumerate(gqa_dataset):
  video_id = video_item['metadata']['video_id']

  if (cidx%count_check) == 0:
    print(cidx, video_id)

  img_size = video_item['vqa_frame'].shape[2:]
  t_img = img_transform(video_item['vqa_frame'])
  gt_tracks = video_item['object_tracking']

  video_pred_tracks = {}
  for q in video_item['grounded_question']:

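    # Frame IDs annotated for this question's answer tracks (gt=True trims
    # each track so that it starts at its initial tracking box).
    combined_frame_id_dict = build_gt_ids(gt_tracks, True, q['answers'])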
    tokenized = tokenizer.batch_encode_plus([q['question']], padding='longest',
                                            return_tensors='pt')
    outputs = mdetr_vqa(t_img.to(DEVICE), tokenized['input_ids'].to(DEVICE))
    probs, boxes, labels = postprocess(outputs, q['question'],
                                       tokenized, img_size)

    qid = 0
    question_results = []
    for prob, box, label in zip(probs, boxes, labels):
      box = rescale_list_boxes(box.detach().cpu().numpy(), img_size)
      start_info = {'start_id': 0,
                    'start_bounding_box': box,
                    'start_idx': 0}

      pred_bounding_boxes, pred_frame_ids = (
          object_tracker.track_object_in_video(video_item['frames'], start_info)
      )
      if filter_outputs:
        pred_bounding_boxes, pred_frame_ids = (
            filter_pred_boxes(pred_bounding_boxes, pred_frame_ids,
                              list(combined_frame_id_dict.keys()))
        )

      pred_track = {}
      pred_track['id'] = qid
      qid += 1
      pred_track['score'] = prob.item()
      pred_track['bounding_boxes'] = pred_bounding_boxes.tolist()
      pred_track['frame_ids'] = pred_frame_ids.tolist()
      question_results.append(pred_track)

    video_pred_tracks[q['id']] = question_results

  results[video_id] = {'grounded_question': video_pred_tracks}
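
The prediction loop relies on two box conversions: `postprocess` turns MDETR's normalised centre-format boxes into pixel corner-format boxes, and `rescale_list_boxes` normalises them back to the [0, 1] range used by the annotations. The plain-Python sketch below walks through both steps for one box. The helper names and the 480x640 image size are assumptions chosen for illustration.

```python
from typing import List, Tuple


def cxcywh_to_xyxy(box: List[float]) -> List[float]:
  """Converts [centre_x, centre_y, width, height] to [xmin, ymin, xmax, ymax]."""
  cx, cy, w, h = box
  return [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]


def to_pixels(box: List[float], size: Tuple[int, int]) -> List[float]:
  """Scales a normalised corner-format box to pixel coordinates."""
  img_h, img_w = size
  return [box[0] * img_w, box[1] * img_h, box[2] * img_w, box[3] * img_h]


def to_normalised(box: List[float], size: Tuple[int, int]) -> List[float]:
  """Scales a pixel corner-format box back to the [0, 1] range."""
  img_h, img_w = size
  return [box[0] / img_w, box[1] / img_h, box[2] / img_w, box[3] / img_h]


size = (480, 640)  # (height, width)
corners = cxcywh_to_xyxy([0.5, 0.5, 0.5, 0.25])
pixels = to_pixels(corners, size)
print(corners)
print(pixels)
print(to_normalised(pixels, size))
```

Note that the size tuple is ordered (height, width), while box coordinates are ordered x before y, so the x values are scaled by the width and the y values by the height.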

In [None]:
# @title Evaluate predictions
# label_path was set for the chosen split when initialising the dataset
label_dict = load_db_json(label_path)
data = run_iou(results, label_dict)

hota_evaluator = HOTA()
eval_results = eval_sequences(hota_evaluator, data)

In [None]:
# @title Serialise example results file
# Write the model outputs in the expected competition format. This JSON file
# contains answers to all questions in the validation split, in the format:
#
# {'video_8913': {'grounded_question': {5: [{'id': 0, 'score': 0.875480055809021,
#     'bounding_boxes': [[0.37, 0.00, 0.83, 0.44],...],
#     'frame_ids': [0, 30, 32, 37, 60, 90, 120, 150,...]},...]}}}
#
# This file can be used directly as a submission for the Eval.ai challenge.
with open(f'{cfg["task"]}_{split}_results.json', 'w') as my_file:
  json.dump(results, my_file)
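
One detail worth checking before submitting: `json` coerces integer dictionary keys to strings, so question ids written as `5` are read back as `'5'`. The sketch below round-trips a toy results dictionary (made-up values in the format described above) and runs a hypothetical sanity check, `check_results`, which is not part of the notebook or the challenge tooling.

```python
import json

# Toy results in the format described above (values are made up).
toy_results = {
    'video_8913': {
        'grounded_question': {
            5: [{'id': 0, 'score': 0.9,
                 'bounding_boxes': [[0.37, 0.0, 0.83, 0.44]],
                 'frame_ids': [0]}],
        },
    },
}

# Round-trip through JSON, as writing and re-reading the file would do.
reloaded = json.loads(json.dumps(toy_results))
question_ids = list(reloaded['video_8913']['grounded_question'].keys())
print(question_ids)  # ['5']


def check_results(results):
  """Hypothetical sanity check: one 4-value box per frame id in every track."""
  for video in results.values():
    for tracks in video['grounded_question'].values():
      for track in tracks:
        assert len(track['bounding_boxes']) == len(track['frame_ids'])
        assert all(len(box) == 4 for box in track['bounding_boxes'])
  return True


print(check_results(reloaded))  # True
```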

In [12]:
# @title Visualisation functions
def make_plot(pil_img: Any, scores: List[float], boxes: List[torch.Tensor],
              labels: List[str]) -> None:
  """Plotting utility function.

  Args:
    pil_img (Any): The PIL image to be plotted (could be of any type,
      but should be compatible with 'np.array').
    scores (List): List of probabilities for each bounding box.
    boxes (List): List of bounding boxes in torch.Tensor format.
    labels (List): List of labels for each bounding box.

  Returns:
    None: The function plots the image with bounding boxes and labels using
      matplotlib.
  """
  plt.figure(figsize=(16, 10))
  np_image = np.array(pil_img)
  ax = plt.gca()
  colors = COLORS * 100
  assert len(scores) == len(boxes) == len(labels)
  for s, box, l, c in zip(scores, boxes, labels, colors):
    (xmin, ymin, xmax, ymax) = box.detach().cpu().numpy()
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color=c, linewidth=3))
    text = f'{l}: {s:0.2f}'
    ax.text(xmin, ymin, text, fontsize=15,
            bbox=dict(facecolor='white', alpha=0.8))

  plt.imshow(np_image)
  plt.axis('off')
  plt.show()


def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
  """Generate a shuffled list of distinct colors for visualizing objects.

  Args:
    num_colors (int): The number of colors to generate.

  Returns:
    List[Tuple[int, int, int]]: A list of RGB tuples, one per generated
      color.
  """
  colors = []
  for i in np.arange(0., 360., 360. / num_colors):
    hue = i / 360.
    lightness = (50 + np.random.rand() * 10) / 100.
    saturation = (90 + np.random.rand() * 10) / 100.
    color = colorsys.hls_to_rgb(hue, lightness, saturation)
    color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
    colors.append(color)
  random.seed(0)
  random.shuffle(colors)
  return colors


def display_video(vid_frames: np.array, fps: int = 30):
  """Create and display temporary video from numpy array frames.

  The channel axis is reversed before writing (e.g. BGR to RGB).

  Args:
    vid_frames (np.array): The frames of the video as a
      numpy array. Format of frames should be:
      (num_frames, height, width, channels)
    fps (int): Frames per second for the video playback. Default is 30.
  """
  kwargs = {'macro_block_size': None}
  imageio.mimwrite('tmp_video_display.mp4',
                   vid_frames[:, :, :, ::-1], fps=fps, **kwargs)
  display(mvp.ipython_display('tmp_video_display.mp4'))


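def display_frame(frame: np.array):
  """Display a frame using cv2_imshow.

  No channel conversion is applied here; the frame is passed to cv2_imshow
  as-is, so it should already be in the channel order cv2 uses (BGR).

  Args:
    frame (np.array): The frame to be displayed.
  """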
  cv2_imshow(frame)


def paint_box(video: np.array, track: Dict[str, Any],
              color: Tuple[int, int, int] = (255, 0, 0),
              addn_label: str = '') -> np.array:
  """Paint bounding box and label on video for a given track.

  Args:
    video (np.array): The video frames as a numpy array.
    track (Dict): The track information containing bounding box
      and frame information; assumes the Perception Test dataset format.
    color (Tuple[int, int, int]): The RGB color values for the bounding box.
      Default is red (255, 0, 0).
    addn_label (str): Additional label to be added to the track label.
      Default is an empty string.

  Returns:
    np.array: The modified video frames with painted bounding box and
      label.
  """
  _, height, width, _ = video.shape
  name = str(track['id']) + ' : ' + track['label'] + addn_label
  bounding_boxes = np.array(track['bounding_boxes'])

  for box, frame_id in zip(bounding_boxes, track['frame_ids']):
    frame = np.array(video[frame_id])
    x1 = int(round(box[0] * width))
    y1 = int(round(box[1] * height))
    x2 = int(round(box[2] * width))
    y2 = int(round(box[3] * height))
    frame = cv2.rectangle(frame, (x1, y1), (x2, y2),
                          color=color, thickness=2)
    frame = cv2.putText(frame, name, (x1, y1 + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
    video[frame_id] = frame

  return video


def paint_boxes(video: np.array, tracks: List[Dict],
                colors: List[Tuple[int, int, int]]) -> np.array:
  """Paint bounding boxes and labels on a video for multiple tracks.

  Args:
    video (np.array): The video frames as a numpy array.
    tracks (List): A list of track information,
      where each track contains bounding box and frame information.
    colors (List): List of RGB color tuples, one per track (for example
      the output of get_colors).

  Returns:
    np.array: The modified video frames with painted bounding boxes
      and labels.
  """
  for i, track in enumerate(tracks):
    video = paint_box(video, track, colors[i])
  return video


def get_answer_tracks(ex_data: dict, goq_ids: List) -> List[dict]:
  """Filters and retrieves object tracks based on the given object ids.

  Args:
    ex_data (dict): The data containing object tracking information.
    goq_ids (List): The list of IDs to filter tracks.

  Returns:
    List[dict]: The filtered tracks matching the goq_ids.
  """
  goq_tracks = []
  for track in ex_data['object_tracking']:
    if track['id'] in goq_ids:
      goq_tracks.append(track)
  return goq_tracks
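
The tracks returned here store boxes in normalised coordinates, which `paint_box` scales by the frame width and height before drawing. The sketch below isolates that conversion; `to_pixel_box` is a hypothetical helper introduced for illustration, and the `[x1, y1, x2, y2]` layout is assumed from how `paint_box` indexes each box.

```python
from typing import Sequence, Tuple


def to_pixel_box(box: Sequence[float], width: int,
                 height: int) -> Tuple[int, int, int, int]:
  """Scale a normalised [x1, y1, x2, y2] box to integer pixel coordinates."""
  x1 = int(round(box[0] * width))
  y1 = int(round(box[1] * height))
  x2 = int(round(box[2] * width))
  y2 = int(round(box[3] * height))
  return x1, y1, x2, y2


print(to_pixel_box([0.25, 0.5, 0.75, 1.0], width=640, height=480))
# (160, 240, 480, 480)
```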

In [None]:
# @title Ground truth annotations visualised
colors = get_colors(num_colors=100)

# sample annotations and videos to visualise the annotations later
sample_annot_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/sample_annotations.zip'
download_and_unzip(sample_annot_url, data_path)

sample_videos_url = 'https://storage.googleapis.com/dm-perception-test/zip_data/sample_videos.zip'
download_and_unzip(sample_videos_url, data_path)

# get sample video info to showcase the annotations
sample_db_path = './data/sample.json'
sample_db_dict = load_db_json(sample_db_path)
video_id = list(sample_db_dict.keys())[7]
example_data = sample_db_dict[video_id]

if example_data['grounded_question']:
  question = example_data['grounded_question'][0]
  print('---------------------------------')
  print('Question:', question['question'])
  print('Answer IDs:', question['answers'])
  print('Question info:')
  print('Reasoning:', question['reasoning'])
  print('Area:', question['area'])
  print('---------------------------------')

  frames = get_video_frames(example_data, cfg['video_folder_path'])
  answer_tracks = get_answer_tracks(example_data, question['answers'])
  frames = paint_boxes(frames, answer_tracks, colors)

  annotated_frames = []
  for frame_idx in answer_tracks[0]['frame_ids']:
    annotated_frames.append(frames[frame_idx])

  annotated_frames = np.array(annotated_frames)
  display_video(annotated_frames, 1)
  del frames