Skip to content

Commit

Permalink
[Fxi] Fix ci bug of BatchATSSAssigner (#232)
Browse files Browse the repository at this point in the history
* Fix ci bug

* Fix ci bug

* Fix ci bug

* Fix ci bug

* Fix ci bug
  • Loading branch information
PeterH0323 authored and hhaAndroid committed Nov 3, 2022
1 parent 8fc8066 commit 0b48313
Showing 1 changed file with 9 additions and 9 deletions.
Expand Up @@ -82,8 +82,8 @@ def test_batch_atss_assigner_with_empty_gt(self):
[20., -4., 36., 12.],
]).unsqueeze(0).repeat(batch_size, 21, 1)

gt_bboxes = torch.empty(batch_size, 2, 4)
gt_labels = torch.empty(batch_size, 2, 1)
gt_bboxes = torch.zeros(batch_size, 0, 4)
gt_labels = torch.zeros(batch_size, 0, 1)

batch_assign_result = batch_atss_assigner.forward(
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
Expand All @@ -101,15 +101,15 @@ def test_batch_atss_assigner_with_empty_gt(self):
torch.Size([batch_size, 84, num_classes]))
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))

def test_batch_atss_assigner_with_empty_boxes(self):
def test_batch_atss_assigner_with_empty_boxs(self):
"""Test corner case where a network might predict no boxes."""
num_classes = 2
batch_size = 2
batch_atss_assigner = BatchATSSAssigner(
topk=3,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
num_classes=num_classes)
priors = torch.empty(84, 4)
priors = torch.zeros(84, 4)
gt_bboxes = torch.FloatTensor([
[0, 0, 60, 93],
[229, 0, 532, 157],
Expand Down Expand Up @@ -152,12 +152,12 @@ def test_batch_atss_assigner_with_empty_boxes_and_gt(self):
topk=3,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
num_classes=num_classes)
priors = torch.empty(84, 4)
gt_bboxes = torch.empty(batch_size, 2, 4)
gt_labels = torch.empty(batch_size, 2, 1)
priors = torch.zeros(84, 4)
gt_bboxes = torch.zeros(batch_size, 0, 4)
gt_labels = torch.zeros(batch_size, 0, 1)
num_level_bboxes = [64, 16, 4]
pad_bbox_flag = torch.empty(batch_size, 2, 1)
pred_bboxes = torch.empty(batch_size, 84, 4)
pad_bbox_flag = torch.zeros(batch_size, 0, 1)
pred_bboxes = torch.zeros(batch_size, 0, 4)

batch_assign_result = batch_atss_assigner.forward(
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
Expand Down

0 comments on commit 0b48313

Please sign in to comment.