YOLACT #3456

Merged
merged 34 commits into from
Sep 26, 2020

Conversation

chongzhou96
Contributor

No description provided.

@CLAassistant

CLAassistant commented Aug 1, 2020

CLA assistant check
All committers have signed the CLA.

@codecov

codecov bot commented Aug 1, 2020

Codecov Report

Merging #3456 into master will increase coverage by 0.64%.
The diff coverage is 82.67%.

@@            Coverage Diff             @@
##           master    #3456      +/-   ##
==========================================
+ Coverage   61.10%   61.75%   +0.64%     
==========================================
  Files         213      215       +2     
  Lines       15213    15668     +455     
  Branches     2587     2661      +74     
==========================================
+ Hits         9296     9675     +379     
- Misses       5449     5510      +61     
- Partials      468      483      +15     

| Flag | Coverage Δ |
|:-----------|:---------------------------------|
| #unittests | 61.75% <82.67%> (+0.64%) ⬆️ |

| Impacted Files | Coverage Δ |
|:--------------------------------------------------------|:----------------------------------|
| mmdet/models/dense_heads/rpn_test_mixin.py | 83.87% <0.00%> (ø) |
| mmdet/datasets/pipelines/loading.py | 48.42% <23.52%> (-2.45%) ⬇️ |
| mmdet/models/detectors/yolact.py | 75.92% <75.92%> (ø) |
| mmdet/models/dense_heads/yolact_head.py | 84.37% <84.37%> (ø) |
| mmdet/core/post_processing/bbox_nms.py | 88.67% <96.42%> (+8.67%) ⬆️ |
| mmdet/__init__.py | 94.11% <100.00%> (ø) |
| ...mdet/core/bbox/iou_calculators/iou2d_calculator.py | 88.46% <100.00%> (+0.96%) ⬆️ |
| mmdet/core/post_processing/__init__.py | 100.00% <100.00%> (ø) |
| mmdet/models/dense_heads/__init__.py | 100.00% <100.00%> (ø) |
| mmdet/models/detectors/__init__.py | 100.00% <100.00%> (ø) |

... and 6 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 14b5d8e...dea1329.

Collaborator

@yhcao6 yhcao6 left a comment

Thanks for your nice work. I have left some comments; feel free to point out anything I got wrong.
The biggest problem may be the incomplete docstrings and unit tests, since we now follow strict criteria.

@@ -0,0 +1,5 @@
_base_ = ['./yolact_r50.py']
Collaborator

The list is redundant; `_base_` can be a plain string.


model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))

work_dir = './work_dirs/yolact_r101'
Collaborator

work_dir now defaults to the name of the config, so this is redundant.

wh = (rb - lt + 1).clamp(min=0)
overlap = wh[:, 0] * wh[:, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
Collaborator

MMDet v2.0 rewrote the coordinate system, so the +1 is no longer needed.
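
To make the effect of dropping the offset concrete, here is a minimal dependency-free sketch of the IoU of two boxes under both conventions. The function name `bbox_iou` and the `legacy_plus_one` switch are hypothetical and exist only for this comparison; the PR itself operates on tensors.

```python
def bbox_iou(box1, box2, legacy_plus_one=False):
    """IoU of two (x1, y1, x2, y2) boxes.

    `legacy_plus_one` is a hypothetical switch that contrasts the older
    pixel-index convention (width = x2 - x1 + 1) with the continuous
    convention (width = x2 - x1).
    """
    offset = 1.0 if legacy_plus_one else 0.0
    # Width and height of the intersection, clamped at zero.
    w = max(0.0, min(box1[2], box2[2]) - max(box1[0], box2[0]) + offset)
    h = max(0.0, min(box1[3], box2[3]) - max(box1[1], box2[1]) + offset)
    overlap = w * h
    area1 = (box1[2] - box1[0] + offset) * (box1[3] - box1[1] + offset)
    area2 = (box2[2] - box2[0] + offset) * (box2[3] - box2[1] + offset)
    union = area1 + area2 - overlap
    return overlap / union if union > 0 else 0.0


a = (0, 0, 10, 10)
b = (5, 5, 15, 15)
iou_new = bbox_iou(a, b)                           # 25 / 175
iou_old = bbox_iou(a, b, legacy_plus_one=True)     # 36 / 206
```

The two conventions give different values for the same pair of boxes, which is why mixing them within one code path is worth avoiding.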

for obj_idx in range(downsampled_masks.size(0)):
segm_targets[cur_gt_labels[obj_idx] - 1] = torch.max(
segm_targets[cur_gt_labels[obj_idx] - 1],
downsampled_masks[obj_idx])
Collaborator

It may be better to move the target computation to a new function get_target to be consistent with other heads.
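
As a rough illustration of what such a helper could look like, the sketch below isolates the loop from the snippet into its own function. It uses nested Python lists instead of tensors so that it runs without dependencies; the name `get_segm_targets`, its signature, and the 1-based label convention (taken from the `- 1` in the snippet) are assumptions rather than the PR's final code.

```python
def get_segm_targets(downsampled_masks, gt_labels, num_classes, height, width):
    """Merge per-instance binary masks into one target map per class.

    downsampled_masks: list of height x width nested lists with 0/1 values.
    gt_labels: list of 1-based class labels, one per mask (assumed from the
        `- 1` in the reviewed snippet).
    Returns a num_classes x height x width nested list.
    """
    targets = [[[0.0] * width for _ in range(height)]
               for _ in range(num_classes)]
    for mask, label in zip(downsampled_masks, gt_labels):
        cls = label - 1
        for r in range(height):
            for c in range(width):
                # Element-wise maximum, the list analogue of torch.max(a, b).
                targets[cls][r][c] = max(targets[cls][r][c], mask[r][c])
    return targets


masks = [
    [[1, 0], [0, 0]],   # object 0, class 1
    [[0, 0], [0, 1]],   # object 1, class 1
    [[0, 1], [0, 0]],   # object 2, class 2
]
segm_targets = get_segm_targets(masks, [1, 1, 2], num_classes=2,
                                height=2, width=2)
```

Keeping the target computation separate from the loss computation also makes it straightforward to unit-test on its own.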

num_protos=32,
num_classes=80,
loss_mask_weight=1.0,
max_masks_to_train=100,
Collaborator

Remove the last comma.

super(YolactProtonet, self).__init__()
config = [(256, 3, {'padding': 1})] * 3 + \
[(None, -2, {}), (256, 3, {'padding': 1})] \
+ [(num_protos, 1, {})]
Collaborator

Can we remove the hard-coding here? For example, we could add a parameter num_proto_convs for the number of leading convs (3 here) and use upsample_cfg to configure the upsampling method; you can refer here.
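
One possible shape for this, sketched without any framework code: a small builder that produces the same layer list from parameters. `build_protonet_cfg`, `channels`, and `upsample_factor` are hypothetical names; `upsample_factor` merely stands in for the suggested `upsample_cfg`. The sketch also assumes that a negative kernel size denotes an upsampling layer, as the `(None, -2, {})` entry in the snippet suggests.

```python
def build_protonet_cfg(num_protos, num_proto_convs=3, channels=256,
                       upsample_factor=2):
    """Build the (out_channels, kernel_size, conv_kwargs) layer list.

    Assumption: an entry with a negative kernel size encodes an upsampling
    layer whose scale factor is the absolute value, mirroring the
    hard-coded `(None, -2, {})` entry.
    """
    conv = (channels, 3, {'padding': 1})
    return ([conv] * num_proto_convs
            + [(None, -upsample_factor, {}), conv]
            + [(num_protos, 1, {})])


# With the defaults, this reproduces the hard-coded list from the snippet.
default_cfg = build_protonet_cfg(num_protos=32)
```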

num_imgs = x.size(0)
# Training state
if sampling_results is not None:
Collaborator

If I am not mistaken, we check whether sampling_results is None here to decide whether the model is training. We could use self.training instead, which would be more straightforward.

mode='bilinear',
align_corners=False).squeeze(0)
cur_gt_masks = cur_gt_masks.gt(0.5).float()
mask_targets = cur_gt_masks[pos_assigned_gt_inds]
Collaborator

As with YolactSegmHead, I think the code would be clearer if we moved the target computation to a new function get_target, even though it is simple.

@yhcao6
Collaborator

yhcao6 commented Aug 10, 2020

Could you add unit tests for the YOLACT head? You can refer to test_head.py.

Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.

Collaborator

Remove the blank line.

@yhcao6
Collaborator

yhcao6 commented Aug 13, 2020

Dear @chongzhou96, I think we are almost done, but there are still some problems with the unit tests. The coverage of yolact_head.py is 70.93%, but yolact.py is only at 33.33%; you can check here. So I think it would be better to add one more unit test in test_forward to test the forward method of the Yolact detector.

@@ -46,3 +47,64 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6):
if exchange:
ious = ious.T
return ious


def batch_bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
Member

It is simple to make bbox_overlaps (in core/bbox/iou_calculators/iou2d_calculator.py) support arbitrary dimensions, so it would be better to avoid adding this method.

will not be considered.
nms_cfg (dict): NMS config.
max_num (int): if there are more than max_num bboxes after NMS,
only top max_num will be kept.
Member

Default: -1.

multi_coeffs (Tensor): shape (n, #class*coeffs_dim).
score_thr (float): bbox threshold, bboxes with scores lower than it
will not be considered.
nms_cfg (dict): NMS config.
Member

Unlike batched_nms, fast_nms does not support multiple methods, so there is no need for an nms_cfg argument. It is fine to use top_k and iou_threshold directly, which is clearer. The API design should be similar to nms and soft_nms.
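
For readers unfamiliar with Fast NMS, the following single-class sketch shows the idea behind the suggested signature. It is written in plain Python for illustration only (the real implementation is matrix-based and operates on tensors), and the exact argument order and return value here are assumptions, not the merged API. The key difference from standard NMS is that a box is dropped if it overlaps any higher-scoring box, even one that was itself suppressed.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes, without the legacy +1 offset."""
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def fast_nms(boxes, scores, iou_threshold, top_k=-1):
    """Return indices of kept boxes, highest score first (hypothetical API)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    keep = []
    for rank, j in enumerate(order):
        # Max IoU against ALL higher-scoring boxes, kept or not. This is the
        # column-wise max over the upper-triangular IoU matrix.
        max_iou = max((iou(boxes[i], boxes[j]) for i in order[:rank]),
                      default=0.0)
        if max_iou <= iou_threshold:
            keep.append(j)
    return keep


boxes = [(0, 0, 10, 10), (0, 3, 10, 13), (0, 6, 10, 16), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7, 0.6]
kept = fast_nms(boxes, scores, iou_threshold=0.5)
```

In this example the third box overlaps only the second one above the threshold. Standard NMS would keep it because the second box is already suppressed, whereas Fast NMS removes it.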


Args:
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
multi_scores (Tensor): shape (n, #class), where the last column
Member

Should the shape be (n, #class+1)?

num_head_convs (int): Number of the conv layers shared by
box and cls branches.
num_protos (int): Number of the mask coefficients.
use_OHEM (bool): If true, `loss_single_OHEM` will be used for
Member

use_ohem

norm_cfg (dict): Dictionary to construct and config norm layer.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
anchor_generator (dict): Config dict for anchor generator.
Member

Follow the argument order of AnchorHead.

gt_bboxes_ignore=None):
"""A combination of the func:`AnchorHead.loss` and func:`SSDHead.loss`.

When self.use_OHEM == True, it functions like `SSDHead.loss`,
Member

The rst syntax uses double `: When ``self.use_OHEM == True``, it functions like ``SSDHead.loss``
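
A docstring written that way might look like the sketch below. The standalone function is only a placeholder for illustrating the markup; it also writes the cross-reference role with its leading colon (`:func:`), which is how Sphinx roles are spelled.

```python
def loss(self):
    """A combination of :func:`AnchorHead.loss` and :func:`SSDHead.loss`.

    When ``self.use_OHEM == True``, it functions like ``SSDHead.loss``.
    """
```

Single backticks are reserved in reStructuredText for interpreted text such as roles, so inline code literals need the doubled form.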

label_weights, bbox_targets, bbox_weights,
num_total_samples):
""""See func:`SSDHead.loss`."""
loss_cls_all = F.cross_entropy(
Member

Losses are hard-coded instead of using self.loss_cls and self.loss_reg.

@hellock
Member

hellock commented Aug 23, 2020

@yhcao6 Please help benchmark the latest version and it can be merged then.

| Image Size | Backbone | *FPS | mAP | Weights |
|:----------:|:-------------:|:----:|:----:|----------------------------------------------------------------------------------------------------------------------|
| 550 | Resnet50-FPN | 42.5 | 29.0 | [yolact_r50_epoch_55.pth](https://drive.google.com/file/d/1S30dWxxF1jmz7Tbh1cWbHZUiMpNOK1ts/view?usp=sharing) |
| 550 | Resnet101-FPN | 33.5 | 30.4 | [yolact_r101_epoch_55.pth](https://drive.google.com/file/d/1yv752K659KbR9arOVD1f6Rst8i00OZnX/view?usp=sharing)|
Collaborator

@chongzhou96 I have trained YOLACT R-50 myself and the mAP matches README.md exactly. However, single-GPU training takes me around 5 days, so may I ask whether you have tried SyncBN, or 8-GPU training with an increased learning rate?
I also benchmarked the inference speed, but I found it hard to reproduce the FPS reported in README.md.
I used benchmark.py, which should exclude the unnecessary time spent on data loading and post-processing.
For YOLACT ResNet-50 on a V100, I get only 17.5 FPS. Have you benchmarked the inference speed with this PR?
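
For context, a measurement that excludes data loading could be structured roughly as follows. This is a generic sketch and not the contents of benchmark.py; `measure_fps`, `run_inference`, and `sync` are hypothetical names. Samples are assumed to be loaded before timing starts, and the first few iterations are treated as warm-up.

```python
import time


def measure_fps(run_inference, inputs, num_warmup=5, sync=None):
    """Time only the model call and return frames per second.

    run_inference: hypothetical callable that takes one pre-loaded sample.
    inputs: samples already in memory, so data loading is not timed.
    sync: optional callable that waits for pending device work, for example
        torch.cuda.synchronize when timing on a GPU.
    """
    if len(inputs) <= num_warmup:
        raise ValueError('need more inputs than warm-up iterations')
    elapsed = 0.0
    for i, sample in enumerate(inputs):
        if sync is not None:
            sync()
        start = time.perf_counter()
        run_inference(sample)
        if sync is not None:
            sync()
        # Skip the warm-up iterations when accumulating time.
        if i >= num_warmup:
            elapsed += time.perf_counter() - start
    if elapsed == 0.0:
        return float('inf')
    return (len(inputs) - num_warmup) / elapsed


# Toy stand-in for a model call, used only to exercise the timer.
fps = measure_fps(lambda sample: sum(range(10000)), list(range(20)))
```

Differences in what is included in the timed region (post-processing, mask assembly, device synchronization) can easily account for large gaps between reported FPS numbers.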

return x1, x2


class InterpolateModule(nn.Module):
Member

There seems to be no need to define this.

@hellock hellock merged commit d460a53 into open-mmlab:master Sep 26, 2020
mattdawkins added a commit to VIAME/mmdetection that referenced this pull request Dec 3, 2020
* tag 'v2.5.0': (102 commits)
  Bump to v2.5.0 (open-mmlab#3879)
  [Fix]: fix mask rcnn training stuck problem when there is no positive rois (open-mmlab#3713)
  Add missing notes in data customization (open-mmlab#3906)
  support to use pytorch 1.6 in docker (open-mmlab#3905)
  [enhance]: Improve documentation of modules and dataset customization (open-mmlab#3821)
  fix cv2 import error of ligGL.so.1 (open-mmlab#3891)
  fix the API change bug of PAA (open-mmlab#3883)
  Fix typo in bbox_flip (open-mmlab#3886)
  [Enhance]: Convert mask to bool before using it as img's index for robustness and speedup (open-mmlab#3870)
  [Docs] Remove duplicate content in docs/config.md (open-mmlab#3875)
  [Docs] Fix typo in docs/tutorials/new_dataset.md (open-mmlab#3876)
  Support TTA of ATSS, FCOS, YOLOv3 (open-mmlab#3844)
  Fix nonzero in NMS for PyTorch 1.6.0 (open-mmlab#3867)
  [Refactor] refactor get_subset_by_classes in dataloader for training with empty-GT images (open-mmlab#3695)
  fix rpn transforming bug in two stage networks (open-mmlab#3754)
  Clean background_labels in the dense heads (open-mmlab#3221)
  improve the function of simple_test_bboxes (open-mmlab#3853)
  Add doc of modify loss (open-mmlab#3777)
  fix sabl validating bug (open-mmlab#3849)
  YOLACT (open-mmlab#3456)
  ...
@OpenMMLab-Assistant-007

Hi!
@chongzhou96
First of all, we want to express our gratitude for your significant PR in the OpenMMLab project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project during your personal time. We believe that many developers will benefit from your PR.

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA

If you have a WeChat account, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + Github ID" as a remark when adding friends :)
Thank you again for your contribution ❤

FANGAreNotGnu pushed a commit to FANGAreNotGnu/mmdetection that referenced this pull request Oct 23, 2023