Allow for images to contain zero true detections #1531

Merged
merged 28 commits into from
Dec 24, 2019

Conversation

Erotemic
Contributor

@Erotemic Erotemic commented Oct 11, 2019

When I went to train a CascadeRCNN on my dataset, the loss computation failed when it loaded an image that had no truth boxes on it. I'm a bit surprised that this was an issue. Perhaps I'm using the library incorrectly? Does this library expect that there exists some magic negative bounding box in cases where there are really no objects of interest in an image?

If this is indeed a real issue, I think I fixed it. I also added tests to ensure that these corner cases don't break in the future. The main issue was in MaxIoUAssigner, which explicitly disallowed both the number of predicted boxes to be zero and the number of truth boxes to be zero. I modified the code so it instead returns an appropriate empty assignment if either truth or predictions have no boxes.

There was also an issue in bbox_head, where it asserted that all images had truth. I simply removed this check, and I believe the rest of the code still functions correctly (but it would be good if someone could double check this).

Lastly I added some docs to AssignResult to make it clear what the object contains.


EDIT 2019-10-18: I also fixed the ApproxMaxIoUAssigner and PointAssigner and added corresponding test cases.

There were issues in CascadeRCNN, where it would crash when trying to RoiAlign the assigned ROIs in the case where the assignment was empty. I simply added some logic to skip that step, which is the correct thing to do.

There was an issue in loss_bbox, where it failed to compute a bbox loss when all boxes are assigned to the background. Again, the fix for this case is a simple check and skipping the computation of that loss term.

There was also an issue in my previous code where, if there was no truth, all predicted boxes got a gt_ind of -1, which means "don't care". I fixed this so now they are correctly assigned 0, which means background.
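
A minimal sketch of the assignment convention described above, written with plain Python lists rather than tensors. The function name `assign_max_iou` and the threshold defaults are illustrative assumptions, and the real MaxIoUAssigner does more than this. The point is only the empty-truth branch: every box gets index 0 (background) rather than -1 (don't care), with an overlap of zero.

```python
def assign_max_iou(overlaps, num_bboxes, pos_iou_thr=0.5, neg_iou_thr=0.4):
    """Simplified max-IoU assignment (illustrative, not mmdet's code).

    overlaps: list of num_gts rows, each holding num_bboxes IoU values.
    Returns (gt_inds, max_overlaps) where gt_inds[j] is
      -1    -> don't care,
       0    -> background,
       k+1  -> matched to truth box k.
    """
    num_gts = len(overlaps)
    if num_gts == 0 or num_bboxes == 0:
        # No truth and/or no boxes: everything is background, zero overlap.
        return [0] * num_bboxes, [0.0] * num_bboxes

    gt_inds = [-1] * num_bboxes
    max_overlaps = [0.0] * num_bboxes
    for j in range(num_bboxes):
        # Truth box with the highest IoU for box j.
        best_gt = max(range(num_gts), key=lambda i: overlaps[i][j])
        best = overlaps[best_gt][j]
        max_overlaps[j] = best
        if best >= pos_iou_thr:
            gt_inds[j] = best_gt + 1
        elif best < neg_iou_thr:
            gt_inds[j] = 0
    return gt_inds, max_overlaps
```

Calling `assign_max_iou([], 3)` exercises the empty-truth case; a single truth row such as `[[0.9, 0.45, 0.1]]` exercises the positive, don't-care, and background outcomes.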

Lastly, I added two tests to make sure cascade rcnn could compute losses for batches that had no truth boxes. I also added a test case for AnchorHead loss to ensure it computes background loss correctly in the case where the batch has no truth.


EDIT 2019-10-21: I found another edge case in convfc_bbox_head, where x.view(x.shape[0], -1) raised an error when x.shape[0] was 0. I added a function _view_flat_trailing_dims which tests for and handles this case.
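
The shape arithmetic behind that edge case can be sketched without any tensor library. The helper name `flat_trailing_shape` is hypothetical (it is not the `_view_flat_trailing_dims` from this PR); it only shows that computing the trailing size explicitly avoids asking the framework to infer a `-1` dimension from zero elements.

```python
import math


def flat_trailing_shape(shape):
    """Return the 2-D shape (shape[0], product of the remaining dims).

    Unlike a -1 placeholder, the product is well defined even when
    shape[0] == 0. math.prod needs Python 3.8+ and returns 1 for an
    empty sequence.
    """
    return (shape[0], math.prod(shape[1:]))
```

The resulting tuple could then be passed to a reshape call in place of `(x.shape[0], -1)`.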


EDIT 2019-10-29: I fixed a logic error where I wrote pos_inds.numel() instead of pos_inds.any(), rebased on master, and added tests for BBoxHead.

@Erotemic
Contributor Author

The dashboards seem to be failing for all builds? I don't see any change in the history that might have caused this. Is there an issue with travis?

@hellock
Member

hellock commented Oct 12, 2019

> The dashboards seem to be failing for all builds? I don't see any change in the history that might have caused this. Is there an issue with travis?

Seems to be caused by the prebuilt PyTorch 1.3, which requires CUDA 10.1. It should be fixed in #1534.

@hellock
Member

hellock commented Oct 12, 2019

#425

@Erotemic
Contributor Author

Cool, I'll rebase on master once #1534 is merged. Then this PR should fix #425

@AAnoosheh

Maybe we can also add an augmentation param to prevent pipelines.transforms.RandomCrop from returning None when the image has no annotations left?

@Erotemic Erotemic force-pushed the dev/allow_empty_gt branch 2 times, most recently from fb0cb9d to 8885bbb Compare October 18, 2019 13:58
@Erotemic
Contributor Author

@AAnoosheh my main concern in this PR is to increase the robustness of the models regardless of the underlying dataloading / training loop. This PR doesn't fix the mmdet training loop because I think the dataloader explicitly ignores empty images. Perhaps fixing the augmenter / data loader would be best addressed by a separate PR after this one is merged?

@Erotemic Erotemic changed the title Allow for images to contain zero true detections WIP: Allow for images to contain zero true detections Oct 18, 2019
@Erotemic Erotemic changed the title WIP: Allow for images to contain zero true detections Allow for images to contain zero true detections Oct 18, 2019
@AAnoosheh

@Erotemic
Sure thing.

Also another question: Do these changes still enforce a background-class loss on images without bounding boxes? Or do they just allow an image to pass through without error, but without computing a background loss?

The codebase is confusing enough that I can't figure out what's actually going on.
Thanks!

@Erotemic
Contributor Author

@AAnoosheh, This codebase is certainly complex, but it's not insurmountable. Things are pretty well compartmentalized and most functions have exactly one job (I give a lot of credit to @hellock et al. for this disciplined design), which makes it possible to grok things in small chunks.

To answer your question "do empty-gt batch items generate loss?": Yes. If you look in my test_anchor_head.py I test this explicitly. In the case where there is no gt, the class loss is nonzero because we want to encourage the RPN to predict background everywhere. However, note the bbox loss is zero because the shape of the boxes isn't particularly important when everything predicts background.
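
As a rough illustration of that answer, here is a list-based sketch. `detection_losses` is a hypothetical function, not mmdet's loss code: it uses cross-entropy on already-normalized probabilities and a mean per-box error as stand-ins, with class 0 assumed to be background.

```python
import math


def detection_losses(cls_probs, labels, box_errors):
    """Toy classification + box loss for one batch of sampled boxes.

    cls_probs[i][c]: predicted probability of class c for box i.
    labels[i]:       target class of box i (0 = background).
    box_errors[i]:   some precomputed regression error for box i.
    """
    # Every sampled box contributes to the class loss, background included.
    cls_loss = sum(-math.log(p[lab]) for p, lab in zip(cls_probs, labels))
    cls_loss /= max(len(labels), 1)

    # Only positives have a box target; with none, the term is simply zero.
    pos = [i for i, lab in enumerate(labels) if lab > 0]
    bbox_loss = sum(box_errors[i] for i in pos) / len(pos) if pos else 0.0
    return cls_loss, bbox_loss
```

With all labels equal to 0, the class loss stays positive (the predictions are pushed toward background) while the box loss is exactly zero.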

If you want to ensure that the builtin mmdet trainer works with empty-gt, then I think you'll have to look in the datasets to ensure that no-gt items aren't skipped. You'll probably have to test that it works, but I think that's all that needs to be done.

@yhcao6
Collaborator

yhcao6 commented Oct 24, 2019

Thanks for the pr.

To test faster R-50 when there is no gt, I hack the program by inserting

gt_bboxes = [gt.new_empty(0, 4) for gt in gt_bboxes]
gt_labels = [gt.new_empty(0, ) for gt in gt_labels]

at the beginning of forward_train in two_stage.py, but the program fails.

It seems there are some logic errors. For example, even when there is no gt, the sampler of the bbox head will still sample some proposals for training. As a result, cls_score.numel() is not zero and pos_inds is also not zero.

Could you check and test it again?

@andriilitvynchuk

andriilitvynchuk commented Oct 28, 2019

Hello! I forked your pull request and tried to train Cascade RCNN + Guided Anchoring, and I got the following error. This error is 100% connected with background pictures, because it occurred when I changed the flag skip_img_without_anno in the Albu extension to False. (This means that before that I skipped all pictures without annotations, using the Albu augmentations extension, and everything worked fine.)

  File "/home/andrii/mmdetection/tools/train.py", line 108, in <module>
    main()
  File "/home/andrii/mmdetection/tools/train.py", line 104, in main
    logger=logger)
  File "/home/andrii/mmdetection/mmdet/apis/train.py", line 58, in train_detector
    _dist_train(model, dataset, cfg, validate=validate)
  File "/home/andrii/mmdetection/mmdet/apis/train.py", line 192, in _dist_train
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/mmcv/runner/runner.py", line 358, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/mmcv/runner/runner.py", line 264, in train
    self.model, data_batch, train_mode=True, **kwargs)
  File "/home/andrii/mmdetection/mmdet/apis/train.py", line 38, in batch_processor
    losses = model(**data)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/mmcv/parallel/distributed.py", line 50, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andrii/mmdetection/mmdet/core/fp16/decorators.py", line 49, in new_func
    return old_func(*args, **kwargs)
  File "/home/andrii/mmdetection/mmdet/models/detectors/base.py", line 100, in forward
    return self.forward_train(img, img_meta, **kwargs)
  File "/home/andrii/mmdetection/mmdet/models/detectors/cascade_rcnn.py", line 197, in forward_train
    *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  File "/home/andrii/mmdetection/mmdet/models/anchor_heads/ga_rpn_head.py", line 53, in loss
    gt_bboxes_ignore=gt_bboxes_ignore)
  File "/home/andrii/mmdetection/mmdet/core/fp16/decorators.py", line 127, in new_func
    return old_func(*args, **kwargs)
  File "/home/andrii/mmdetection/mmdet/models/anchor_heads/guided_anchor_head.py", line 420, in loss
    ignore_ratio=cfg.ignore_ratio)
  File "/home/andrii/mmdetection/mmdet/core/anchor/guided_anchor_target.py", line 79, in ga_loc_target
    scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) *
IndexError: too many indices for tensor of dimension 1

I guess that this problem occurred not because of your PR, but because the problem with background pictures runs much deeper and there is still a lot of work to be done.
Then I wanted to try Cascade RCNN + Libra RCNN. I got almost the same error (I can reproduce it if needed).

@Erotemic
Contributor Author

@yhcao6 you are right. My faster_rcnn test is broken. Unfortunately it doesn't run on travis-CI because the RoIAlign forward implementation either needs a GPU or torchvision with the _C backend, so those tests were skipped in the CI runs.

I did have a logic error in bbox_head.BBoxHead.loss. I was assuming pos_inds was an array of indexes, but in fact it was an array of boolean flags, which means I should have used pos_inds.any() instead of numel. I fixed that error and also added a test in test_heads.py which tests the BBoxHead.loss function explicitly.
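
To make the distinction concrete, here is a small list-based sketch in which `len` stands in for `numel()` and the built-in `any` stands in for `.any()`. The helper names are made up for illustration.

```python
def buggy_any_positive(pos_mask):
    """The mistaken check: counts entries, so any non-empty mask passes."""
    return len(pos_mask) > 0


def any_positive(pos_mask):
    """True only if at least one entry of the boolean mask is set."""
    return any(pos_mask)


labels = [0, 0, 0]                       # every sampled box is background
pos_mask = [lab > 0 for lab in labels]   # boolean flags, not indexes
```

For this all-background batch the mask has three entries, all False, so the count-based check wrongly reports positives while the `any`-based check does not.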

Correct me if I'm wrong, but I think the cls_score.numel() check is ok. I only added it in case there were no predictions. In the more common case where there are sampled predictions but the truth is empty there should be a cls loss because any prediction should be encouraged to choose background as its class.

@LitvinchukAndrey Yes, there are probably several other code paths where empty bounding boxes will still cause a problem. Because this problem is pretty big it might make sense to fix it incrementally, which will help prevent the PRs from becoming monolithic. In this PR I'm trying to only focus on issues in models. Furthermore, I'm only fixing the code where I encounter the problem and can write a unit test that demonstrates that the problem is fixed. I'm explicitly not fixing the code in apis and datasets. Once we have unit tests demonstrating that the loss functions can handle empty truth, it should be easier to go back and fix those other components of the system.

However, it does look like the problem you are encountering is in models/guided_anchor_head. Perhaps you can follow my examples in tests/test_heads.py and write a unit test that reproduces the issue and then find a fix for it in that particular component?

@ternaus
Contributor

ternaus commented Oct 29, 2019

Very excited for this PR to get merged.

I have strange errors when I am trying to run training of the cascade RCNN on 32 GPUs that may be addressed by this PR.

@andriilitvynchuk

andriilitvynchuk commented Oct 30, 2019

@Erotemic I tried removing all features like Guided Anchoring and Libra RCNN and training vanilla Cascade RCNN and Faster RCNN. But with both of them I hit the same problem:

  File "/home/andrii/mmdetection/tools/train.py", line 108, in <module>
    main()
  File "/home/andrii/mmdetection/tools/train.py", line 104, in main
    logger=logger)
  File "/home/andrii/mmdetection/mmdet/apis/train.py", line 58, in train_detector
    _dist_train(model, dataset, cfg, validate=validate)
  File "/home/andrii/mmdetection/mmdet/apis/train.py", line 192, in _dist_train
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/mmcv/runner/runner.py", line 358, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/mmcv/runner/runner.py", line 264, in train
    self.model, data_batch, train_mode=True, **kwargs)
  File "/home/andrii/mmdetection/mmdet/apis/train.py", line 38, in batch_processor
    losses = model(**data)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/mmcv/parallel/distributed.py", line 50, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/andrii/anaconda3/envs/dev/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andrii/mmdetection/mmdet/core/fp16/decorators.py", line 49, in new_func
    return old_func(*args, **kwargs)
  File "/home/andrii/mmdetection/mmdet/models/detectors/base.py", line 100, in forward
    return self.forward_train(img, img_meta, **kwargs)
  File "/home/andrii/mmdetection/mmdet/models/detectors/two_stage.py", line 176, in forward_train
    *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  File "/home/andrii/mmdetection/mmdet/models/anchor_heads/rpn_head.py", line 51, in loss
    gt_bboxes_ignore=gt_bboxes_ignore)
  File "/home/andrii/mmdetection/mmdet/core/fp16/decorators.py", line 127, in new_func
    return old_func(*args, **kwargs)
  File "/home/andrii/mmdetection/mmdet/models/anchor_heads/anchor_head.py", line 189, in loss
    sampling=self.sampling)
  File "/home/andrii/mmdetection/mmdet/core/anchor/anchor_target.py", line 63, in anchor_target
    unmap_outputs=unmap_outputs)
  File "/home/andrii/mmdetection/mmdet/core/utils/misc.py", line 24, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/andrii/mmdetection/mmdet/core/anchor/anchor_target.py", line 116, in anchor_target_single
    anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
  File "/home/andrii/mmdetection/mmdet/core/bbox/assign_sampling.py", line 32, in assign_and_sample
    gt_labels)
  File "/home/andrii/mmdetection/mmdet/core/bbox/samplers/base_sampler.py", line 78, in sample
    assign_result, gt_flags)
  File "/home/andrii/mmdetection/mmdet/core/bbox/samplers/sampling_result.py", line 17, in __init__
    self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]

I debugged it and found that the problem is that self.pos_assigned_gt_inds=tensor([], device='cuda:1', dtype=torch.int64) and it tries to take an empty slice. I wonder how you didn't encounter this problem? Maybe the problem is on my side. But previously everything worked fine (I understand that previously background pictures were filtered out by the dataset). Maybe I need to change the types of boxes and labels in dataset preparation? For a background picture I now have the following assignments:

bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))

Or we need to fix SamplingResult for an empty self.pos_assigned_gt_inds tensor.
UPDATE
I fixed that problem by changing the code in class SamplingResult in mmdet/core/bbox/samplers/sampling_result from

self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
if assign_result.labels is not None:
    self.pos_gt_labels = assign_result.labels[pos_inds]
else:
    self.pos_gt_labels = None

to

self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
if len(self.pos_assigned_gt_inds) != 0:
    self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
else:
    self.pos_gt_bboxes = torch.empty(0)

if assign_result.labels is not None:
    self.pos_gt_labels = assign_result.labels[pos_inds]
else:
    self.pos_gt_labels = None

(self.pos_gt_bboxes must always be a torch.Tensor, because its size() is called later.)
It works for both Faster RCNN and Cascade RCNN. It also fixes the problems with Libra RCNN. However, there are still problems with Guided Anchoring, and fixing them isn't trivial, because Guided Anchoring is entirely based on ground truth boxes. I think it deserves a separate PR.
Let me know if this fix is the right solution to the problem, because I am new to mmdetection and don't understand the whole architecture yet.

@Erotemic
Contributor Author

@LitvinchukAndrey I am completely unable to reproduce your issue.

I went as far as to use the mmdet train scripts to attempt to reproduce the issue.

I used configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py (I don't have the COCO dataset on my machine) and modified LoadAnnotations so skip_img_without_anno=False. I then added:

            ann_info['bboxes'] = np.empty(shape=(0, 4), dtype=ann_info['bboxes'].dtype)
            ann_info['labels'] = np.empty(shape=(0), dtype=ann_info['labels'].dtype)
            ann_info['bboxes_ignore'] = np.empty(shape=(0, 4), dtype=ann_info['bboxes_ignore'].dtype)
            ann_info['labels_ignore'] = np.empty(shape=(0), dtype=ann_info['labels_ignore'].dtype)

to CustomDataset.prepare_train_img, which should hack all truth to be empty. (I also tried randomly removing truth in case the issue appeared when only some of the annotations were empty).

I also added print debugging to TwoStageDetector.forward_train:

                print('gt_bboxes[{}] = {}'.format(i, gt_bboxes[i]))
                print('gt_labels[{}] = {}'.format(i, gt_labels[i]))
                print('gt_bboxes_ignore[i] = {}'.format(gt_bboxes_ignore[i]))
                print('proposal_list[i].shape = {}'.format(proposal_list[i].shape))
                print('sampling_result.bboxes.shape = {!r}'.format(sampling_result.bboxes.shape))
                print('sampling_result.neg_bboxes.shape = {!r}'.format(sampling_result.neg_bboxes.shape))
                print('sampling_result.neg_inds.shape = {!r}'.format(sampling_result.neg_inds.shape))
                print('sampling_result.num_gts = {!r}'.format(sampling_result.num_gts))
                print('sampling_result.pos_assigned_gt_inds = {!r}'.format(sampling_result.pos_assigned_gt_inds))
                print('sampling_result.pos_bboxes = {!r}'.format(sampling_result.pos_bboxes))
                print('sampling_result.pos_gt_bboxes = {!r}'.format(sampling_result.pos_gt_bboxes))
                print('sampling_result.pos_gt_labels = {!r}'.format(sampling_result.pos_gt_labels))
                print('sampling_result.pos_inds = {!r}'.format(sampling_result.pos_inds))
                print('sampling_result.pos_is_gt = {!r}'.format(sampling_result.pos_is_gt))
                print('sampling_result = {!r}'.format(sampling_result))

And got things like:

gt_bboxes[0] = tensor([], device='cuda:0', size=(0, 4))
gt_labels[0] = tensor([], device='cuda:0', dtype=torch.int64)
gt_bboxes_ignore[i] = None
proposal_list[i].shape = torch.Size([2000, 5])
sampling_result.bboxes.shape = torch.Size([512, 4])
sampling_result.neg_bboxes.shape = torch.Size([512, 4])
sampling_result.neg_inds.shape = torch.Size([512])
sampling_result.num_gts = 0
sampling_result.pos_assigned_gt_inds = tensor([], device='cuda:0', dtype=torch.int64)
sampling_result.pos_bboxes = tensor([], device='cuda:0', size=(0, 4))
sampling_result.pos_gt_bboxes = tensor([], device='cuda:0', size=(0, 4))
sampling_result.pos_gt_labels = None
sampling_result.pos_inds = tensor([], device='cuda:0', dtype=torch.int64)
sampling_result.pos_is_gt = tensor([], device='cuda:0', dtype=torch.uint8)
sampling_result = <mmdet.core.bbox.samplers.sampling_result.SamplingResult object at 0x7faf26787510>
gt_bboxes[1] = tensor([], device='cuda:0', size=(0, 4))
gt_labels[1] = tensor([], device='cuda:0', dtype=torch.int64)
gt_bboxes_ignore[i] = None
proposal_list[i].shape = torch.Size([2000, 5])
sampling_result.bboxes.shape = torch.Size([512, 4])
sampling_result.neg_bboxes.shape = torch.Size([512, 4])
sampling_result.neg_inds.shape = torch.Size([512])
sampling_result.num_gts = 0
sampling_result.pos_assigned_gt_inds = tensor([], device='cuda:0', dtype=torch.int64)
sampling_result.pos_bboxes = tensor([], device='cuda:0', size=(0, 4))
sampling_result.pos_gt_bboxes = tensor([], device='cuda:0', size=(0, 4))
sampling_result.pos_gt_labels = None
sampling_result.pos_inds = tensor([], device='cuda:0', dtype=torch.int64)
sampling_result.pos_is_gt = tensor([], device='cuda:0', dtype=torch.uint8)
sampling_result = <mmdet.core.bbox.samplers.sampling_result.SamplingResult object at 0x7faeac7a64d0>
loss_bbox = {'loss_cls': tensor(3.0426, device='cuda:0', grad_fn=<MulBackward0>), 'acc': tensor([1.8555], device='cuda:0')}

So, this experiment shows that the mmdet trainer does work with the current version of this PR (given that you set skip_img_without_anno=False). I'm not sure why you are getting errors. Looking at your code, what happens if you set labels = np.zeros((0, ), dtype=np.int64)? (I did try setting my labels to float, but I still didn't get the error, so perhaps this is not the problem.)

@Erotemic
Contributor Author

Erotemic commented Nov 18, 2019

I've rebased on master, removed test_forward2 and consolidated it with test_forward.

I also determined that my initial guess on how to handle AssignResult.labels was incorrect. When gt_labels is specified and there are no truths, it is incorrect to set labels=None. This causes an error in bbox_target_single where num_pos > 0, but pos_gt_labels = None.

The correct behavior seems to be creating an empty LongTensor of shape (num_gt, num_bboxes). I'm now fairly confident that the code that creates the AssignResult objects is correct.

@yhcao6
Collaborator

yhcao6 commented Nov 26, 2019

Many thanks for your hard work. I found another inconsistent config.

To test OHEM Sampler when there is no gt, I insert the following code

gt_bboxes = [gt.new_empty(0, 4) for gt in gt_bboxes]
gt_labels = [gt.new_empty(0, ) for gt in gt_labels]

at the beginning of forward_train in two_stage.py, but the program fails.

Could you check that?

@Erotemic
Contributor Author

@yhcao6 I think I've addressed the issue.

When I was looking into it I found that I may have been setting the assign result slightly incorrectly. Previously I was setting AssignResult.max_overlaps to an empty tensor when there were no truth boxes; however, I believe it should be a 1D zero tensor with shape equal to the number of predicted boxes to indicate that no pred box had any overlap with the truth.

While debugging this I had to inspect the contents of the AssignResult class often. To make this easier I added a __str__ and __repr__ method that print out nice stats about it. I also added a docstring to AssignResult.

While inspecting AssignResult, I also noticed that there seemed to be a bug in add_gt_. The max_overlaps tensor was extended by a ones tensor of length self.num_gts. This causes gt_inds, labels, and max_overlaps to have different shapes, which I think is not correct. To keep the shapes the same I think it should be extended by len(gt_labels) instead. I left the original code in with a comment. Please double check that I've understood what is happening here correctly.
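
The shape argument can be sketched with plain lists. The function below is an illustrative stand-in for add_gt_ (the name `add_gt` and the list-based signature are assumptions, not the PR's code): all three fields are extended by the same count, `len(gt_labels)`, so they stay aligned.

```python
def add_gt(gt_inds, max_overlaps, labels, gt_labels):
    """Prepend the truth boxes themselves to an assignment (list sketch).

    Each truth box is treated as matched to itself: 1-based index,
    overlap 1.0, and its own label. Using one length for all three
    fields keeps them the same size.
    """
    num_new = len(gt_labels)
    self_inds = list(range(1, num_new + 1))
    new_gt_inds = self_inds + gt_inds
    new_max_overlaps = [1.0] * num_new + max_overlaps
    new_labels = list(gt_labels) + labels
    return new_gt_inds, new_max_overlaps, new_labels
```

With an empty `gt_labels` list the function returns the inputs unchanged, which is the behaviour wanted for images without truth.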

I also added a doctest to bbox_overlaps because I thought there was an issue there, but it turns out that there wasn't. But now there's a nice test for it.

Finally, I added standalone test for samplers in test_sampler.py and I added a test of OHEM with empty truth in test_forward to ensure it works end-to-end and in a unit test.

Please take a look and let me know if there are any other outstanding issues.

        return ', '.join(parts)

    def __repr__(self):
        devnice = self.__nice__()
Collaborator

devnice --> device

mmdet/core/bbox/assigners/assign_result.py (outdated, resolved)
         self.max_overlaps = torch.cat(
-            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
+            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
Collaborator

why num_gts not equal to len(gt_labels)?


        if num_squares == 0 or num_gts == 0:
            # No predictions and/or truth, return empty assignment
            overlaps = approxs.new(num_gts, num_squares)
Collaborator

overlaps initialization is not consistent.
In approx_max_iou_assigner: overlaps = approxs.new(num_gts, num_squares)
In max_iou_assigner: max_overlaps = overlaps.new_zeros((num_bboxes, ))
In point_assigner: max_overlaps = None

mmdet/core/bbox/assigners/max_iou_assigner.py (resolved)

        >>> assert tuple(x.shape) == (0, 1)
    """
    if x.numel() == 0:
        num_trailing = reduce(mul, x.shape[1:], 1)
Collaborator

reduce and mul introduce two extra imports (functools and operator); there should be a better way to compute the product of x.shape[1:].

Collaborator

How about x = x.flatten(1)?

mmdet/models/detectors/cascade_rcnn.py (outdated, resolved)
@Erotemic
Contributor Author

@yhcao6 I fixed the linting errors causing Travis to fail. Any chance this could get merged in the near future?

@hellock
Member

hellock commented Dec 21, 2019

@Erotemic Thanks for your contribution. This is an important and non-trivial improvement, so we have to be cautious. We are now doing a final test and it is very close to being merged.

@Erotemic
Contributor Author

@hellock I agree that it is best to be cautious. I think merging this change is likely to break something, although that something should be small considering that the test cases cover a large portion of this functionality.

I've actually run into one of these issues recently: training MaskRCNN using segmentation masks breaks when there is no truth. The fix is simple (just don't compute the mask loss if you have no masks), but I doubt this is the only other issue with enabling this feature.

@hellock
Member

hellock commented Dec 22, 2019

Yes, the core logic for supporting empty gts is not likely to break since the coverage of test cases is quite good. We mainly verify whether other parts, like data pre-processing and logging, are working well.

I trained two models for person detection, with and without images that have no person annotations. It is confirmed that those images are definitely used for training. The performance of the two models is similar, and I believe that hyper-parameters need to be tuned when using additional background images, which is out of the scope of this PR.

Overall, this PR looks good to me now. I've also reproduced the errors for empty masks. You may push a quick fix.

@Erotemic
Contributor Author

On a software note: I pushed up the fix for empty masks.

On a research note: I'm not surprised that models trained with / without empty images are comparable. I think on average the negative cases in an image without any objects of interest won't provide SGD with much more information than images with objects of interest. However, it does open the possibility of finding and including truly difficult examples that will benefit SGD. I think that is where adding this feature will really shine in terms of improving PR / ROC curves. It also makes mmdet more robust to unseen datasets, which will often contain images without any annotations. (The main reason I'm interested in getting this merged in a timely fashion is that we want to let the users of our VIAME project train models on their custom datasets, which will almost certainly contain empty images.)

@hellock
Member

hellock commented Dec 23, 2019

Yes, hard negative mining is usually applied to a pool of background images in practice.

I fixed the padding transform for empty masks; you may want to check it. I will merge it if there are no further issues.

@hellock hellock merged commit b696670 into open-mmlab:master Dec 24, 2019
@hellock
Member

hellock commented Dec 24, 2019

@Erotemic Thanks for your enthusiasm and nice work! It finally got merged.

@yangninghua

https://github.com/open-mmlab/mmdetection/tree/v1.0.0

I use version 1.0.0 and added training with unlabeled background images. There is no problem when using a single GPU, but when I try to train with multiple GPUs,

when I run the script "./tools/dist_train.sh ./configs/mask_rcnn_r50_fpn_1x.py 4", I hit this problem:
2020-01-14 09:01:53,683 - INFO - workflow: [('train', 1)], max: 12 epochs

It stays stuck here for hours, please help

@yangninghua

yangninghua commented Feb 21, 2020

And the occupancy rate of the multiple GPUs is 100%.

I ran the tests myself. As soon as unlabeled background images take part in multi-GPU training, it gets stuck.

@leemengwei

@yangninghua
In the file mmdet/models/bbox_heads/bbox_head.py,
after line 119 (pos_inds = labels > 0),
add: losses['loss_bbox'] = torch.tensor(0.).to(bbox_pred.device)
This will solve your problem. I forget where I found this fix, but it works.

ioir123ju pushed a commit to ioir123ju/mmdetection that referenced this pull request Mar 30, 2020
* Allow for images to contain zero true detections

* Allow for empty assignment in PointAssigner

* Allow ApproxMaxIouAssigner to return an empty result

* Fix CascadeRNN forward when entire batch has no truth

* Correctly assign boxes to background when there is no truth

* Fix assignment tests

* Make flatten robust

* Fix bbox loss with empty pred/truth

* Fix logic error in BBoxHead.loss

* Add tests for empty truth cases

* tests faster rcnn empty forward

* Skip roipool forward tests if torchvision is not installed

* Add tests for bbox/anchor heads

* Consolidate test_forward and test_forward2

* Fix assign_results.labels = None when gt_labels is given; Add test for this case

* Fix OHEM Sampler with zero truth

* remove xdev

* resolve 3 reviews

* Fix flake8

* refactoring

* fix yaml format

* add filter flag

* minor fix

* delete redundant code in load anno

* fix flake8 errors

* quick fix for empty truth with masks

* fix yapf error

* fix mask padding for empty masks

Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
mike112223 pushed a commit to mike112223/mmdetection that referenced this pull request Aug 25, 2020
liuhuiCNN pushed a commit to liuhuiCNN/mmdetection that referenced this pull request May 21, 2021
* share train_batch_size from reader to model

* update floor_divide

* add makedirs for dumping config
@joihn

joihn commented Oct 14, 2021

How should I format a COCO-style dataset JSON to take into account these "pure background" images?
Edit: all good, having no annotations at all for those images and setting filter_empty_gt=False in the cfg did the trick :)
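
For readers with the same question, a minimal sketch of what such a file might look like, following the answer above: the background image is listed under "images" and simply has no entries under "annotations". The file names, sizes, ids, and the helper `images_without_annotations` are made up for illustration; the filter_empty_gt=False setting mentioned above still has to be applied in the config.

```python
import json

# COCO-style structure with one annotated image and one pure-background image.
coco = {
    "images": [
        {"id": 1, "file_name": "with_object.jpg", "width": 640, "height": 480},
        {"id": 2, "file_name": "background_only.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        # Only image 1 is referenced; image 2 has no annotation entries at all.
        {"id": 10, "image_id": 1, "category_id": 1,
         "bbox": [100, 120, 50, 80], "area": 4000, "iscrowd": 0},
    ],
    "categories": [{"id": 1, "name": "person"}],
}


def images_without_annotations(dataset):
    """Return the ids of images that no annotation refers to."""
    annotated = {ann["image_id"] for ann in dataset["annotations"]}
    return [img["id"] for img in dataset["images"] if img["id"] not in annotated]


coco_json = json.dumps(coco, indent=2)  # text that could be written to a .json file
```

Running `images_without_annotations(coco)` is a quick way to confirm which images will be treated as background before training.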
