Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix]: fix multi-batch test problem #95

Merged
merged 11 commits into from
Sep 11, 2020
2 changes: 1 addition & 1 deletion mmdet3d/apis/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
if show:
model.module.show_results(data, result, out_dir)

results.append(result)
results.extend(result)

batch_size = len(data['img_metas'][0].data)
for _ in range(batch_size):
Expand Down
3 changes: 0 additions & 3 deletions mmdet3d/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ def forward_test(self, points, img_metas, img=None, **kwargs):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(points), len(img_metas)))
# TODO: remove the restriction of imgs_per_gpu == 1 when prepared
samples_per_gpu = len(points[0])
assert samples_per_gpu == 1

if num_augs == 1:
img = [img] if img is None else img
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/models/detectors/h3dnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def aug_test(self, points, img_metas, imgs=None, rescale=False):
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)

return merged_bboxes
return [merged_bboxes]

def extract_feats(self, points, img_metas):
"""Extract features of multiple samples."""
Expand Down
12 changes: 7 additions & 5 deletions mmdet3d/models/detectors/mvx_two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,22 +389,24 @@ def simple_test_pts(self, x, img_metas, rescale=False):
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
return bbox_results

def simple_test(self, points, img_metas, img=None, rescale=False):
"""Test function without augmentaiton."""
img_feats, pts_feats = self.extract_feat(
points, img=img, img_metas=img_metas)

bbox_list = dict()
bbox_list = [dict() for i in range(len(img_metas))]
if pts_feats and self.with_pts_bbox:
bbox_pts = self.simple_test_pts(
pts_feats, img_metas, rescale=rescale)
bbox_list.update(pts_bbox=bbox_pts)
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
if img_feats and self.with_img_bbox:
bbox_img = self.simple_test_img(
img_feats, img_metas, rescale=rescale)
bbox_list.update(img_bbox=bbox_img)
for result_dict, img_bbox in zip(bbox_list, bbox_img):
result_dict['img_bbox'] = img_bbox
return bbox_list

def aug_test(self, points, img_metas, imgs=None, rescale=False):
Expand All @@ -415,7 +417,7 @@ def aug_test(self, points, img_metas, imgs=None, rescale=False):
if pts_feats and self.with_pts_bbox:
bbox_pts = self.aug_test_pts(pts_feats, img_metas, rescale)
bbox_list.update(pts_bbox=bbox_pts)
return bbox_list
return [bbox_list]

def extract_feats(self, points, img_metas, imgs=None):
"""Extract point and image features of multiple samples."""
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/models/detectors/votenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False):
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
return bbox_results

def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test with augmentation."""
Expand All @@ -104,4 +104,4 @@ def aug_test(self, points, img_metas, imgs=None, rescale=False):
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)

return merged_bboxes
return [merged_bboxes]
4 changes: 2 additions & 2 deletions mmdet3d/models/detectors/voxelnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False):
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
return bbox_results

def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function with augmentaiton."""
Expand All @@ -123,4 +123,4 @@ def aug_test(self, points, img_metas, imgs=None, rescale=False):
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)

return merged_bboxes
return [merged_bboxes]
2 changes: 1 addition & 1 deletion mmdet3d/models/roi_heads/h3d_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,4 @@ def simple_test(self, feats_dict, img_metas, points, rescale=False):
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
return bbox_results
2 changes: 1 addition & 1 deletion mmdet3d/models/roi_heads/part_aggregation_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
return bbox_results

def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
sampling_results):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def test_inference_detector():
'6x8_160e_kitti-3d-3class.py'
detector = init_detector(detector_cfg, device='cpu')
results = inference_detector(detector, pcd)
bboxes_3d = results[0]['boxes_3d']
scores_3d = results[0]['scores_3d']
labels_3d = results[0]['labels_3d']
bboxes_3d = results[0][0]['boxes_3d']
scores_3d = results[0][0]['scores_3d']
labels_3d = results[0][0]['labels_3d']
assert bboxes_3d.tensor.shape[0] >= 0
assert bboxes_3d.tensor.shape[1] == 7
assert scores_3d.shape[0] >= 0
Expand Down
22 changes: 11 additions & 11 deletions tests/test_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def test_voxel_net():

# test simple_test
results = self.simple_test(points, img_metas)
boxes_3d = results['boxes_3d']
scores_3d = results['scores_3d']
labels_3d = results['labels_3d']
boxes_3d = results[0]['boxes_3d']
scores_3d = results[0]['scores_3d']
labels_3d = results[0]['labels_3d']
assert boxes_3d.tensor.shape == (50, 7)
assert scores_3d.shape == torch.Size([50])
assert labels_3d.shape == torch.Size([50])
Expand Down Expand Up @@ -155,9 +155,9 @@ def test_vote_net():

# test simple_test
results = self.simple_test(points, img_metas)
boxes_3d = results['boxes_3d']
scores_3d = results['scores_3d']
labels_3d = results['labels_3d']
boxes_3d = results[0]['boxes_3d']
scores_3d = results[0]['scores_3d']
labels_3d = results[0]['labels_3d']
assert boxes_3d.tensor.shape[0] >= 0
assert boxes_3d.tensor.shape[1] == 7
assert scores_3d.shape[0] >= 0
Expand All @@ -171,8 +171,8 @@ def test_parta2():
parta2 = _get_detector_cfg(
'parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py')
self = build_detector(parta2).cuda()
points_0 = torch.rand([2000, 4], device='cuda')
points_1 = torch.rand([2000, 4], device='cuda')
points_0 = torch.rand([1000, 4], device='cuda')
points_1 = torch.rand([1000, 4], device='cuda')
points = [points_0, points_1]
img_meta_0 = dict(box_type_3d=LiDARInstance3DBoxes)
img_meta_1 = dict(box_type_3d=LiDARInstance3DBoxes)
Expand All @@ -197,9 +197,9 @@ def test_parta2():

# test_simple_test
results = self.simple_test(points, img_metas)
boxes_3d = results['boxes_3d']
scores_3d = results['scores_3d']
labels_3d = results['labels_3d']
boxes_3d = results[0]['boxes_3d']
scores_3d = results[0]['scores_3d']
labels_3d = results[0]['labels_3d']
assert boxes_3d.tensor.shape[0] >= 0
assert boxes_3d.tensor.shape[1] == 7
assert scores_3d.shape[0] >= 0
Expand Down
61 changes: 38 additions & 23 deletions tests/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,22 +372,20 @@ def test_part_aggregation_ROI_head():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')

_setup_seed(0)
roi_head_cfg = _get_roi_head_cfg(
'parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py')
self = build_head(roi_head_cfg).cuda()
spatial_features = torch.rand([1, 256, 200, 176], device='cuda')
seg_features = torch.rand([32000, 16], device='cuda')
neck_features = [torch.rand([1, 512, 200, 176], device='cuda')]
feats_dict = dict(
spatial_features=spatial_features,
seg_features=seg_features,
neck_features=neck_features)

voxels = torch.rand([32000, 5, 4], device='cuda')
num_points = torch.ones([32000], device='cuda')
coors = torch.zeros([32000, 4], device='cuda')
voxel_centers = torch.zeros([32000, 3], device='cuda')

features = np.load('./tests/test_samples/parta2_roihead_inputs.npz')
seg_features = torch.tensor(
features['seg_features'], dtype=torch.float32, device='cuda')
feats_dict = dict(seg_features=seg_features)

voxels = torch.tensor(
features['voxels'], dtype=torch.float32, device='cuda')
num_points = torch.ones([500], device='cuda')
coors = torch.zeros([500, 4], device='cuda')
voxel_centers = torch.zeros([500, 3], device='cuda')
box_type_3d = LiDARInstance3DBoxes
img_metas = [dict(box_type_3d=box_type_3d)]
voxels_dict = dict(
Expand All @@ -396,10 +394,27 @@ def test_part_aggregation_ROI_head():
coors=coors,
voxel_centers=voxel_centers)

pred_bboxes = LiDARInstance3DBoxes(torch.rand([5, 7], device='cuda'))
pred_scores = torch.rand([5], device='cuda')
pred_labels = torch.randint(0, 3, [5], device='cuda')
pred_clses = torch.rand([5, 3], device='cuda')
pred_bboxes = LiDARInstance3DBoxes(
torch.tensor(
[[0.3990, 0.5167, 0.0249, 0.9401, 0.9459, 0.7967, 0.4150],
[0.8203, 0.2290, 0.9096, 0.1183, 0.0752, 0.4092, 0.9601],
[0.2093, 0.1940, 0.8909, 0.4387, 0.3570, 0.5454, 0.8299],
[0.2099, 0.7684, 0.4290, 0.2117, 0.6606, 0.1654, 0.4250],
[0.9927, 0.6964, 0.2472, 0.7028, 0.7494, 0.9303, 0.0494]],
dtype=torch.float32,
device='cuda'))
pred_scores = torch.tensor([0.9722, 0.7910, 0.4690, 0.3300, 0.3345],
dtype=torch.float32,
device='cuda')
pred_labels = torch.tensor([0, 1, 0, 2, 1],
dtype=torch.int64,
device='cuda')
pred_clses = torch.tensor(
[[0.7874, 0.1344, 0.2190], [0.8193, 0.6969, 0.7304],
[0.2328, 0.9028, 0.3900], [0.6177, 0.5012, 0.2330],
[0.8985, 0.4894, 0.7152]],
dtype=torch.float32,
device='cuda')
proposal = dict(
boxes_3d=pred_bboxes,
scores_3d=pred_scores,
Expand All @@ -419,12 +434,12 @@ def test_part_aggregation_ROI_head():

bbox_results = self.simple_test(feats_dict, voxels_dict, img_metas,
proposal_list)
boxes_3d = bbox_results['boxes_3d']
scores_3d = bbox_results['scores_3d']
labels_3d = bbox_results['labels_3d']
assert boxes_3d.tensor.shape == (6, 7)
assert scores_3d.shape == (6, )
assert labels_3d.shape == (6, )
boxes_3d = bbox_results[0]['boxes_3d']
scores_3d = bbox_results[0]['scores_3d']
labels_3d = bbox_results[0]['labels_3d']
assert boxes_3d.tensor.shape == (12, 7)
assert scores_3d.shape == (12, )
assert labels_3d.shape == (12, )


def test_free_anchor_3D_head():
Expand Down
Binary file added tests/test_samples/parta2_roihead_inputs.npz
Binary file not shown.
4 changes: 2 additions & 2 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def main():
set_random_seed(args.seed, deterministic=args.deterministic)

# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
Expand Down