Skip to content

Commit

Permalink
Fix coreml (#1658)
Browse files Browse the repository at this point in the history
* fix coreml topk

* update

* fix lint
  • Loading branch information
grimoire committed Jan 19, 2023
1 parent bce276e commit 513b1c3
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 47 deletions.
11 changes: 11 additions & 0 deletions configs/mmcls/classification_coreml_static-384x384.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = ['./classification_coreml_dynamic-224x224-224x224.py']

ir_config = dict(input_shape=(384, 384))
backend_config = dict(model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 384, 384],
max_shape=[1, 3, 384, 384],
default_shape=[1, 3, 384, 384])))
])
11 changes: 11 additions & 0 deletions configs/mmdet/detection/detection_coreml_static-608x608.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = ['../_base_/base_torchscript.py', '../../_base_/backends/coreml.py']

ir_config = dict(input_shape=(608, 608))
backend_config = dict(model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 608, 608],
max_shape=[1, 3, 608, 608],
default_shape=[1, 3, 608, 608])))
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = [
'../../_base_/torchscript_config.py', '../../_base_/backends/coreml.py'
]

codebase_config = dict(type='mmocr', task='TextRecognition')
backend_config = dict(model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 32, 32],
max_shape=[1, 3, 32, 640],
default_shape=[1, 3, 32, 64])))
])
42 changes: 37 additions & 5 deletions mmdeploy/backend/coreml/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def is_available(cls, with_custom_ops: bool = False) -> bool:
bool: True if backend package is installed.
"""
import importlib
return importlib.util.find_spec('coreml') is not None
return importlib.util.find_spec('coremltools') is not None

@classmethod
def get_version(cls) -> str:
Expand All @@ -53,7 +53,7 @@ def get_version(cls) -> str:
else:
import pkg_resources
try:
return pkg_resources.get_distribution('coreml').version
return pkg_resources.get_distribution('coremltools').version
except Exception:
return 'None'

Expand All @@ -78,14 +78,46 @@ def to_backend(cls,
Returns:
Seqeuence[str]: Backend files.
"""
from .torchscript2coreml import from_torchscript
from mmdeploy.utils import (get_common_config, get_ir_config,
get_model_inputs, load_config)
from .torchscript2coreml import from_torchscript, get_model_suffix

coreml_files = []
for model_id, torchscript_path in enumerate(ir_files):
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
output_file_prefix = osp.join(work_dir, torchscript_name)

from_torchscript(model_id, torchscript_path, output_file_prefix,
deploy_cfg, coreml_files)
deploy_cfg = load_config(deploy_cfg)[0]

common_params = get_common_config(deploy_cfg)
model_params = get_model_inputs(deploy_cfg)[model_id]

final_params = common_params
final_params.update(model_params)

ir_config = get_ir_config(deploy_cfg)
input_names = ir_config.get('input_names', [])
output_names = ir_config.get('output_names', [])
input_shapes = final_params['input_shapes']
compute_precision = final_params.get('compute_precision',
'FLOAT32')
convert_to = deploy_cfg.backend_config.convert_to

minimum_deployment_target = final_params.get(
'minimum_deployment_target', None)
skip_model_load = final_params.get('skip_model_load', False)
from_torchscript(
torchscript_path,
output_file_prefix,
input_names=input_names,
output_names=output_names,
input_shapes=input_shapes,
compute_precision=compute_precision,
convert_to=convert_to,
minimum_deployment_target=minimum_deployment_target,
skip_model_load=skip_model_load)

suffix = get_model_suffix(convert_to)
output_path = output_file_prefix + suffix
coreml_files.append(output_path)
return coreml_files
50 changes: 13 additions & 37 deletions mmdeploy/backend/coreml/torchscript2coreml.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, List, Union
from typing import Dict, Optional, Sequence, Union

import coremltools as ct
import mmcv
import torch

from mmdeploy.utils import (get_common_config, get_model_inputs,
get_root_logger, load_config)
from mmdeploy.utils.config_utils import get_ir_config
from mmdeploy.utils import get_root_logger

try:
# user might need ops from torchvision
Expand Down Expand Up @@ -50,24 +46,23 @@ def create_shape(name: str, input_shapes: Dict) -> ct.Shape:
return ct.TensorType(shape=shape, name=name)


def from_torchscript(model_id: int,
torchscript_model: Union[str,
def from_torchscript(torchscript_model: Union[str,
torch.jit.RecursiveScriptModule],
output_file_prefix: str, deploy_cfg: Union[str,
mmcv.Config],
backend_files: List[str], **kwargs):
output_file_prefix: str,
input_names: Sequence[str],
output_names: Sequence[str],
input_shapes: Dict[str, Dict],
compute_precision: str = 'FLOAT32',
convert_to: str = 'neuralnetwork',
minimum_deployment_target: Optional[str] = None,
skip_model_load: bool = False):
"""Create a coreml engine from torchscript.
Args:
model_id (int): Index of input model.
torchscript_model (Union[str, torch.jit.RecursiveScriptModule]):
The torchscript model to be converted.
output_file_prefix (str): The output file prefix.
deploy_cfg (str | mmcv.Config): Deployment config.
backend_files (List[str]):
Backend files used by deployment for testing pipeline
"""

try:
from mmdeploy.backend.torchscript import get_ops_path
torch.ops.load_library(get_ops_path())
Expand All @@ -80,40 +75,22 @@ def from_torchscript(model_id: int,
if isinstance(torchscript_model, str):
torchscript_model = torch.jit.load(torchscript_model)

deploy_cfg = load_config(deploy_cfg)[0]

common_params = get_common_config(deploy_cfg)
model_params = get_model_inputs(deploy_cfg)[model_id]

final_params = common_params
final_params.update(model_params)

ir_config = get_ir_config(deploy_cfg)

input_names = ir_config.get('input_names', [])
input_shapes = final_params['input_shapes']
inputs = []

for name in input_names:
shape = create_shape(name, input_shapes[name])
inputs.append(shape)

output_names = ir_config.get('output_names', [])
outputs = []

for name in output_names:
outputs.append(ct.TensorType(name=name))

convert_to = deploy_cfg.backend_config.convert_to
if convert_to == 'neuralnetwork':
# Compute precision must be None for neuralnetwork conversion
compute_precision = None
else:
compute_precision = ct.precision[final_params.get(
'compute_precision', 'FLOAT32')]

minimum_deployment_target = final_params.get('minimum_deployment_target',
None)
compute_precision = ct.precision[compute_precision]

mlmodel = ct.convert(
model=torchscript_model,
Expand All @@ -123,9 +100,8 @@ def from_torchscript(model_id: int,
convert_to=convert_to,
minimum_deployment_target=ct.target[minimum_deployment_target]
if minimum_deployment_target else None,
skip_model_load=final_params.get('skip_model_load', False))
skip_model_load=skip_model_load)

suffix = get_model_suffix(convert_to)
output_path = output_file_prefix + suffix
backend_files.append(output_path)
mlmodel.save(output_path)
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def base_dense_head__get_bbox(ctx,
else:
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)

bbox_pred, scores, score_factors = gather_topk(
bbox_pred,
scores,
Expand Down
24 changes: 24 additions & 0 deletions mmdeploy/pytorch/functions/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,27 @@ def topk__tensorrt(ctx,
k = TENSORRT_MAX_TOPK

return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)


@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='coreml')
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='coreml')
def topk__coreml(ctx,
input: torch.Tensor,
k: int,
dim: Optional[int] = None,
largest: bool = True,
sorted: bool = True):
"""Rewrite `topk` for coreml backend.
Cast k to tensor and make sure k is smaller than input.shape[dim].
"""

if dim is None:
dim = int(input.ndim - 1)
size = input.shape[dim]
if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=input.device, dtype=torch.long)
# Always keep topk op for dynamic input
k = torch.where(k < size, k, size)
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
35 changes: 30 additions & 5 deletions tests/test_backend/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def generate_torchscript_file():
context_info=context_info)


def onnx2backend(backend, onnx_file):
def ir2backend(backend, onnx_file, ts_file):
if backend == Backend.TENSORRT:
from mmdeploy.backend.tensorrt import from_onnx
backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
Expand Down Expand Up @@ -143,6 +143,34 @@ def onnx2backend(backend, onnx_file):
onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
assert osp.exists(lib_file)
return lib_file
elif backend == Backend.TORCHSCRIPT:
return ts_file
elif backend == Backend.COREML:
output_names = ['output']
from mmdeploy.backend.coreml.torchscript2coreml import (
from_torchscript, get_model_suffix)
backend_dir = tempfile.TemporaryDirectory().name
work_dir = backend_dir
torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
output_file_prefix = osp.join(work_dir, torchscript_name)
convert_to = 'mlprogram'
from_torchscript(
ts_file,
output_file_prefix,
input_names=input_names,
output_names=output_names,
input_shapes=dict(
input=dict(
min_shape=[1, 3, 8, 8],
default_shape=[1, 3, 8, 8],
max_shape=[1, 3, 8, 8])),
convert_to=convert_to)

suffix = get_model_suffix(convert_to)
return output_file_prefix + suffix
else:
raise NotImplementedError(
f'Convert for {backend.value} has not been implemented.')


def create_wrapper(backend, model_files):
Expand Down Expand Up @@ -186,10 +214,7 @@ def run_wrapper(backend, wrapper, input):
@pytest.mark.parametrize('backend', ALL_BACKEND)
def test_wrapper(backend):
check_backend(backend, True)
if backend == Backend.TORCHSCRIPT:
model_files = ts_file
else:
model_files = onnx2backend(backend, onnx_file)
model_files = ir2backend(backend, onnx_file, ts_file)
assert model_files is not None
wrapper = create_wrapper(backend, model_files)
assert wrapper is not None
Expand Down

0 comments on commit 513b1c3

Please sign in to comment.