Skip to content

Commit

Permalink
RCNN FPN enhancement (#700)
Browse files Browse the repository at this point in the history
* RCNN FPN enhancement.
1. change faster rcnn fpn roi to 7x7 instead of 14x14.
2. change RPN Sampler, and box encoder to use numpy. This is significantly faster than mxnet.ndarray implementation.

* model hash update

* group import

* doc update

* Revert "model hash update"
fix model hash

This reverts commit b6ad1f7.

* fix docs link

* update model store

* fix author typo
  • Loading branch information
Jerryzcn authored and zhreshold committed Apr 12, 2019
1 parent 0dbd05c commit 23cb790
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 55 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ If you feel our code or models helps in your research, please kindly cite our pa
@article{zhang2019bag,
title={Bag of Freebies for Training Object Detection Neural Networks},
author={Zhang, Zhi and He, Tong and Zhang, Hang and Zhang, Zhongyuan and Xie, Junyuan and Li, Mu},
author={Zhang, Zhi and He, Tong and Zhang, Hang and Zhang, Zhongyue and Xie, Junyuan and Li, Mu},
journal={arXiv preprint arXiv:1902.04103},
year={2019}
}```
26 changes: 13 additions & 13 deletions docs/model_zoo/detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,19 @@ Checkout Faster-RCNN demo tutorial here: :ref:`sphx_glr_build_examples_detection
.. table::
:widths: 50 5 25 20

+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+
| Model | Box AP | Training Command | Training Log |
+===========================================+=================+=========================================================================================================================================+==================================================================================================================================+
| faster_rcnn_resnet50_v1b_coco [2]_ | 37.0/57.8/39.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_resnet101_v1d_coco [2]_ | 40.1/60.9/43.3 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet50_v1b_coco [4]_ | 38.4/60.3/41.4 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet101_v1d_coco [4]_ | 41.2/62.7/44.8 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_bn_resnet50_v1b_coco [5]_ | 39.3/61.3/42.9 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Model | Box AP | Training Command | Training Log |
+===========================================+=================+=========================================================================================================================================+=======================================================================================================================================+
| faster_rcnn_resnet50_v1b_coco [2]_ | 37.0/57.8/39.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_resnet101_v1d_coco [2]_ | 40.1/60.9/43.3 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet50_v1b_coco [4]_ | 38.5/60.1/41.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet101_v1d_coco [4]_ | 40.8/62.4/44.7 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_bn_resnet50_v1b_coco [5]_ | 39.3/61.3/42.9 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+

YOLO-v3
-------
Expand Down
4 changes: 2 additions & 2 deletions gluoncv/model_zoo/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def faster_rcnn_fpn_resnet50_v1b_coco(pretrained=False, pretrained_base=True, **
name='fpn_resnet50_v1b', dataset='coco', pretrained=pretrained, features=features,
top_features=top_features, classes=classes, box_features=box_features,
short=800, max_size=1333, min_stage=2, max_stage=6, train_patterns=train_patterns,
nms_thresh=0.5, nms_topk=-1, post_nms=-1, roi_mode='align', roi_size=(14, 14),
nms_thresh=0.5, nms_topk=-1, post_nms=-1, roi_mode='align', roi_size=(7, 7),
strides=(4, 8, 16, 32, 64), clip=4.42, rpn_channel=1024, base_size=16,
scales=(2, 4, 8, 16, 32), ratios=(0.5, 1, 2), alloc_size=(384, 384),
rpn_nms_thresh=0.7, rpn_train_pre_nms=12000, rpn_train_post_nms=2000,
Expand Down Expand Up @@ -878,7 +878,7 @@ def faster_rcnn_fpn_resnet101_v1d_coco(pretrained=False, pretrained_base=True, *
name='fpn_resnet101_v1d', dataset='coco', pretrained=pretrained, features=features,
top_features=top_features, classes=classes, box_features=box_features,
short=800, max_size=1333, min_stage=2, max_stage=6, train_patterns=train_patterns,
nms_thresh=0.5, nms_topk=-1, post_nms=-1, roi_mode='align', roi_size=(14, 14),
nms_thresh=0.5, nms_topk=-1, post_nms=-1, roi_mode='align', roi_size=(7, 7),
strides=(4, 8, 16, 32, 64), clip=4.42, rpn_channel=1024, base_size=16,
scales=(2, 4, 8, 16, 32), ratios=(0.5, 1, 2), alloc_size=(384, 384),
rpn_nms_thresh=0.7, rpn_train_pre_nms=12000, rpn_train_post_nms=2000,
Expand Down
4 changes: 2 additions & 2 deletions gluoncv/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
('447328d89d70ae1e2ca49226b8d834e5a5456df3', 'faster_rcnn_resnet50_v1b_voc'),
('5b4690fb7c5b62c44fb36c67d0642b633697f1bb', 'faster_rcnn_resnet50_v1b_coco'),
('a465eca35e78aba6ebdf99bf52031a447e501063', 'faster_rcnn_resnet101_v1d_coco'),
('24727e5541734d9703260a3ac3509a1a0cec8b82', 'faster_rcnn_fpn_resnet50_v1b_coco'),
('233572743bc537291590f4edf8a0c17c14b234bb', 'faster_rcnn_fpn_resnet50_v1b_coco'),
('977c247d70c33d1426f62147fc0e04dd329fc5ec', 'faster_rcnn_fpn_bn_resnet50_v1b_coco'),
('c24d9227b75f53b06e66f1c6a0f9115b04acc583', 'faster_rcnn_fpn_resnet101_v1d_coco'),
('1194ab4ec6e06386aadd55820add312c8ef59c74', 'faster_rcnn_fpn_resnet101_v1d_coco'),
('a3527fdc2cee5b1f32a61e5fd7cda8fb673e86e5', 'mask_rcnn_resnet50_v1b_coco'),
('4a3249c584f81c2a9b5d852b742637cd692ebdcb', 'mask_rcnn_resnet101_v1d_coco'),
('1364d0afe4de575af5d4389d50c2dbf22449ceac', 'mask_rcnn_fpn_resnet50_v1b_coco'),
Expand Down
54 changes: 25 additions & 29 deletions gluoncv/model_zoo/rpn/rpn_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet import autograd
from mxnet import autograd, gluon

from ...nn.bbox import BBoxSplit
from ...nn.coder import SigmoidClassEncoder, NormalizedBoxCenterEncoder
from ...nn.coder import SigmoidClassEncoder, NumPyNormalizedBoxCenterEncoder


class RPNTargetSampler(gluon.Block):
class RPNTargetSampler(object):
"""A sampler to choose positive/negative samples from RPN anchors
Parameters
Expand All @@ -25,6 +25,7 @@ class RPNTargetSampler(gluon.Block):
to be sampled.
"""

def __init__(self, num_sample, pos_iou_thresh, neg_iou_thresh, pos_ratio):
super(RPNTargetSampler, self).__init__()
self._num_sample = num_sample
Expand All @@ -33,8 +34,7 @@ def __init__(self, num_sample, pos_iou_thresh, neg_iou_thresh, pos_ratio):
self._neg_iou_thresh = neg_iou_thresh
self._eps = np.spacing(np.float32(1.0))

# pylint: disable=arguments-differ
def forward(self, ious):
def __call__(self, ious):
"""RPNTargetSampler is only used in data transform with no batch dimension.
Parameters
Expand All @@ -47,30 +47,28 @@ def forward(self, ious):
matches: (num_anchors,) value [0, M).
"""
matches = mx.nd.argmax(ious, axis=1)
matches = np.argmax(ious, axis=1)

# samples init with 0 (ignore)
ious_max_per_anchor = mx.nd.max(ious, axis=1)
samples = mx.nd.zeros_like(ious_max_per_anchor)
ious_max_per_anchor = np.max(ious, axis=1)
samples = np.zeros_like(ious_max_per_anchor)

# set argmax (1, num_gt)
ious_max_per_gt = mx.nd.max(ious, axis=0, keepdims=True)
ious_max_per_gt = np.max(ious, axis=0, keepdims=True)
# ious (num_anchor, num_gt) >= argmax (1, num_gt) -> mark row as positive
mask = mx.nd.broadcast_greater(ious + self._eps, ious_max_per_gt)
mask = (ious + self._eps) > ious_max_per_gt
# reduce column (num_anchor, num_gt) -> (num_anchor)
mask = mx.nd.sum(mask, axis=1)
mask = np.sum(mask, axis=1)
# row maybe sampled by 2 columns but still only matches to most overlapping gt
samples = mx.nd.where(mask, mx.nd.ones_like(samples), samples)
samples = np.where(mask, 1.0, samples)

# set positive overlap to 1
samples = mx.nd.where(ious_max_per_anchor >= self._pos_iou_thresh,
mx.nd.ones_like(samples), samples)
samples = np.where(ious_max_per_anchor >= self._pos_iou_thresh, 1.0, samples)
# set negative overlap to -1
tmp = (ious_max_per_anchor < self._neg_iou_thresh) * (ious_max_per_anchor >= 0)
samples = mx.nd.where(tmp, mx.nd.ones_like(samples) * -1, samples)
samples = np.where(tmp, -1.0, samples)

# subsample fg labels
samples = samples.asnumpy()
num_pos = int((samples > 0).sum())
if num_pos > self._max_pos:
disable_indices = np.random.choice(
Expand All @@ -86,8 +84,6 @@ def forward(self, ious):
np.where(samples < 0)[0], size=(num_neg - max_neg), replace=False)
samples[disable_indices] = 0

# convert to ndarray
samples = mx.nd.array(samples, ctx=matches.context)
return samples, matches


Expand All @@ -114,6 +110,7 @@ class RPNTargetGenerator(gluon.Block):
border anchors. You can set it to very large value to keep all anchors.
"""

def __init__(self, num_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3,
pos_ratio=0.5, stds=(1., 1., 1., 1.), allowed_border=0):
super(RPNTargetGenerator, self).__init__()
Expand All @@ -125,7 +122,7 @@ def __init__(self, num_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3,
self._bbox_split = BBoxSplit(axis=-1)
self._sampler = RPNTargetSampler(num_sample, pos_iou_thresh, neg_iou_thresh, pos_ratio)
self._cls_encoder = SigmoidClassEncoder()
self._box_encoder = NormalizedBoxCenterEncoder(stds=stds)
self._box_encoder = NumPyNormalizedBoxCenterEncoder(stds=stds)

# pylint: disable=arguments-differ
def forward(self, bbox, anchor, width, height):
Expand All @@ -147,23 +144,22 @@ def forward(self, bbox, anchor, width, height):
box_mask: (N, 4) only anchors whose cls_target > 0 has nonzero mask
"""
F = mx.nd
with autograd.pause():
# calculate ious between (N, 4) anchors and (M, 4) bbox ground-truths
# ious is (N, M)
ious = mx.nd.contrib.box_iou(anchor, bbox, format='corner')
ious = mx.nd.contrib.box_iou(anchor, bbox, format='corner').asnumpy()

# mask out invalid anchors, (N, 4)
a_xmin, a_ymin, a_xmax, a_ymax = F.split(anchor, num_outputs=4, axis=-1)
a_xmin, a_ymin, a_xmax, a_ymax = mx.nd.split(anchor, 4, axis=-1)
invalid_mask = (a_xmin < 0) + (a_ymin < 0) + (a_xmax >= width) + (a_ymax >= height)
invalid_mask = F.repeat(invalid_mask, repeats=bbox.shape[0], axis=-1)
ious = F.where(invalid_mask, mx.nd.ones_like(ious) * -1, ious)

ious = np.where(invalid_mask.asnumpy(), -1.0, ious)
samples, matches = self._sampler(ious)

# training targets for RPN
cls_target, _ = self._cls_encoder(samples)
box_target, box_mask = self._box_encoder(
samples.expand_dims(axis=0), matches.expand_dims(0),
anchor.expand_dims(axis=0), bbox.expand_dims(0))
return cls_target, box_target[0], box_mask[0]
np.expand_dims(samples, axis=0), np.expand_dims(matches, axis=0),
np.expand_dims(anchor.asnumpy(), axis=0), np.expand_dims(bbox.asnumpy(), axis=0))
return mx.nd.array(cls_target, ctx=bbox.context), \
mx.nd.array(box_target[0], ctx=bbox.context), \
mx.nd.array(box_mask[0], ctx=bbox.context)

0 comments on commit 23cb790

Please sign in to comment.