[Refactor] Refactor SemiBaseDetector and SoftTeacher #8786

Merged
merged 6 commits on Sep 15, 2022
Changes from 2 commits
54 changes: 29 additions & 25 deletions mmdet/models/detectors/semi_base.py
@@ -69,6 +69,21 @@ def reweight_loss(losses: dict, weight: float) -> dict:
losses[name] = loss * weight
return losses

@staticmethod
def rename_loss(prefix: str, losses: dict) -> dict:
"""Rename loss for different branches."""
return {prefix + k: v for k, v in losses.items()}

@staticmethod
def filter_pseudo_instances_by_score(batch_data_samples: SampleList,
thr: float) -> SampleList:
"""Filter invalid pseudo instances by scores."""
for data_samples in batch_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores > thr]
return batch_data_samples

def loss(self, multi_batch_inputs: Dict[str, Tensor],
multi_batch_data_samples: Dict[str, SampleList]) -> dict:
"""Calculate losses from multi-branch inputs and data samples.
@@ -113,13 +128,11 @@ def loss_by_gt_instances(self, batch_inputs: Tensor,
Returns:
dict: A dictionary of loss components
"""
gt_loss = {
'sup_' + k: v
for k, v in self.reweight_loss(
self.student.loss(batch_inputs, batch_data_samples),
self.semi_train_cfg.get('sup_weight', 1.)).items()
}
return gt_loss

losses = self.student.loss(batch_inputs, batch_data_samples)
sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
return self.rename_loss('sup_',
self.reweight_loss(losses, sup_weight))

def loss_by_pseudo_instances(self,
batch_inputs: Tensor,
@@ -141,30 +154,21 @@ def loss_by_pseudo_instances(self,
Returns:
dict: A dictionary of loss components
"""
for data_samples in batch_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.cls_pseudo_thr]

batch_data_samples = self.filter_pseudo_instances_by_score(
batch_data_samples, self.semi_train_cfg.cls_pseudo_thr)
losses = self.student.loss(batch_inputs, batch_data_samples)
pseudo_instances_num = sum([
len(data_samples.gt_instances)
for data_samples in batch_data_samples
])
unsup_weight = self.semi_train_cfg.get(
'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
return self.rename_loss(
'unsup_', self.reweight_loss(losses, unsup_weight))

pseudo_loss = {
'unsup_' + k: v
for k, v in self.reweight_loss(
self.student.loss(batch_inputs, batch_data_samples),
unsup_weight).items()
}
return pseudo_loss

def filter_pseudo_instances(self,
batch_data_samples: SampleList) -> SampleList:
"""Filter invalid pseudo instances from teacher model."""
def filter_pseudo_instances_by_sizes(
self, batch_data_samples: SampleList) -> SampleList:
"""Filter invalid pseudo instances by sizes from teacher model."""
for data_samples in batch_data_samples:
pseudo_bboxes = data_samples.gt_instances.bboxes
if pseudo_bboxes.shape[0] > 0:
@@ -198,7 +202,7 @@ def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
data_samples.gt_instances.bboxes,
torch.tensor(data_samples.homography_matrix).to(
self.data_preprocessor.device), data_samples.img_shape)
return self.filter_pseudo_instances(batch_data_samples)
return self.filter_pseudo_instances_by_sizes(batch_data_samples)

def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
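For readers skimming the diff: the two new static helpers are small pure-dict utilities, and the branch-specific loss methods now reduce to `rename_loss(prefix, reweight_loss(losses, weight))`. Below is a minimal standalone sketch of that composition with plain tensors and no mmdet objects; the real `reweight_loss` also handles list-valued losses and mutates its input in place, which this sketch deliberately simplifies.

```python
import torch


def reweight_loss(losses: dict, weight: float) -> dict:
    # Scale every loss term by the branch weight (sup_weight / unsup_weight).
    return {name: loss * weight for name, loss in losses.items()}


def rename_loss(prefix: str, losses: dict) -> dict:
    # Prefix every key so supervised and unsupervised terms stay distinguishable.
    return {prefix + k: v for k, v in losses.items()}


# Toy loss dict standing in for self.student.loss(...).
losses = {'loss_cls': torch.tensor(0.8), 'loss_bbox': torch.tensor(0.4)}

sup = rename_loss('sup_', reweight_loss(losses, 1.0))
unsup = rename_loss('unsup_', reweight_loss(losses, 4.0))
print(sup)    # {'sup_loss_cls': tensor(0.8000), 'sup_loss_bbox': tensor(0.4000)}
print(unsup)  # {'unsup_loss_cls': tensor(3.2000), 'unsup_loss_bbox': tensor(1.6000)}
```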
63 changes: 27 additions & 36 deletions mmdet/models/detectors/soft_teacher.py
@@ -77,12 +77,9 @@ def loss_by_pseudo_instances(self,
x, rpn_results_list, batch_data_samples, batch_info))
losses.update(**self.rcnn_reg_loss_by_pseudo_instances(
x, rpn_results_list, batch_data_samples))
pseudo_loss = {
'unsup_' + k: v
for k, v in self.reweight_loss(
losses, self.semi_train_cfg.get('unsup_weight', 1.)).items()
}
return pseudo_loss
unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.)
return self.rename_loss(
'unsup_', self.reweight_loss(losses, unsup_weight))

@torch.no_grad()
def get_pseudo_instances(
@@ -106,10 +103,10 @@ def get_pseudo_instances(

for data_samples, results in zip(batch_data_samples, results_list):
data_samples.gt_instances = results
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.pseudo_label_initial_score_thr]

batch_data_samples = self.filter_pseudo_instances_by_score(
batch_data_samples,
self.semi_train_cfg.pseudo_label_initial_score_thr)

reg_uncs_list = self.compute_uncertainty_with_aug(
x, batch_data_samples)
@@ -151,12 +148,8 @@ def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor],
"""

rpn_data_samples = copy.deepcopy(batch_data_samples)
for data_samples in rpn_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.rpn_pseudo_thr]

rpn_data_samples = self.filter_pseudo_instances_by_score(
rpn_data_samples, self.semi_train_cfg.rpn_pseudo_thr)
proposal_cfg = self.student.train_cfg.get('rpn_proposal',
self.student.test_cfg.rpn)
# set cat_id of gt_labels to 0 in RPN
@@ -196,11 +189,8 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
"""
rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
cls_data_samples = copy.deepcopy(batch_data_samples)
for data_samples in cls_data_samples:
if data_samples.gt_instances.bboxes.shape[0] > 0:
data_samples.gt_instances = data_samples.gt_instances[
data_samples.gt_instances.scores >
self.semi_train_cfg.cls_pseudo_thr]
cls_data_samples = self.filter_pseudo_instances_by_score(
cls_data_samples, self.semi_train_cfg.cls_pseudo_thr)

outputs = unpack_gt_instances(cls_data_samples)
batch_gt_instances, batch_gt_instances_ignore, _ = outputs
Expand All @@ -212,7 +202,6 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
# rename rpn_results.bboxes to rpn_results.priors
rpn_results = rpn_results_list[i]
rpn_results.priors = rpn_results.pop('bboxes')

assign_result = self.student.roi_head.bbox_assigner.assign(
rpn_results, batch_gt_instances[i],
batch_gt_instances_ignore[i])
@@ -226,19 +215,21 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
selected_bboxes = [res.priors for res in sampling_results]
rois = bbox2roi(selected_bboxes)
bbox_results = self.student.roi_head._bbox_forward(x, rois)
# cls_reg_targets is a tuple of labels, label_weights,
# and bbox_targets, bbox_weights
cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
sampling_results, self.student.train_cfg.rcnn)

selected_results_list = []
for bboxes, data_samples, homography_matrix, img_shape in zip(
for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
selected_bboxes, batch_data_samples,
batch_info['homography_matrix'], batch_info['img_shape']):
selected_results_list.append(
InstanceData(
bboxes=bbox_project(
bboxes, homography_matrix @ torch.tensor(
data_samples.homography_matrix).inverse().to(
self.data_preprocessor.device), img_shape)))
student_matrix = torch.tensor(
data_samples.homography_matrix, device=teacher_matrix.device)
homography_matrix = teacher_matrix @ student_matrix.inverse()
projected_bboxes = bbox_project(bboxes, homography_matrix,
teacher_img_shape)
selected_results_list.append(InstanceData(bboxes=projected_bboxes))

with torch.no_grad():
results_list = self.teacher.roi_head.predict_bbox(
Expand All @@ -249,18 +240,18 @@ def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
rescale=False)
bg_score = torch.cat(
[results.scores[:, -1] for results in results_list])
# cls_reg_targets[0] is labels
neg_inds = cls_reg_targets[
0] == self.student.roi_head.bbox_head.num_classes
# cls_reg_targets[1] is label_weights
cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()

losses = self.student.roi_head.bbox_head.loss(
bbox_results['cls_score'],
bbox_results['bbox_pred'],
rois,
*cls_reg_targets,
reduction_override='none')
losses['loss_cls'] = losses['loss_cls'].sum() / max(
cls_reg_targets[1].sum(), 1.0)
bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
*cls_reg_targets)
# cls_reg_targets[1] is label_weights
losses['loss_cls'] = losses['loss_cls'] * len(
cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0)
return losses

def rcnn_reg_loss_by_pseudo_instances(
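A note on the `loss_cls` normalization in `rcnn_cls_loss_by_pseudo_instances`: the old code used `reduction_override='none'` and divided the summed loss by the sum of `label_weights`, while the new code keeps the default mean reduction and rescales by `len(label_weights) / sum(label_weights)`. Assuming the mean-reduced weighted loss is `sum(w_i * l_i) / N`, the two forms are equal, with the negative-sample weights `w_i` being the detached teacher background scores. A toy numeric check under that assumption:

```python
import torch

# Per-proposal classification losses and soft label weights; negatives carry
# the teacher's background score instead of a hard 1.0.
per_sample_loss = torch.tensor([0.9, 0.2, 0.5, 0.7])
label_weights = torch.tensor([1.0, 1.0, 0.3, 0.1])

# Old formulation: reduction_override='none', then sum / clamped weight sum.
old = (per_sample_loss * label_weights).sum() / label_weights.sum().clamp(min=1.0)

# New formulation: mean-reduced weighted loss rescaled by N / clamped weight sum.
mean_reduced = (per_sample_loss * label_weights).mean()
new = mean_reduced * len(label_weights) / label_weights.sum().clamp(min=1.0)

assert torch.allclose(old, new)  # both are the weighted mean over soft weights
```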