Fix issue in refine_bboxes and add doctest #1962

Merged
merged 5 commits into from
Jan 13, 2020

Erotemic
Contributor

The BBoxHead.refine_bboxes function was failing in cases where there is only a single roi in an image. When this happened, the squeeze function removed all dimensions of the tensor and it became a scalar. This caused issues when indexing into the rois tensor, because the result did not keep its two-dimensional shape, which caused an assertion to fail in regress_by_class.

This issue is fixed by specifying that we are squeezing over dimension 1. This prevents the 0-th dimension from collapsing and guarantees that rois keeps its 2d shape.
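
A minimal sketch of the failure mode described above. It assumes the squeezed tensor is an index produced by torch.nonzero and that rois uses an (img_id, x1, y1, x2, y2) layout; the variable names are illustrative and not taken from the diff.

```python
import torch

# Hypothetical batch in which image 0 has exactly one roi.
rois = torch.tensor([[0.0, 10.0, 10.0, 50.0, 60.0]])

hits = torch.nonzero(rois[:, 0] == 0)   # shape (1, 1)

inds_all = hits.squeeze()               # every size-1 dim removed: 0-d scalar
inds_dim1 = hits.squeeze(1)             # only dim 1 removed: shape (1,)

rois_bad = rois[inds_all, 1:]           # scalar index drops a dim: shape (4,)
rois_good = rois[inds_dim1, 1:]         # 1-d index keeps the 2-D shape: (1, 4)

print(rois_bad.shape, rois_good.shape)
```

With more than one roi per image both variants agree, which is why the problem only shows up in the single-roi case.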

Also while I was in this function I noticed the assertion:

        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() == len(img_metas)

This fails in the corner case where an image doesn't have an roi. Changing the == to a <= fixes this problem. I tested for this case and the rest of the logic works fine.
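
A small sketch of that corner case, assuming a batch of three images where the middle one contributes no rois. The img_metas placeholders and the per-image loop are illustrative, not the function's actual body.

```python
import torch

# Hypothetical batch of 3 images; image 1 has no rois.
img_metas = [{}, {}, {}]
rois = torch.tensor([
    [0.0, 10.0, 10.0, 50.0, 60.0],
    [2.0, 20.0, 30.0, 80.0, 90.0],
])

img_ids = rois[:, 0].long().unique(sorted=True)
# Two distinct image ids for three images: `==` would fail, `<=` passes.
assert img_ids.numel() <= len(img_metas)

# Per-image selection still works; the empty image yields a (0, 4) tensor.
per_image = []
for i in range(len(img_metas)):
    inds = torch.nonzero(rois[:, 0] == i).squeeze(1)
    per_image.append(rois[inds, 1:])

print([tuple(p.shape) for p in per_image])
```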

Speaking of test cases I added a doctest for this function so I could debug it. I was having a difficult time running it without the functions in Kitware's kwarray and kwimage vision utility libraries. I added simple versions of the kwimage.Boxes.random and kwarray.ensure_rng functions in a new file mmdet/core/bbox/demodata.py. These functions let a developer generate random valid bounding boxes, which is very useful for writing tests (and sometimes algorithms!).
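
For readers who want a feel for what such helpers can look like, here is a rough sketch. It is not the code in mmdet/core/bbox/demodata.py; the signatures, defaults, and the (x1, y1, x2, y2) ordering are assumptions made for illustration.

```python
import numpy as np
import torch


def ensure_rng(rng=None):
    """Coerce None, an int seed, or a RandomState into a RandomState."""
    if rng is None:
        return np.random.RandomState()
    if isinstance(rng, int):
        return np.random.RandomState(rng)
    return rng


def random_boxes(num=1, scale=1, rng=None):
    """Return a (num, 4) float32 tensor of valid (x1, y1, x2, y2) boxes."""
    rng = ensure_rng(rng)
    tlbr = rng.rand(num, 4).astype(np.float32)
    # Order the sampled coordinates so that x1 <= x2 and y1 <= y2.
    tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
    tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
    br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
    br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
    tlbr[:, 0] = tl_x * scale
    tlbr[:, 1] = tl_y * scale
    tlbr[:, 2] = br_x * scale
    tlbr[:, 3] = br_y * scale
    return torch.from_numpy(tlbr)


boxes = random_boxes(num=10, scale=512, rng=0)
```

Passing an integer seed makes the output reproducible, which is the property a test needs.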

Lastly, when constructing pos_is_gts for the test, I ran into an issue where I really needed the kwarray.group_items function. I didn't want to dump that function in this file or the demodata file, so I simply used it and added a # xdoctest: +REQUIRES(module:kwarray) to the doctests, so it won't run unless that module exists.
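
To illustrate the kind of grouping involved, here is a standard-library sketch. The helper group_by_id is hypothetical and is not kwarray's implementation, and the flat per-roi inputs are made up for the example.

```python
from collections import defaultdict


def group_by_id(items, group_ids):
    """Hypothetical stand-in: bucket items by their parallel group id."""
    groups = defaultdict(list)
    for item, gid in zip(items, group_ids):
        groups[gid].append(item)
    return dict(groups)


# Made-up flat data: the image each roi belongs to, and whether that roi
# is a ground-truth box.
roi_img_ids = [0, 0, 1, 2, 2, 2]
roi_is_gt = [True, False, False, True, True, False]
n_imgs = 3

grouped = group_by_id(roi_is_gt, roi_img_ids)
# One list of flags per image; images without rois get an empty list.
pos_is_gts = [grouped.get(i, []) for i in range(n_imgs)]
print(pos_is_gts)
```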

I could port that code to a utility library here to avoid the dependency on kwarray, but I'm not sure where it goes. Alternatively, kwarray is a pure-python pip installable module, so it's not the end of the world if it's added as a dependency.

>>> num = 10
>>> scale = 512
>>> rng = 0
>>> boxes = random_boxes()
Member

Did you intend random(num, scale, rng)?

Contributor Author

I did, yes. Fixing.

@@ -187,13 +188,45 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):

Returns:
list[Tensor]: Refined bboxes of each image in a mini-batch.

Example:
>>> # xdoctest: +REQUIRES(module:kwarray)
Member

As a temporary solution, kwarray can be added as a test-time dependency in tests/requirements.txt.
For future use beyond unit tests, I suggest porting it to mmcv.

tlbr[:, 2] = br_x * scale
tlbr[:, 3] = br_y * scale

boxes = torch.FloatTensor(tlbr)
Member

torch.from_numpy(tlbr) would be better.

Contributor Author

Are they not equivalent in this case? The numpy array is already a float32, so there shouldn't be any difference in calling FloatTensor vs from_numpy. Am I understanding this wrong? I'm basing my understanding on this thread: https://stackoverflow.com/questions/48482787/pytorch-memory-model-torch-from-numpy-vs-torch-tensor

Member

from_numpy() avoids the data copy. BTW, torch.FloatTensor(input) is less preferred than the newer API torch.tensor(input) in recent versions of PyTorch.
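
A quick way to see the difference between sharing and copying, as a sketch with a made-up array. It compares only torch.from_numpy and torch.tensor, whose documented behavior is to share and to copy respectively.

```python
import numpy as np
import torch

tlbr = np.zeros((2, 4), dtype=np.float32)

shared = torch.from_numpy(tlbr)   # wraps the same memory as the array
copied = torch.tensor(tlbr)       # always makes its own copy

# Mutate the numpy array after both tensors were created.
tlbr[0, 0] = 7.0

print(shared[0, 0].item(), copied[0, 0].item())
```

The shared tensor observes the later write to the array while the copy does not, so from_numpy is cheaper but ties the tensor's contents to the array's.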

Contributor Author

Good to know, thanks for the info!

@Erotemic
Contributor Author

The latest push should fix the issues in the test. I also added a unittest variant of the doctest. The unit test is more comprehensive and tests multiple values of n_rois and n_imgs. I also ensured that all randomness was seeded in the unit-test variant, so changes that cause issues in the future will be easier to reproduce. The doctest variant still contains some non-determinism.
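
The seeding idea can be sketched as follows. This is not the unit test from the PR: make_demo_rois is a hypothetical helper, and the particular n_rois and n_imgs values are chosen only for illustration.

```python
import itertools

import numpy as np


def make_demo_rois(n_rois, n_imgs, rng):
    """Hypothetical helper: (n_rois, 5) array of (img_id, x1, y1, x2, y2)."""
    img_ids = rng.randint(0, n_imgs, size=(n_rois, 1)).astype(np.float32)
    # Sample two corners per roi and sort them so that x1 <= x2, y1 <= y2.
    corners = np.sort(rng.rand(n_rois, 2, 2).astype(np.float32), axis=1)
    boxes = corners.reshape(n_rois, 4) * 512
    return np.hstack([img_ids, boxes])


# Sweep several sizes, including the empty and single-roi cases.
for n_rois, n_imgs in itertools.product([0, 1, 2, 33], [1, 2, 5]):
    a = make_demo_rois(n_rois, n_imgs, np.random.RandomState(0))
    b = make_demo_rois(n_rois, n_imgs, np.random.RandomState(0))
    assert a.shape == (n_rois, 5)
    # Same seed, same data: a failure can be replayed exactly.
    assert np.array_equal(a, b)
```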

@hellock hellock merged commit 8c0ecd1 into open-mmlab:master Jan 13, 2020
mattdawkins added a commit to VIAME/mmdetection that referenced this pull request Mar 13, 2020
* origin/viame/master: (28 commits)
  Fix FPN upscale
  Extra compiler args
  VIAME-specific build parameters
  Bump version to 1.0.0 (open-mmlab#2029)
  Fix the incompatibility of the latest numpy and pycocotools (open-mmlab#2024)
  format configs with yapf (open-mmlab#2023)
  options for FCNMaskHead during testtime (open-mmlab#2013)
  Enhance AssignResult and SamplingResult (open-mmlab#1995)
  Fix typo activatation -> activation (open-mmlab#2007)
  Reorganize requirements, make albumentations optional (open-mmlab#1969)
  Encapsulate DCN into a ConvModule & Conv_layers (open-mmlab#1894)
  Code for Paper "Bridging the Gap Between Anchor-based and Anchor-free… (open-mmlab#1872)
  Non color images (open-mmlab#1976)
  Fix albu mask format bug (open-mmlab#1818)
  Fix CI by limiting the version of torchvision (open-mmlab#2005)
  Add ability to overwite existing module in Registry (open-mmlab#1982)
  bug for distributed training (open-mmlab#1985)
  Update Libra RetinaNet config with the latest code (open-mmlab#1975)
  Fix issue in refine_bboxes and add doctest (open-mmlab#1962)
  add link to official repo (open-mmlab#1971)
  ...
ioir123ju pushed a commit to ioir123ju/mmdetection that referenced this pull request Mar 30, 2020
* Fix issue in refine_bboxes and add doctest

* fix pillow version on travis

* Fixes based on review

* Fix errors in doctest and add comprehensive unit test

* Fix linting error
mike112223 pushed a commit to mike112223/mmdetection that referenced this pull request Aug 25, 2020
* Fix issue in refine_bboxes and add doctest

* fix pillow version on travis

* Fixes based on review

* Fix errors in doctest and add comprehensive unit test

* Fix linting error