Skip to content

Commit

Permalink
Allow for images to contain zero true detections (open-mmlab#1531)
Browse files Browse the repository at this point in the history
* Allow for images to contain zero true detections

* Allow for empty assignment in PointAssigner

* Allow ApproxMaxIouAssigner to return an empty result

* Fix CascadeRNN forward when entire batch has no truth

* Correctly assign boxes to background when there is no truth

* Fix assignment tests

* Make flatten robust

* Fix bbox loss with empty pred/truth

* Fix logic error in BBoxHead.loss

* Add tests for empty truth cases

* tests faster rcnn empty forward

* Skip roipool forward tests if torchvision is not installed

* Add tests for bbox/anchor heads

* Consolidate test_forward and test_forward2

* Fix assign_results.labels = None when gt_labels is given; Add test for this case

* Fix OHEM Sampler with zero truth

* remove xdev

* resolve 3 reviews

* Fix flake8

* refactoring

* fix yaml format

* add filter flag

* minor fix

* delete redundant code in load anno

* fix flake8 errors

* quick fix for empty truth with masks

* fix yapf error

* fix mask padding for empty masks

Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
  • Loading branch information
3 people authored and ioir123ju committed Mar 30, 2020
1 parent 3eb3f4b commit e60e341
Show file tree
Hide file tree
Showing 18 changed files with 1,032 additions and 66 deletions.
14 changes: 9 additions & 5 deletions mmdet/core/bbox/assigners/approx_max_iou_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def assign(self,
Args:
approxs (Tensor): Bounding boxes to be assigned,
shape(approxs_per_octave*n, 4).
shape(approxs_per_octave*n, 4).
squares (Tensor): Base Bounding boxes to be assigned,
shape(n, 4).
shape(n, 4).
approxs_per_octave (int): number of approxs per octave
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
Expand All @@ -86,11 +86,15 @@ def assign(self,
Returns:
:obj:`AssignResult`: The assign result.
"""

if squares.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or approxs')
num_squares = squares.size(0)
num_gts = gt_bboxes.size(0)

if num_squares == 0 or num_gts == 0:
# No predictions and/or truth, return empty assignment
overlaps = approxs.new(num_gts, num_squares)
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
return assign_result

# re-organize anchors by approxs_per_octave x num_squares
approxs = torch.transpose(
approxs.view(num_squares, approxs_per_octave, 4), 0,
Expand Down
75 changes: 74 additions & 1 deletion mmdet/core/bbox/assigners/assign_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,41 @@


class AssignResult(object):
"""
Stores assignments between predicted and truth boxes.
Attributes:
num_gts (int): the number of truth boxes considered when computing this
assignment
gt_inds (LongTensor): for each predicted box indicates the 1-based
index of the assigned truth box. 0 means unassigned and -1 means
ignore.
max_overlaps (FloatTensor): the iou between the predicted box and its
assigned truth box.
labels (None | LongTensor): If specified, for each predicted box
indicates the category label of the assigned truth box.
Example:
>>> # An assign result between 4 predicted boxes and 9 true boxes
>>> # where only two boxes were assigned.
>>> num_gts = 9
>>> max_overlaps = torch.LongTensor([0, .5, .9, 0])
>>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
>>> labels = torch.LongTensor([0, 3, 4, 0])
>>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
>>> print(str(self)) # xdoctest: +IGNORE_WANT
<AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
labels.shape=(4,))>
>>> # Force addition of gt labels (when adding gt as proposals)
>>> new_labels = torch.LongTensor([3, 4, 5])
>>> self.add_gt_(new_labels)
>>> print(str(self)) # xdoctest: +IGNORE_WANT
<AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
labels.shape=(7,))>
"""

def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
self.num_gts = num_gts
Expand All @@ -13,7 +48,45 @@ def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])

# Was this a bug?
# self.max_overlaps = torch.cat(
# [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
# IIUC, It seems like the correct code should be:
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])

if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])

def __nice__(self):
"""
Create a "nice" summary string describing this assign result
"""
parts = []
parts.append('num_gts={!r}'.format(self.num_gts))
if self.gt_inds is None:
parts.append('gt_inds={!r}'.format(self.gt_inds))
else:
parts.append('gt_inds.shape={!r}'.format(
tuple(self.gt_inds.shape)))
if self.max_overlaps is None:
parts.append('max_overlaps={!r}'.format(self.max_overlaps))
else:
parts.append('max_overlaps.shape={!r}'.format(
tuple(self.max_overlaps.shape)))
if self.labels is None:
parts.append('labels={!r}'.format(self.labels))
else:
parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
return ', '.join(parts)

def __repr__(self):
nice = self.__nice__()
classname = self.__class__.__name__
return '<{}({}) at {}>'.format(classname, nice, hex(id(self)))

def __str__(self):
classname = self.__class__.__name__
nice = self.__nice__()
return '<{}({})>'.format(classname, nice)
31 changes: 26 additions & 5 deletions mmdet/core/bbox/assigners/max_iou_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,15 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
Returns:
:obj:`AssignResult`: The assign result.
Example:
>>> self = MaxIoUAssigner(0.5, 0.5)
>>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
>>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
>>> assign_result = self.assign(bboxes, gt_bboxes)
>>> expected_gt_inds = torch.LongTensor([1, 0])
>>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
"""
if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or bboxes')
assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
gt_bboxes.shape[0] > self.gpu_assign_thr) else False
# compute overlap and assign gt on CPU when number of GT is large
Expand All @@ -88,6 +94,7 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
gt_bboxes_ignore = gt_bboxes_ignore.cpu()
if gt_labels is not None:
gt_labels = gt_labels.cpu()

bboxes = bboxes[:, :4]
overlaps = bbox_overlaps(gt_bboxes, bboxes)

Expand Down Expand Up @@ -122,16 +129,30 @@ def assign_wrt_overlaps(self, overlaps, gt_labels=None):
Returns:
:obj:`AssignResult`: The assign result.
"""
if overlaps.numel() == 0:
raise ValueError('No gt or proposals')

num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)

# 1. assign -1 by default
assigned_gt_inds = overlaps.new_full((num_bboxes, ),
-1,
dtype=torch.long)

if num_gts == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
max_overlaps = overlaps.new_zeros((num_bboxes, ))
if num_gts == 0:
# No truth, assign everything to background
assigned_gt_inds[:] = 0
if gt_labels is None:
assigned_labels = None
else:
assigned_labels = overlaps.new_zeros((num_bboxes, ),
dtype=torch.long)
return AssignResult(
num_gts,
assigned_gt_inds,
max_overlaps,
labels=assigned_labels)

# for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
Expand Down
20 changes: 17 additions & 3 deletions mmdet/core/bbox/assigners/point_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,33 @@ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
NOTE: currently unused.
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
Returns:
:obj:`AssignResult`: The assign result.
"""
if points.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or bboxes')
num_points = points.shape[0]
num_gts = gt_bboxes.shape[0]

if num_gts == 0 or num_points == 0:
# If no truth assign everything to the background
assigned_gt_inds = points.new_full((num_points, ),
0,
dtype=torch.long)
if gt_labels is None:
assigned_labels = None
else:
assigned_labels = points.new_zeros((num_points, ),
dtype=torch.long)
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)

points_xy = points[:, :2]
points_stride = points[:, 2]
points_lvl = torch.log2(
points_stride).int() # [3...,4...,5...,6...,7...]
lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
num_gts, num_points = gt_bboxes.shape[0], points.shape[0]

# assign gt box
gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
Expand Down
31 changes: 28 additions & 3 deletions mmdet/core/bbox/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,39 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
bboxes1 and bboxes2.
Args:
bboxes1 (Tensor): shape (m, 4)
bboxes2 (Tensor): shape (n, 4), if is_aligned is ``True``, then m and n
must be equal.
bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format.
bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format.
If is_aligned is ``True``, then m and n must be equal.
mode (str): "iou" (intersection over union) or iof (intersection over
foreground).
Returns:
ious(Tensor): shape (m, n) if is_aligned == False else shape (m, 1)
Example:
>>> bboxes1 = torch.FloatTensor([
>>> [0, 0, 10, 10],
>>> [10, 10, 20, 20],
>>> [32, 32, 38, 42],
>>> ])
>>> bboxes2 = torch.FloatTensor([
>>> [0, 0, 10, 20],
>>> [0, 10, 10, 19],
>>> [10, 10, 20, 20],
>>> ])
>>> bbox_overlaps(bboxes1, bboxes2)
tensor([[0.5238, 0.0500, 0.0041],
[0.0323, 0.0452, 1.0000],
[0.0000, 0.0000, 0.0000]])
Example:
>>> empty = torch.FloatTensor([])
>>> nonempty = torch.FloatTensor([
>>> [0, 0, 10, 9],
>>> ])
>>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
>>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
>>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
"""

assert mode in ['iou', 'iof']
Expand Down
2 changes: 1 addition & 1 deletion mmdet/core/bbox/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def sample(self,
bboxes = bboxes[:, :4]

gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals:
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
Expand Down
2 changes: 1 addition & 1 deletion mmdet/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _filter_imgs(self, min_size=32):
valid_inds = []
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
for i, img_info in enumerate(self.img_infos):
if self.img_ids[i] not in ids_with_ann:
if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
Expand Down
6 changes: 4 additions & 2 deletions mmdet/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def __init__(self,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False):
test_mode=False,
filter_empty_gt=True):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt

# join paths if data_root is specified
if self.data_root is not None:
Expand All @@ -66,7 +68,7 @@ def __init__(self,
self.proposals = self.load_proposals(self.proposal_file)
else:
self.proposals = None
# filter images with no annotation during training
# filter images too small
if not test_mode:
valid_inds = self._filter_imgs()
self.img_infos = [self.img_infos[i] for i in valid_inds]
Expand Down
15 changes: 1 addition & 14 deletions mmdet/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os.path as osp
import warnings

import mmcv
import numpy as np
Expand Down Expand Up @@ -42,28 +41,16 @@ def __init__(self,
with_label=True,
with_mask=False,
with_seg=False,
poly2mask=True,
skip_img_without_anno=True):
poly2mask=True):
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
self.with_seg = with_seg
self.poly2mask = poly2mask
self.skip_img_without_anno = skip_img_without_anno

def _load_bboxes(self, results):
ann_info = results['ann_info']
results['gt_bboxes'] = ann_info['bboxes']
if len(results['gt_bboxes']) == 0 and self.skip_img_without_anno:
if results['img_prefix'] is not None:
file_path = osp.join(results['img_prefix'],
results['img_info']['filename'])
else:
file_path = results['img_info']['filename']
warnings.warn(
'Skip the image "{}" that has no valid gt bbox'.format(
file_path))
return None

gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
if gt_bboxes_ignore is not None:
Expand Down
5 changes: 4 additions & 1 deletion mmdet/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,10 @@ def _pad_masks(self, results):
mmcv.impad(mask, pad_shape, pad_val=self.pad_val)
for mask in results[key]
]
results[key] = np.stack(padded_masks, axis=0)
if padded_masks:
results[key] = np.stack(padded_masks, axis=0)
else:
results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8)

def __call__(self, results):
self._pad_img(results)
Expand Down
Loading

0 comments on commit e60e341

Please sign in to comment.