Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support RTMDet and RTMPose ncnn deployment #1857

Merged
merged 18 commits into from
Mar 21, 2023
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Static-shape detection config for the ncnn backend with FP16 precision
# and a fixed 320x320 input.
_base_ = [
    '../_base_/base_static.py',
    '../../_base_/backends/ncnn.py',
]

# Backend precision setting.
backend_config = {'precision': 'FP16'}
# Selects the 'ncnn_end2end' model type for the codebase.
codebase_config = {'model_type': 'ncnn_end2end'}
# Single exported output tensor and fixed export input shape.
onnx_config = {
    'output_names': ['detection_output'],
    'input_shape': [320, 320],
}
4 changes: 4 additions & 0 deletions configs/mmdet/detection/single-stage_ncnn_static-320x320.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Static-shape detection config for the ncnn backend with a fixed 320x320
# input.
_base_ = [
    '../_base_/base_static.py',
    '../../_base_/backends/ncnn.py',
]

# Selects the 'ncnn_end2end' model type for the codebase.
codebase_config = {'model_type': 'ncnn_end2end'}
# Single exported output tensor and fixed export input shape.
onnx_config = {
    'output_names': ['detection_output'],
    'input_shape': [320, 320],
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Static-shape pose-detection config for the ncnn backend with FP16
# precision.
_base_ = [
    './pose-detection_static.py',
    '../_base_/backends/ncnn.py',
]

# Backend precision setting.
backend_config = {'precision': 'FP16'}
# Fixed export input shape and the two exported output tensors.
onnx_config = {
    'input_shape': [192, 256],
    'output_names': ['simcc_x', 'simcc_y'],
}
2 changes: 0 additions & 2 deletions csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2200,8 +2200,6 @@ int main(int argc, char** argv) {
}
fprintf(pp, " 4=%d", keepdims);
fprintf(pp, " 5=1");
// Force set Reduction for FP32, FP16 may exceed for some models.
fprintf(pp, " 31=15");
} else if (op == "Reorg") {
int stride = get_node_attr_i(node, "stride", 1);
fprintf(pp, " 0=%d", stride);
Expand Down
118 changes: 118 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import multiclass_nms
from mmdeploy.utils import Backend


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -105,3 +106,120 @@ def __mark_pred_maps(cls_scores, bbox_preds):
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.dense_heads.rtmdet_head.'
    'RTMDetHead.predict_by_feat',
    backend=Backend.NCNN.value)
def rtmdet_head__predict_by_feat__ncnn(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_img_metas: Optional[List[dict]] = None,
        cfg: Optional[ConfigDict] = None,
        rescale: bool = False,
        with_nms: bool = True):
    """Rewrite `predict_by_feat` of RTMDetHead for ncnn backend.

    1. Decode the priors to a box format for the ncnn DetectionOutput layer,
       which does the post-processing.
    2. Batch dimension is not supported by ncnn, but supported by pytorch.
       Only positive `dim` values are used in `torch.cat` to avoid axis
       shift.
    3. 2-dimension tensor broadcast of `BinaryOps` operator is not supported
       by ncnn, so 2-dimension tensors are unsqueezed to 3 dimensions for
       correct `BinaryOps` calculation by ncnn.

    Args:
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 4, H, W).
        batch_img_metas (list[dict], Optional): Batch image meta info.
            The 'img_shape' entry of the first element is read to normalize
            the boxes, so it must be provided even though the signature
            defaults it to None.
        cfg (ConfigDict, optional): Test / postprocessing
            configuration, if None, test_cfg would be used.
            Defaults to None.
        rescale (bool): Not used by this rewrite; kept for signature
            compatibility. Defaults to False.
        with_nms (bool): Not used by this rewrite; kept for signature
            compatibility. Defaults to True.

    Returns:
        output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
    """
    from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward
    from mmdeploy.utils import get_root_logger
    from mmdeploy.utils.config_utils import is_dynamic_shape

    ctx = FUNCTION_REWRITER.get_context()
    deploy_cfg = ctx.cfg
    if is_dynamic_shape(deploy_cfg):
        logger = get_root_logger()
        logger.warning('RTMDet does not support dynamic shape with ncnn.')
    img_height = int(batch_img_metas[0]['img_shape'][0])
    img_width = int(batch_img_metas[0]['img_shape'][1])

    assert len(cls_scores) == len(bbox_preds)
    device = cls_scores[0].device
    cfg = self.test_cfg if cfg is None else cfg
    batch_size = bbox_preds[0].shape[0]
    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, device=device, with_stride=True)
    # Add a leading axis so the priors line up with the (B, N, C) tensors.
    mlvl_priors = [mlvl_prior.unsqueeze(0) for mlvl_prior in mlvl_priors]
    flatten_priors = torch.cat(mlvl_priors, dim=1)

    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                              self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        for bbox_pred in bbox_preds
    ]

    cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
    # Prepend an all-zero score column; `self.num_classes + 1` classes are
    # passed to the DetectionOutput op below.
    dummy_cls_scores = torch.zeros(
        batch_size, cls_scores.shape[-2], 1, device=cls_scores.device)
    batch_mlvl_scores = torch.cat([dummy_cls_scores, cls_scores], dim=2)

    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
    assert flatten_priors.shape[-1] == 4, (
        'rtmdet needs (B, N, 4) priors, got '
        f'(B, N, {flatten_priors.shape[-1]})')

    # Decode prior points and predicted distances into corner boxes,
    # normalized by the image width / height.
    tl_x = (flatten_priors[:, :, 0:1] -
            flatten_bbox_preds[:, :, 0:1]) / img_width
    tl_y = (flatten_priors[:, :, 1:2] -
            flatten_bbox_preds[:, :, 1:2]) / img_height
    br_x = (flatten_priors[:, :, 0:1] +
            flatten_bbox_preds[:, :, 2:3]) / img_width
    br_y = (flatten_priors[:, :, 1:2] +
            flatten_bbox_preds[:, :, 3:4]) / img_height
    prior_box_ncnn = torch.stack([tl_x, tl_y, br_x, br_y], -1)

    # Flatten everything to (batch_size, 1, -1) for the DetectionOutput op.
    batch_mlvl_bboxes = flatten_bbox_preds.reshape(batch_size, 1, -1)
    batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
    batch_mlvl_priors = prior_box_ncnn.reshape(batch_size, 1, -1)
    # Second row of the prior tensor is filled with ones.
    batch_mlvl_vars = torch.ones_like(batch_mlvl_priors)
    batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)

    post_params = get_post_processing_params(deploy_cfg)
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    # Named `variances` rather than `vars` to avoid shadowing the builtin.
    variances = torch.tensor([1, 1, 1, 1], dtype=torch.float32)
    output__ncnn = ncnn_detection_output_forward(
        batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors,
        score_threshold, iou_threshold, pre_top_k, keep_top_k,
        self.num_classes + 1,
        variances.cpu().detach().numpy())
    return output__ncnn
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmpose/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from . import heads # noqa: F401,F403
from . import pose_estimators # noqa: F401,F403
from . import utils # noqa: F401,F403
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmpose/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from . import rtmcc_block

__all__ = ['rtmcc_block']
91 changes: 91 additions & 0 deletions mmdeploy/codebase/mmpose/models/utils/rtmcc_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn.functional as F
from mmpose.models.utils import rope

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'mmpose.models.utils.rtmcc_block.ScaleNorm.forward', backend='ncnn')
def scalenorm__forward__ncnn(self, x):
    """Rewrite `ScaleNorm.forward` for ncnn backend.

    The rewrite is intended to avoid FP16 overflow in ncnn on Android
    platforms. The input is divided by its norm along dim 2 (scaled by
    `self.scale` and clamped from below by `self.eps`) and then multiplied
    by `self.g`.
    """
    # No `p` is given, so torch.norm uses its default Frobenius norm; over a
    # single dimension this equals the L2 norm.
    # NOTE(review): the intent is for the norm to be exported as ReduceL2;
    # presumably the ncnn `torch.norm` rewriter substitutes p=2 -- confirm.
    denominator = (torch.norm(x, dim=2, keepdim=True) * self.scale).clamp(
        min=self.eps)
    # Both operands get a temporary axis at dim 2 for ncnn binaryop
    # broadcast; the axis is dropped again right after the division.
    quotient = x.unsqueeze(2) / denominator.unsqueeze(2)
    return quotient.squeeze(2) * self.g


@FUNCTION_REWRITER.register_rewriter(
    'mmpose.models.utils.rtmcc_block.RTMCCBlock._forward', backend='ncnn')
def rtmccblock___forward_ncnn(self, inputs):
    """Rewrite `_forward` of RTMCCBlock for ncnn backend.

    The projections are sliced instead of unbound, and the attention
    products are expressed with `torch.bmm`, so that the graph suits the
    ncnn backend.

    Args:
        inputs: A single tensor when `self.attn_type` is 'self-attn',
            otherwise a 3-item sequence unpacked as `(x, k, v)`.

    Returns:
        The result of `self.o` applied to the gated attention output.
    """
    is_self_attn = self.attn_type == 'self-attn'
    if is_self_attn:
        x = inputs
    else:
        x, k, v = inputs

    uv = self.uv(self.ln(x))
    if is_self_attn:
        uv = self.act_fn(uv)
        # Slice the projection into gate (u), value (v) and shared base.
        v_end = 2 * self.e
        u = uv[..., :self.e]
        v = uv[..., self.e:v_end]
        base = uv[..., v_end:v_end + self.s]

        # Query / key are per-row affine transforms of `base`; the extra
        # axis is inserted before the multiply-add and removed afterwards.
        expanded_base = base.unsqueeze(1)
        q = (expanded_base * self.gamma[None, None, 0:1, :] +
             self.beta[None, None, 0:1, :]).squeeze(1)
        k = (expanded_base * self.gamma[None, None, 1:2, :] +
             self.beta[None, None, 1:2, :]).squeeze(1)
    else:
        u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)
        k = self.k_fc(k)
        v = self.v_fc(v)

    # Both attention types apply `rope` to q and k along dim 1.
    if self.pos_enc:
        q = rope(q, dim=1)
        k = rope(k, dim=1)

    qk = torch.bmm(q, k.permute(0, 2, 1))
    if self.use_rel_bias:
        q_len = q.size(1)
        k_len = k.size(1)
        if is_self_attn:
            bias = self.rel_pos_bias(q_len)
        else:
            bias = self.rel_pos_bias(q_len, k_len)
        qk += bias[:, :q_len, :k_len]

    # Squared-ReLU attention kernel.
    kernel = torch.square(F.relu(qk / self.sqrt_s))
    if self.dropout_rate > 0.:
        kernel = self.dropout(kernel)

    return self.o(u * torch.bmm(kernel, v))


@FUNCTION_REWRITER.register_rewriter(
    'mmpose.models.utils.rtmcc_block.Scale.forward', backend='ncnn')
def scale__forward_ncnn(self, x):
    """Rewrite `forward` of Scale for ncnn backend.

    The multiplication is carried out on tensors with an extra axis, which
    adapts the shapes to avoid an ncnn BinaryOp seg fault; the extra axis
    is removed from the result.
    """
    expanded_scale = self.scale[None, None, None, :]
    scaled = x.unsqueeze(1) * expanded_scale
    return scaled.squeeze(1)
25 changes: 25 additions & 0 deletions mmdeploy/pytorch/functions/normalize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Sequence, Union

import torch

from mmdeploy.core import FUNCTION_REWRITER
Expand Down Expand Up @@ -39,3 +41,26 @@ def normalize__ncnn(input: torch.Tensor,
input.transpose(1, dim), p=p, dim=1,
eps=eps).transpose(1, dim)
return output


@FUNCTION_REWRITER.register_rewriter(func_name='torch.norm', backend='ncnn')
def norm__ncnn(input: torch.Tensor,
               p: Optional[Union[int, str]] = 'fro',
               dim: Optional[Union[int, Sequence]] = None,
               keepdim: Optional[bool] = False,
               out: Optional[torch.Tensor] = None,
               dtype: Optional[torch.dtype] = None):
    """Rewrite `torch.norm` for ncnn backend.

    When the Frobenius norm is requested over exactly one dimension it is
    replaced by the equivalent L2 norm (p=2) to avoid FP16 overflow in ncnn
    on Android platforms. Every other call is forwarded to the original
    `torch.norm` unchanged.

    Args:
        input (torch.Tensor): Tensor whose norm is computed.
        p (int | str, optional): Order of the norm. Defaults to 'fro'.
        dim (int | Sequence, optional): Dimension(s) to reduce over.
            Defaults to None, in which case no substitution takes place.
        keepdim (bool, optional): Forwarded to the original function.
        out (torch.Tensor, optional): Forwarded to the original function.
        dtype (torch.dtype, optional): Forwarded to the original function.

    Returns:
        Whatever the original `torch.norm` returns for the (possibly
        substituted) arguments.
    """
    ctx = FUNCTION_REWRITER.get_context()
    origin_func = ctx.origin_func
    # `dim` defaults to None, so guard before calling len() on it; otherwise
    # a plain `torch.norm(x)` would raise TypeError.
    reduces_single_dim = isinstance(dim, int) or (dim is not None
                                                  and len(dim) == 1)
    if p == 'fro' and reduces_single_dim:
        # Substitute Frobenius norm with L2 norm.
        p = 2
    return origin_func(
        input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
85 changes: 85 additions & 0 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,3 +2121,88 @@ def test_solo_head_predict_by_feat(backend_type: Backend):
atol=1e-05)
else:
assert rewrite_outputs is not None


def get_rtmdet_head_model():
    """Build an RTMDetHead fixture for the rewrite tests.

    The head is constructed as `RTMDetHead(1, 64)` (presumably one class
    and 64 input channels -- confirm against mmdet). Its prior generator is
    replaced by a `MlvlPointGenerator` with strides [8, 4, 2], a test
    config is attached, and gradients are disabled.
    """
    from mmdet.models.dense_heads import RTMDetHead
    from mmdet.models.task_modules.prior_generators.point_generator import \
        MlvlPointGenerator

    head = RTMDetHead(1, 64)
    head.prior_generator = MlvlPointGenerator([8, 4, 2])
    head.test_cfg = Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.6),
            max_per_img=100))

    head.requires_grad_(False)
    return head


def test_rtmdet_head_predict_by_feat_ncnn():
    """Test the ncnn `predict_by_feat` rewrite of RTMDetHead."""
    backend_type = Backend.NCNN
    check_backend(backend_type)
    rtmdet_head = get_rtmdet_head_model()
    rtmdet_head.cpu().eval()

    img_size = 320
    img_shape = (img_size, img_size, 3)
    batch_img_metas = [
        dict(
            scale_factor=np.ones(4), pad_shape=img_shape, img_shape=img_shape)
    ]

    post_processing = dict(
        score_threshold=0.05,
        iou_threshold=0.45,
        confidence_threshold=0.005,
        max_output_boxes_per_class=200,
        pre_top_k=-1,
        keep_top_k=10,
        background_label_id=-1,
    )
    deploy_cfg = Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(
                output_names=['detection_output'], input_shape=None),
            codebase_config=dict(
                type='mmdet',
                model_type='ncnn_end2end',
                task='ObjectDetection',
                post_processing=post_processing)))

    seed_everything(1234)
    # Random tensors are drawn in the same order as before (all class
    # scores first, then all box predictions) so the seeded values match.
    featmap_sizes = (40, 20, 10)
    cls_scores = [torch.rand(1, 1, size, size) for size in featmap_sizes]
    bbox_preds = [torch.rand(1, 4, size, size) for size in featmap_sizes]

    # to get outputs of onnx model after rewrite
    wrapped_model = WrapModel(
        rtmdet_head,
        'predict_by_feat',
        batch_img_metas=batch_img_metas,
        with_nms=True)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=dict(cls_scores=cls_scores, bbox_preds=bbox_preds),
        deploy_cfg=deploy_cfg,
        run_with_backend=False)

    # output should be of shape [1, N, 6]
    detections = rewrite_outputs[0] if is_backend_output else rewrite_outputs
    assert detections.shape[-1] == 6
Loading