From ab5c51f3ab62a20fc0c224022619d10082d200b8 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Sat, 18 Dec 2021 14:36:54 +0800 Subject: [PATCH] fix mmdet tests (#302) * fix mmdet tests * fix --- .../mmdet/models/dense_heads/base_dense_head.py | 10 +++++----- tests/test_codebase/test_mmdet/test_mmdet_models.py | 11 ++++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index 9a38bd549..538f2779f 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -104,23 +104,23 @@ def base_dense_head__get_bbox(ctx, score_factors = score_factors.permute(0, 2, 3, 1).reshape(batch_size, -1).sigmoid() + score_factors = score_factors.unsqueeze(2) bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) if not is_dynamic_flag: priors = priors.data priors = priors.expand(batch_size, -1, priors.size(-1)) if pre_topk > 0: if with_score_factors: - nms_pre_score = (nms_pre_score * score_factors[..., None]) + nms_pre_score = nms_pre_score * score_factors if backend == Backend.TENSORRT: priors = pad_with_value(priors, 1, pre_topk) bbox_pred = pad_with_value(bbox_pred, 1, pre_topk) scores = pad_with_value(scores, 1, pre_topk, 0.) nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.) if with_score_factors: - score_factors = pad_with_value( - score_factors.unsqueeze(2), 1, pre_topk, 0.) - else: - score_factors = score_factors.unsqueeze(2) + score_factors = pad_with_value(score_factors, 1, pre_topk, + 0.) + # Get maximum scores for foreground classes. if self.use_sigmoid_cls: max_scores, _ = nms_pre_score.max(-1) diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 28cc4e4c8..32da1d0c5 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -11,8 +11,8 @@ from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase -from mmdeploy.utils.test import (WrapModel, backend_checker, check_backend, - get_model_outputs, get_rewrite_outputs) +from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs, + get_rewrite_outputs) import_codebase(Codebase.MMDET) @@ -292,15 +292,16 @@ def _replace_r50_with_r18(model): return model +@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) @pytest.mark.parametrize('model_cfg_path', [ 'tests/test_codebase/test_mmdet/data/single_stage_model.json', 'tests/test_codebase/test_mmdet/data/mask_model.json' ]) -@backend_checker(Backend.ONNXRUNTIME) -def test_forward_of_base_detector(model_cfg_path): +def test_forward_of_base_detector(model_cfg_path, backend): + check_backend(backend) deploy_cfg = mmcv.Config( dict( - backend_config=dict(type='onnxruntime'), + backend_config=dict(type=backend.value), onnx_config=dict( output_names=['dets', 'labels'], input_shape=None), codebase_config=dict(