support fp16 training and testing (#230)
* support fp16 of vid training

* support fp16 of sot training

* support fp16 of mot training

* fix a bug for reid fp16 testing

* support fp16 of mot testing

* add configs for fp16 training & testing

* add template for fp16 readme

* add readme for fp16 configs

* modify readme for fp16 configs

* modify readme for fp16 configs

* modify readme for fp16 configs

* add metafile for fp16 configs

* update based on the first round of review comments
GT9505 committed Aug 5, 2021
1 parent 983a1bc commit c89fe8e
Showing 23 changed files with 316 additions and 62 deletions.
32 changes: 32 additions & 0 deletions configs/fp16/README.md
@@ -0,0 +1,32 @@
# Mixed Precision Training

## Introduction

<!-- [OTHERS] -->

```latex
@article{micikevicius2017mixed,
  title={Mixed precision training},
  author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
  journal={arXiv preprint arXiv:1710.03740},
  year={2017}
}
```

## Results and Models on VID task

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP@50 | Config | Download |
| :-------: | :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :------: | :--------: |
| SELSA | R-50-DC5 | pytorch | 7e | 2.71 | - | 78.7 | [config](selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid.py) | [model](https://download.openmmlab.com/mmtracking/fp16/selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid_20210728_193846-dce6eb09.pth) &#124; [log](https://download.openmmlab.com/mmtracking/fp16/selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid_20210728_193846.log.json) |

## Results and Models on MOT task

| Method | Detector | ReID | Train Set | Test Set | Public | Inf time (fps) | MOTA | IDF1 | FP | FN | IDSw. | Config | Download |
| :-------: | :-------------: | :----: | :-------: | :------: | :----: | :------------: | :--: | :--: |:--:|:--:| :---: | :----: | :------: |
| Tracktor | R50-FasterRCNN-FPN | R50 | half-train | half-val | N | - | 64.7 | 66.6 | 10710 | 45270 | 1152 | [config](tracktor_faster-rcnn_r50_fpn_fp16_4e_mot17-private-half.py) | [detector](https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436-f4ba7d61.pth) &#124; [detector_log](https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436.log.json) &#124; [reid](https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055-4747ee95.pth) &#124; [reid_log](https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055.log.json) |

## Results and Models on SOT task

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | Success | Norm precision | Config | Download |
| :-------: | :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :----: | :------: | :--------: |
| SiameseRPN++ | R-50 | - | 20e | - | - | 49.1 | 57.0 | [config](siamese_rpn_r50_fp16_1x_lasot.py) | [model](https://download.openmmlab.com/mmtracking/fp16/siamese_rpn_r50_fp16_1x_lasot_20210731_110245-6733c67e.pth) &#124; [log](https://download.openmmlab.com/mmtracking/fp16/siamese_rpn_r50_fp16_1x_lasot_20210731_110245.log.json) |
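
The three configs listed above are thin wrappers over the existing non-fp16 configs. As a sketch of the pattern (the file name below is illustrative, not part of this commit), an fp16 variant of any existing config simply inherits the base config and adds the `fp16` field:

```python
# configs/fp16/my_method_fp16.py -- illustrative path following the pattern of this commit
_base_ = ['../vid/selsa/selsa_faster_rcnn_r50_dc5_1x_imagenetvid.py']

# enable mixed precision training with a static loss scale,
# matching the configs added below
fp16 = dict(loss_scale=512.)
```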
50 changes: 50 additions & 0 deletions configs/fp16/metafile.yml
@@ -0,0 +1,50 @@
Collections:
  - Name: FP16
    Metadata:
      Training Techniques:
        - Mixed Precision Training
      Training Resources: 8x TITAN Xp GPUs
    Paper: https://arxiv.org/abs/1710.03740
    README: configs/fp16/README.md

Models:
  - Name: selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid
    In Collection: FP16
    Config: configs/fp16/selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid.py
    Metadata:
      Training Memory (GB): 2.71
      Epochs: 7
    Results:
      - Task: Video Object Detection
        Dataset: ILSVRC
        Metrics:
          box AP@0.5: 78.7
    Weights: https://download.openmmlab.com/mmtracking/fp16/selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid_20210728_193846-dce6eb09.pth

  - Name: tracktor_faster-rcnn_r50_fpn_fp16_4e_mot17-private-half
    In Collection: FP16
    Config: configs/fp16/tracktor_faster-rcnn_r50_fpn_fp16_4e_mot17-private-half.py
    Metadata:
      Training Data: MOT17-half-train
    Results:
      - Task: Multiple Object Tracking
        Dataset: MOT17-half-val
        Metrics:
          MOTA: 64.7
          IDF1: 66.6
    Weights:
      - https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436-f4ba7d61.pth
      - https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055-4747ee95.pth

  - Name: siamese_rpn_r50_fp16_1x_lasot
    In Collection: FP16
    Config: configs/fp16/siamese_rpn_r50_fp16_1x_lasot.py
    Metadata:
      Epochs: 20
    Results:
      - Task: Single Object Tracking
        Dataset: LaSOT
        Metrics:
          Success: 49.1
          Norm precision: 57.0
    Weights: https://download.openmmlab.com/mmtracking/fp16/siamese_rpn_r50_fp16_1x_lasot_20210731_110245-6733c67e.pth
2 changes: 2 additions & 0 deletions configs/fp16/selsa_faster_rcnn_r50_dc5_fp16_1x_imagenetvid.py
@@ -0,0 +1,2 @@
_base_ = ['../vid/selsa/selsa_faster_rcnn_r50_dc5_1x_imagenetvid.py']
fp16 = dict(loss_scale=512.)
3 changes: 3 additions & 0 deletions configs/fp16/siamese_rpn_r50_fp16_1x_lasot.py
@@ -0,0 +1,3 @@
_base_ = ['../sot/siamese_rpn/siamese_rpn_r50_1x_lasot.py']
optimizer_config = dict(type='SiameseRPNFp16OptimizerHook')
fp16 = dict(loss_scale=512.)
11 changes: 11 additions & 0 deletions configs/fp16/tracktor_faster-rcnn_r50_fpn_fp16_4e_mot17-private-half.py
@@ -0,0 +1,11 @@
_base_ = [
    '../mot/tracktor/tracktor_faster-rcnn_r50_fpn_4e_mot17-private-half.py'
]
model = dict(
    pretrains=dict(
        detector=  # noqa: E251
        'https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436-f4ba7d61.pth',  # noqa: E501
        reid=  # noqa: E251
        'https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055-4747ee95.pth'  # noqa: E501
    ))
fp16 = dict(loss_scale=512.)
17 changes: 9 additions & 8 deletions mmtrack/apis/train.py
@@ -1,7 +1,7 @@
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
-                         Fp16OptimizerHook, OptimizerHook, build_optimizer)
+                         build_optimizer)
 from mmcv.utils import build_from_cfg
 from mmdet.datasets import build_dataset

@@ -86,13 +86,14 @@ def train_model(model,

     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        optimizer_config = Fp16OptimizerHook(
-            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
-    elif distributed and 'type' not in cfg.optimizer_config:
-        optimizer_config = OptimizerHook(**cfg.optimizer_config)
-    else:
-        optimizer_config = cfg.optimizer_config
+    optimizer_config = cfg.optimizer_config
+    if 'type' not in cfg.optimizer_config:
+        optimizer_config.type = 'Fp16OptimizerHook' \
+            if fp16_cfg else 'OptimizerHook'
+    if fp16_cfg:
+        optimizer_config.update(fp16_cfg)
+    if 'Fp16' in optimizer_config.type:
+        optimizer_config.update(distributed=distributed)

     # register hooks
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
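As a minimal standalone sketch of the resolution logic above (using plain `addict` dicts in place of the real mmcv `Config`; values mirror `configs/fp16/*.py`), the branch-free version produces a single hook config dict:

```python
# not MMTracking code: addict.Dict stands in for a cfg loaded from a config file
from addict import Dict

cfg = Dict(optimizer_config=dict(grad_clip=None), fp16=dict(loss_scale=512.))
distributed = False

fp16_cfg = cfg.get('fp16', None)
optimizer_config = cfg.optimizer_config
if 'type' not in cfg.optimizer_config:
    optimizer_config.type = 'Fp16OptimizerHook' if fp16_cfg else 'OptimizerHook'
if fp16_cfg:
    optimizer_config.update(fp16_cfg)
if 'Fp16' in optimizer_config.type:
    optimizer_config.update(distributed=distributed)

print(optimizer_config)
# {'grad_clip': None, 'type': 'Fp16OptimizerHook', 'loss_scale': 512.0, 'distributed': False}
```

When `runner.register_training_hooks` later receives this dict, mmcv builds the matching hook class (`OptimizerHook` or `Fp16OptimizerHook`) from its `HOOKS` registry.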
4 changes: 2 additions & 2 deletions mmtrack/core/motion/flow.py
@@ -23,9 +23,9 @@ def flow_warp_feats(x, flow):
     H, W = x.shape[-2:]
     h_grid, w_grid = torch.meshgrid(torch.arange(H), torch.arange(W))
     # [1, 1, H, W]
-    h_grid = h_grid.to(flow.device).float()[None, None, ...]
+    h_grid = h_grid.to(flow)[None, None, ...]
     # [1, 1, H, W]
-    w_grid = w_grid.to(flow.device).float()[None, None, ...]
+    w_grid = w_grid.to(flow)[None, None, ...]
     # [1, 2, H, W]
     grid = torch.cat((w_grid, h_grid), dim=1)
     # [N, 2, H, W]
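The `.to(flow)` form matters for fp16: `Tensor.to(other_tensor)` matches both the device and the dtype of `other_tensor`, while the old `.to(flow.device).float()` always produced fp32 grids even when `flow` was half precision. A quick standalone check (not MMTracking code):

```python
import torch

flow = torch.randn(1, 2, 4, 4).half()        # stands in for an fp16 flow map
h_grid = torch.arange(4)                     # int64

print(h_grid.to(flow.device).float().dtype)  # torch.float32 (old behaviour)
print(h_grid.to(flow).dtype)                 # torch.float16, same dtype/device as `flow`
```

The same dtype-matching idea drives the `torch.tensor(mean, dtype=imgs.dtype, ...)` changes in `flownet_simple.py` further down.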
8 changes: 6 additions & 2 deletions mmtrack/core/optimizer/__init__.py
@@ -1,4 +1,8 @@
 from .sot_lr_updater import SiameseRPNLrUpdaterHook
-from .sot_optimizer_hook import SiameseRPNOptimizerHook
+from .sot_optimizer_hook import (SiameseRPNFp16OptimizerHook,
+                                 SiameseRPNOptimizerHook)

-__all__ = ['SiameseRPNOptimizerHook', 'SiameseRPNLrUpdaterHook']
+__all__ = [
+    'SiameseRPNOptimizerHook', 'SiameseRPNLrUpdaterHook',
+    'SiameseRPNFp16OptimizerHook'
+]
34 changes: 33 additions & 1 deletion mmtrack/core/optimizer/sot_optimizer_hook.py
@@ -1,5 +1,5 @@
 import torch.nn as nn
-from mmcv.runner.hooks import HOOKS, OptimizerHook
+from mmcv.runner.hooks import HOOKS, Fp16OptimizerHook, OptimizerHook


 @HOOKS.register_module()
@@ -32,3 +32,35 @@ def before_train_epoch(self, runner):
                                  layer).modules():
                     if isinstance(m, nn.BatchNorm2d):
                         m.train()
+
+
+@HOOKS.register_module()
+class SiameseRPNFp16OptimizerHook(Fp16OptimizerHook):
+    """FP16Optimizer hook for siamese rpn.
+
+    Args:
+        backbone_start_train_epoch (int): Start to train the backbone at
+            `backbone_start_train_epoch`-th epoch. Note the epoch in this
+            class counts from 0, while the epoch in the log file counts from 1.
+        backbone_train_layers (list(str)): List of str denoting the stages
+            needed be trained in backbone.
+    """
+
+    def __init__(self, backbone_start_train_epoch, backbone_train_layers,
+                 **kwargs):
+        super(SiameseRPNFp16OptimizerHook, self).__init__(**kwargs)
+        self.backbone_start_train_epoch = backbone_start_train_epoch
+        self.backbone_train_layers = backbone_train_layers
+
+    def before_train_epoch(self, runner):
+        """If `runner.epoch >= self.backbone_start_train_epoch`, start to train
+        the backbone."""
+        if runner.epoch >= self.backbone_start_train_epoch:
+            for layer in self.backbone_train_layers:
+                for param in getattr(runner.model.module.backbone,
+                                     layer).parameters():
+                    param.requires_grad = True
+                for m in getattr(runner.model.module.backbone,
+                                 layer).modules():
+                    if isinstance(m, nn.BatchNorm2d):
+                        m.train()
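
The new hook keeps the backbone-unfreezing schedule of `SiameseRPNOptimizerHook` and inherits loss scaling from mmcv's `Fp16OptimizerHook`. A hedged example of how it could be selected in a SOT config (the epoch and layer values below are placeholders, not copied from the base config):

```python
# hypothetical config snippet; argument values are illustrative
optimizer_config = dict(
    type='SiameseRPNFp16OptimizerHook',
    backbone_start_train_epoch=10,   # unfreeze the backbone from this (0-based) epoch
    backbone_train_layers=['layer2', 'layer3', 'layer4'])
fp16 = dict(loss_scale=512.)
```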
6 changes: 3 additions & 3 deletions mmtrack/core/track/transforms.py
@@ -22,7 +22,7 @@ def imrenormalize(img, img_norm_cfg, new_img_norm_cfg):
         new_img = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
         new_img = _imrenormalize(new_img, img_norm_cfg, new_img_norm_cfg)
         new_img = new_img.transpose(2, 0, 1)[None]
-        return torch.from_numpy(new_img).to(img.device)
+        return torch.from_numpy(new_img).to(img)
     else:
         return _imrenormalize(img, img_norm_cfg, new_img_norm_cfg)

@@ -33,14 +33,14 @@ def _imrenormalize(img, img_norm_cfg, new_img_norm_cfg):
     new_img_norm_cfg = new_img_norm_cfg.copy()
     for k, v in img_norm_cfg.items():
         if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray):
-            img_norm_cfg[k] = np.array(v, dtype=np.float32)
+            img_norm_cfg[k] = np.array(v, dtype=img.dtype)
     # reverse cfg
     if 'to_rgb' in img_norm_cfg:
         img_norm_cfg['to_bgr'] = img_norm_cfg['to_rgb']
         img_norm_cfg.pop('to_rgb')
     for k, v in new_img_norm_cfg.items():
         if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray):
-            new_img_norm_cfg[k] = np.array(v, dtype=np.float32)
+            new_img_norm_cfg[k] = np.array(v, dtype=img.dtype)
     img = mmcv.imdenormalize(img, **img_norm_cfg)
     img = mmcv.imnormalize(img, **new_img_norm_cfg)
     return img
1 change: 1 addition & 0 deletions mmtrack/models/mot/base.py
@@ -17,6 +17,7 @@ class BaseMultiObjectTracker(nn.Module, metaclass=ABCMeta):
     def __init__(self):
         super(BaseMultiObjectTracker, self).__init__()
         self.logger = get_root_logger()
+        self.fp16_enabled = False

     def init_module(self, module_name, pretrain=None):
         """Initialize the weights of a sub-module.
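Declaring `fp16_enabled = False` gives mmcv's fp16 utilities a flag to flip: during fp16 training/testing, `wrap_fp16_model` (or the fp16 optimizer hook) switches it to `True`, which is what activates the `@auto_fp16`/`@force_fp32` decorators added to the trackers and the ReID model below. A minimal sketch, assuming mmcv's standard behaviour (the toy class is not MMTracking code):

```python
import torch.nn as nn
from mmcv.runner import wrap_fp16_model


class ToyMOT(nn.Module):
    """Mimics the one-line change to BaseMultiObjectTracker above."""

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # same declaration as in the diff


model = ToyMOT()
wrap_fp16_model(model)     # flips the flag on every module that defines it
print(model.fp16_enabled)  # True
```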
4 changes: 2 additions & 2 deletions mmtrack/models/mot/deep_sort.py
@@ -51,8 +51,8 @@ def init_weights(self, pretrain):
     def forward_train(self, *args, **kwargs):
         """Forward function during training."""
         raise NotImplementedError(
-            'Please train `detector` and `reid` models first and \
-                inference with Tracktor.')
+            'Please train `detector` and `reid` models firstly, then \
+                inference with SORT/DeepSORT.')

     def simple_test(self,
                     img,
4 changes: 3 additions & 1 deletion mmtrack/models/mot/trackers/base_tracker.py
@@ -1,14 +1,15 @@
 from abc import ABCMeta, abstractmethod

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from addict import Dict

 from mmtrack.models import TRACKERS


 @TRACKERS.register_module()
-class BaseTracker(metaclass=ABCMeta):
+class BaseTracker(nn.Module, metaclass=ABCMeta):
     """Base tracker model.

     Args:
@@ -25,6 +26,7 @@ def __init__(self, momentums=None, num_frames_retain=10):
             assert isinstance(momentums, dict), 'momentums must be a dict'
         self.momentums = momentums
         self.num_frames_retain = num_frames_retain
+        self.fp16_enabled = False

         self.reset()

2 changes: 2 additions & 0 deletions mmtrack/models/mot/trackers/sort_tracker.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+from mmcv.runner import force_fp32
 from mmdet.core import bbox_overlaps
 from motmetrics.lap import linear_sum_assignment

@@ -97,6 +98,7 @@ def pop_invalid_tracks(self, frame_id):
         for invalid_id in invalid_ids:
             self.tracks.pop(invalid_id)

+    @force_fp32(apply_to=('img', ))
     def track(self,
               img,
               img_metas,
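`@force_fp32(apply_to=('img', ))` casts the listed tensor arguments back to fp32 inside `track()` whenever the owning module has `fp16_enabled=True`; this is also why `BaseTracker` becomes an `nn.Module` above, since mmcv's decorator only works on methods of modules. A standalone sketch of the effect (toy class, not the real tracker):

```python
import torch
import torch.nn as nn
from mmcv.runner import force_fp32


class ToyTracker(nn.Module):

    def __init__(self):
        super().__init__()
        self.fp16_enabled = True  # normally flipped on by wrap_fp16_model

    @force_fp32(apply_to=('img', ))
    def track(self, img):
        # the decorator has already cast the fp16 input back to fp32 here
        return img.dtype


print(ToyTracker().track(torch.randn(1, 3, 8, 8).half()))  # torch.float32
```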
2 changes: 2 additions & 0 deletions mmtrack/models/mot/trackers/tracktor_tracker.py
@@ -1,4 +1,5 @@
 import torch
+from mmcv.runner import force_fp32
 from mmdet.core import bbox_overlaps, multiclass_nms
 from scipy.optimize import linear_sum_assignment

@@ -76,6 +77,7 @@ def regress_tracks(self, x, img_metas, detector, frame_id, rescale=False):
         return track_bboxes[valid_inds], track_labels[valid_inds], ids[
             valid_inds]

+    @force_fp32(apply_to=('img', 'feats'))
     def track(self,
               img,
               img_metas,
10 changes: 3 additions & 7 deletions mmtrack/models/mot/tracktor.py
@@ -72,7 +72,7 @@ def with_linear_motion(self):
     def forward_train(self, *args, **kwargs):
         """Forward function during training."""
         raise NotImplementedError(
-            'Please train `detector` and `reid` models first and \
+            'Please train `detector` and `reid` models firstly, then \
                 inference with Tracktor.')

     def simple_test(self,
@@ -122,13 +122,9 @@ def simple_test(self,
             det_labels = det_labels[0]
             num_classes = self.detector.roi_head.bbox_head.num_classes
         elif hasattr(self.detector, 'bbox_head'):
-            outs = self.detector.bbox_head(x)
-            result_list = self.detector.bbox_head.get_bboxes(
-                *outs, img_metas=img_metas, rescale=rescale)
-            # TODO: support batch inference
-            det_bboxes = result_list[0][0]
-            det_labels = result_list[0][1]
             num_classes = self.detector.bbox_head.num_classes
+            raise NotImplementedError(
+                'Tracktor must need "roi_head" to refine proposals.')
         else:
             raise TypeError('detector must has roi_head or bbox_head.')

8 changes: 4 additions & 4 deletions mmtrack/models/motion/flownet_simple.py
@@ -164,20 +164,20 @@ def prepare_imgs(self, imgs, img_metas):
         """
         if not hasattr(self, 'img_norm_mean'):
             mean = img_metas[0]['img_norm_cfg']['mean']
-            mean = torch.tensor(mean, device=imgs.device)
+            mean = torch.tensor(mean, dtype=imgs.dtype, device=imgs.device)
             self.img_norm_mean = mean.repeat(2)[None, :, None, None]

             mean = self.flow_img_norm_mean
-            mean = torch.tensor(mean, device=imgs.device)
+            mean = torch.tensor(mean, dtype=imgs.dtype, device=imgs.device)
             self.flow_img_norm_mean = mean.repeat(2)[None, :, None, None]

         if not hasattr(self, 'img_norm_std'):
             std = img_metas[0]['img_norm_cfg']['std']
-            std = torch.tensor(std, device=imgs.device)
+            std = torch.tensor(std, dtype=imgs.dtype, device=imgs.device)
             self.img_norm_std = std.repeat(2)[None, :, None, None]

             std = self.flow_img_norm_std
-            std = torch.tensor(std, device=imgs.device)
+            std = torch.tensor(std, dtype=imgs.dtype, device=imgs.device)
             self.flow_img_norm_std = std.repeat(2)[None, :, None, None]

         flow_img = imgs * self.img_norm_std + self.img_norm_mean
13 changes: 9 additions & 4 deletions mmtrack/models/reid/base_reid.py
@@ -1,4 +1,5 @@
 from mmcls.models import ImageClassifier
+from mmcv.runner import auto_fp16

 from ..builder import REID

@@ -19,16 +20,20 @@ def forward_train(self, img, gt_label, **kwargs):
         # change the shape of label tensor from NxS to NS
         gt_label = gt_label.view(-1)
         x = self.extract_feat(img)
+        head_outputs = self.head.forward_train(x)

         losses = dict()
-        loss = self.head.forward_train(x, gt_label)
-        losses.update(loss)
-        return loss
+        reid_loss = self.head.loss(gt_label, *head_outputs)
+        losses.update(reid_loss)
+        return losses

+    @auto_fp16(apply_to=('img', ), out_fp32=True)
     def simple_test(self, img, **kwargs):
         """Test without augmentation."""
         if img.nelement() > 0:
             x = self.extract_feat(img)
-            return self.head.simple_test(x)
+            head_outputs = self.head.forward_train(x)
+            feats = head_outputs[0]
+            return feats
         else:
             return img.new_zeros(0, self.head.out_channels)
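
`@auto_fp16(apply_to=('img', ), out_fp32=True)` is the mirror image of `@force_fp32`: with `fp16_enabled` on, the fp32 input images are cast down to fp16 for the forward pass and the returned features are cast back to fp32, so the trackers decorated with `@force_fp32` keep consuming fp32 embeddings. A standalone sketch (toy class, not the real ReID model):

```python
import torch
import torch.nn as nn
from mmcv.runner import auto_fp16


class ToyReID(nn.Module):

    def __init__(self):
        super().__init__()
        self.fp16_enabled = True  # normally flipped on by wrap_fp16_model

    @auto_fp16(apply_to=('img', ), out_fp32=True)
    def simple_test(self, img):
        assert img.dtype == torch.half  # cast down by the decorator on the way in
        return img                      # cast back to fp32 on the way out


feats = ToyReID().simple_test(torch.randn(2, 3, 8, 8))
print(feats.dtype)  # torch.float32
```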
