Skip to content

Commit

Permalink
[Feature]: Support ONNX inference with dynamic input shape in AnchorH…
Browse files Browse the repository at this point in the history
…ead (#4684)

* make anchor_generator exportable to ONNX

* make k of topk dynamic for onnx

* rename nms_pre_t -> nms_pre_tensor
  • Loading branch information
RunningLeon committed Mar 3, 2021
1 parent 66ccfeb commit 9946e12
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
9 changes: 4 additions & 5 deletions mmdet/core/anchor/anchor_generator.py
Expand Up @@ -196,8 +196,9 @@ def _meshgrid(self, x, y, row_major=True):
Returns:
tuple[torch.Tensor]: The mesh grids of x and y.
"""
xx = x.repeat(len(y))
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
# use shape instead of len to keep tracing while exporting to onnx
xx = x.repeat(y.shape[0])
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
if row_major:
return xx, yy
else:
Expand Down Expand Up @@ -250,10 +251,8 @@ def single_level_grid_anchors(self,
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""
# keep as Tensor, so that we can covert to ONNX correctly
feat_h, feat_w = featmap_size
# convert Tensor to int, so that we can covert to ONNX correctlly
feat_h = int(feat_h)
feat_w = int(feat_w)
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
shift_y = torch.arange(0, feat_h, device=device) * stride[1]

Expand Down
15 changes: 13 additions & 2 deletions mmdet/models/dense_heads/anchor_head.py
Expand Up @@ -620,6 +620,11 @@ def _get_bboxes_single(self,
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
# convert to tensor to keep tracing
nms_pre_tensor = torch.tensor(
cfg.get('nms_pre', -1),
device=cls_score_list[0].device,
dtype=torch.long)
mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_score_list,
Expand All @@ -632,8 +637,14 @@ def _get_bboxes_single(self,
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
# Always keep topk op for dynamic input in onnx
if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
or scores.shape[-2] > nms_pre_tensor):
from torch import _shape_as_tensor
# keep shape as tensor and get k
num_anchor = _shape_as_tensor(scores)[-2]
nms_pre = torch.where(nms_pre_tensor < num_anchor,
nms_pre_tensor, num_anchor)
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
Expand Down

0 comments on commit 9946e12

Please sign in to comment.