Skip to content

Commit

Permalink
[Feature] Support OC-SORT for MOT (#545)
Browse files Browse the repository at this point in the history
* [WIP] support OC-SORT

* fix config

* implement OC-SORT, verified on MOT17-val

* edit OC-SORT readme

* fixed issue from code review suggestions

* format fix

* fix comments / format issues

* add unit test for oc-sort

* update mot17-reproduce results

* fixed issue mentioned in #545

* resolve conflict

* add trained weights and log

* update ocsort metafile.yml

* fix format

* fix typo

Co-authored-by: Jinkun Cao <jinkuncao@fb.com>
  • Loading branch information
noahcao and Jinkun Cao committed Aug 16, 2022
1 parent 904407e commit bcc0c9b
Show file tree
Hide file tree
Showing 13 changed files with 937 additions and 6 deletions.
34 changes: 34 additions & 0 deletions configs/mot/ocsort/README.md
@@ -0,0 +1,34 @@
# Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking

## Abstract

<!-- [ABSTRACT] -->

Multi-Object Tracking (MOT) has rapidly progressed with the development of object detection and re-identification. However, motion modeling, which facilitates object association by forecasting short-term trajec- tories with past observations, has been relatively under-explored in recent years. Current motion models in MOT typically assume that the object motion is linear in a small time window and needs continuous observations, so these methods are sensitive to occlusions and non-linear motion and require high frame-rate videos. In this work, we show that a simple motion model can obtain state-of-the-art tracking performance without other cues like appearance. We emphasize the role of “observation” when recovering tracks from being lost and reducing the error accumulated by linear motion models during the lost period. We thus name the proposed method as Observation-Centric SORT, OC-SORT for short. It remains simple, online, and real-time but improves robustness over occlusion and non-linear motion. It achieves 63.2 and 62.1 HOTA on MOT17 and MOT20, respectively, surpassing all published methods. It also sets new states of the art on KITTI Pedestrian Tracking and DanceTrack where the object motion is highly non-linear

<!-- [IMAGE] -->

<div align="center">
<img src="https://user-images.githubusercontent.com/17743251/168193097-b3ad1a94-b18c-4b14-b7b1-5f8c6ed842f0.png"/>
</div>

## Citation

<!-- [ALGORITHM] -->

```latex
@article{cao2022observation,
title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
journal={arXiv preprint arXiv:2203.14360},
year={2022}
}
```

## Results and models on MOT17

The performance on `MOT17-half-val` is comparable with the performance from [the OC-SORT official implementation](https://github.com/noahcao/OC_SORT). We use the same YOLO-X detector weights as in [ByteTrack](https://github.com/open-mmlab/mmtracking/tree/master/configs/mot/bytetrack).

| Method | Detector | Train Set | Test Set | Public | Inf time (fps) | HOTA | MOTA | IDF1 | FP | FN | IDSw. | Config | Download |
| :-----: | :------: | :---------------------: | :------: | :----: | :------------: | :--: | :--: | :--: | :---: | :---: | :---: | :-------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| OC-SORT | YOLOX-X | CrowdHuman + half-train | half-val | N | - | 67.5 | 77.8 | 78.4 | 15576 | 19494 | 825 | [config](ocsort_yolox_x_crowdhuman_mot17-private-half.py) | [model](https://download.openmmlab.com/mmtracking/mot/ocsort/mot_dataset/ocsort_yolox_x_crowdhuman_mot17-private-half_20220813_101618-fe150582.pth) \| [log](https://download.openmmlab.com/mmtracking/mot/ocsort/mot_dataset/ocsort_yolox_x_crowdhuman_mot17-private-half_20220813_101618.log.json) |
26 changes: 26 additions & 0 deletions configs/mot/ocsort/metafile.yml
@@ -0,0 +1,26 @@
Collections:
- Name: OCSORT
Metadata:
Training Techniques:
- SGD with Momentum
Training Resources: 8x V100 GPUs
Architecture:
- YOLOX
Paper:
URL: https://arxiv.org/abs/2203.14360
Title: Observation-Centric SORT Rethinking SORT for Robust Multi-Object Tracking
README: configs/mot/ocsort/README.md

Models:
- Name: ocsort_yolox_x_crowdhuman_mot17-private-half
In Collection: OCSORT
Config: configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private-half.py
Metadata:
Training Data: CrowdHuman + MOT17-half-train
Results:
- Task: Multiple Object Tracking
Dataset: MOT17-half-val
Metrics:
MOTA: 77.8
IDF1: 78.4
Weights: https://download.openmmlab.com/mmtracking/mot/ocsort/mot_dataset/ocsort_yolox_x_crowdhuman_mot17-private-half_20220813_101618-fe150582.pth
167 changes: 167 additions & 0 deletions configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private-half.py
@@ -0,0 +1,167 @@
_base_ = [
'../../_base_/models/yolox_x_8x8.py',
'../../_base_/datasets/mot_challenge.py', '../../_base_/default_runtime.py'
]

img_scale = (800, 1440)
samples_per_gpu = 4

model = dict(
type='OCSORT',
detector=dict(
input_size=img_scale,
random_size_range=(18, 32),
bbox_head=dict(num_classes=1),
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
)),
motion=dict(type='KalmanFilter'),
tracker=dict(
type='OCSORTTracker',
obj_score_thr=0.3,
init_track_thr=0.7,
weight_iou_with_det_scores=True,
match_iou_thr=0.3,
num_tentatives=3,
vel_consist_weight=0.2,
vel_delta_t=3,
num_frames_retain=30))

train_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
bbox_clip_border=False),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-img_scale[0] // 2, -img_scale[1] // 2),
bbox_clip_border=False),
dict(
type='MixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0,
bbox_clip_border=False),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Resize',
img_scale=img_scale,
keep_ratio=True,
bbox_clip_border=False),
dict(type='Pad', size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data = dict(
samples_per_gpu=samples_per_gpu,
workers_per_gpu=4,
persistent_workers=True,
train=dict(
_delete_=True,
type='MultiImageMixDataset',
dataset=dict(
type='CocoDataset',
ann_file=[
'data/MOT17/annotations/half-train_cocoformat.json',
'data/crowdhuman/annotations/crowdhuman_train.json',
'data/crowdhuman/annotations/crowdhuman_val.json'
],
img_prefix=[
'data/MOT17/train', 'data/crowdhuman/train',
'data/crowdhuman/val'
],
classes=('pedestrian', ),
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True)
],
filter_empty_gt=False),
pipeline=train_pipeline),
val=dict(
pipeline=test_pipeline,
interpolate_tracks_cfg=dict(min_num_frames=5, max_num_frames=20)),
test=dict(
pipeline=test_pipeline,
interpolate_tracks_cfg=dict(min_num_frames=5, max_num_frames=20)))

# optimizer
# default 8 gpu
optimizer = dict(
type='SGD',
lr=0.001 / 8 * samples_per_gpu,
momentum=0.9,
weight_decay=5e-4,
nesterov=True,
paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0))
optimizer_config = dict(grad_clip=None)

# some hyper parameters
total_epochs = 80
num_last_epochs = 10
resume_from = None
interval = 5

# learning policy
lr_config = dict(
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=1,
num_last_epochs=num_last_epochs,
min_lr_ratio=0.05)

custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=num_last_epochs,
interval=interval,
priority=48),
dict(
type='ExpMomentumEMAHook',
resume_from=resume_from,
momentum=0.0001,
priority=49)
]

checkpoint_config = dict(interval=1)
evaluation = dict(metric=['bbox', 'track'], interval=1)
search_metrics = ['MOTA', 'IDF1', 'FN', 'FP', 'IDs', 'MT', 'ML']

# you need to set mode='dynamic' if you are using pytorch<=1.5.0
fp16 = dict(loss_scale=dict(init_scale=512.))
6 changes: 6 additions & 0 deletions configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private.py
@@ -0,0 +1,6 @@
_base_ = ['./ocsort_yolox_x_crowdhuman_mot17-private-half.py']

data = dict(
test=dict(
ann_file='data/MOT17/annotations/test_cocoformat.json',
img_prefix='data/MOT17/test'))
76 changes: 76 additions & 0 deletions configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot20-private-half.py
@@ -0,0 +1,76 @@
_base_ = ['./ocsort_yolox_x_crowdhuman_mot17-private-half.py']

img_scale = (896, 1600)

model = dict(
detector=dict(input_size=img_scale, random_size_range=(20, 36)),
tracker=dict(
weight_iou_with_det_scores=False,
match_iou_thr=0.3,
))

train_pipeline = [
dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='MixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
dict(type='Pad', size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data = dict(
train=dict(
dataset=dict(
ann_file=[
'data/MOT20/annotations/train_cocoformat.json',
'data/crowdhuman/annotations/crowdhuman_train.json',
'data/crowdhuman/annotations/crowdhuman_val.json'
],
img_prefix=[
'data/MOT20/train', 'data/crowdhuman/train',
'data/crowdhuman/val'
]),
pipeline=train_pipeline),
val=dict(
ann_file='data/MOT17/annotations/train_cocoformat.json',
img_prefix='data/MOT17/train',
pipeline=test_pipeline),
test=dict(
ann_file='data/MOT20/annotations/test_cocoformat.json',
img_prefix='data/MOT20/test',
pipeline=test_pipeline))

checkpoint_config = dict(interval=1)
evaluation = dict(metric=['bbox', 'track'], interval=1)
4 changes: 3 additions & 1 deletion mmtrack/models/mot/__init__.py
Expand Up @@ -2,9 +2,11 @@
from .base import BaseMultiObjectTracker
from .byte_track import ByteTrack
from .deep_sort import DeepSORT
from .ocsort import OCSORT
from .qdtrack import QDTrack
from .tracktor import Tracktor

__all__ = [
'BaseMultiObjectTracker', 'Tracktor', 'DeepSORT', 'ByteTrack', 'QDTrack'
'BaseMultiObjectTracker', 'Tracktor', 'DeepSORT', 'ByteTrack', 'QDTrack',
'OCSORT'
]

0 comments on commit bcc0c9b

Please sign in to comment.