Skip to content

Commit

Permalink
Support unclip border bbox regression (#4076)
Browse files Browse the repository at this point in the history
* update

* clip border

* clip border

* clip

* update

* update

* update

* update
  • Loading branch information
OceanPang committed Nov 19, 2020
1 parent f382ec8 commit caa4a4e
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -104,6 +104,7 @@ venv.bak/
.mypy_cache/

data/
data
.vscode
.idea
.DS_Store
Expand Down
15 changes: 11 additions & 4 deletions mmdet/core/bbox/coder/bucketing_bbox_coder.py
Expand Up @@ -26,20 +26,24 @@ class BucketingBBoxCoder(BaseBBoxCoder):
To avoid too large offset displacements. Defaults to 1.0.
cls_ignore_neighbor (bool): Whether to ignore the second nearest
bucket. Defaults to True.
clip_border (bool, optional): Whether to clip the objects outside the
border of the image. Defaults to True.
"""

def __init__(self,
num_buckets,
scale_factor,
offset_topk=2,
offset_upperbound=1.0,
cls_ignore_neighbor=True):
cls_ignore_neighbor=True,
clip_border=True):
super(BucketingBBoxCoder, self).__init__()
self.num_buckets = num_buckets
self.scale_factor = scale_factor
self.offset_topk = offset_topk
self.offset_upperbound = offset_upperbound
self.cls_ignore_neighbor = cls_ignore_neighbor
self.clip_border = clip_border

def encode(self, bboxes, gt_bboxes):
"""Get bucketing estimation and fine regression targets during
Expand Down Expand Up @@ -81,7 +85,7 @@ def decode(self, bboxes, pred_bboxes, max_shape=None):
0) == bboxes.size(0)
decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
self.num_buckets, self.scale_factor,
max_shape)
max_shape, self.clip_border)

return decoded_bboxes

Expand Down Expand Up @@ -262,7 +266,8 @@ def bucket2bbox(proposals,
offset_preds,
num_buckets,
scale_factor=1.0,
max_shape=None):
max_shape=None,
clip_border=True):
"""Apply bucketing estimation (cls preds) and fine regression (offset
preds) to generate det bboxes.
Expand All @@ -273,6 +278,8 @@ def bucket2bbox(proposals,
num_buckets (int): Number of buckets.
scale_factor (float): Scale factor to rescale proposals.
max_shape (tuple[int, int]): Maximum bounds for boxes, given as (H, W).
clip_border (bool, optional): Whether to clip the objects outside the
border of the image. Defaults to True.
Returns:
tuple[Tensor]: (bboxes, loc_confidence).
Expand Down Expand Up @@ -322,7 +329,7 @@ def bucket2bbox(proposals,
y1 = t_buckets - t_offsets * bucket_h
y2 = d_buckets - d_offsets * bucket_h

if max_shape is not None:
if clip_border and max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1] - 1)
y1 = y1.clamp(min=0, max=max_shape[0] - 1)
x2 = x2.clamp(min=0, max=max_shape[1] - 1)
Expand Down
15 changes: 11 additions & 4 deletions mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
Expand Up @@ -18,14 +18,18 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
delta coordinates
target_stds (Sequence[float]): Denormalizing standard deviation of
target for delta coordinates
clip_border (bool, optional): Whether to clip the objects outside the
border of the image. Defaults to True.
"""

def __init__(self,
target_means=(0., 0., 0., 0.),
target_stds=(1., 1., 1., 1.)):
target_stds=(1., 1., 1., 1.),
clip_border=True):
super(BaseBBoxCoder, self).__init__()
self.means = target_means
self.stds = target_stds
self.clip_border = clip_border

def encode(self, bboxes, gt_bboxes):
"""Get box regression transformation deltas that can be used to
Expand Down Expand Up @@ -66,7 +70,7 @@ def decode(self,

assert pred_bboxes.size(0) == bboxes.size(0)
decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
max_shape, wh_ratio_clip)
max_shape, wh_ratio_clip, self.clip_border)

return decoded_bboxes

Expand Down Expand Up @@ -121,7 +125,8 @@ def delta2bbox(rois,
means=(0., 0., 0., 0.),
stds=(1., 1., 1., 1.),
max_shape=None,
wh_ratio_clip=16 / 1000):
wh_ratio_clip=16 / 1000,
clip_border=True):
"""Apply deltas to shift/scale base boxes.
Typically the rois are anchor or proposed bounding boxes and the deltas are
Expand All @@ -138,6 +143,8 @@ def delta2bbox(rois,
coordinates
max_shape (tuple[int, int]): Maximum bounds for boxes, given as (H, W).
wh_ratio_clip (float): Maximum aspect ratio for boxes.
clip_border (bool, optional): Whether to clip the objects outside the
border of the image. Defaults to True.
Returns:
Tensor: Boxes with shape (N, 4), where columns represent
Expand Down Expand Up @@ -188,7 +195,7 @@ def delta2bbox(rois,
y1 = gy - gh * 0.5
x2 = gx + gw * 0.5
y2 = gy + gh * 0.5
if max_shape is not None:
if clip_border and max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
Expand Down
15 changes: 11 additions & 4 deletions mmdet/core/bbox/coder/tblr_bbox_coder.py
Expand Up @@ -17,11 +17,14 @@ class TBLRBBoxCoder(BaseBBoxCoder):
divided with when coding the coordinates. If it is a list, it should
have length of 4 indicating normalization factor in tblr dims.
Otherwise it is a unified float factor for all dims. Default: 4.0
clip_border (bool, optional): Whether to clip the objects outside the
border of the image. Defaults to True.
"""

def __init__(self, normalizer=4.0):
def __init__(self, normalizer=4.0, clip_border=True):
super(BaseBBoxCoder, self).__init__()
self.normalizer = normalizer
self.clip_border = clip_border

def encode(self, bboxes, gt_bboxes):
"""Get box regression transformation deltas that can be used to
Expand Down Expand Up @@ -59,7 +62,8 @@ def decode(self, bboxes, pred_bboxes, max_shape=None):
bboxes,
pred_bboxes,
normalizer=self.normalizer,
max_shape=max_shape)
max_shape=max_shape,
clip_border=self.clip_border)

return decoded_bboxes

Expand Down Expand Up @@ -114,7 +118,8 @@ def tblr2bboxes(priors,
tblr,
normalizer=4.0,
normalize_by_wh=True,
max_shape=None):
max_shape=None,
clip_border=True):
"""Decode tblr outputs to prediction boxes.
The process includes 3 steps: 1) De-normalize tblr coordinates by
Expand All @@ -136,6 +141,8 @@ def tblr2bboxes(priors,
normalized by the side length (wh) of prior bboxes.
max_shape (tuple, optional): Shape of the image. Decoded bboxes
that exceed it will be clamped.
clip_border (bool, optional): Whether to clip the objects outside the
border of the image. Defaults to True.
Return:
decoded boxes (Tensor), Shape: (n, 4)
Expand All @@ -157,7 +164,7 @@ def tblr2bboxes(priors,
ymin = prior_centers[:, 1].unsqueeze(1) - top
ymax = prior_centers[:, 1].unsqueeze(1) + bottom
boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1)
if max_shape is not None:
if clip_border and max_shape is not None:
boxes[:, 0].clamp_(min=0, max=max_shape[1])
boxes[:, 1].clamp_(min=0, max=max_shape[0])
boxes[:, 2].clamp_(min=0, max=max_shape[1])
Expand Down
59 changes: 44 additions & 15 deletions mmdet/datasets/pipelines/transforms.py
Expand Up @@ -49,6 +49,8 @@ class Resize(object):
ratio_range (tuple[float]): (min_ratio, max_ratio)
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generate slightly different results. Defaults
to 'cv2'.
Expand All @@ -59,6 +61,7 @@ def __init__(self,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True,
bbox_clip_border=True,
backend='cv2'):
if img_scale is None:
self.img_scale = None
Expand All @@ -80,6 +83,7 @@ def __init__(self,
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
self.bbox_clip_border = bbox_clip_border

@staticmethod
def random_select(img_scales):
Expand Down Expand Up @@ -219,11 +223,12 @@ def _resize_img(self, results):

def _resize_bboxes(self, results):
"""Resize bounding boxes with ``results['scale_factor']``."""
img_shape = results['img_shape']
for key in results.get('bbox_fields', []):
bboxes = results[key] * results['scale_factor']
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
if self.bbox_clip_border:
img_shape = results['img_shape']
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
results[key] = bboxes

def _resize_masks(self, results):
Expand Down Expand Up @@ -290,6 +295,7 @@ def __repr__(self):
repr_str += f'multiscale_mode={self.multiscale_mode}, '
repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'keep_ratio={self.keep_ratio})'
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str


Expand Down Expand Up @@ -570,6 +576,8 @@ class RandomCrop(object):
crop_size (tuple): Expected size after cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area. Defaults to False.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. Defaults to True.
Note:
- If the image is smaller than the crop size, return the original image
Expand All @@ -581,10 +589,14 @@ class RandomCrop(object):
`allow_negative_crop` is set to False, skip this image.
"""

def __init__(self, crop_size, allow_negative_crop=False):
def __init__(self,
crop_size,
allow_negative_crop=False,
bbox_clip_border=True):
assert crop_size[0] > 0 and crop_size[1] > 0
self.crop_size = crop_size
self.allow_negative_crop = allow_negative_crop
self.bbox_clip_border = bbox_clip_border
# The key correspondence from bboxes to labels and masks.
self.bbox2label = {
'gt_bboxes': 'gt_labels',
Expand Down Expand Up @@ -628,8 +640,9 @@ def __call__(self, results):
bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
dtype=np.float32)
bboxes = results[key] - bbox_offset
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
if self.bbox_clip_border:
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
bboxes[:, 3] > bboxes[:, 1])
# If the crop does not contain any gt-bbox area and
Expand Down Expand Up @@ -657,7 +670,9 @@ def __call__(self, results):
return results

def __repr__(self):
return self.__class__.__name__ + f'(crop_size={self.crop_size})'
repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}), '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str


@PIPELINES.register_module()
Expand Down Expand Up @@ -907,18 +922,24 @@ class MinIoURandomCrop(object):
bounding boxes
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
where a >= min_crop_size).
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. Defaults to True.
Note:
The keys for bboxes, labels and masks should be paired. That is, \
`gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
`gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
"""

def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
def __init__(self,
min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
min_crop_size=0.3,
bbox_clip_border=True):
# 1: return ori img
self.min_ious = min_ious
self.sample_mode = (1, *min_ious, 0)
self.min_crop_size = min_crop_size
self.bbox_clip_border = bbox_clip_border
self.bbox2label = {
'gt_bboxes': 'gt_labels',
'gt_bboxes_ignore': 'gt_labels_ignore'
Expand Down Expand Up @@ -995,8 +1016,9 @@ def is_center_of_bboxes_in_patch(boxes, patch):
boxes = results[key].copy()
mask = is_center_of_bboxes_in_patch(boxes, patch)
boxes = boxes[mask]
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
if self.bbox_clip_border:
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
boxes -= np.tile(patch[:2], 2)

results[key] = boxes
Expand Down Expand Up @@ -1024,7 +1046,8 @@ def is_center_of_bboxes_in_patch(boxes, patch):
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(min_ious={self.min_ious}, '
repr_str += f'min_crop_size={self.min_crop_size})'
repr_str += f'min_crop_size={self.min_crop_size}), '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str


Expand Down Expand Up @@ -1351,6 +1374,8 @@ class RandomCenterCropPad(object):
- 'logical_or': final_shape = input_shape | padding_shape_value
- 'size_divisor': final_shape = int(
ceil(input_shape / padding_shape_value) * padding_shape_value)
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. Defaults to True.
"""

def __init__(self,
Expand All @@ -1361,7 +1386,8 @@ def __init__(self,
std=None,
to_rgb=None,
test_mode=False,
test_pad_mode=('logical_or', 127)):
test_pad_mode=('logical_or', 127),
bbox_clip_border=True):
if test_mode:
assert crop_size is None, 'crop_size must be None in test mode'
assert ratios is None, 'ratios must be None in test mode'
Expand Down Expand Up @@ -1394,6 +1420,7 @@ def __init__(self,
self.std = std
self.test_mode = test_mode
self.test_pad_mode = test_pad_mode
self.bbox_clip_border = bbox_clip_border

def _get_border(self, border, size):
"""Get final border for the target size.
Expand Down Expand Up @@ -1527,8 +1554,9 @@ def _train_aug(self, results):
bboxes = results[key][mask]
bboxes[:, 0:4:2] += cropped_center_x - left_w - x0
bboxes[:, 1:4:2] += cropped_center_y - top_h - y0
bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
if self.bbox_clip_border:
bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
keep = (bboxes[:, 2] > bboxes[:, 0]) & (
bboxes[:, 3] > bboxes[:, 1])
bboxes = bboxes[keep]
Expand Down Expand Up @@ -1602,7 +1630,8 @@ def __repr__(self):
repr_str += f'std={self.input_std}, '
repr_str += f'to_rgb={self.to_rgb}, '
repr_str += f'test_mode={self.test_mode}, '
repr_str += f'test_pad_mode={self.test_pad_mode})'
repr_str += f'test_pad_mode={self.test_pad_mode}), '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str


Expand Down
1 change: 1 addition & 0 deletions mmdet/models/dense_heads/anchor_head.py
Expand Up @@ -41,6 +41,7 @@ def __init__(self,
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
clip_border=True,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)),
reg_decoded_bbox=False,
Expand Down
1 change: 1 addition & 0 deletions mmdet/models/dense_heads/ssd_head.py
Expand Up @@ -40,6 +40,7 @@ def __init__(self,
basesize_ratio_range=(0.1, 0.9)),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
clip_border=True,
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
),
Expand Down
1 change: 1 addition & 0 deletions mmdet/models/roi_heads/bbox_heads/bbox_head.py
Expand Up @@ -23,6 +23,7 @@ def __init__(self,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
clip_border=True,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
Expand Down

0 comments on commit caa4a4e

Please sign in to comment.