Skip to content

Commit

Permalink
fix mmdet tests (open-mmlab#302)
Browse files Browse the repository at this point in the history
* fix mmdet tests

* fix
  • Loading branch information
RunningLeon committed Dec 18, 2021
1 parent 3be1779 commit ab5c51f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
10 changes: 5 additions & 5 deletions mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,23 @@ def base_dense_head__get_bbox(ctx,
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
-1).sigmoid()
score_factors = score_factors.unsqueeze(2)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
if not is_dynamic_flag:
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
if pre_topk > 0:
if with_score_factors:
nms_pre_score = (nms_pre_score * score_factors[..., None])
nms_pre_score = nms_pre_score * score_factors
if backend == Backend.TENSORRT:
priors = pad_with_value(priors, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
scores = pad_with_value(scores, 1, pre_topk, 0.)
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value(
score_factors.unsqueeze(2), 1, pre_topk, 0.)
else:
score_factors = score_factors.unsqueeze(2)
score_factors = pad_with_value(score_factors, 1, pre_topk,
0.)

# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = nms_pre_score.max(-1)
Expand Down
11 changes: 6 additions & 5 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import (WrapModel, backend_checker, check_backend,
get_model_outputs, get_rewrite_outputs)
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)

import_codebase(Codebase.MMDET)

Expand Down Expand Up @@ -292,15 +292,16 @@ def _replace_r50_with_r18(model):
return model


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('model_cfg_path', [
'tests/test_codebase/test_mmdet/data/single_stage_model.json',
'tests/test_codebase/test_mmdet/data/mask_model.json'
])
@backend_checker(Backend.ONNXRUNTIME)
def test_forward_of_base_detector(model_cfg_path):
def test_forward_of_base_detector(model_cfg_path, backend):
check_backend(backend)
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
backend_config=dict(type=backend.value),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
Expand Down

0 comments on commit ab5c51f

Please sign in to comment.