Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code for paper "Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training" #3040

Merged
merged 9 commits into from Jun 26, 2020

Conversation

hkzhang95
Copy link
Contributor

No description provided.

@CLAassistant
Copy link

CLAassistant commented Jun 16, 2020

CLA assistant check
All committers have signed the CLA.

@hellock hellock requested a review from xvjiarui June 16, 2020 11:58
@xvjiarui
Copy link
Collaborator

Hi @hkzhang95
Thanks for your contribution. I will review and benchmark this PR very soon.

torch.topk(
assign_result.max_overlaps,
min(self.iou_topk,
len(assign_result.max_overlaps)))[0][-1].item())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify the logic to make it more readable for users.

iou_topk = min(self.iou_topk, len(assign_result.max_overlaps))
ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
cur_iou.append(ious[-1].item())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Copy link

@LorenzoQ5 LorenzoQ5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stop doing this sow I'm saying I'm going to report you

@hkzhang95
Copy link
Contributor Author

Stop doing this sow I'm saying I'm going to report you

What do you mean by saying that?

self.beta_topk = self.train_cfg.dynamic_rcnn.beta_topk
self.iteration_count = self.train_cfg.dynamic_rcnn.iteration_count
self.initial_iou = 0.4
self.initial_beta = 1.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to make initial_iou and initial_beta configurable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added related settings in the configuration file.

super(DynamicRoIHead, self).__init__(**kwargs)
self.iou_topk = self.train_cfg.dynamic_rcnn.iou_topk
self.beta_topk = self.train_cfg.dynamic_rcnn.beta_topk
self.iteration_count = self.train_cfg.dynamic_rcnn.iteration_count
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iteration_count -> udpate_iter_interval

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.


def update_statistics(self):
new_iou_thr = max(self.initial_iou,
sum(self.cur_iou) / self.iteration_count)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new_iou_thr = max(self.initial_iou, np.mean(self. iou_history))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

self.initial_iou = 0.4
self.initial_beta = 1.0
self.cur_iou = []
self.cur_beta = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use more comprehensible variable names.

# the IoU history of the past `udpate_iter_interval` iterations
self.iou_history = []
# the beta history of the past `udpate_iter_interval` iterations
self.beta_history = []

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

bbox_results.update(loss_bbox=loss_bbox)
return bbox_results

def update_statistics(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • update_statistics is not a good method name since this method is not updating statistics but updating parameters like iou_thr and beta.
  • Add a docstring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the method name to update_hyperparameters and add a docstring.

mmdet/models/roi_heads/dynamic_roi_head.py Show resolved Hide resolved
mmdet/models/roi_heads/dynamic_roi_head.py Show resolved Hide resolved
mmdet/models/roi_heads/dynamic_roi_head.py Show resolved Hide resolved
new_beta = min(self.initial_beta,
sorted(self.cur_beta)[self.iteration_count // 2])
self.cur_beta = []
assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion can be moved to __init__ method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

self.bbox_assigner.neg_iou_thr = new_iou_thr
self.bbox_assigner.min_pos_iou = new_iou_thr
new_beta = min(self.initial_beta,
sorted(self.cur_beta)[self.iteration_count // 2])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new_beta = min(self.initial_beta, np.median(self.beta_history))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@hellock hellock merged commit d2a8ba7 into open-mmlab:master Jun 26, 2020
mike112223 pushed a commit to mike112223/mmdetection that referenced this pull request Aug 25, 2020
…via Dynamic Training" (open-mmlab#3040)

* Code for paper "Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training"

* update configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py

* reformat code

* simplify code

* update model link

* simplify code

* simplify code logic

* simplify code and add comments

* minor updates of the docstring

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
liuhuiCNN pushed a commit to liuhuiCNN/mmdetection that referenced this pull request May 21, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants