-
Notifications
You must be signed in to change notification settings - Fork 7.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #3868 Add a vanilla Hungarian tracker Reviewed By: zhanghang1989 Differential Revision: D32690383 fbshipit-source-id: dc4b8579fdc6541fd48d2ad53a3d3f8d3f601422
- Loading branch information
1 parent
8ba4dd8
commit e767c5a
Showing
6 changed files
with
415 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| #!/usr/bin/env python3 | ||
| from detectron2.structures import Instances | ||
| import numpy as np | ||
| from typing import List | ||
|
|
||
|
|
||
def create_prediction_pairs(
    instances: Instances,
    prev_instances: Instances,
    iou_all: np.ndarray,
    threshold: float = 0.5,
) -> List:
    """
    Collect every (current, previous) prediction pair whose IoU is not below
    the threshold.

    Args:
        instances: predictions from current frame
        prev_instances: predictions from previous frame; its ``ID`` and
            ``ID_period`` fields are read for every kept pair
        iou_all: 2D array of IoU values, indexed as
            ``iou_all[current_index, previous_index]``
        threshold: pairs whose IoU is below this value are skipped
    Return:
        List of dicts, one per kept pair, ordered by current index and then by
        previous index. Each dict has the keys ``idx``, ``prev_idx``,
        ``prev_id``, ``IoU`` and ``prev_period``.
    """
    return [
        {
            "idx": cur_idx,
            "prev_idx": old_idx,
            "prev_id": prev_instances.ID[old_idx],
            "IoU": iou_all[cur_idx, old_idx],
            "prev_period": prev_instances.ID_period[old_idx],
        }
        for cur_idx in range(len(instances))
        for old_idx in range(len(prev_instances))
        # written as "not below" so the comparison matches a `< threshold` skip
        if not iou_all[cur_idx, old_idx] < threshold
    ]
131 changes: 131 additions & 0 deletions
131
detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| #!/usr/bin/env python3 | ||
| # Copyright 2004-present Facebook. All Rights Reserved. | ||
|
|
||
| from typing import List | ||
|
|
||
| import numpy as np | ||
| from detectron2.structures import Instances | ||
| from detectron2.structures.boxes import pairwise_iou | ||
| from detectron2.tracking.utils import create_prediction_pairs | ||
|
|
||
| from .base_tracker import TRACKER_HEADS_REGISTRY | ||
| from .hungarian_tracker import BaseHungarianTracker | ||
| from detectron2.config import configurable, CfgNode as CfgNode_ | ||
|
|
||
|
|
||
@TRACKER_HEADS_REGISTRY.register()
class VanillaHungarianBBoxIOUTracker(BaseHungarianTracker):
    """
    Hungarian algo based tracker using bbox iou as metric
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of ids allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose
                                  tracking; beyond this number the id is
                                  considered lost forever
            min_box_rel_dim: a percentage, smaller than this dimension, a bbox is
                             removed from tracking
            min_instance_period: an instance will be shown after this number of
                                 periods since its first showing up in the video
            track_iou_threshold: iou threshold, below this number a bbox pair is
                                 removed from tracking
            kwargs: accepted for call-site compatibility; not used by this
                    constructor and not forwarded to the base class
        """
        super().__init__(
            video_height=video_height,
            video_width=video_width,
            max_num_instances=max_num_instances,
            max_lost_frame_count=max_lost_frame_count,
            min_box_rel_dim=min_box_rel_dim,
            min_instance_period=min_instance_period
        )
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file; ``cfg.TRACKER_HEADS`` must contain
                 ``VIDEO_HEIGHT`` and ``VIDEO_WIDTH``, every other key falls
                 back to the same default as ``__init__``
        Return:
            dictionary storing arguments for __init__ method
        """
        tracker_cfg = cfg.TRACKER_HEADS
        assert "VIDEO_HEIGHT" in tracker_cfg, "TRACKER_HEADS.VIDEO_HEIGHT is required"
        assert "VIDEO_WIDTH" in tracker_cfg, "TRACKER_HEADS.VIDEO_WIDTH is required"
        return {
            "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker",  # noqa
            "video_height": tracker_cfg.get("VIDEO_HEIGHT"),
            "video_width": tracker_cfg.get("VIDEO_WIDTH"),
            "max_num_instances": tracker_cfg.get("MAX_NUM_INSTANCES", 200),
            "max_lost_frame_count": tracker_cfg.get("MAX_LOST_FRAME_COUNT", 0),
            "min_box_rel_dim": tracker_cfg.get("MIN_BOX_REL_DIM", 0.02),
            "min_instance_period": tracker_cfg.get("MIN_INSTANCE_PERIOD", 1),
            "track_iou_threshold": tracker_cfg.get("TRACK_IOU_THRESHOLD", 0.5),
        }

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the cost matrix for assignment problem
        (https://en.wikipedia.org/wiki/Assignment_problem)

        Args:
            instances: D2 Instances, for current frame predictions
            prev_instances: D2 Instances, for previous frame predictions
        Return:
            the cost matrix in numpy array, with one row per current instance
            and one column per previous instance; -1 for every pair returned by
            create_prediction_pairs, np.inf everywhere else
        """
        assert instances is not None and prev_instances is not None
        # Use the prev_instances argument for IoU, pairing and the matrix shape
        # alike, so the pair indices always fit the matrix that is filled below.
        iou_all = pairwise_iou(
            boxes1=instances.pred_boxes,
            boxes2=prev_instances.pred_boxes,
        )
        bbox_pairs = create_prediction_pairs(
            instances,
            prev_instances,
            iou_all,
            threshold=self._track_iou_threshold,
        )
        # assign inf to make sure pair below IoU threshold won't be matched
        # NOTE(review): if the base tracker solves this with
        # scipy.optimize.linear_sum_assignment, a row or column containing only
        # inf may be rejected as infeasible -- confirm against the base class.
        cost_matrix = np.full((len(instances), len(prev_instances)), np.inf)
        return self.assign_cost_matrix_values(cost_matrix, bbox_pairs)

    def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray:
        """
        Based on IoU for each pair of bbox, assign the associated value in cost matrix

        Args:
            cost_matrix: np.ndarray, initialized 2D array with target dimensions;
                         modified in place
            bbox_pairs: list of bbox pair dicts; only the "idx" and "prev_idx"
                        entries of each pair are read here
        Return:
            np.ndarray, the same cost_matrix object with assigned values
        """
        for pair in bbox_pairs:
            # assign -1 for IoU above threshold pairs, algorithms will minimize cost
            cost_matrix[pair["idx"], pair["prev_idx"]] = -1
        return cost_matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.