diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py index 29b5ed04b95..388d2608b81 100644 --- a/mmdet/core/anchor/anchor_generator.py +++ b/mmdet/core/anchor/anchor_generator.py @@ -196,8 +196,9 @@ def _meshgrid(self, x, y, row_major=True): Returns: tuple[torch.Tensor]: The mesh grids of x and y. """ - xx = x.repeat(len(y)) - yy = y.view(-1, 1).repeat(1, len(x)).view(-1) + # use shape instead of len to keep tracing while exporting to onnx + xx = x.repeat(y.shape[0]) + yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1) if row_major: return xx, yy else: @@ -250,10 +251,8 @@ def single_level_grid_anchors(self, Returns: torch.Tensor: Anchors in the overall feature maps. """ + # keep as Tensor, so that we can covert to ONNX correctly feat_h, feat_w = featmap_size - # convert Tensor to int, so that we can covert to ONNX correctlly - feat_h = int(feat_h) - feat_w = int(feat_w) shift_x = torch.arange(0, feat_w, device=device) * stride[0] shift_y = torch.arange(0, feat_h, device=device) * stride[1] diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py index 970b63dc2b6..8ae50d38dd9 100644 --- a/mmdet/models/dense_heads/anchor_head.py +++ b/mmdet/models/dense_heads/anchor_head.py @@ -620,6 +620,11 @@ def _get_bboxes_single(self, """ cfg = self.test_cfg if cfg is None else cfg assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + # convert to tensor to keep tracing + nms_pre_tensor = torch.tensor( + cfg.get('nms_pre', -1), + device=cls_score_list[0].device, + dtype=torch.long) mlvl_bboxes = [] mlvl_scores = [] for cls_score, bbox_pred, anchors in zip(cls_score_list, @@ -632,8 +637,14 @@ def _get_bboxes_single(self, else: scores = cls_score.softmax(-1) bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: + # Always keep topk op for dynamic input in onnx + if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export() + or scores.shape[-2] > nms_pre_tensor): + from torch import _shape_as_tensor + # keep shape as tensor and get k + num_anchor = _shape_as_tensor(scores)[-2] + nms_pre = torch.where(nms_pre_tensor < num_anchor, + nms_pre_tensor, num_anchor) # Get maximum scores for foreground classes. if self.use_sigmoid_cls: max_scores, _ = scores.max(dim=1)