[Feature] Support RTMDet and RTMPose ncnn deployment #1857

Merged: 18 commits, Mar 21, 2023
@@ -0,0 +1,5 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']

backend_config = dict(precision='FP16')
codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[320, 320])
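
For readers unfamiliar with MMEngine `_base_` inheritance: the base files are merged dict-wise, and the keys above override or extend them. A rough sketch of the merged result for the FP16 config above (the base-file values here are assumptions for illustration, not taken from this PR):

# Approximate merged deploy config (sketch; fields marked "assumed"
# come from the _base_ files and may differ in the actual repo).
backend_config = dict(
    type='ncnn',       # assumed, from _base_/backends/ncnn.py
    precision='FP16',  # set by this config
)
codebase_config = dict(
    type='mmdet',            # assumed, from base_static.py
    task='ObjectDetection',  # assumed, from base_static.py
    model_type='ncnn_end2end',
)
onnx_config = dict(
    output_names=['detection_output'],
    input_shape=[320, 320],
)
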
4 changes: 4 additions & 0 deletions configs/mmdet/detection/single-stage_ncnn_static-320x320.py
@@ -0,0 +1,4 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']

codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[320, 320])
@@ -0,0 +1,4 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/ncnn.py']

backend_config = dict(precision='FP16')
onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])
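
The `simcc_x`/`simcc_y` outputs named above are RTMPose's SimCC classification vectors for the horizontal and vertical axes; a keypoint coordinate is recovered with an argmax per axis. A minimal decoding sketch, assuming the default `simcc_split_ratio=2.0` (illustrative only, not part of the PR):

import torch

def decode_simcc(simcc_x, simcc_y, split_ratio=2.0):
    # simcc_x: (N, K, W * split_ratio); simcc_y: (N, K, H * split_ratio)
    x = simcc_x.argmax(dim=2).float() / split_ratio
    y = simcc_y.argmax(dim=2).float() / split_ratio
    return torch.stack([x, y], dim=-1)  # (N, K, 2) keypoint coordinates
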
2 changes: 0 additions & 2 deletions csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
@@ -2200,8 +2200,6 @@ int main(int argc, char** argv) {
}
fprintf(pp, " 4=%d", keepdims);
fprintf(pp, " 5=1");
// Force set Reduction for FP32, FP16 may exceed for some models.
fprintf(pp, " 31=15");
} else if (op == "Reorg") {
int stride = get_node_attr_i(node, "stride", 1);
fprintf(pp, " 0=%d", stride);
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmpose/models/__init__.py
@@ -2,3 +2,4 @@

from . import heads # noqa: F401,F403
from . import pose_estimators # noqa: F401,F403
from . import utils # noqa: F401,F403
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmpose/models/utils/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from . import rtmcc_block

__all__ = ['rtmcc_block']
91 changes: 91 additions & 0 deletions mmdeploy/codebase/mmpose/models/utils/rtmcc_block.py
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn.functional as F
from mmpose.models.utils import rope

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.utils.rtmcc_block.ScaleNorm.forward', backend='ncnn')
def scalenorm__forward__ncnn(self, x):
"""Rewrite `scalenorm` for ncnn backend.

Rewrite scalenorm to avoid FP16 exceed in ncnn Android platform.
"""
    # Over a single dimension, the Frobenius norm is equal to the L2 norm.
    # Set p=2 explicitly to map torch.norm to the ONNX ReduceL2 op,
    # which avoids FP16 overflow.
norm = torch.norm(x, dim=2, keepdim=True)
norm = norm * self.scale
    # Rewritten for ncnn BinaryOp broadcasting.
norm = norm.clamp(min=self.eps)
return (x.unsqueeze(2) / norm.unsqueeze(2)).squeeze(2) * self.g
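
# For comparison, the upstream mmpose ScaleNorm.forward is roughly the
# following (a paraphrased sketch, not code from this PR):
#
#     def forward(self, x):
#         norm = torch.norm(x, dim=2, keepdim=True) * self.scale
#         return x / norm.clamp(min=self.eps) * self.g
#
# The rewrite above differs in two ways: the torch.norm rewrite in
# mmdeploy/pytorch/functions/normalize.py (also in this PR) turns the
# norm into an FP16-safe ReduceL2, and the unsqueeze/divide/squeeze
# pattern keeps the division in a broadcast form that ncnn's BinaryOp
# supports.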


@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.utils.rtmcc_block.RTMCCBlock._forward', backend='ncnn')
def rtmccblock___forward_ncnn(self, inputs):
"""Rewrite `_forward` of RTMBlock for ncnn backend.

Rewrite the matmul and avoid unbind for ncnn backend.
"""
if self.attn_type == 'self-attn':
x = inputs
else:
x, k, v = inputs

x = self.ln(x)
uv = self.uv(x)
if self.attn_type == 'self-attn':
uv = self.act_fn(uv)
u = uv[..., :self.e]
v = uv[..., self.e:2 * self.e]
base = uv[..., 2 * self.e:2 * self.e + self.s]

q = (base.unsqueeze(1) * self.gamma[None, None, 0:1, :] +
self.beta[None, None, 0:1, :]).squeeze(1)
k = (base.unsqueeze(1) * self.gamma[None, None, 1:2, :] +
self.beta[None, None, 1:2, :]).squeeze(1)

if self.pos_enc:
q = rope(q, dim=1)
k = rope(k, dim=1)
else:
u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)

k = self.k_fc(k)
v = self.v_fc(v)

if self.pos_enc:
q = rope(q, 1)
k = rope(k, 1)
qk = torch.bmm(q, k.permute(0, 2, 1))
if self.use_rel_bias:
if self.attn_type == 'self-attn':
bias = self.rel_pos_bias(q.size(1))
else:
bias = self.rel_pos_bias(q.size(1), k.size(1))
qk += bias[:, :q.size(1), :k.size(1)]

kernel = torch.square(F.relu(qk / self.sqrt_s))
if self.dropout_rate > 0.:
kernel = self.dropout(kernel)

x = u * torch.bmm(kernel, v)
x = self.o(x)

return x


@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.utils.rtmcc_block.Scale.forward', backend='ncnn')
def scale__forward_ncnn(self, x):
"""Rewrite `forward` of Scale for ncnn backend.

Adapt the shape to avoid ncnn BinaryOp seg fault.
"""
x = x.unsqueeze(1)
scale = self.scale[None, None, None, :]
return (x * scale).squeeze(1)
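
To see what the RTMCCBlock rewrite buys, the standalone sketch below (illustrative shapes, not part of the PR) checks that the gamma/beta slice-and-squeeze pattern used above reproduces the q/k of the upstream `torch.unbind` formulation, which is the op the rewrite avoids because it exports poorly to ncnn:

import torch

B, N, s = 1, 17, 48  # batch, tokens, attention dim (illustrative)
base = torch.rand(B, N, s)
gamma = torch.rand(2, s)
beta = torch.rand(2, s)

# Upstream-style: broadcast to (B, N, 2, s), then unbind along dim 2.
qk = base.unsqueeze(2) * gamma[None, None, :] + beta[None, None, :]
q_ref, k_ref = torch.unbind(qk, dim=2)

# Rewrite-style: slice gamma/beta per branch, multiply, then squeeze.
q = (base.unsqueeze(1) * gamma[None, None, 0:1, :] +
     beta[None, None, 0:1, :]).squeeze(1)
k = (base.unsqueeze(1) * gamma[None, None, 1:2, :] +
     beta[None, None, 1:2, :]).squeeze(1)

assert torch.allclose(q, q_ref) and torch.allclose(k, k_ref)
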
25 changes: 25 additions & 0 deletions mmdeploy/pytorch/functions/normalize.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional, Sequence, Union

import torch

from mmdeploy.core import FUNCTION_REWRITER
@@ -39,3 +41,26 @@ def normalize__ncnn(input: torch.Tensor,
input.transpose(1, dim), p=p, dim=1,
eps=eps).transpose(1, dim)
return output


@FUNCTION_REWRITER.register_rewriter(func_name='torch.norm', backend='ncnn')
def norm__ncnn(input: torch.Tensor,
p: Optional[Union[int, str]] = 'fro',
dim: Optional[Union[int, Sequence]] = None,
keepdim: Optional[bool] = False,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None):
"""Rewrite `torch.norm` for ncnn backend.

    Rewrite `torch.norm` when p is the Frobenius norm, to avoid FP16
    overflow on the ncnn Android platform.
"""
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
    if p == 'fro' and dim is not None and (isinstance(dim, int)
                                           or len(dim) == 1):
# Substitute Frobenius norm with L2 norm.
return origin_func(
input, p=2, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
else:
return origin_func(
input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
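
The substitution is exact: over a single dimension, both the Frobenius norm and the L2 norm compute the square root of the sum of squares. A standalone sanity check (not part of the PR):

import torch

x = torch.rand(1, 17, 48)
fro = torch.norm(x, p='fro', dim=2, keepdim=True)
l2 = torch.norm(x, p=2, dim=2, keepdim=True)
# Both reduce to sqrt((x ** 2).sum(dim=2)), so they agree exactly.
assert torch.allclose(fro, l2)
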
80 changes: 80 additions & 0 deletions tests/test_codebase/test_mmpose/test_mmpose_models.py
@@ -6,6 +6,11 @@
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs

try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close

try:
import_codebase(Codebase.MMPOSE)
except ImportError:
@@ -108,3 +113,78 @@ def test_estimator_forward(backend_type: Backend):
run_with_backend=False,
deploy_cfg=deploy_cfg)
assert isinstance(rewrite_outputs, torch.Tensor)


def get_scale_norm_model():
from mmpose.models.utils.rtmcc_block import ScaleNorm

model = ScaleNorm(48)
model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_scale_norm_forward(backend_type: Backend):
check_backend(backend_type, True)
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
model = get_scale_norm_model()
x = torch.rand(1, 17, 48)
wrapped_model = WrapModel(model, 'forward')
model_outputs = model.forward(x)
rewrite_inputs = {'x': x}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg,
run_with_backend=False)
torch_assert_close(rewrite_outputs, model_outputs)


def get_rtmcc_block_model():
from mmpose.models.utils.rtmcc_block import RTMCCBlock

model = RTMCCBlock(48, 48, 48)
model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_rtmcc_block_forward(backend_type: Backend):
check_backend(backend_type, True)
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
model = get_rtmcc_block_model()
inputs = torch.rand(1, 17, 48)
wrapped_model = WrapModel(model, '_forward')
model_outputs = model._forward(inputs)
rewrite_inputs = {'inputs': inputs}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg,
run_with_backend=False)
torch_assert_close(rewrite_outputs, model_outputs)


def get_scale_model():
from mmpose.models.utils.rtmcc_block import Scale

model = Scale(48)
model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_scale_forward(backend_type: Backend):
check_backend(backend_type, True)
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
model = get_scale_model()
x = torch.rand(1, 17, 48)
wrapped_model = WrapModel(model, 'forward')
model_outputs = model.forward(x)
rewrite_inputs = {'x': x}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg,
run_with_backend=False)
torch_assert_close(rewrite_outputs, model_outputs)
22 changes: 22 additions & 0 deletions tests/test_pytorch/test_pytorch_functions.py
@@ -163,6 +163,28 @@ def linear_caller(*arg, **kwargs):
assert np.allclose(model_output, rewrite_output[0], rtol=1e-03, atol=1e-05)


@backend_checker(Backend.NCNN)
def test_norm_ncnn():
import onnx

import mmdeploy.apis.ncnn as ncnn_apis
from mmdeploy.utils.test import get_onnx_model

input = torch.rand(1, 17, 24)
wrapped_func = WrapFunction(torch.norm, p='fro', dim=2, keepdim=True)
model_inputs = {'input': input}
ir_file_path = get_onnx_model(wrapped_func, model_inputs, deploy_cfg_ncnn)
assert osp.exists(ir_file_path)
onnx_model = onnx.load(ir_file_path)
nodes = onnx_model.graph.node
assert nodes[-1].name.startswith('ReduceL2')
ncnn_files_prefix = osp.splitext(ir_file_path)[0]
ncnn_apis.from_onnx(ir_file_path, ncnn_files_prefix)
param_path, bin_path = ncnn_apis.get_output_model_file(ir_file_path)
assert osp.exists(param_path)
assert osp.exists(bin_path)


@backend_checker(Backend.TENSORRT)
def test_repeat_static():
input = torch.rand([1])