Skip to content

Commit

Permalink
Merge branch 'dev-3.x' into dev-3.x-xjy
Browse files Browse the repository at this point in the history
  • Loading branch information
ZwwWayne committed Sep 19, 2022
2 parents 291ce50 + f1bdc96 commit 0057860
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 103 deletions.
2 changes: 1 addition & 1 deletion configs/faster_rcnn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ We trained with R-50-FPN pytorch style backbone for 1x schedule.
| R-50-FPN | L1Loss | 4.0 | 21.4 | 37.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130_204655.log.json) |
| R-50-FPN | IoULoss | | | 37.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn/faster-rcnn_r50_fpn_iou_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_iou_1x_coco/faster_rcnn_r50_fpn_iou_1x_coco_20200506_095954-938e81f0.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_iou_1x_coco/faster_rcnn_r50_fpn_iou_1x_coco_20200506_095954.log.json) |
| R-50-FPN | GIoULoss | | | 37.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_giou_1x_coco-0eada910.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_giou_1x_coco_20200505_161120.log.json) |
| R-50-FPN | BoundedIoULoss | | | 37.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_bounded_iou_1x_coco-98ad993b.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_bounded_iou_1x_coco_20200505_160738.log.json) |
| R-50-FPN | BoundedIoULoss | | | 37.4 | [config](./faster-rcnn_r50_fpn_bounded-iou_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_bounded_iou_1x_coco-98ad993b.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_bounded_iou_1x_coco_20200505_160738.log.json) |

## Pre-trained Models

Expand Down
4 changes: 2 additions & 2 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ The downloading will take several seconds or more, depending on your network env
Option (a). If you install MMDetection from source, just run the following command.

```shell
python demo/image_demo.py demo/demo.jpg yolov3_mobilenetv2_8xb24-ms-416-300e_coco.py yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth --device cpu --out-file result.jpg
python demo/image_demo.py demo/demo.jpg yolov3_mobilenetv2_8xb24-320-300e_coco.py yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth --device cpu --out-file result.jpg
```

You will see a new image `result.jpg` on your current folder, where bounding boxes are plotted on cars, benches, etc.
Expand All @@ -98,7 +98,7 @@ from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules

register_all_modules()
config_file = 'yolov3_mobilenetv2_8xb24-ms-416-300e_coco.py'
config_file = 'yolov3_mobilenetv2_8xb24-320-300e_coco.py'
checkpoint_file = 'yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth'
model = init_detector(config_file, checkpoint_file, device='cpu') # or device='cuda:0'
inference_detector(model, 'demo/demo.jpg')
Expand Down
4 changes: 2 additions & 2 deletions docs/zh_cn/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ mim download mmdet --config yolov3_mobilenetv2_8xb24-320-300e_coco --dest .
方案 1. 如果你通过源码安装的 MMDetection,那么直接运行以下命令进行验证:

```shell
python demo/image_demo.py demo/demo.jpg yolov3_mobilenetv2_8xb24-ms-416-300e_coco.py yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth --device cpu --out-file result.jpg
python demo/image_demo.py demo/demo.jpg yolov3_mobilenetv2_8xb24-320-300e_coco.py yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth --device cpu --out-file result.jpg
```

你会在当前文件夹中看到一个新的图像`result.jpg`,图像中包含有网络预测的检测框。
Expand All @@ -95,7 +95,7 @@ from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules

register_all_modules()
config_file = 'yolov3_mobilenetv2_8xb24-ms-416-300e_coco.py'
config_file = 'yolov3_mobilenetv2_8xb24-320-300e_coco.py'
checkpoint_file = 'yolov3_mobilenetv2_320_300e_coco_20210719_215349-d18dff72.pth'
model = init_detector(config_file, checkpoint_file, device='cpu') # or device='cuda:0'
inference_detector(model, 'demo/demo.jpg')
Expand Down
6 changes: 3 additions & 3 deletions mmdet/engine/hooks/num_class_check_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def _check_head(self, runner: Runner, mode: str) -> None:
f'CLASSES = ({CLASSES},)')
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
for name, module in model.named_modules():
if hasattr(module, 'num_classes'
) and name != 'rpn_head' and not isinstance(
module, (VGG, FusedSemanticHead)):
if hasattr(module, 'num_classes') and not name.endswith(
'rpn_head') and not isinstance(
module, (VGG, FusedSemanticHead)):
assert module.num_classes == len(CLASSES), \
(f'The `num_classes` ({module.num_classes}) in '
f'{module.__class__.__name__} of '
Expand Down
2 changes: 1 addition & 1 deletion mmdet/models/detectors/panoptic_two_stage_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def loss(self, batch_inputs: Tensor,
x, rpn_data_samples, proposal_cfg=proposal_cfg)
# avoid get same name with roi_head loss
keys = rpn_losses.keys()
for key in keys:
for key in list(keys):
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses)
Expand Down
62 changes: 14 additions & 48 deletions mmdet/models/detectors/semi_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import Sequence
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_project
Expand All @@ -16,7 +17,7 @@

@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
"""Base class for semi-supervsed detectors.
"""Base class for semi-supervised detectors.
Semi-supervised detectors typically consisting of a teacher model
updated by exponential moving average and a student model updated
Expand Down Expand Up @@ -58,17 +59,6 @@ def freeze(model: nn.Module):
for param in model.parameters():
param.requires_grad = False

@staticmethod
def reweight_loss(losses: dict, weight: float) -> dict:
"""Reweight loss for different branches."""
for name, loss in losses.items():
if 'loss' in name:
if isinstance(loss, Sequence):
losses[name] = [item * weight for item in loss]
else:
losses[name] = loss * weight
return losses

def loss(self, multi_batch_inputs: Dict[str, Tensor],
multi_batch_data_samples: Dict[str, SampleList]) -> dict:
"""Calculate losses from multi-branch inputs and data samples.
Expand Down Expand Up @@ -113,13 +103,10 @@ def loss_by_gt_instances(self, batch_inputs: Tensor,
Returns:
dict: A dictionary of loss components
"""
gt_loss = {
'sup_' + k: v
for k, v in self.reweight_loss(
self.student.loss(batch_inputs, batch_data_samples),
self.semi_train_cfg.get('sup_weight', 1.)).items()
}
return gt_loss

losses = self.student.loss(batch_inputs, batch_data_samples)
sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))

def loss_by_pseudo_instances(self,
batch_inputs: Tensor,
Expand All @@ -141,39 +128,17 @@ def loss_by_pseudo_instances(self,
Returns:
dict: A dictionary of loss components
"""
for data_samples in batch_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.cls_pseudo_thr]

batch_data_samples = filter_gt_instances(
batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
losses = self.student.loss(batch_inputs, batch_data_samples)
pseudo_instances_num = sum([
len(data_samples.gt_instances)
for data_samples in batch_data_samples
])
unsup_weight = self.semi_train_cfg.get(
'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.

pseudo_loss = {
'unsup_' + k: v
for k, v in self.reweight_loss(
self.student.loss(batch_inputs, batch_data_samples),
unsup_weight).items()
}
return pseudo_loss

def filter_pseudo_instances(self,
batch_data_samples: SampleList) -> SampleList:
"""Filter invalid pseudo instances from teacher model."""
for data_samples in batch_data_samples:
pseudo_bboxes = data_samples.gt_instances.bboxes
if pseudo_bboxes.shape[0] > 0:
w = pseudo_bboxes[:, 2] - pseudo_bboxes[:, 0]
h = pseudo_bboxes[:, 3] - pseudo_bboxes[:, 1]
data_samples.gt_instances = data_samples.gt_instances[
(w > self.semi_train_cfg.min_pseudo_bbox_wh[0])
& (h > self.semi_train_cfg.min_pseudo_bbox_wh[1])]
return batch_data_samples
return rename_loss_dict('unsup_',
reweight_loss_dict(losses, unsup_weight))

@torch.no_grad()
def get_pseudo_instances(
Expand All @@ -198,7 +163,8 @@ def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
data_samples.gt_instances.bboxes,
torch.tensor(data_samples.homography_matrix).to(
self.data_preprocessor.device), data_samples.img_shape)
return self.filter_pseudo_instances(batch_data_samples)
wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
Expand Down
65 changes: 29 additions & 36 deletions mmdet/models/detectors/soft_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox2roi, bbox_project
Expand Down Expand Up @@ -77,12 +79,9 @@ def loss_by_pseudo_instances(self,
x, rpn_results_list, batch_data_samples, batch_info))
losses.update(**self.rcnn_reg_loss_by_pseudo_instances(
x, rpn_results_list, batch_data_samples))
pseudo_loss = {
'unsup_' + k: v
for k, v in self.reweight_loss(
losses, self.semi_train_cfg.get('unsup_weight', 1.)).items()
}
return pseudo_loss
unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.)
return rename_loss_dict('unsup_',
reweight_loss_dict(losses, unsup_weight))

@torch.no_grad()
def get_pseudo_instances(
Expand All @@ -106,10 +105,10 @@ def get_pseudo_instances(

for data_samples, results in zip(batch_data_samples, results_list):
data_samples.gt_instances = results
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.pseudo_label_initial_score_thr]

batch_data_samples = filter_gt_instances(
batch_data_samples,
score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr)

reg_uncs_list = self.compute_uncertainty_with_aug(
x, batch_data_samples)
Expand Down Expand Up @@ -151,12 +150,8 @@ def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor],
"""

rpn_data_samples = copy.deepcopy(batch_data_samples)
for data_samples in rpn_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.rpn_pseudo_thr]

rpn_data_samples = filter_gt_instances(
rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr)
proposal_cfg = self.student.train_cfg.get('rpn_proposal',
self.student.test_cfg.rpn)
# set cat_id of gt_labels to 0 in RPN
Expand Down Expand Up @@ -196,11 +191,8 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
"""
rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
cls_data_samples = copy.deepcopy(batch_data_samples)
for data_samples in cls_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.cls_pseudo_thr]
cls_data_samples = filter_gt_instances(
cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)

outputs = unpack_gt_instances(cls_data_samples)
batch_gt_instances, batch_gt_instances_ignore, _ = outputs
Expand All @@ -212,7 +204,6 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
# rename rpn_results.bboxes to rpn_results.priors
rpn_results = rpn_results_list[i]
rpn_results.priors = rpn_results.pop('bboxes')

assign_result = self.student.roi_head.bbox_assigner.assign(
rpn_results, batch_gt_instances[i],
batch_gt_instances_ignore[i])
Expand All @@ -226,19 +217,21 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
selected_bboxes = [res.priors for res in sampling_results]
rois = bbox2roi(selected_bboxes)
bbox_results = self.student.roi_head._bbox_forward(x, rois)
# cls_reg_targets is a tuple of labels, label_weights,
# and bbox_targets, bbox_weights
cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
sampling_results, self.student.train_cfg.rcnn)

selected_results_list = []
for bboxes, data_samples, homography_matrix, img_shape in zip(
for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
selected_bboxes, batch_data_samples,
batch_info['homography_matrix'], batch_info['img_shape']):
selected_results_list.append(
InstanceData(
bboxes=bbox_project(
bboxes, homography_matrix @ torch.tensor(
data_samples.homography_matrix).inverse().to(
self.data_preprocessor.device), img_shape)))
student_matrix = torch.tensor(
data_samples.homography_matrix, device=teacher_matrix.device)
homography_matrix = teacher_matrix @ student_matrix.inverse()
projected_bboxes = bbox_project(bboxes, homography_matrix,
teacher_img_shape)
selected_results_list.append(InstanceData(bboxes=projected_bboxes))

with torch.no_grad():
results_list = self.teacher.roi_head.predict_bbox(
Expand All @@ -249,18 +242,18 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
rescale=False)
bg_score = torch.cat(
[results.scores[:, -1] for results in results_list])
# cls_reg_targets[0] is labels
neg_inds = cls_reg_targets[
0] == self.student.roi_head.bbox_head.num_classes
# cls_reg_targets[1] is label_weights
cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()

losses = self.student.roi_head.bbox_head.loss(
bbox_results['cls_score'],
bbox_results['bbox_pred'],
rois,
*cls_reg_targets,
reduction_override='none')
losses['loss_cls'] = losses['loss_cls'].sum() / max(
cls_reg_targets[1].sum(), 1.0)
bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
*cls_reg_targets)
# cls_reg_targets[1] is label_weights
losses['loss_cls'] = losses['loss_cls'] * len(
cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0)
return losses

def rcnn_reg_loss_by_pseudo_instances(
Expand Down
2 changes: 1 addition & 1 deletion mmdet/models/detectors/two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def loss(self, batch_inputs: Tensor,
x, rpn_data_samples, proposal_cfg=proposal_cfg)
# avoid get same name with roi_head loss
keys = rpn_losses.keys()
for key in keys:
for key in list(keys):
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses)
Expand Down
16 changes: 9 additions & 7 deletions mmdet/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
get_topk_from_heatmap, transpose_and_gather_feat)
from .make_divisible import make_divisible
from .misc import (cat_boxes, center_of_mass, empty_instances,
filter_scores_and_topk, flip_tensor, generate_coordinate,
get_box_tensor, get_box_wh, images_to_levels,
interpolate_as, levels_to_images, mask2ndarray, multi_apply,
samplelist_boxlist2tensor, scale_boxes, select_single_mlvl,
sigmoid_geometric_mean, stack_boxes, unmap,
unpack_gt_instances)
filter_gt_instances, filter_scores_and_topk, flip_tensor,
generate_coordinate, get_box_tensor, get_box_wh,
images_to_levels, interpolate_as, levels_to_images,
mask2ndarray, multi_apply, rename_loss_dict,
reweight_loss_dict, samplelist_boxlist2tensor, scale_boxes,
select_single_mlvl, sigmoid_geometric_mean, stack_boxes,
unmap, unpack_gt_instances)
from .panoptic_gt_processing import preprocess_panoptic_gt
from .point_sample import (get_uncertain_point_coords_with_randomness,
get_uncertainty)
Expand All @@ -24,5 +25,6 @@
'generate_coordinate', 'levels_to_images', 'mask2ndarray', 'multi_apply',
'select_single_mlvl', 'unmap', 'images_to_levels',
'samplelist_boxlist2tensor', 'cat_boxes', 'stack_boxes', 'scale_boxes',
'get_box_tensor', 'get_box_wh'
'get_box_tensor', 'get_box_wh', 'filter_gt_instances', 'rename_loss_dict',
'reweight_loss_dict'
]
Loading

0 comments on commit 0057860

Please sign in to comment.