Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tta to HTC and Cascade RCNN #1251

Merged
merged 8 commits into from
Sep 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 107 additions & 4 deletions mmdet/models/detectors/cascade_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import torch
import torch.nn as nn

from mmdet.core import (bbox2result, bbox2roi, build_assigner, build_sampler,
merge_aug_masks)
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
build_sampler, merge_aug_bboxes, merge_aug_masks,
multiclass_nms)
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
Expand Down Expand Up @@ -399,8 +400,110 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):

return results

def aug_test(self, img, img_meta, proposals=None, rescale=False):
raise NotImplementedError
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
# recompute feats to save memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

rcnn_test_cfg = self.test_cfg.rcnn
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(self.extract_feats(imgs), img_metas):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']

proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip)
# "ms" in variable names means multi-stage
ms_scores = []

rois = bbox2roi([proposals])
for i in range(self.num_stages):
bbox_roi_extractor = self.bbox_roi_extractor[i]
bbox_head = self.bbox_head[i]

bbox_feats = bbox_roi_extractor(
x[:len(bbox_roi_extractor.featmap_strides)], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)

cls_score, bbox_pred = bbox_head(bbox_feats)
ms_scores.append(cls_score)

if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label,
bbox_pred, img_meta[0])

cls_score = sum(ms_scores) / float(len(ms_scores))
bboxes, scores = self.bbox_head[-1].get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)

bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)

if self.with_mask:
if det_bboxes.shape[0] == 0:
segm_result = [[]
for _ in range(self.mask_head[-1].num_classes -
1)]
else:
aug_masks = []
aug_img_metas = []
for x, img_meta in zip(self.extract_feats(imgs), img_metas):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
for i in range(self.num_stages):
mask_feats = self.mask_roi_extractor[i](
x[:len(self.mask_roi_extractor[i].featmap_strides
)], mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head[i](mask_feats)
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
aug_img_metas.append(img_meta)
merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
self.test_cfg.rcnn)

ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head[-1].get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
rcnn_test_cfg,
ori_shape,
scale_factor=1.0,
rescale=False)
return bbox_result, segm_result
else:
return bbox_result

def show_result(self, data, result, **kwargs):
if self.with_mask:
Expand Down
128 changes: 124 additions & 4 deletions mmdet/models/detectors/htc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
import torch.nn.functional as F

from mmdet.core import (bbox2result, bbox2roi, build_assigner, build_sampler,
merge_aug_masks)
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
build_sampler, merge_aug_bboxes, merge_aug_masks,
multiclass_nms)
from .. import builder
from ..registry import DETECTORS
from .cascade_rcnn import CascadeRCNN
Expand Down Expand Up @@ -431,5 +432,124 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):

return results

def aug_test(self, img, img_meta, proposals=None, rescale=False):
raise NotImplementedError
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
"""Test with augmentations.

If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
if self.with_semantic:
semantic_feats = [
self.semantic_head(feat)[1]
for feat in self.extract_feats(imgs)
]
else:
semantic_feats = [None] * len(img_metas)

# recompute feats to save memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

rcnn_test_cfg = self.test_cfg.rcnn
aug_bboxes = []
aug_scores = []
for x, img_meta, semantic in zip(
self.extract_feats(imgs), img_metas, semantic_feats):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']

proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip)
# "ms" in variable names means multi-stage
ms_scores = []

rois = bbox2roi([proposals])
for i in range(self.num_stages):
bbox_head = self.bbox_head[i]
cls_score, bbox_pred = self._bbox_forward_test(
i, x, rois, semantic_feat=semantic)
ms_scores.append(cls_score)

if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label,
bbox_pred, img_meta[0])

cls_score = sum(ms_scores) / float(len(ms_scores))
bboxes, scores = self.bbox_head[-1].get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)

bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)

if self.with_mask:
if det_bboxes.shape[0] == 0:
segm_result = [[]
for _ in range(self.mask_head[-1].num_classes -
1)]
else:
aug_masks = []
aug_img_metas = []
for x, img_meta, semantic in zip(
self.extract_feats(imgs), img_metas, semantic_feats):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor[-1](
x[:len(self.mask_roi_extractor[-1].featmap_strides)],
mask_rois)
if self.with_semantic:
semantic_feat = semantic
mask_semantic_feat = self.semantic_roi_extractor(
[semantic_feat], mask_rois)
if mask_semantic_feat.shape[-2:] != mask_feats.shape[
-2:]:
mask_semantic_feat = F.adaptive_avg_pool2d(
mask_semantic_feat, mask_feats.shape[-2:])
mask_feats += mask_semantic_feat
last_feat = None
for i in range(self.num_stages):
mask_head = self.mask_head[i]
if self.mask_info_flow:
mask_pred, last_feat = mask_head(
mask_feats, last_feat)
else:
mask_pred = mask_head(mask_feats)
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
aug_img_metas.append(img_meta)
merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
self.test_cfg.rcnn)

ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head[-1].get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
rcnn_test_cfg,
ori_shape,
scale_factor=1.0,
rescale=False)
return bbox_result, segm_result
else:
return bbox_result