Skip to content

Commit

Permalink
Fix (transforms): fix bugs in transforms when the gt is empty (#2289)
Browse files Browse the repository at this point in the history
* fix (transforms): fix bugs in transforms when the gt is empty

* Fix (SSD): fix bug of data pipeline used in SSD, test more pipelines

* Fix (test_config): move print for more clear debug info

* clean comments
  • Loading branch information
ZwwWayne committed Mar 19, 2020
1 parent 091f5e2 commit 1a7354b
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 32 deletions.
93 changes: 61 additions & 32 deletions mmdet/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,11 @@ def _resize_masks(self, results):
mmcv.imresize(mask, mask_size, interpolation='nearest')
for mask in results[key]
]
results[key] = np.stack(masks)
if masks:
results[key] = np.stack(masks)
else:
results[key] = np.empty(
(0, ) + results['img_shape'], dtype=np.uint8)

def _resize_seg(self, results):
for key in results.get('seg_fields', []):
Expand Down Expand Up @@ -245,10 +249,15 @@ def __call__(self, results):
results['flip_direction'])
# flip masks
for key in results.get('mask_fields', []):
results[key] = np.stack([
masks = [
mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key]
])
]
if masks:
results[key] = np.stack(masks)
else:
results[key] = np.empty(
(0, ) + results['img_shape'], dtype=np.uint8)

# flip segs
for key in results.get('seg_fields', []):
Expand Down Expand Up @@ -410,7 +419,12 @@ def __call__(self, results):
gt_mask = results['gt_masks'][i][crop_y1:crop_y2,
crop_x1:crop_x2]
valid_gt_masks.append(gt_mask)
results['gt_masks'] = np.stack(valid_gt_masks)

if valid_gt_masks:
results['gt_masks'] = np.stack(valid_gt_masks)
else:
results['gt_masks'] = np.empty(
(0, ) + results['img_shape'], dtype=np.uint8)

return results

Expand Down Expand Up @@ -528,7 +542,8 @@ def __repr__(self):
'saturation_range={}, hue_delta={})').format(
self.brightness_delta,
(self.contrast_lower, self.contrast_upper),
self.saturation_range, self.hue_delta)
(self.saturation_lower, self.saturation_upper),
self.hue_delta)
return repr_str


Expand Down Expand Up @@ -587,7 +602,12 @@ def __call__(self, results):
0).astype(mask.dtype)
expand_mask[top:top + h, left:left + w] = mask
expand_gt_masks.append(expand_mask)
results['gt_masks'] = np.stack(expand_gt_masks)

if expand_gt_masks:
results['gt_masks'] = np.stack(expand_gt_masks)
else:
results['gt_masks'] = np.empty(
(0, ) + results['img_shape'], dtype=np.uint8)

# not tested
if 'gt_semantic_seg' in results:
Expand Down Expand Up @@ -623,6 +643,7 @@ class MinIoURandomCrop(object):

def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
# 1: return ori img
self.min_ious = min_ious
self.sample_mode = (1, *min_ious, 0)
self.min_crop_size = min_crop_size

Expand Down Expand Up @@ -652,37 +673,45 @@ def __call__(self, results):
(int(left), int(top), int(left + new_w), int(top + new_h)))
overlaps = bbox_overlaps(
patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
if overlaps.min() < min_iou:
if len(overlaps) > 0 and overlaps.min() < min_iou:
continue

# center of boxes should inside the crop img
center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) *
(center[:, 0] < patch[2]) * (center[:, 1] < patch[3]))
if not mask.any():
continue
boxes = boxes[mask]
labels = labels[mask]

# adjust boxes
# only adjust boxes and instance masks when the gt is not empty
if len(overlaps) > 0:
# adjust boxes
center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = ((center[:, 0] > patch[0]) *
(center[:, 1] > patch[1]) *
(center[:, 0] < patch[2]) *
(center[:, 1] < patch[3]))
if not mask.any():
continue

boxes = boxes[mask]
labels = labels[mask]

boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
boxes -= np.tile(patch[:2], 2)

results['gt_bboxes'] = boxes
results['gt_labels'] = labels

if 'gt_masks' in results:
valid_masks = [
results['gt_masks'][i] for i in range(len(mask))
if mask[i]
]
# here the valid_masks is not empty
results['gt_masks'] = np.stack([
gt_mask[patch[1]:patch[3], patch[0]:patch[2]]
for gt_mask in valid_masks
])

# adjust the img no matter whether the gt is empty before crop
img = img[patch[1]:patch[3], patch[0]:patch[2]]
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
boxes -= np.tile(patch[:2], 2)

results['img'] = img
results['gt_bboxes'] = boxes
results['gt_labels'] = labels

if 'gt_masks' in results:
valid_masks = [
results['gt_masks'][i] for i in range(len(mask))
if mask[i]
]
results['gt_masks'] = np.stack([
gt_mask[patch[1]:patch[3], patch[0]:patch[2]]
for gt_mask in valid_masks
])

# not tested
if 'gt_semantic_seg' in results:
Expand Down
107 changes: 107 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,110 @@ def test_config_build_detector():
train_cfg=config_mod.train_cfg,
test_cfg=config_mod.test_cfg)
assert detector is not None


def test_config_data_pipeline():
"""
Test whether the data pipeline is valid and can process corner cases.
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
from xdoctest.utils import import_module_from_path
from mmdet.datasets.pipelines import Compose
import numpy as np

config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))

# Only tests a representative subset of configurations
# TODO: test pipelines using Albu, current Albu throw None given empty GT
config_names = [
'wider_face/ssd300_wider_face.py',
'pascal_voc/ssd300_voc.py',
'pascal_voc/ssd512_voc.py',
# 'albu_example/mask_rcnn_r50_fpn_1x.py',
'fp16/mask_rcnn_r50_fpn_fp16_1x.py',
]

print('Using {} config files'.format(len(config_names)))

for config_fname in config_names:
config_fpath = join(config_dpath, config_fname)
config_mod = import_module_from_path(config_fpath)

# remove loading pipeline
loading_pipeline = config_mod.train_pipeline.pop(0)
config_mod.train_pipeline.pop(0)
config_mod.test_pipeline.pop(0)

train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)

print(
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))

print('Test training data pipeline: \n{!r}'.format(train_pipeline))
img = np.random.randint(0, 255, size=(888, 666, 3), dtype=np.uint8)
if loading_pipeline.get('to_float32', False):
img = img.astype(np.float32)
results = dict(
filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
gt_labels=np.array([1], dtype=np.int64),
gt_masks=[(img[..., 0] == 233).astype(np.uint8)],
)
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = train_pipeline(results)
assert output_results is not None

print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
results = dict(
filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
gt_labels=np.array([1], dtype=np.int64),
gt_masks=[(img[..., 0] == 233).astype(np.uint8)],
)
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = test_pipeline(results)
assert output_results is not None

# test empty GT
print('Test empty GT with training data pipeline: \n{!r}'.format(
train_pipeline))
results = dict(
filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.zeros((0, 4), dtype=np.float32),
gt_labels=np.array([], dtype=np.int64),
gt_masks=[],
)
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = train_pipeline(results)
assert output_results is not None

print('Test empty GT with testing data pipeline: \n{!r}'.format(
test_pipeline))
results = dict(
filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.zeros((0, 4), dtype=np.float32),
gt_labels=np.array([], dtype=np.int64),
gt_masks=[],
)
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = test_pipeline(results)
assert output_results is not None

0 comments on commit 1a7354b

Please sign in to comment.