diff --git a/.gitignore b/.gitignore index dc02f824b..742b078d1 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,6 @@ bin/ # !docs/zh_cn/build !docs/en/build + +# ncnn +mmdeploy/backend/ncnn/onnx2ncnn diff --git a/docs/en/tutorials/how_to_support_new_backends.md b/docs/en/tutorials/how_to_support_new_backends.md index c18cd8614..45038e053 100644 --- a/docs/en/tutorials/how_to_support_new_backends.md +++ b/docs/en/tutorials/how_to_support_new_backends.md @@ -38,11 +38,11 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi if is_available(): - from .utils import create_trt_engine, load_trt_engine, save_trt_engine + from .utils import from_onnx, load, save from .wrapper import TRTWrapper __all__ = [ - 'create_trt_engine', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper' + 'from_onnx', 'save', 'load', 'TRTWrapper' ] ``` diff --git a/docs/zh_cn/04-developer-guide/support_new_backend.md b/docs/zh_cn/04-developer-guide/support_new_backend.md index 0f3bc33ed..6267cfa83 100644 --- a/docs/zh_cn/04-developer-guide/support_new_backend.md +++ b/docs/zh_cn/04-developer-guide/support_new_backend.md @@ -38,11 +38,11 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx” if is_available(): - from .utils import create_trt_engine, load_trt_engine, save_trt_engine + from .utils import from_onnx, load, save from .wrapper import TRTWrapper __all__ = [ - 'create_trt_engine', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper' + 'from_onnx', 'save', 'load', 'TRTWrapper' ] ``` diff --git a/mmdeploy/apis/__init__.py b/mmdeploy/apis/__init__.py index 48b1339d1..fc560e8c4 100644 --- a/mmdeploy/apis/__init__.py +++ b/mmdeploy/apis/__init__.py @@ -1,14 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .calibration import create_calib_table -from .extract_model import extract_model -from .inference import inference_model -from .pytorch2onnx import torch2onnx, torch2onnx_impl -from .pytorch2torchscript import torch2torchscript, torch2torchscript_impl -from .utils import build_task_processor, get_predefined_partition_cfg -from .visualize import visualize_model -__all__ = [ - 'create_calib_table', 'extract_model', 'inference_model', 'torch2onnx', - 'torch2onnx_impl', 'torch2torchscript', 'torch2torchscript_impl', - 'build_task_processor', 'get_predefined_partition_cfg', 'visualize_model' -] +# mmcv dependency +try: + from .calibration import create_calib_input_data + from .extract_model import extract_model + from .inference import inference_model + from .pytorch2onnx import torch2onnx + from .pytorch2torchscript import torch2torchscript + from .utils import build_task_processor, get_predefined_partition_cfg + from .visualize import visualize_model + + __all__ = [ + 'create_calib_input_data', 'extract_model', 'inference_model', + 'torch2onnx', 'torch2torchscript', 'build_task_processor', + 'get_predefined_partition_cfg', 'visualize_model' + ] +except Exception: + pass diff --git a/mmdeploy/apis/calibration.py b/mmdeploy/apis/calibration.py index 1939d502f..5623dd5ef 100644 --- a/mmdeploy/apis/calibration.py +++ b/mmdeploy/apis/calibration.py @@ -1,107 +1,78 @@ # Copyright (c) OpenMMLab. All rights reserved. 
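The hunks that follow rewrite `mmdeploy/apis/calibration.py`: the public entry point is renamed from `create_calib_table` to `create_calib_input_data` and its default device becomes `'cpu'`. A minimal usage sketch of the renamed API, reusing the config and checkpoint paths from the docstring example that this hunk removes (all paths are illustrative):

```python
from mmdeploy.apis import create_calib_input_data

create_calib_input_data(
    'work_dir/calib_data.h5',  # output calibration file (HDF5)
    deploy_cfg='configs/mmdet/detection/'
    'detection_tensorrt-int8_dynamic-320x320-1344x1344.py',
    model_cfg='mmdetection/configs/fcos/'
    'fcos_r50_caffe_fpn_gn-head_1x_coco.py',
    model_checkpoint='checkpoints/'
    'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth',
    device='cuda:0')
```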
from typing import Optional, Union -import h5py import mmcv import torch from mmcv.parallel import MMDataParallel -from mmdeploy.core import (RewriterContext, patch_model, - reset_mark_function_count) +from mmdeploy.core import patch_model from mmdeploy.utils import cfg_apply_marks, load_config +from .core import PIPELINE_MANAGER, no_mp +from .utils import create_calib_input_data as create_calib_input_data_impl -def create_calib_table(calib_file: str, - deploy_cfg: Union[str, mmcv.Config], - model_cfg: Union[str, mmcv.Config], - model_checkpoint: Optional[str] = None, - dataset_cfg: Optional[Union[str, mmcv.Config]] = None, - dataset_type: str = 'val', - device: str = 'cuda:0', - **kwargs) -> None: - """Create calibration table. - - Examples: - >>> from mmdeploy.apis import create_calib_table - >>> from mmdeploy.utils import get_calib_filename, load_config - >>> deploy_cfg = 'configs/mmdet/detection/' \ - 'detection_tensorrt-int8_dynamic-320x320-1344x1344.py' - >>> deploy_cfg = load_config(deploy_cfg)[0] - >>> calib_file = get_calib_filename(deploy_cfg) - >>> model_cfg = 'mmdetection/configs/fcos/' \ - 'fcos_r50_caffe_fpn_gn-head_1x_coco.py' - >>> model_checkpoint = 'checkpoints/' \ - 'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth' - >>> create_calib_table(calib_file, deploy_cfg, \ - model_cfg, model_checkpoint, device='cuda:0') +@PIPELINE_MANAGER.register_pipeline() +def create_calib_input_data(calib_file: str, + deploy_cfg: Union[str, mmcv.Config], + model_cfg: Union[str, mmcv.Config], + model_checkpoint: Optional[str] = None, + dataset_cfg: Optional[Union[str, + mmcv.Config]] = None, + dataset_type: str = 'val', + device: str = 'cpu') -> None: + """Create dataset for post-training quantization. Args: - calib_file (str): Input calibration file. - deploy_cfg (str | mmcv.Config): Deployment config. - model_cfg (str | mmcv.Config): The model config. - model_checkpoint (str): PyTorch model checkpoint, defaults to `None`. - dataset_cfg (str | mmcv.Config): Dataset config, defaults to `None` - dataset_type (str): A string specifying dataset type, e.g.: 'test', - 'val', defaults to 'val'. - device (str): Specifying the device to run on, defaults to 'cuda:0'. + calib_file (str): The output calibration data file. + deploy_cfg (str | mmcv.Config): Deployment config file or + Config object. + model_cfg (str | mmcv.Config): Model config file or Config object. + model_checkpoint (str): A checkpoint path of PyTorch model, + defaults to `None`. + dataset_cfg (Optional[Union[str, mmcv.Config]], optional): Model + config to provide calibration dataset. If none, use `model_cfg` + as the dataset config. Defaults to None. + dataset_type (str, optional): The dataset type. Defaults to 'val'. + device (str, optional): Device to create dataset. Defaults to 'cpu'. 
""" - if dataset_cfg is None: - dataset_cfg = model_cfg + with no_mp(): + if dataset_cfg is None: + dataset_cfg = model_cfg + + device_id = torch.device(device).index + if device_id is None: + device_id = 0 + + # load cfg if necessary + deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) - # load cfg if necessary - deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) - device_id = torch.device(device).index - if device_id is None: - device_id = 0 + if dataset_cfg is None: + dataset_cfg = model_cfg - if dataset_cfg is None: - dataset_cfg = model_cfg - # load dataset_cfg if necessary - dataset_cfg = load_config(dataset_cfg)[0] + # load dataset_cfg if necessary + dataset_cfg = load_config(dataset_cfg)[0] - from mmdeploy.apis.utils import build_task_processor - task_processor = build_task_processor(model_cfg, deploy_cfg, device) + from mmdeploy.apis.utils import build_task_processor + task_processor = build_task_processor(model_cfg, deploy_cfg, device) - apply_marks = cfg_apply_marks(deploy_cfg) - backend = 'default' - model = task_processor.init_pytorch_model(model_checkpoint) - dataset = task_processor.build_dataset(dataset_cfg, dataset_type) + apply_marks = cfg_apply_marks(deploy_cfg) - # patch model - patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) + model = task_processor.init_pytorch_model(model_checkpoint) + dataset = task_processor.build_dataset(dataset_cfg, dataset_type) - with h5py.File(calib_file, mode='w') as file: - calib_data_group = file.create_group('calib_data') + # patch model + patched_model = patch_model(model, cfg=deploy_cfg) - if not apply_marks: - # create end2end group - input_data_group = calib_data_group.create_group('end2end') - input_group = input_data_group.create_group('input') dataloader = task_processor.build_dataloader( dataset, 1, 1, dist=False, shuffle=False) patched_model = MMDataParallel(patched_model, device_ids=[device_id]) - prog_bar = mmcv.ProgressBar(len(dataset)) - for data_id, input_data in enumerate(dataloader): - - if not apply_marks: - # save end2end data - input_tensor = task_processor.get_tensor_from_input(input_data) - input_ndarray = input_tensor.detach().cpu().numpy() - input_group.create_dataset( - str(data_id), - shape=input_ndarray.shape, - compression='gzip', - compression_opts=4, - data=input_ndarray) - - with torch.no_grad(), RewriterContext( - cfg=deploy_cfg, - backend=backend, - create_calib=True, - calib_file=file, - data_id=data_id): - reset_mark_function_count() - _ = task_processor.run_inference(patched_model, input_data) - file.flush() - prog_bar.update() + create_calib_input_data_impl( + calib_file, + patched_model, + dataloader, + get_tensor_func=task_processor.get_tensor_from_input, + inference_func=task_processor.run_inference, + model_partition=apply_marks, + context_info=dict(cfg=deploy_cfg), + device=device) diff --git a/mmdeploy/apis/core/__init__.py b/mmdeploy/apis/core/__init__.py new file mode 100644 index 000000000..b2f35bf75 --- /dev/null +++ b/mmdeploy/apis/core/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .pipeline_manager import PIPELINE_MANAGER, no_mp + +__all__ = ['PIPELINE_MANAGER', 'no_mp'] diff --git a/mmdeploy/apis/core/pipeline_manager.py b/mmdeploy/apis/core/pipeline_manager.py new file mode 100644 index 000000000..f46697a23 --- /dev/null +++ b/mmdeploy/apis/core/pipeline_manager.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import importlib +import inspect +import logging +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from mmdeploy.utils import get_root_logger + +try: + import torch.multiprocessing as mp +except Exception: + import multiprocessing as mp + + +def _get_func_name(func: Callable) -> str: + """get function name.""" + assert isinstance(func, Callable), f'{func} is not a Callable object.' + _func_name = None + if hasattr(func, '__qualname__'): + _func_name = f'{func.__module__}.{func.__qualname__}' + elif hasattr(func, '__class__'): + _func_name = func.__class__ + else: + _func_name = str(func) + return _func_name + + +class PipelineCaller: + """Classes to record the attribute of each pipeline function.""" + + def __init__(self, + module_name: str, + impl_name: str, + func_name: Optional[str] = None, + log_level: int = logging.DEBUG, + is_multiprocess_available: bool = True) -> None: + if func_name is not None: + self._func_name = func_name + else: + self._func_name = impl_name + # Can not save the function directly since multiprocess with spawn mode + # require all field can be pickled. + self._module_name = module_name + self._impl_name = impl_name + self._is_multiprocess_available = is_multiprocess_available + self._enable_multiprocess = False + self._mp_dict = None + self._mp_async = False + self._call_id = 0 + self._log_level = log_level + self._input_hooks: List[Callable] = [] + self._output_hooks: List[Callable] = [] + + @property + def is_multiprocess_available(self) -> bool: + """check if multiprocess is available for this pipeline.""" + return self._is_multiprocess_available + + @property + def is_multiprocess(self) -> bool: + """check if this pipeline is multiprocess.""" + return self._enable_multiprocess + + @property + def input_hooks(self) -> List[Callable]: + """get input hooks.""" + return self._input_hooks + + @property + def output_hooks(self) -> List[Callable]: + """get output hooks.""" + return self._output_hooks + + def pop_mp_output(self, call_id: int = None) -> Any: + """pop multiprocess output.""" + assert self._mp_dict is not None, 'mp_dict is None.' + call_id = self._call_id if call_id is None else call_id + assert call_id in self._mp_dict, \ + f'`{self._func_name}` with Call id: {call_id} failed.' 
+ ret = self._mp_dict[call_id] + self._mp_dict.pop(call_id) + return ret + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + do_multiprocess = self.is_multiprocess_available \ + and self.is_multiprocess\ + and self._mp_dict is not None + + logger = get_root_logger(log_level=self._log_level) + mp_log_str = 'subprocess' if do_multiprocess else 'main process' + logger.log(self._log_level, + f'Start pipeline {self._func_name} in {mp_log_str}') + + for input_hook in self.input_hooks: + args, kwargs = input_hook(*args, **kwargs) + + module_name = self._module_name + impl_name = self._impl_name + # TODO: find another way to load function + mod = importlib.import_module(module_name) + func = getattr(mod, impl_name, None) + assert func is not None, \ + f'Can not find implementation of {self._func_name}' + ret = func(*args, **kwargs) + for output_hook in self.output_hooks: + ret = output_hook(ret) + + if do_multiprocess: + self._mp_dict[self._call_id] = ret + + logger.log(self._log_level, f'Finish pipeline {self._func_name}') + return ret + + +class PipelineResult: + """The result of async pipeline.""" + + def __init__(self, manager: Any, call_id: int) -> None: + self._manager = manager + self._call_id = call_id + + @property + def call_id(self) -> int: + return self._call_id + + def get(self) -> Any: + """get result.""" + return self._manager.get_result_sync(self._call_id) + + +FUNC_NAME_TYPE = Union[str, Callable] + + +class PipelineManager: + """This is a tool to manager all pipeline functions.""" + + def __init__(self) -> None: + self._enable_multiprocess = True + self._mp_manager = None + self._callers: Dict[str, PipelineCaller] = dict() + self._call_id = 0 + self._proc_async: Dict[int, (str, mp.Process)] = dict() + + @property + def mp_manager(self) -> Optional[mp.Manager]: + """get multiprocess manager.""" + return self._mp_manager + + def get_caller(self, func_name: FUNC_NAME_TYPE) -> PipelineCaller: + """get caller of given function.""" + if isinstance(func_name, Callable): + func_name = _get_func_name(func_name) + assert func_name in self._callers, \ + f'{func_name} has not been registered.' + return self._callers[func_name] + + def __set_caller_val(self, + val_name: str, + val: Any, + func_name: Optional[FUNC_NAME_TYPE] = None) -> None: + """helper to set any caller value.""" + if func_name is None: + for func_name_ in self._callers: + setattr(self.get_caller(func_name_), val_name, val) + else: + setattr(self.get_caller(func_name), val_name, val) + + def _create_mp_manager(self) -> None: + """create multiprocess manager if not exists.""" + if self._mp_manager is None: + self._mp_manager = mp.Manager() + + def _enable_multiprocess_single(self, + val: bool, + func_name: FUNC_NAME_TYPE = None) -> None: + """implement of enable_multiprocess.""" + pipe_caller = self.get_caller(func_name) + # check if multiprocess is available for this function + if not pipe_caller.is_multiprocess_available: + return + pipe_caller._enable_multiprocess = val + if val is True and self.mp_manager is not None: + pipe_caller._mp_dict = self.mp_manager.dict() + else: + pipe_caller._mp_dict = None + + def enable_multiprocess( + self, + val: bool, + func_names: Optional[Union[FUNC_NAME_TYPE, + Sequence[FUNC_NAME_TYPE]]] = None + ) -> None: + """enable multiprocess for pipeline function. + + Args: + val (bool): enable or disable multiprocess. + func_names (str | List[str]): function names to enable. If + func_name is None, all registered function will be enabled. 
+ """ + if val is True: + self._create_mp_manager() + if func_names is None: + for func_name in self._callers: + self._enable_multiprocess_single(val, func_name=func_name) + else: + if isinstance(func_names, str): + func_names = [func_names] + for func_name in func_names: + self._enable_multiprocess_single(val, func_name=func_name) + + def set_mp_async(self, + val: bool, + func_name: Optional[FUNC_NAME_TYPE] = None) -> None: + """set multiprocess async of the pipeline function. + + Args: + val (bool): enable async call. + func_name (str | None): function name to set. If func_name is + None, all registered function will be set. + """ + self.__set_caller_val('_mp_async', val, func_name) + + def set_log_level( + self, + level: int, + func_names: Optional[Union[FUNC_NAME_TYPE, + Sequence[FUNC_NAME_TYPE]]] = None + ) -> None: + """set log level of the pipeline function. + + Args: + level (int): the log level. + func_names (str | List[str]): function names to set. If func_names + is None, all registered function will be set. + """ + if isinstance(func_names, str): + func_names = [func_names] + for func_name in func_names: + self.__set_caller_val('_log_level', level, func_name) + + def get_input_hooks(self, func_name: FUNC_NAME_TYPE): + """get input hooks of given function name. + + Args: + func_name (str): function name. + """ + pipe_caller = self.get_caller(func_name) + return pipe_caller.input_hooks + + def get_output_hooks(self, func_name: FUNC_NAME_TYPE): + """get output hooks of given function name. + + Args: + func_name (str): function name. + """ + pipe_caller = self.get_caller(func_name) + return pipe_caller.output_hooks + + def call_function_local(self, func_name: FUNC_NAME_TYPE, *args, + **kwargs) -> Any: + """call pipeline function. + + Args: + func_name (str): function name to be called. + + Returns: + Any: The result of call function + """ + pipe_caller = self.get_caller(func_name) + pipe_caller._call_id = self._call_id + self._call_id += 1 + return pipe_caller(*args, **kwargs) + + def call_function_async(self, func_name: FUNC_NAME_TYPE, *args, + **kwargs) -> int: + """call pipeline function. + + Args: + func_name (str): function name to be called. + + Returns: + int: Call id of this function + """ + pipe_caller = self.get_caller(func_name) + assert pipe_caller.is_multiprocess, \ + f'multiprocess of {func_name} has not been enabled.' + + call_id = self._call_id + pipe_caller._call_id = call_id + self._call_id += 1 + proc = mp.Process(target=pipe_caller, args=args, kwargs=kwargs) + proc.start() + self._proc_async[call_id] = (func_name, proc) + + return call_id + + def get_result_sync(self, call_id: int): + """get result of async call.""" + assert call_id in self._proc_async, f'Unknown call id: {call_id}' + func_name, proc = self._proc_async.pop(call_id) + proc.join() + ret = self.get_caller(func_name).pop_mp_output(call_id) + + return ret + + def call_function(self, func_name: FUNC_NAME_TYPE, *args, **kwargs) -> Any: + """call pipeline function. + + Args: + func_name (str): function name to be called. 
+ + Returns: + Any: The result of call function + """ + pipe_caller = self.get_caller(func_name) + + if self._enable_multiprocess and pipe_caller.is_multiprocess: + call_id = self.call_function_async(func_name, *args, **kwargs) + if pipe_caller._mp_async: + return PipelineResult(self, call_id) + return self.get_result_sync(call_id) + else: + return self.call_function_local(func_name, *args, **kwargs) + + def register_pipeline(self, + is_multiprocess_available: bool = True, + log_level: int = logging.DEBUG): + """register the pipeline function.""" + + def _register(func): + assert isinstance(func, Callable), f'{func} is not Callable.' + func_name_ = _get_func_name(func) + + # save the implementation into the registry module + impl_name = f'_pipe_{func.__name__}__impl_' + frame = inspect.stack()[1] + outer_mod = inspect.getmodule(frame[0]) + mod_name = outer_mod.__name__ + setattr(outer_mod, impl_name, func) + + # create caller + pipe_caller = PipelineCaller( + mod_name, + impl_name, + func_name=func_name_, + log_level=log_level, + is_multiprocess_available=is_multiprocess_available) + PIPELINE_MANAGER._callers[func_name_] = pipe_caller + + # wrap call + @wraps(func) + def _wrap(*args, **kwargs): + return self.call_function(func_name_, *args, **kwargs) + + return _wrap + + return _register + + +PIPELINE_MANAGER = PipelineManager() + + +class no_mp: + """The context manager used to disable multiprocess.""" + + def __init__(self, manager: PipelineManager = PIPELINE_MANAGER) -> None: + self._manager = manager + self._old_enable_multiprocess = True + + def __enter__(self): + self._old_enable_multiprocess = self._manager._enable_multiprocess + self._manager._enable_multiprocess = False + + def __exit__(self, type, val, tb): + self._manager._enable_multiprocess = self._old_enable_multiprocess diff --git a/mmdeploy/apis/extract_model.py b/mmdeploy/apis/extract_model.py index a01afd59b..62b53b185 100644 --- a/mmdeploy/apis/extract_model.py +++ b/mmdeploy/apis/extract_model.py @@ -1,33 +1,31 @@ # Copyright (c) OpenMMLab. All rights reserved. + from typing import Dict, Iterable, Optional, Union import onnx -import onnx.helper -import onnx.utils -from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor, - get_new_name, parse_extractor_io_string, - remove_identity, rename_value) -from mmdeploy.utils import get_root_logger +from .core import PIPELINE_MANAGER +from .onnx import extract_partition +@PIPELINE_MANAGER.register_pipeline() def extract_model(model: Union[str, onnx.ModelProto], - start: Union[str, Iterable[str]], - end: Union[str, Iterable[str]], + start_marker: Union[str, Iterable[str]], + end_marker: Union[str, Iterable[str]], start_name_map: Optional[Dict[str, str]] = None, end_name_map: Optional[Dict[str, str]] = None, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, save_file: Optional[str] = None) -> onnx.ModelProto: - """Extract sub-model from an ONNX model. + """Extract partition-model from an ONNX model. - The sub-model is defined by the names of the input and output tensors + The partition-model is defined by the names of the input and output tensors exactly. 
Examples: >>> from mmdeploy.apis import extract_model >>> model = 'work_dir/fastrcnn.onnx' - >>> start = 'detector:input' - >>> end = ['extract_feat:output', 'multiclass_nms[0]:input'] + >>> start_marker = 'detector:input' + >>> end_marker = ['extract_feat:output', 'multiclass_nms[0]:input'] >>> dynamic_axes = { 'input': { 0: 'batch', @@ -44,13 +42,14 @@ def extract_model(model: Union[str, onnx.ModelProto], } } >>> save_file = 'partition_model.onnx' - >>> extract_model(model, start, end, dynamic_axes=dynamic_axes, \ + >>> extract_model(model, start_marker, end_marker, \ + dynamic_axes=dynamic_axes, \ save_file=save_file) Args: model (str | onnx.ModelProto): Input ONNX model to be extracted. - start (str | Sequence[str]): Start marker(s) to extract. - end (str | Sequence[str]): End marker(s) to extract. + start_marker (str | Sequence[str]): Start marker(s) to extract. + end_marker (str | Sequence[str]): End marker(s) to extract. start_name_map (Dict[str, str]): A mapping of start names, defaults to `None`. end_name_map (Dict[str, str]): A mapping of end names, defaults to @@ -61,142 +60,8 @@ def extract_model(model: Union[str, onnx.ModelProto], `None`. Returns: - onnx.ModelProto: The extracted sub-model. + onnx.ModelProto: The extracted model. """ - if isinstance(model, str): - model = onnx.load(model) - - num_value_info = len(model.graph.value_info) - inputs = [] - outputs = [] - logger = get_root_logger() - if not isinstance(start, (list, tuple)): - start = [start] - for s in start: - start_name, func_id, start_type = parse_extractor_io_string(s) - for node in model.graph.node: - if node.op_type == 'Mark': - attr = attribute_to_dict(node.attribute) - if attr['func'] == start_name and attr[ - 'type'] == start_type and attr['func_id'] == func_id: - name = node.input[0] - if name not in inputs: - new_name = get_new_name( - attr, mark_name=s, name_map=start_name_map) - rename_value(model, name, new_name) - if not any([ - v_info.name == new_name - for v_info in model.graph.value_info - ]): - new_val_info = onnx.helper.make_tensor_value_info( - new_name, attr['dtype'], attr['shape']) - model.graph.value_info.append(new_val_info) - inputs.append(new_name) - - logger.info(f'inputs: {", ".join(inputs)}') - - # collect outputs - if not isinstance(end, (list, tuple)): - end = [end] - for e in end: - end_name, func_id, end_type = parse_extractor_io_string(e) - for node in model.graph.node: - if node.op_type == 'Mark': - attr = attribute_to_dict(node.attribute) - if attr['func'] == end_name and attr[ - 'type'] == end_type and attr['func_id'] == func_id: - name = node.output[0] - if name not in outputs: - new_name = get_new_name( - attr, mark_name=e, name_map=end_name_map) - rename_value(model, name, new_name) - if not any([ - v_info.name == new_name - for v_info in model.graph.value_info - ]): - new_val_info = onnx.helper.make_tensor_value_info( - new_name, attr['dtype'], attr['shape']) - model.graph.value_info.append(new_val_info) - outputs.append(new_name) - - logger.info(f'outputs: {", ".join(outputs)}') - - # replace Mark with Identity - for node in model.graph.node: - if node.op_type == 'Mark': - del node.attribute[:] - node.domain = '' - node.op_type = 'Identity' - - extractor = create_extractor(model) - extracted_model = extractor.extract_model(inputs, outputs) - - # remove all Identity, this may be done by onnx simplifier - remove_identity(extracted_model) - - # collect all used inputs - used = set() - for node in extracted_model.graph.node: - for input in node.input: - used.add(input) - - for 
output in extracted_model.graph.output: - used.add(output.name) - - # delete unused inputs - success = True - while success: - success = False - for i, input in enumerate(extracted_model.graph.input): - if input.name not in used: - del extracted_model.graph.input[i] - success = True - break - - # eliminate output without shape - for xs in [extracted_model.graph.output]: - for x in xs: - if not x.type.tensor_type.shape.dim: - logger.info(f'fixing output shape: {x.name}') - x.CopyFrom( - onnx.helper.make_tensor_value_info( - x.name, x.type.tensor_type.elem_type, [])) - - # eliminate 0-batch dimension, dirty workaround for two-stage detectors - for input in extracted_model.graph.input: - if input.name in inputs: - if input.type.tensor_type.shape.dim[0].dim_value == 0: - input.type.tensor_type.shape.dim[0].dim_value = 1 - - # eliminate duplicated value_info for inputs - success = True - # num_value_info == 0 if dynamic shape - if num_value_info == 0: - while len(extracted_model.graph.value_info) > 0: - extracted_model.graph.value_info.pop() - while success: - success = False - for i, x in enumerate(extracted_model.graph.value_info): - if x.name in inputs: - del extracted_model.graph.value_info[i] - success = True - break - - # dynamic shape support - if dynamic_axes is not None: - for input_node in extracted_model.graph.input: - if input_node.name in dynamic_axes: - axes = dynamic_axes[input_node.name] - for k, v in axes.items(): - input_node.type.tensor_type.shape.dim[k].dim_value = 0 - input_node.type.tensor_type.shape.dim[k].dim_param = v - for output_node in extracted_model.graph.output: - for idx, dim in enumerate(output_node.type.tensor_type.shape.dim): - dim.dim_value = 0 - dim.dim_param = f'dim_{idx}' - - # save extract_model if save_file is given - if save_file is not None: - onnx.save(extracted_model, save_file) - return extracted_model + return extract_partition(model, start_marker, end_marker, start_name_map, + end_name_map, dynamic_axes, save_file) diff --git a/mmdeploy/apis/ncnn/__init__.py b/mmdeploy/apis/ncnn/__init__.py index 4d56ae678..196e7f2bf 100644 --- a/mmdeploy/apis/ncnn/__init__.py +++ b/mmdeploy/apis/ncnn/__init__.py @@ -1,13 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdeploy.backend.ncnn import is_available, is_plugin_available +from mmdeploy.backend.ncnn import from_onnx as _from_onnx +from mmdeploy.backend.ncnn import is_available, is_custom_ops_available +from ..core import PIPELINE_MANAGER -__all__ = ['is_available', 'is_plugin_available'] +from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx) + +__all__ = ['is_available', 'is_custom_ops_available', 'from_onnx'] if is_available(): - from mmdeploy.backend.ncnn.onnx2ncnn import (get_output_model_file, - onnx2ncnn) - from mmdeploy.backend.ncnn.quant import get_quant_model_file, ncnn2int8 - __all__ += [ - 'onnx2ncnn', 'get_output_model_file', 'ncnn2int8', - 'get_quant_model_file' - ] + try: + from mmdeploy.backend.ncnn.onnx2ncnn import get_output_model_file + from mmdeploy.backend.ncnn.quant import get_quant_model_file, ncnn2int8 + __all__ += [ + 'get_output_model_file', 'ncnn2int8', 'get_quant_model_file' + ] + except Exception: + pass diff --git a/mmdeploy/apis/onnx/__init__.py b/mmdeploy/apis/onnx/__init__.py new file mode 100644 index 000000000..7ea2c9fd0 --- /dev/null +++ b/mmdeploy/apis/onnx/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
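With `mmdeploy.apis.ncnn` above re-exporting the conversion as a registered `from_onnx` pipeline, callers now pass an output prefix instead of separate `.param`/`.bin` paths. A short sketch mirroring the docstring example that appears later in this patch (the `work_dir` paths are placeholders):

```python
from mmdeploy.apis.ncnn import from_onnx, is_available

if is_available():
    # writes work_dir/end2end.param and work_dir/end2end.bin
    from_onnx('work_dir/end2end.onnx', 'work_dir/end2end')
```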
+from .export import export +from .partition import extract_partition + +__all__ = ['export', 'extract_partition'] diff --git a/mmdeploy/apis/onnx/export.py b/mmdeploy/apis/onnx/export.py new file mode 100644 index 000000000..ddad7f30c --- /dev/null +++ b/mmdeploy/apis/onnx/export.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from functools import partial +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch + +from mmdeploy.apis.core import PIPELINE_MANAGER +from mmdeploy.core import RewriterContext, patch_model +from mmdeploy.utils import Backend, get_root_logger + + +@PIPELINE_MANAGER.register_pipeline() +def export(model: torch.nn.Module, + args: Union[torch.Tensor, Tuple, Dict], + output_path_prefix: str, + backend: Union[Backend, str] = 'default', + input_metas: Optional[Dict] = None, + context_info: Dict = dict(), + input_names: Optional[Sequence[str]] = None, + output_names: Optional[Sequence[str]] = None, + opset_version: int = 11, + dynamic_axes: Optional[Dict] = None, + verbose: bool = False, + keep_initializers_as_inputs: Optional[bool] = None, + **kwargs): + """Export a PyTorch model into ONNX format. This is a wrap of + `torch.onnx.export` with some enhancement. + + Examples: + >>> from mmdeploy.apis.onnx import export + >>> + >>> model = create_model() + >>> args = get_input_tensor() + >>> + >>> export( + >>> model, + >>> args, + >>> 'place/to/save/model', + >>> backend='tensorrt', + >>> input_names=['input'], + >>> output_names=['output'], + >>> dynamic_axes={'input': { + >>> 0: 'batch', + >>> 2: 'height', + >>> 3: 'width' + >>> }}) + + Args: + model (torch.nn.Module): the model to be exported. + args (torch.Tensor|Tuple|Dict): Dummy input of the model. + output_path_prefix (str): The output file prefix. The model will + be saved to `.onnx`. + backend (Backend|str): Which backend will the graph be used. Different + backend would generate different graph. + input_metas (Dict): The constant inputs of the model. + context_info (Dict): The information that would be used in the context + of exporting. + input_names (Sequence[str]): The input names of the model. + output_names (Sequence[str]): The output names of the model. + opset_version (int): The version of ONNX opset version. 11 as default. + dynamic_axes (Dict): The information used to determine which axes are + dynamic. + verbose (bool): Enable verbose model on `torch.onnx.export`. + keep_initializers_as_inputs (bool): Whether we should add inputs for + each initializer. 
+ """ + output_path = output_path_prefix + '.onnx' + + logger = get_root_logger() + logger.info(f'Export PyTorch model to ONNX: {output_path}.') + + def _add_or_update(cfg: dict, key: str, val: Any): + if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict): + cfg[key].update(val) + else: + cfg[key] = val + + context_info = deepcopy(context_info) + deploy_cfg = context_info.pop('deploy_cfg', dict()) + ir_config = dict( + type='onnx', + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + verbose=verbose, + keep_initializers_as_inputs=keep_initializers_as_inputs) + _add_or_update(deploy_cfg, 'ir_config', ir_config) + + if isinstance(backend, Backend): + backend = backend.value + backend_config = dict(type=backend) + _add_or_update(deploy_cfg, 'backend_config', backend_config) + + context_info['cfg'] = deploy_cfg + if 'backend' not in context_info: + context_info['backend'] = backend + if 'opset' not in context_info: + context_info['opset'] = opset_version + + # patch model + patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) + + with RewriterContext(**context_info), torch.no_grad(): + # patch input_metas + if input_metas is not None: + assert isinstance( + input_metas, dict + ), f'Expect input_metas type is dict, get {type(input_metas)}.' + model_forward = model.forward + model.forward = partial(model.forward, **input_metas) + + torch.onnx.export( + patched_model, + args, + output_path, + export_params=True, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose) + + if input_metas is not None: + model.forward = model_forward diff --git a/mmdeploy/apis/onnx/partition.py b/mmdeploy/apis/onnx/partition.py new file mode 100644 index 000000000..31e0663db --- /dev/null +++ b/mmdeploy/apis/onnx/partition.py @@ -0,0 +1,205 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Iterable, Optional, Union + +import onnx +import onnx.helper +import onnx.utils + +from mmdeploy.apis.core import PIPELINE_MANAGER +from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor, + get_new_name, parse_extractor_io_string, + remove_identity, rename_value) +from mmdeploy.utils import get_root_logger + + +@PIPELINE_MANAGER.register_pipeline() +def extract_partition(model: Union[str, onnx.ModelProto], + start_marker: Union[str, Iterable[str]], + end_marker: Union[str, Iterable[str]], + start_name_map: Optional[Dict[str, str]] = None, + end_name_map: Optional[Dict[str, str]] = None, + dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, + save_file: Optional[str] = None) -> onnx.ModelProto: + """Extract partition-model from an ONNX model. + + The partition-model is defined by the names of the input and output tensors + exactly. + + Examples: + >>> from mmdeploy.apis import extract_model + >>> model = 'work_dir/fastrcnn.onnx' + >>> start_marker = 'detector:input' + >>> end_marker = ['extract_feat:output', 'multiclass_nms[0]:input'] + >>> dynamic_axes = { + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'scores': { + 0: 'batch', + 1: 'num_boxes', + }, + 'boxes': { + 0: 'batch', + 1: 'num_boxes', + } + } + >>> save_file = 'partition_model.onnx' + >>> extract_partition(model, start_marker, end_marker, \ + dynamic_axes=dynamic_axes, \ + save_file=save_file) + + Args: + model (str | onnx.ModelProto): Input ONNX model to be extracted. 
+ start_marker (str | Sequence[str]): Start marker(s) to extract. + end_marker (str | Sequence[str]): End marker(s) to extract. + start_name_map (Dict[str, str]): A mapping of start names, defaults to + `None`. + end_name_map (Dict[str, str]): A mapping of end names, defaults to + `None`. + dynamic_axes (Dict[str, Dict[int, str]]): A dictionary to specify + dynamic axes of input/output, defaults to `None`. + save_file (str): A file to save the extracted model, defaults to + `None`. + + Returns: + onnx.ModelProto: The extracted model. + """ + if isinstance(model, str): + model = onnx.load(model) + + num_value_info = len(model.graph.value_info) + inputs = [] + outputs = [] + logger = get_root_logger() + if not isinstance(start_marker, (list, tuple)): + start_marker = [start_marker] + for s in start_marker: + start_name, func_id, start_type = parse_extractor_io_string(s) + for node in model.graph.node: + if node.op_type == 'Mark': + attr = attribute_to_dict(node.attribute) + if attr['func'] == start_name and attr[ + 'type'] == start_type and attr['func_id'] == func_id: + name = node.input[0] + if name not in inputs: + new_name = get_new_name( + attr, mark_name=s, name_map=start_name_map) + rename_value(model, name, new_name) + if not any([ + v_info.name == new_name + for v_info in model.graph.value_info + ]): + new_val_info = onnx.helper.make_tensor_value_info( + new_name, attr['dtype'], attr['shape']) + model.graph.value_info.append(new_val_info) + inputs.append(new_name) + + logger.info(f'inputs: {", ".join(inputs)}') + + # collect outputs + if not isinstance(end_marker, (list, tuple)): + end_marker = [end_marker] + for e in end_marker: + end_name, func_id, end_type = parse_extractor_io_string(e) + for node in model.graph.node: + if node.op_type == 'Mark': + attr = attribute_to_dict(node.attribute) + if attr['func'] == end_name and attr[ + 'type'] == end_type and attr['func_id'] == func_id: + name = node.output[0] + if name not in outputs: + new_name = get_new_name( + attr, mark_name=e, name_map=end_name_map) + rename_value(model, name, new_name) + if not any([ + v_info.name == new_name + for v_info in model.graph.value_info + ]): + new_val_info = onnx.helper.make_tensor_value_info( + new_name, attr['dtype'], attr['shape']) + model.graph.value_info.append(new_val_info) + outputs.append(new_name) + + logger.info(f'outputs: {", ".join(outputs)}') + + # replace Mark with Identity + for node in model.graph.node: + if node.op_type == 'Mark': + del node.attribute[:] + node.domain = '' + node.op_type = 'Identity' + + extractor = create_extractor(model) + extracted_model = extractor.extract_model(inputs, outputs) + + # remove all Identity, this may be done by onnx simplifier + remove_identity(extracted_model) + + # collect all used inputs + used = set() + for node in extracted_model.graph.node: + for input in node.input: + used.add(input) + + for output in extracted_model.graph.output: + used.add(output.name) + + # delete unused inputs + success = True + while success: + success = False + for i, input in enumerate(extracted_model.graph.input): + if input.name not in used: + del extracted_model.graph.input[i] + success = True + break + + # eliminate output without shape + for xs in [extracted_model.graph.output]: + for x in xs: + if not x.type.tensor_type.shape.dim: + logger.info(f'fixing output shape: {x.name}') + x.CopyFrom( + onnx.helper.make_tensor_value_info( + x.name, x.type.tensor_type.elem_type, [])) + + # eliminate 0-batch dimension, dirty workaround for two-stage detectors + for input in 
extracted_model.graph.input: + if input.name in inputs: + if input.type.tensor_type.shape.dim[0].dim_value == 0: + input.type.tensor_type.shape.dim[0].dim_value = 1 + + # eliminate duplicated value_info for inputs + success = True + # num_value_info == 0 if dynamic shape + if num_value_info == 0: + while len(extracted_model.graph.value_info) > 0: + extracted_model.graph.value_info.pop() + while success: + success = False + for i, x in enumerate(extracted_model.graph.value_info): + if x.name in inputs: + del extracted_model.graph.value_info[i] + success = True + break + + # dynamic shape support + if dynamic_axes is not None: + for input_node in extracted_model.graph.input: + if input_node.name in dynamic_axes: + axes = dynamic_axes[input_node.name] + for k, v in axes.items(): + input_node.type.tensor_type.shape.dim[k].dim_value = 0 + input_node.type.tensor_type.shape.dim[k].dim_param = v + for output_node in extracted_model.graph.output: + for idx, dim in enumerate(output_node.type.tensor_type.shape.dim): + dim.dim_value = 0 + dim.dim_param = f'dim_{idx}' + + # save extract_model if save_file is given + if save_file is not None: + onnx.save(extracted_model, save_file) + + return extracted_model diff --git a/mmdeploy/apis/onnxruntime/__init__.py b/mmdeploy/apis/onnxruntime/__init__.py index fd70945e7..63ef448d5 100644 --- a/mmdeploy/apis/onnxruntime/__init__.py +++ b/mmdeploy/apis/onnxruntime/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdeploy.backend.onnxruntime import is_available, is_plugin_available +from mmdeploy.backend.onnxruntime import is_available, is_custom_ops_available -__all__ = ['is_available', 'is_plugin_available'] +__all__ = ['is_available', 'is_custom_ops_available'] diff --git a/mmdeploy/apis/openvino/__init__.py b/mmdeploy/apis/openvino/__init__.py index 97f6ade95..c06d29996 100644 --- a/mmdeploy/apis/openvino/__init__.py +++ b/mmdeploy/apis/openvino/__init__.py @@ -4,10 +4,14 @@ __all__ = ['is_available'] if is_available(): - from mmdeploy.backend.openvino.onnx2openvino import (get_output_model_file, - onnx2openvino) + from mmdeploy.backend.openvino.onnx2openvino import from_onnx as _from_onnx + from mmdeploy.backend.openvino.onnx2openvino import get_output_model_file + from ..core import PIPELINE_MANAGER + + from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx) + from .utils import get_input_info_from_cfg, get_mo_options_from_cfg __all__ += [ - 'onnx2openvino', 'get_output_model_file', 'get_input_info_from_cfg', + 'from_onnx', 'get_output_model_file', 'get_input_info_from_cfg', 'get_mo_options_from_cfg' ] diff --git a/mmdeploy/apis/pplnn/__init__.py b/mmdeploy/apis/pplnn/__init__.py index b0d585506..696a2526a 100644 --- a/mmdeploy/apis/pplnn/__init__.py +++ b/mmdeploy/apis/pplnn/__init__.py @@ -4,6 +4,8 @@ __all__ = ['is_available'] if is_available(): - from mmdeploy.backend.pplnn.onnx2pplnn import onnx2pplnn + from mmdeploy.backend.pplnn.onnx2pplnn import from_onnx as _from_onnx + from ..core import PIPELINE_MANAGER + from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx) - __all__ += ['onnx2pplnn'] + __all__ += ['from_onnx'] diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py index e9912bc89..5d8d9c91a 100644 --- a/mmdeploy/apis/pytorch2onnx.py +++ b/mmdeploy/apis/pytorch2onnx.py @@ -1,60 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. 
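The OpenVINO API above follows the same renaming: `onnx2openvino(input_info, output_names, onnx_path, work_dir)` becomes `from_onnx(onnx_model, output_file_prefix, input_info, output_names)`. A sketch based on the updated docstring example in the `onnx2openvino.py` hunk near the end of this section (the input shape and output names are the placeholder values used there):

```python
from mmdeploy.apis.openvino import from_onnx, get_output_model_file

input_info = {'input': [1, 3, 800, 1344]}
output_names = ['dets', 'labels']
# writes the OpenVINO IR (.xml/.bin) into work_dir
from_onnx('work_dir/end2end.onnx', 'work_dir', input_info, output_names)
model_xml = get_output_model_file('work_dir/end2end.onnx', 'work_dir')
```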
import os.path as osp -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Union import mmcv import torch -from mmdeploy.core import RewriterContext, patch_model +from mmdeploy.apis.core.pipeline_manager import no_mp from mmdeploy.utils import (get_backend, get_dynamic_axes, get_input_shape, get_onnx_config, load_config) +from .core import PIPELINE_MANAGER +from .onnx import export -def torch2onnx_impl(model: torch.nn.Module, input: Union[torch.Tensor, Tuple], - deploy_cfg: Union[str, mmcv.Config], output_file: str): - """Converting torch model to ONNX. - - Args: - model (torch.nn.Module): Input pytorch model. - input (torch.Tensor | Tuple): Input tensor used to convert model. - deploy_cfg (str | mmcv.Config): Deployment config file or - Config object. - output_file (str): Output file to save ONNX model. - """ - # load deploy_cfg if needed - deploy_cfg = load_config(deploy_cfg)[0] - - onnx_cfg = get_onnx_config(deploy_cfg) - backend = get_backend(deploy_cfg).value - opset_version = onnx_cfg.get('opset_version', 11) - - input_names = onnx_cfg['input_names'] - output_names = onnx_cfg['output_names'] - axis_names = input_names + output_names - dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) - verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get( - 'verbose', False) - - # patch model - patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) - - with RewriterContext( - cfg=deploy_cfg, backend=backend, - opset=opset_version), torch.no_grad(): - torch.onnx.export( - patched_model, - input, - output_file, - export_params=onnx_cfg['export_params'], - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=onnx_cfg[ - 'keep_initializers_as_inputs'], - verbose=verbose) - - +@PIPELINE_MANAGER.register_pipeline() def torch2onnx(img: Any, work_dir: str, save_file: str, @@ -94,10 +52,10 @@ def torch2onnx(img: Any, # load deploy_cfg if necessary deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) mmcv.mkdir_or_exist(osp.abspath(work_dir)) - output_file = osp.join(work_dir, save_file) input_shape = get_input_shape(deploy_cfg) + # create model an inputs from mmdeploy.apis import build_task_processor task_processor = build_task_processor(model_cfg, deploy_cfg, device) @@ -106,8 +64,34 @@ def torch2onnx(img: Any, if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1: model_inputs = model_inputs[0] - torch2onnx_impl( - torch_model, - model_inputs, - deploy_cfg=deploy_cfg, - output_file=output_file) + # export to onnx + context_info = dict() + context_info['deploy_cfg'] = deploy_cfg + output_prefix = osp.join(work_dir, + osp.splitext(osp.basename(save_file))[0]) + backend = get_backend(deploy_cfg).value + + onnx_cfg = get_onnx_config(deploy_cfg) + opset_version = onnx_cfg.get('opset_version', 11) + + input_names = onnx_cfg['input_names'] + output_names = onnx_cfg['output_names'] + axis_names = input_names + output_names + dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) + verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get( + 'verbose', False) + keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs', + True) + with no_mp(): + export( + torch_model, + model_inputs, + output_path_prefix=output_prefix, + backend=backend, + input_names=input_names, + output_names=output_names, + context_info=context_info, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + verbose=verbose, + 
keep_initializers_as_inputs=keep_initializers_as_inputs) diff --git a/mmdeploy/apis/pytorch2torchscript.py b/mmdeploy/apis/pytorch2torchscript.py index 8b54ce4ce..451084931 100644 --- a/mmdeploy/apis/pytorch2torchscript.py +++ b/mmdeploy/apis/pytorch2torchscript.py @@ -1,73 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Union import mmcv import torch -from packaging.version import parse as version_parse -from mmdeploy.backend.torchscript import get_ops_path -from mmdeploy.core import RewriterContext, patch_model -from mmdeploy.utils import (IR, get_backend, get_input_shape, get_root_logger, - load_config) - - -def torch2torchscript_impl(model: torch.nn.Module, - inputs: Union[torch.Tensor, Sequence[torch.Tensor]], - deploy_cfg: Union[str, - mmcv.Config], output_file: str): - """Converting torch model to torchscript. - - Args: - model (torch.nn.Module): Input pytorch model. - inputs (torch.Tensor | Sequence[torch.Tensor]): Input tensors used to - convert model. - deploy_cfg (str | mmcv.Config): Deployment config file or - Config object. - output_file (str): Output file to save torchscript model. - """ - # load custom ops if exist - custom_ops_path = get_ops_path() - if osp.exists(custom_ops_path): - torch.ops.load_library(custom_ops_path) - - deploy_cfg = load_config(deploy_cfg)[0] - - backend = get_backend(deploy_cfg).value - - patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) - - with RewriterContext( - cfg=deploy_cfg, backend=backend, - ir=IR.TORCHSCRIPT), torch.no_grad(), torch.jit.optimized_execution( - True): - # for exporting models with weight that depends on inputs - patched_model(*inputs) if isinstance(inputs, Sequence) \ - else patched_model(inputs) - ts_model = torch.jit.trace(patched_model, inputs) - - # perform optimize, note that optimizing models may trigger errors when - # loading the saved .pt file, as described in - # https://github.com/pytorch/pytorch/issues/62706 - logger = get_root_logger() - logger.info('perform torchscript optimizer.') - try: - # custom optimizer - from mmdeploy.backend.torchscript import ts_optimizer - logger = get_root_logger() - ts_optimizer.optimize_for_backend( - ts_model._c, ir=IR.TORCHSCRIPT.value, backend=backend) - except Exception: - # use pytorch builtin optimizer - ts_model = torch.jit.freeze(ts_model) - torch_version = version_parse(torch.__version__) - if torch_version.minor >= 9: - ts_model = torch.jit.optimize_for_inference(ts_model) - - # save model - torch.jit.save(ts_model, output_file) +from mmdeploy.apis.core.pipeline_manager import PIPELINE_MANAGER, no_mp +from mmdeploy.utils import get_backend, get_input_shape, load_config +from .torch_jit import trace +@PIPELINE_MANAGER.register_pipeline() def torch2torchscript(img: Any, work_dir: str, save_file: str, @@ -92,7 +35,6 @@ def torch2torchscript(img: Any, # load deploy_cfg if necessary deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) mmcv.mkdir_or_exist(osp.abspath(work_dir)) - output_file = osp.join(work_dir, save_file) input_shape = get_input_shape(deploy_cfg) @@ -104,8 +46,15 @@ def torch2torchscript(img: Any, if not isinstance(model_inputs, torch.Tensor): model_inputs = model_inputs[0] - torch2torchscript_impl( - torch_model, - model_inputs, - deploy_cfg=deploy_cfg, - output_file=output_file) + context_info = dict(deploy_cfg=deploy_cfg) + backend = get_backend(deploy_cfg).value + output_prefix = osp.join(work_dir, 
osp.splitext(save_file)[0]) + + with no_mp(): + trace( + torch_model, + model_inputs, + output_path_prefix=output_prefix, + backend=backend, + context_info=context_info, + check_trace=False) diff --git a/mmdeploy/apis/sdk/__init__.py b/mmdeploy/apis/sdk/__init__.py new file mode 100644 index 000000000..3f9013e3d --- /dev/null +++ b/mmdeploy/apis/sdk/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.backend.sdk import is_available + +__all__ = ['is_available'] + +if is_available(): + try: + from mmdeploy.backend.sdk.export_info import export2SDK as _export2SDK + from ..core import PIPELINE_MANAGER + export2SDK = PIPELINE_MANAGER.register_pipeline()(_export2SDK) + + __all__ += ['export2SDK'] + except Exception: + pass diff --git a/mmdeploy/apis/tensorrt/__init__.py b/mmdeploy/apis/tensorrt/__init__.py index 8b912b23e..fe31b6ac2 100644 --- a/mmdeploy/apis/tensorrt/__init__.py +++ b/mmdeploy/apis/tensorrt/__init__.py @@ -1,9 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdeploy.backend.tensorrt import is_available, is_plugin_available +from mmdeploy.backend.tensorrt import from_onnx as _from_onnx +from mmdeploy.backend.tensorrt import (is_available, is_custom_ops_available, + load, save) +from ..core import PIPELINE_MANAGER -__all__ = ['is_available', 'is_plugin_available'] +from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx) + +__all__ = [ + 'is_available', 'is_custom_ops_available', 'from_onnx', 'save', 'load' +] if is_available(): - from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt + try: + from mmdeploy.backend.tensorrt.onnx2tensorrt import \ + onnx2tensorrt as _onnx2tensorrt - __all__ += ['onnx2tensorrt'] + onnx2tensorrt = PIPELINE_MANAGER.register_pipeline()(_onnx2tensorrt) + __all__ += ['onnx2tensorrt'] + except Exception: + pass diff --git a/mmdeploy/apis/torch_jit/__init__.py b/mmdeploy/apis/torch_jit/__init__.py new file mode 100644 index 000000000..d93d95de5 --- /dev/null +++ b/mmdeploy/apis/torch_jit/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.backend.torchscript import get_ops_path +from .trace import trace + +__all__ = ['get_ops_path', 'trace'] diff --git a/mmdeploy/apis/torch_jit/trace.py b/mmdeploy/apis/torch_jit/trace.py new file mode 100644 index 000000000..901a22913 --- /dev/null +++ b/mmdeploy/apis/torch_jit/trace.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +from packaging.version import parse as version_parse + +from mmdeploy.core import RewriterContext, patch_model +from mmdeploy.utils import IR, Backend, get_root_logger +from ..core import PIPELINE_MANAGER + + +@PIPELINE_MANAGER.register_pipeline() +def trace(func: torch.nn.Module, + inputs: Union[torch.Tensor, Tuple], + output_path_prefix: Optional[str] = None, + backend: Union[Backend, str] = 'default', + context_info: Dict = dict(), + check_trace: bool = True, + check_tolerance: float = 1e-05) -> torch.jit.TracedModule: + """A wrapper of `torch.jit.trace` with some enhancement. + + Examples: + >>> from mmdeploy.apis.torch_jit import trace + >>> + >>> func = create_model() + >>> inputs = get_input_tensor() + >>> + >>> jit_model = trace( + >>> func, + >>> inputs, + >>> backend='torchscript', + >>> check_trace=False) + >>> + + Args: + func (torch.nn.Module): A Python function or `torch.nn.Module` that + will be run with `example_inputs`. 
+ inputs (torch.Tensor, Tuple): A tuple of example inputs that will be + passed to the function while tracing. + output_path_prefix (str): The model would be serialized in + `.pth`, None if you don't want to + save the model. + backend (Backend|str): Which backend will the graph be used. Different + backend would generate different graph. + context_info (Dict): The information that would be used in the context + of exporting. + check_trace (bool): Check if the same inputs run through traced code + produce the same outputs. + check_tolerance (float): Floating-point comparison tolerance to use in + the checker procedure. + + Returns: + torch.jit.TracedModule: The traced torch jit model. + """ + logger = get_root_logger() + logger.info('Export PyTorch model to torchscript.') + + def _add_or_update(cfg: dict, key: str, val: Any): + if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict): + cfg[key].update(val) + else: + cfg[key] = val + + context_info = deepcopy(context_info) + deploy_cfg = context_info.pop('deploy_cfg', dict()) + ir_config = dict(type='torchscript') + _add_or_update(deploy_cfg, 'ir_config', ir_config) + + if isinstance(backend, Backend): + backend = backend.value + backend_config = dict(type=backend) + _add_or_update(deploy_cfg, 'backend_config', backend_config) + + context_info['cfg'] = deploy_cfg + if 'backend' not in context_info: + context_info['backend'] = backend + elif context_info['backend'] != backend: + logger.warning( + f'Find backend {context_info["backend"]} in context_info.' + f' Expect {backend}.') + if 'ir' not in context_info: + context_info['ir'] = IR.TORCHSCRIPT + elif context_info['ir'] != backend: + logger.warning(f'Find ir {context_info["ir"]} in context_info.' + f' Expect {IR.TORCHSCRIPT}.') + + # patch model + if isinstance(func, torch.nn.Module): + func = patch_model(func, cfg=deploy_cfg, backend=backend) + + with RewriterContext(**context_info), torch.no_grad(): + # for exporting models with weight that depends on inputs + func(*inputs) if isinstance(inputs, Sequence) \ + else func(inputs) + ts_model = torch.jit.trace( + func, + inputs, + check_trace=check_trace, + check_tolerance=check_tolerance) + + logger.info('perform torchscript optimizer.') + try: + # custom optimizer + from mmdeploy.backend.torchscript import ts_optimizer + logger = get_root_logger() + ts_optimizer.optimize_for_backend( + ts_model._c, ir=IR.TORCHSCRIPT.value, backend=backend) + except Exception: + # use pytorch builtin optimizer + ts_model = torch.jit.freeze(ts_model) + torch_version = version_parse(torch.__version__) + if torch_version.minor >= 9: + ts_model = torch.jit.optimize_for_inference(ts_model) + + # save model + if output_path_prefix is not None: + output_path = output_path_prefix + '.pt' + logger.info(f'Save PyTorch model: {output_path}.') + torch.jit.save(ts_model, output_path) + + return ts_model diff --git a/mmdeploy/apis/utils/__init__.py b/mmdeploy/apis/utils/__init__.py new file mode 100644 index 000000000..27740b5e9 --- /dev/null +++ b/mmdeploy/apis/utils/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
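Beyond the synchronous wrappers shown so far, any entry point registered with `PIPELINE_MANAGER` in this patch can also be executed in a separate process and awaited through a `PipelineResult`. An illustrative sketch, assuming this patch is installed; the image, configs and checkpoint are placeholders, and `model_checkpoint` is an assumed keyword name since the full `torch2onnx` signature is elided in its hunk:

```python
from mmdeploy.apis import torch2onnx
from mmdeploy.apis.core import PIPELINE_MANAGER

if __name__ == '__main__':
    # run torch2onnx in a child process and get a handle back immediately
    PIPELINE_MANAGER.enable_multiprocess(True, func_names=[torch2onnx])
    PIPELINE_MANAGER.set_mp_async(True, func_name=torch2onnx)

    result = torch2onnx(
        'demo.jpg',
        'work_dir',
        'end2end.onnx',
        deploy_cfg='configs/mmdet/detection/'
        'detection_onnxruntime_dynamic.py',
        model_cfg='mmdetection/configs/fcos/'
        'fcos_r50_caffe_fpn_gn-head_1x_coco.py',
        model_checkpoint='checkpoints/fcos.pth',  # assumed keyword name
        device='cpu')

    # blocks until the child process finishes and returns its result
    result.get()
```

The `no_mp` context manager used inside `torch2onnx` and `torch2torchscript` temporarily disables this behaviour, so nested pipeline calls such as `export` and `trace` stay in the current process.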
+from .calibration import create_calib_input_data +from .utils import build_task_processor, get_predefined_partition_cfg + +__all__ = [ + 'create_calib_input_data', 'build_task_processor', + 'get_predefined_partition_cfg' +] diff --git a/mmdeploy/apis/utils/calibration.py b/mmdeploy/apis/utils/calibration.py new file mode 100644 index 000000000..4b250bfe1 --- /dev/null +++ b/mmdeploy/apis/utils/calibration.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Callable, Dict, Optional + +import h5py +import torch +import tqdm +from torch.utils.data import DataLoader + +from mmdeploy.core import RewriterContext, reset_mark_function_count +from ..core import PIPELINE_MANAGER + + +@PIPELINE_MANAGER.register_pipeline() +def create_calib_input_data(calib_file: str, + model: torch.nn.Module, + dataloader: DataLoader, + get_tensor_func: Optional[Callable] = None, + inference_func: Optional[Callable] = None, + model_partition: bool = False, + context_info: Dict = dict(), + device: str = 'cpu') -> None: + """Create calibration table. + + Examples: + >>> from mmdeploy.apis.utils import create_calib_input_data + >>> from mmdeploy.utils import get_calib_filename, load_config + >>> deploy_cfg = 'configs/mmdet/detection/' + 'detection_tensorrt-int8_dynamic-320x320-1344x1344.py' + >>> deploy_cfg = load_config(deploy_cfg)[0] + >>> calib_file = get_calib_filename(deploy_cfg) + >>> model_cfg = 'mmdetection/configs/fcos/' + 'fcos_r50_caffe_fpn_gn-head_1x_coco.py' + >>> model_checkpoint = 'checkpoints/' + 'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth' + >>> create_calib_input_data(calib_file, deploy_cfg, + model_cfg, model_checkpoint, device='cuda:0') + + Args: + calib_file (str): Input calibration file. + deploy_cfg (str | mmcv.Config): Deployment config. + model_cfg (str | mmcv.Config): The model config. + model_checkpoint (str): PyTorch model checkpoint, defaults to `None`. + dataset_cfg (str | mmcv.Config): Dataset config, defaults to `None` + dataset_type (str): A string specifying dataset type, e.g.: 'test', + 'val', defaults to 'val'. + device (str): Specifying the device to run on, defaults to 'cpu'. 
+ """ + + backend = 'default' + + with h5py.File(calib_file, mode='w') as file: + calib_data_group = file.create_group('calib_data') + + if not model_partition: + # create end2end group + input_data_group = calib_data_group.create_group('end2end') + input_group = input_data_group.create_group('input') + for data_id, input_data in enumerate(tqdm.tqdm(dataloader)): + + if not model_partition: + # save end2end data + if get_tensor_func is not None: + input_tensor = get_tensor_func(input_data) + else: + input_tensor = input_data + input_ndarray = input_tensor.detach().cpu().numpy() + input_group.create_dataset( + str(data_id), + shape=input_ndarray.shape, + compression='gzip', + compression_opts=4, + data=input_ndarray) + else: + context_info_ = deepcopy(context_info) + if 'cfg' not in context_info: + context_info_['cfg'] = dict() + context_info_['backend'] = backend + context_info_['create_calib'] = True + context_info_['calib_file'] = file + context_info_['data_id'] = data_id + + with torch.no_grad(), RewriterContext(**context_info_): + reset_mark_function_count() + if inference_func is not None: + inference_func(model, input_data) + else: + model(input_data) + + file.flush() diff --git a/mmdeploy/apis/utils.py b/mmdeploy/apis/utils/utils.py similarity index 100% rename from mmdeploy/apis/utils.py rename to mmdeploy/apis/utils/utils.py diff --git a/mmdeploy/backend/ncnn/__init__.py b/mmdeploy/backend/ncnn/__init__.py index 41b2c0ad9..134493242 100644 --- a/mmdeploy/backend/ncnn/__init__.py +++ b/mmdeploy/backend/ncnn/__init__.py @@ -3,6 +3,7 @@ import os.path as osp from .init_plugins import get_onnx2ncnn_path, get_ops_path +from .onnx2ncnn import from_onnx def is_available(): @@ -19,7 +20,7 @@ def is_available(): return has_pyncnn and osp.exists(onnx2ncnn) -def is_plugin_available(): +def is_custom_ops_available(): """Check whether ncnn extension and custom ops are installed. Returns: @@ -31,10 +32,12 @@ def is_plugin_available(): return has_pyncnn_ext and osp.exists(ncnn_ops_path) +__all__ = ['from_onnx'] + if is_available(): try: from .wrapper import NCNNWrapper - __all__ = ['NCNNWrapper'] + __all__ += ['NCNNWrapper'] except Exception: pass diff --git a/mmdeploy/backend/ncnn/onnx2ncnn.py b/mmdeploy/backend/ncnn/onnx2ncnn.py index ec69d941d..fd5c0251c 100644 --- a/mmdeploy/backend/ncnn/onnx2ncnn.py +++ b/mmdeploy/backend/ncnn/onnx2ncnn.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +import tempfile from subprocess import call -from typing import List +from typing import List, Union + +import onnx from .init_plugins import get_onnx2ncnn_path @@ -32,7 +35,8 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> List[str]: return [save_param, save_bin] -def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str): +def from_onnx(onnx_model: Union[onnx.ModelProto, str], + output_file_prefix: str): """Convert ONNX to ncnn. The inputs of ncnn include a model file and a weight file. We need to use @@ -40,18 +44,24 @@ def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str): a `.bin` file. The output files will save to work_dir. 
Example: - >>> from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn + >>> from mmdeploy.apis.ncnn import from_onnx >>> onnx_path = 'work_dir/end2end.onnx' - >>> save_param = 'work_dir/end2end.param' - >>> save_bin = 'work_dir/end2end.bin' - >>> onnx2ncnn(onnx_path, save_param, save_bin) + >>> output_file_prefix = 'work_dir/end2end' + >>> from_onnx(onnx_path, output_file_prefix) Args: - onnx_path (str): The path of the onnx model. - save_param (str): The path to save the output `.param` file. - save_bin (str): The path to save the output `.bin` file. + onnx_model (ModelProto|str): The onnx model or its path. + output_file_prefix (str): The prefix of the output ncnn files (`.param` and `.bin`). """ - onnx2ncnn_path = get_onnx2ncnn_path() + if not isinstance(onnx_model, str): + onnx_path = tempfile.NamedTemporaryFile(suffix='.onnx').name + onnx.save(onnx_model, onnx_path) + else: + onnx_path = onnx_model + save_param = output_file_prefix + '.param' + save_bin = output_file_prefix + '.bin' + + onnx2ncnn_path = get_onnx2ncnn_path() call([onnx2ncnn_path, onnx_path, save_param, save_bin]) diff --git a/mmdeploy/backend/onnxruntime/__init__.py b/mmdeploy/backend/onnxruntime/__init__.py index a6dc88c6b..e808311bc 100644 --- a/mmdeploy/backend/onnxruntime/__init__.py +++ b/mmdeploy/backend/onnxruntime/__init__.py @@ -15,7 +15,7 @@ def is_available(): return importlib.util.find_spec('onnxruntime') is not None -def is_plugin_available(): +def is_custom_ops_available(): """Check whether ONNX Runtime custom ops are installed. Returns: diff --git a/mmdeploy/backend/openvino/onnx2openvino.py b/mmdeploy/backend/openvino/onnx2openvino.py index 7252efabb..c96088903 100644 --- a/mmdeploy/backend/openvino/onnx2openvino.py +++ b/mmdeploy/backend/openvino/onnx2openvino.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import subprocess +import tempfile from subprocess import PIPE, CalledProcessError, run -from typing import Dict, List, Optional, Union +from typing import Dict, Optional, Sequence, Union import mmcv -import torch +import onnx from mmdeploy.utils import get_root_logger from .utils import ModelOptimizerOptions @@ -55,30 +56,33 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> str: return model_xml -def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]], - output_names: List[str], - onnx_path: str, - work_dir: str, - mo_options: Optional[ModelOptimizerOptions] = None): +def from_onnx(onnx_model: Union[str, onnx.ModelProto], + output_file_prefix: str, + input_info: Dict[str, Sequence[int]], + output_names: Sequence[str], + mo_options: Optional[ModelOptimizerOptions] = None): """Convert ONNX to OpenVINO. Examples: - >>> from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino + >>> from mmdeploy.apis.openvino import from_onnx >>> input_info = {'input': [1,3,800,1344]} >>> output_names = ['dets', 'labels'] >>> onnx_path = 'work_dir/end2end.onnx' - >>> work_dir = 'work_dir' - >>> onnx2openvino(input_info, output_names, onnx_path, work_dir) + >>> output_dir = 'work_dir' + >>> from_onnx(onnx_path, output_dir, input_info, output_names) Args: - input_info (Dict[str, Union[List[int], torch.Size]]): + onnx_model (str|ModelProto): The onnx model or its path. + output_file_prefix (str): The path to the directory for saving + the results. + input_info (Dict[str, Sequence[int]]): The shape of each input. - output_names (List[str]): Output names. Example: ['dets', 'labels']. - onnx_path (str): The path to the onnx model.
- work_dir (str): The path to the directory for saving the results. + output_names (Sequence[str]): Output names. Example: + ['dets', 'labels']. mo_options (None | ModelOptimizerOptions): The class with additional arguments for the Model Optimizer. """ + work_dir = output_file_prefix input_names = ','.join(input_info.keys()) input_shapes = ','.join(str(list(elem)) for elem in input_info.values()) output = ','.join(output_names) @@ -89,6 +93,12 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]], raise RuntimeError( 'OpenVINO Model Optimizer is not found or configured improperly') + if isinstance(onnx_model, str): + onnx_path = onnx_model + else: + onnx_path = tempfile.NamedTemporaryFile(suffix='.onnx').name + onnx.save(onnx_model, onnx_path) + mo_args = f'--input_model="{onnx_path}" '\ f'--output_dir="{work_dir}" ' \ f'--output="{output}" ' \ diff --git a/mmdeploy/backend/pplnn/__init__.py b/mmdeploy/backend/pplnn/__init__.py index 8e639c161..d9a5e70f3 100644 --- a/mmdeploy/backend/pplnn/__init__.py +++ b/mmdeploy/backend/pplnn/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import importlib +from .utils import register_engines + def is_available(): """Check whether pplnn is installed. @@ -11,6 +13,8 @@ def is_available(): return importlib.util.find_spec('pyppl') is not None +__all__ = ['register_engines'] + if is_available(): - from .wrapper import PPLNNWrapper, register_engines - __all__ = ['register_engines', 'PPLNNWrapper'] + from .wrapper import PPLNNWrapper + __all__ += ['PPLNNWrapper'] diff --git a/mmdeploy/backend/pplnn/onnx2pplnn.py b/mmdeploy/backend/pplnn/onnx2pplnn.py index d5fb97759..878238d04 100644 --- a/mmdeploy/backend/pplnn/onnx2pplnn.py +++ b/mmdeploy/backend/pplnn/onnx2pplnn.py @@ -4,14 +4,14 @@ from pyppl import nn as pplnn from mmdeploy.utils.device import parse_cuda_device_id -from .wrapper import register_engines +from .utils import register_engines -def onnx2pplnn(algo_file: str, - onnx_model: str, - device: str = 'cuda:0', - input_shapes: Optional[Sequence[Sequence[int]]] = None, - **kwargs): +def from_onnx(onnx_model: str, + output_file_prefix: str, + device: str = 'cuda:0', + input_shapes: Optional[Sequence[Sequence[int]]] = None, + **kwargs): """Convert ONNX to PPLNN. PPLNN is capable of optimizing onnx model. The optimized algorithm is saved @@ -21,16 +21,18 @@ def onnx2pplnn(algo_file: str, own preferences. Args: - algo_file (str): File path to save PPLNN optimization algorithm. + output_file_prefix (str): File path to save PPLNN optimization + algorithm and ONNX file onnx_model (str): Input onnx model. device (str): A string specifying device, defaults to 'cuda:0'. input_shapes (Sequence[Sequence[int]] | None): Shapes for PPLNN optimization, default to None. Examples: - >>> from mmdeploy.apis.pplnn import onnx2pplnn + >>> from mmdeploy.apis.pplnn import from_onnx >>> - >>> onnx2pplnn(algo_file = 'example.json', onnx_model = 'example.onnx') + >>> from_onnx(onnx_model = 'example.onnx', + output_file_prefix = 'example') """ if device == 'cpu': device_id = -1 @@ -42,6 +44,8 @@ def onnx2pplnn(algo_file: str, input_shapes = [[1, 3, 224, 224]] # PPLNN default shape for optimization + algo_file = output_file_prefix + '.json' + onnx_output_path = output_file_prefix + '.onnx' engines = register_engines( device_id, disable_avx512=False, @@ -52,3 +56,6 @@ def onnx2pplnn(algo_file: str, onnx_model, engines) assert runtime_builder is not None, 'Failed to create '\ 'OnnxRuntimeBuilder.' 
+ import shutil + if onnx_output_path != onnx_model: + shutil.copy2(onnx_model, onnx_output_path) diff --git a/mmdeploy/backend/pplnn/utils.py b/mmdeploy/backend/pplnn/utils.py new file mode 100644 index 000000000..f18d45330 --- /dev/null +++ b/mmdeploy/backend/pplnn/utils.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from typing import List, Sequence + +import pyppl.common as pplcommon +import pyppl.nn as pplnn + +from mmdeploy.utils import get_root_logger + + +def register_engines(device_id: int, + disable_avx512: bool = False, + quick_select: bool = False, + input_shapes: Sequence[Sequence[int]] = None, + export_algo_file: str = None, + import_algo_file: str = None) -> List[pplnn.Engine]: + """Register engines for pplnn runtime. + + Args: + device_id (int): Specifying device index. `-1` for cpu. + disable_avx512 (bool): Whether to disable avx512 for x86. + Defaults to `False`. + quick_select (bool): Whether to use default algorithms. + Defaults to `False`. + input_shapes (Sequence[Sequence[int]]): shapes for PPLNN optimization. + export_algo_file (str): File path for exporting PPLNN optimization + file. + import_algo_file (str): File path for loading PPLNN optimization file. + + Returns: + list[pplnn.Engine]: A list of registered pplnn engines. + """ + engines = [] + logger = get_root_logger() + if device_id == -1: + x86_options = pplnn.X86EngineOptions() + x86_engine = pplnn.X86EngineFactory.Create(x86_options) + if not x86_engine: + logger.error('Failed to create x86 engine') + sys.exit(-1) + + if disable_avx512: + status = x86_engine.Configure(pplnn.X86_CONF_DISABLE_AVX512) + if status != pplcommon.RC_SUCCESS: + logger.error('x86 engine Configure() failed: ' + + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + engines.append(pplnn.Engine(x86_engine)) + + else: + cuda_options = pplnn.CudaEngineOptions() + cuda_options.device_id = device_id + + cuda_engine = pplnn.CudaEngineFactory.Create(cuda_options) + if not cuda_engine: + logger.error('Failed to create cuda engine.') + sys.exit(-1) + + if quick_select: + status = cuda_engine.Configure( + pplnn.CUDA_CONF_USE_DEFAULT_ALGORITHMS) + if status != pplcommon.RC_SUCCESS: + logger.error('cuda engine Configure() failed: ' + + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + if input_shapes is not None: + status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS, + input_shapes) + if status != pplcommon.RC_SUCCESS: + logger.error( + 'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: ' + + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + if export_algo_file is not None: + status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS, + export_algo_file) + if status != pplcommon.RC_SUCCESS: + logger.error( + 'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) ' + 'failed: ' + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + if import_algo_file is not None: + status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS, + import_algo_file) + if status != pplcommon.RC_SUCCESS: + logger.error( + 'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) ' + 'failed: ' + pplcommon.GetRetCodeStr(status)) + sys.exit(-1) + + engines.append(pplnn.Engine(cuda_engine)) + + return engines diff --git a/mmdeploy/backend/pplnn/wrapper.py b/mmdeploy/backend/pplnn/wrapper.py index 2692925fe..aaba37f5d 100644 --- a/mmdeploy/backend/pplnn/wrapper.py +++ b/mmdeploy/backend/pplnn/wrapper.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
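For reference, a minimal sketch of the renamed PPLNN converter shown above: `from_onnx` now takes an output prefix instead of an explicit algo file path. This is illustrative only; the paths and shapes are placeholders and `pyppl` must be installed.

```python
# Illustrative sketch: optimize an ONNX model with the prefix-based PPLNN API.
from mmdeploy.apis.pplnn import from_onnx

from_onnx(
    'work_dir/end2end.onnx',   # input ONNX model (placeholder path)
    'work_dir/end2end',        # writes end2end.json and copies end2end.onnx
    device='cuda:0',
    input_shapes=[[1, 3, 224, 224]])
```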
-import sys -from typing import Dict, List, Optional, Sequence +from typing import Dict, Optional, Sequence import numpy as np import onnx @@ -8,98 +7,10 @@ import pyppl.nn as pplnn import torch -from mmdeploy.utils import Backend, get_root_logger, parse_device_id +from mmdeploy.utils import Backend, parse_device_id from mmdeploy.utils.timer import TimeCounter from ..base import BACKEND_WRAPPER, BaseWrapper - - -def register_engines(device_id: int, - disable_avx512: bool = False, - quick_select: bool = False, - input_shapes: Sequence[Sequence[int]] = None, - export_algo_file: str = None, - import_algo_file: str = None) -> List[pplnn.Engine]: - """Register engines for pplnn runtime. - - Args: - device_id (int): Specifying device index. `-1` for cpu. - disable_avx512 (bool): Whether to disable avx512 for x86. - Defaults to `False`. - quick_select (bool): Whether to use default algorithms. - Defaults to `False`. - input_shapes (Sequence[Sequence[int]]): shapes for PPLNN optimization. - export_algo_file (str): File path for exporting PPLNN optimization - file. - import_algo_file (str): File path for loading PPLNN optimization file. - - Returns: - list[pplnn.Engine]: A list of registered pplnn engines. - """ - engines = [] - logger = get_root_logger() - if device_id == -1: - x86_options = pplnn.X86EngineOptions() - x86_engine = pplnn.X86EngineFactory.Create(x86_options) - if not x86_engine: - logger.error('Failed to create x86 engine') - sys.exit(-1) - - if disable_avx512: - status = x86_engine.Configure(pplnn.X86_CONF_DISABLE_AVX512) - if status != pplcommon.RC_SUCCESS: - logger.error('x86 engine Configure() failed: ' + - pplcommon.GetRetCodeStr(status)) - sys.exit(-1) - - engines.append(pplnn.Engine(x86_engine)) - - else: - cuda_options = pplnn.CudaEngineOptions() - cuda_options.device_id = device_id - - cuda_engine = pplnn.CudaEngineFactory.Create(cuda_options) - if not cuda_engine: - logger.error('Failed to create cuda engine.') - sys.exit(-1) - - if quick_select: - status = cuda_engine.Configure( - pplnn.CUDA_CONF_USE_DEFAULT_ALGORITHMS) - if status != pplcommon.RC_SUCCESS: - logger.error('cuda engine Configure() failed: ' + - pplcommon.GetRetCodeStr(status)) - sys.exit(-1) - - if input_shapes is not None: - status = cuda_engine.Configure(pplnn.CUDA_CONF_SET_INPUT_DIMS, - input_shapes) - if status != pplcommon.RC_SUCCESS: - logger.error( - 'cuda engine Configure(CUDA_CONF_SET_INPUT_DIMS) failed: ' - + pplcommon.GetRetCodeStr(status)) - sys.exit(-1) - - if export_algo_file is not None: - status = cuda_engine.Configure(pplnn.CUDA_CONF_EXPORT_ALGORITHMS, - export_algo_file) - if status != pplcommon.RC_SUCCESS: - logger.error( - 'cuda engine Configure(CUDA_CONF_EXPORT_ALGORITHMS) ' - 'failed: ' + pplcommon.GetRetCodeStr(status)) - sys.exit(-1) - - if import_algo_file is not None: - status = cuda_engine.Configure(pplnn.CUDA_CONF_IMPORT_ALGORITHMS, - import_algo_file) - if status != pplcommon.RC_SUCCESS: - logger.error( - 'cuda engine Configure(CUDA_CONF_IMPORT_ALGORITHMS) ' - 'failed: ' + pplcommon.GetRetCodeStr(status)) - sys.exit(-1) - - engines.append(pplnn.Engine(cuda_engine)) - - return engines +from .utils import register_engines @BACKEND_WRAPPER.register_module(Backend.PPLNN.value) diff --git a/mmdeploy/utils/export_info.py b/mmdeploy/backend/sdk/export_info.py similarity index 99% rename from mmdeploy/utils/export_info.py rename to mmdeploy/backend/sdk/export_info.py index 3dadce53f..459e82b18 100644 --- a/mmdeploy/utils/export_info.py +++ b/mmdeploy/backend/sdk/export_info.py @@ -350,8 
+350,8 @@ def get_detail(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, calib_config=calib_config) -def dump_info(deploy_cfg: Union[str, mmcv.Config], - model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str): +def export2SDK(deploy_cfg: Union[str, mmcv.Config], + model_cfg: Union[str, mmcv.Config], work_dir: str, pth: str): """Export information to SDK. This function dump `deploy.json`, `pipeline.json` and `detail.json` to work dir. diff --git a/mmdeploy/backend/tensorrt/__init__.py b/mmdeploy/backend/tensorrt/__init__.py index b86fd8efd..885bae40d 100644 --- a/mmdeploy/backend/tensorrt/__init__.py +++ b/mmdeploy/backend/tensorrt/__init__.py @@ -16,7 +16,7 @@ def is_available(): return importlib.util.find_spec('tensorrt') is not None -def is_plugin_available(): +def is_custom_ops_available(): """Check whether TensorRT custom ops are installed. Returns: @@ -27,12 +27,9 @@ def is_plugin_available(): if is_available(): - from .utils import create_trt_engine, load_trt_engine, save_trt_engine + from .utils import from_onnx, load, save - __all__ = [ - 'create_trt_engine', 'save_trt_engine', 'load_trt_engine', - 'load_tensorrt_plugin' - ] + __all__ = ['from_onnx', 'save', 'load', 'load_tensorrt_plugin'] try: # import wrapper if pytorch is available diff --git a/mmdeploy/backend/tensorrt/onnx2tensorrt.py b/mmdeploy/backend/tensorrt/onnx2tensorrt.py index f0e316e46..8e38f100b 100644 --- a/mmdeploy/backend/tensorrt/onnx2tensorrt.py +++ b/mmdeploy/backend/tensorrt/onnx2tensorrt.py @@ -8,7 +8,7 @@ from mmdeploy.utils import (get_calib_filename, get_common_config, get_model_inputs, load_config, parse_device_id) from mmdeploy.utils.config_utils import get_ir_config -from .utils import create_trt_engine, get_trt_log_level, save_trt_engine +from .utils import from_onnx, get_trt_log_level def onnx2tensorrt(work_dir: str, @@ -72,8 +72,13 @@ def onnx2tensorrt(work_dir: str, but given: {device}' device_id = parse_device_id(device) - engine = create_trt_engine( + assert save_file.endswith( + '.engine' + ), 'Expect save file ends with `.engine`.' f' but get {save_file}' + save_path = osp.join(work_dir, save_file) + from_onnx( onnx_model, + osp.splitext(save_path)[0], input_shapes=input_shapes, log_level=get_trt_log_level(), fp16_mode=final_params.get('fp16_mode', False), @@ -81,5 +86,3 @@ def onnx2tensorrt(work_dir: str, int8_param=int8_param, max_workspace_size=final_params.get('max_workspace_size', 0), device_id=device_id) - - save_trt_engine(engine, osp.join(work_dir, save_file)) diff --git a/mmdeploy/backend/tensorrt/utils.py b/mmdeploy/backend/tensorrt/utils.py index 4a2c56ca2..ace6b5598 100644 --- a/mmdeploy/backend/tensorrt/utils.py +++ b/mmdeploy/backend/tensorrt/utils.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging -from typing import Dict, Sequence, Union +from typing import Dict, Optional, Sequence, Union import onnx import tensorrt as trt @@ -10,38 +10,68 @@ from .init_plugins import load_tensorrt_plugin -def create_trt_engine(onnx_model: Union[str, onnx.ModelProto], - input_shapes: Dict[str, Sequence[int]], - log_level: trt.Logger.Severity = trt.Logger.ERROR, - fp16_mode: bool = False, - int8_mode: bool = False, - int8_param: dict = None, - max_workspace_size: int = 0, - device_id: int = 0, - **kwargs) -> trt.ICudaEngine: +def save(engine: trt.ICudaEngine, path: str) -> None: + """Serialize TensorRT engine to disk. + + Args: + engine (tensorrt.ICudaEngine): TensorRT engine to be serialized. + path (str): The absolute disk path to write the engine. 
+ """ + with open(path, mode='wb') as f: + f.write(bytearray(engine.serialize())) + + +def load(path: str) -> trt.ICudaEngine: + """Deserialize TensorRT engine from disk. + + Args: + path (str): The disk path to read the engine. + + Returns: + tensorrt.ICudaEngine: The TensorRT engine loaded from disk. + """ + load_tensorrt_plugin() + with trt.Logger() as logger, trt.Runtime(logger) as runtime: + with open(path, mode='rb') as f: + engine_bytes = f.read() + engine = runtime.deserialize_cuda_engine(engine_bytes) + return engine + + +def from_onnx(onnx_model: Union[str, onnx.ModelProto], + output_file_prefix: str, + input_shapes: Dict[str, Sequence[int]], + max_workspace_size: int = 0, + fp16_mode: bool = False, + int8_mode: bool = False, + int8_param: Optional[dict] = None, + device_id: int = 0, + log_level: trt.Logger.Severity = trt.Logger.ERROR, + **kwargs) -> trt.ICudaEngine: """Create a tensorrt engine from ONNX. Args: onnx_model (str or onnx.ModelProto): Input onnx model to convert from. + output_file_prefix (str): The path to save the output ncnn file. input_shapes (Dict[str, Sequence[int]]): The min/opt/max shape of each input. - log_level (trt.Logger.Severity): The log level of TensorRT. Defaults to - `trt.Logger.ERROR`. + max_workspace_size (int): To set max workspace size of TensorRT engine. + some tactics and layers need large workspace. Defaults to `0`. fp16_mode (bool): Specifying whether to enable fp16 mode. Defaults to `False`. int8_mode (bool): Specifying whether to enable int8 mode. Defaults to `False`. int8_param (dict): A dict of parameter int8 mode. Defaults to `None`. - max_workspace_size (int): To set max workspace size of TensorRT engine. - some tactics and layers need large workspace. Defaults to `0`. device_id (int): Choice the device to create engine. Defaults to `0`. + log_level (trt.Logger.Severity): The log level of TensorRT. Defaults to + `trt.Logger.ERROR`. Returns: tensorrt.ICudaEngine: The TensorRT engine created from onnx_model. Example: - >>> from mmdeploy.apis.tensorrt import create_trt_engine - >>> engine = create_trt_engine( + >>> from mmdeploy.apis.tensorrt import from_onnx + >>> engine = from_onnx( >>> "onnx_model.onnx", >>> {'input': {"min_shape" : [1, 3, 160, 160], >>> "opt_shape" : [1, 3, 320, 320], @@ -121,35 +151,9 @@ def create_trt_engine(onnx_model: Union[str, onnx.ModelProto], engine = builder.build_engine(network, config) assert engine is not None, 'Failed to create TensorRT engine' - return engine - - -def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None: - """Serialize TensorRT engine to disk. - - Args: - engine (tensorrt.ICudaEngine): TensorRT engine to be serialized. - path (str): The absolute disk path to write the engine. - """ - with open(path, mode='wb') as f: - f.write(bytearray(engine.serialize())) - -def load_trt_engine(path: str) -> trt.ICudaEngine: - """Deserialize TensorRT engine from disk. - - Args: - path (str): The disk path to read the engine. - - Returns: - tensorrt.ICudaEngine: The TensorRT engine loaded from disk. 
- """ - load_tensorrt_plugin() - with trt.Logger() as logger, trt.Runtime(logger) as runtime: - with open(path, mode='rb') as f: - engine_bytes = f.read() - engine = runtime.deserialize_cuda_engine(engine_bytes) - return engine + save(engine, output_file_prefix + '.engine') + return engine def get_trt_log_level() -> trt.Logger.Severity: diff --git a/mmdeploy/backend/tensorrt/wrapper.py b/mmdeploy/backend/tensorrt/wrapper.py index 9a23d5b2a..888c3cae2 100644 --- a/mmdeploy/backend/tensorrt/wrapper.py +++ b/mmdeploy/backend/tensorrt/wrapper.py @@ -8,7 +8,7 @@ from mmdeploy.utils.timer import TimeCounter from ..base import BACKEND_WRAPPER, BaseWrapper from .init_plugins import load_tensorrt_plugin -from .utils import load_trt_engine +from .utils import load def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype: @@ -81,7 +81,7 @@ def __init__(self, load_tensorrt_plugin() self.engine = engine if isinstance(self.engine, str): - self.engine = load_trt_engine(engine) + self.engine = load(engine) if not isinstance(self.engine, trt.ICudaEngine): raise TypeError(f'`engine` should be str or trt.ICudaEngine, \ diff --git a/mmdeploy/utils/logging.py b/mmdeploy/utils/logging.py index 7a6ea65d1..2a7c7afc6 100644 --- a/mmdeploy/utils/logging.py +++ b/mmdeploy/utils/logging.py @@ -15,7 +15,6 @@ def get_logger(name: str, logger by adding one or two handlers, otherwise the initialized logger will be directly returned. During initialization, a StreamHandler will always be added. If `log_file` is specified, a FileHandler will also be added. - Args: name (str): Logger name. log_file (str | None): The log filename. If specified, a FileHandler @@ -23,7 +22,6 @@ def get_logger(name: str, log_level (int): The logger level. file_mode (str): The file mode used in opening log file. Defaults to 'w'. - Returns: logging.Logger: The expected logger. """ diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py index 6f1decaa6..f484e3803 100644 --- a/mmdeploy/utils/test.py +++ b/mmdeploy/utils/test.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp import random import string import tempfile @@ -28,21 +29,21 @@ def backend_checker(backend: Backend, require_plugin: bool = False): will also check if the backend plugin has been compiled. Default to `False`. 
""" - is_plugin_available = None + is_custom_ops_available = None if backend == Backend.ONNXRUNTIME: from mmdeploy.apis.onnxruntime import is_available if require_plugin: - from mmdeploy.apis.onnxruntime import is_plugin_available + from mmdeploy.apis.onnxruntime import is_custom_ops_available elif backend == Backend.TENSORRT: from mmdeploy.apis.tensorrt import is_available if require_plugin: - from mmdeploy.apis.tensorrt import is_plugin_available + from mmdeploy.apis.tensorrt import is_custom_ops_available elif backend == Backend.PPLNN: from mmdeploy.apis.pplnn import is_available elif backend == Backend.NCNN: from mmdeploy.apis.ncnn import is_available if require_plugin: - from mmdeploy.apis.ncnn import is_plugin_available + from mmdeploy.apis.ncnn import is_custom_ops_available elif backend == Backend.OPENVINO: from mmdeploy.apis.openvino import is_available else: @@ -51,9 +52,9 @@ def backend_checker(backend: Backend, require_plugin: bool = False): checker = pytest.mark.skipif( not is_available(), reason=f'{backend.value} package is not available') - if require_plugin and is_plugin_available is not None: + if require_plugin and is_custom_ops_available is not None: plugin_checker = pytest.mark.skipif( - not is_plugin_available(), + not is_custom_ops_available(), reason=f'{backend.value} plugin is not available') def double_checker(func): @@ -76,21 +77,21 @@ def check_backend(backend: Backend, require_plugin: bool = False): will also check if the backend plugin has been compiled. Default to `False`. """ - is_plugin_available = None + is_custom_ops_available = None if backend == Backend.ONNXRUNTIME: from mmdeploy.apis.onnxruntime import is_available if require_plugin: - from mmdeploy.apis.onnxruntime import is_plugin_available + from mmdeploy.apis.onnxruntime import is_custom_ops_available elif backend == Backend.TENSORRT: from mmdeploy.apis.tensorrt import is_available if require_plugin: - from mmdeploy.apis.tensorrt import is_plugin_available + from mmdeploy.apis.tensorrt import is_custom_ops_available elif backend == Backend.PPLNN: from mmdeploy.apis.pplnn import is_available elif backend == Backend.NCNN: from mmdeploy.apis.ncnn import is_available if require_plugin: - from mmdeploy.apis.ncnn import is_plugin_available + from mmdeploy.apis.ncnn import is_custom_ops_available elif backend == Backend.OPENVINO: from mmdeploy.apis.openvino import is_available elif backend == Backend.TORCHSCRIPT: @@ -101,8 +102,8 @@ def check_backend(backend: Backend, require_plugin: bool = False): if not is_available(): pytest.skip(f'{backend.value} package is not available') - if require_plugin and is_plugin_available is not None: - if not is_plugin_available(): + if require_plugin and is_custom_ops_available is not None: + if not is_custom_ops_available(): pytest.skip(f'{backend.value} plugin is not available') @@ -409,14 +410,18 @@ def get_ts_model(wrapped_model: nn.Module, """ ir_file_path = tempfile.NamedTemporaryFile(suffix='.pt').name backend = get_backend(deploy_cfg) - patched_model = patch_model( - wrapped_model, cfg=deploy_cfg, backend=backend.value) - from mmdeploy.apis.pytorch2torchscript import torch2torchscript_impl - torch2torchscript_impl( - patched_model, [v for _, v in model_inputs.items()], - deploy_cfg=deploy_cfg, - output_file=ir_file_path) + from mmdeploy.apis.torch_jit import trace + context_info = dict(deploy_cfg=deploy_cfg) + output_prefix = osp.splitext(ir_file_path)[0] + + example_inputs = [v for _, v in model_inputs.items()] + trace( + wrapped_model, + example_inputs, + 
output_path_prefix=output_prefix, + backend=backend, + context_info=context_info) return ir_file_path @@ -450,7 +455,8 @@ def get_backend_outputs(ir_file_path: str, if backend == Backend.TENSORRT: # convert to engine import mmdeploy.apis.tensorrt as trt_apis - if not (trt_apis.is_available() and trt_apis.is_plugin_available()): + if not (trt_apis.is_available() + and trt_apis.is_custom_ops_available()): return None trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name trt_apis.onnx2tensorrt( @@ -467,7 +473,8 @@ def get_backend_outputs(ir_file_path: str, device = 'cuda:0' elif backend == Backend.ONNXRUNTIME: import mmdeploy.apis.onnxruntime as ort_apis - if not (ort_apis.is_available() and ort_apis.is_plugin_available()): + if not (ort_apis.is_available() + and ort_apis.is_custom_ops_available()): return None feature_list = [] backend_feats = {} @@ -495,12 +502,14 @@ def get_backend_outputs(ir_file_path: str, device = 'cpu' elif backend == Backend.NCNN: import mmdeploy.apis.ncnn as ncnn_apis - if not (ncnn_apis.is_available() and ncnn_apis.is_plugin_available()): + if not (ncnn_apis.is_available() + and ncnn_apis.is_custom_ops_available()): return None work_dir = tempfile.TemporaryDirectory().name param_path, bin_path = ncnn_apis.get_output_model_file( ir_file_path, work_dir) - ncnn_apis.onnx2ncnn(ir_file_path, param_path, bin_path) + ir_file_name = osp.splitext(ir_file_path)[0] + ncnn_apis.from_onnx(ir_file_path, osp.join(work_dir, ir_file_name)) backend_files = [param_path, bin_path] backend_feats = flatten_model_inputs device = 'cpu' @@ -518,8 +527,8 @@ def get_backend_outputs(ir_file_path: str, for name, value in flatten_model_inputs.items() } mo_options = get_mo_options_from_cfg(deploy_cfg) - openvino_apis.onnx2openvino(input_info, output_names, ir_file_path, - openvino_work_dir, mo_options) + openvino_apis.from_onnx(ir_file_path, openvino_work_dir, input_info, + output_names, mo_options) backend_files = [openvino_file_path] backend_feats = flatten_model_inputs device = 'cpu' diff --git a/mmdeploy/utils/utils.py b/mmdeploy/utils/utils.py index 29b94dde6..12cc89f77 100644 --- a/mmdeploy/utils/utils.py +++ b/mmdeploy/utils/utils.py @@ -6,7 +6,10 @@ import traceback from typing import Callable, Optional, Union -import multiprocess as mp +try: + from torch import multiprocessing as mp +except ImportError: + import multiprocess as mp from mmdeploy.utils.logging import get_logger diff --git a/tests/test_apis/test_calibration.py b/tests/test_apis/test_calibration.py index 6734fc58f..ce8fc76ba 100644 --- a/tests/test_apis/test_calibration.py +++ b/tests/test_apis/test_calibration.py @@ -6,7 +6,7 @@ import h5py import mmcv -from mmdeploy.apis import create_calib_table +from mmdeploy.apis import create_calib_input_data calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name ann_file = 'tests/data/annotation.json' @@ -173,7 +173,7 @@ def get_model_cfg(): def run_test_create_calib_end2end(): model_cfg = get_model_cfg() deploy_cfg = get_end2end_deploy_cfg() - create_calib_table( + create_calib_input_data( calib_file, deploy_cfg, model_cfg, @@ -205,7 +205,7 @@ def test_create_calib_end2end(): def run_test_create_calib_parittion(): model_cfg = get_model_cfg() deploy_cfg = get_partition_deploy_cfg() - create_calib_table( + create_calib_input_data( calib_file, deploy_cfg, model_cfg, diff --git a/tests/test_apis/test_extract.py b/tests/test_apis/test_extract.py index 754ce5bcf..07432b7b0 100644 --- a/tests/test_apis/test_extract.py +++ b/tests/test_apis/test_extract.py @@ -4,7 +4,7 @@ 
import onnx import torch -from mmdeploy.apis import extract_model +from mmdeploy.apis.onnx import extract_partition from mmdeploy.core import mark output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name @@ -33,7 +33,7 @@ def forward(self, x, y): torch.onnx.export(model, (x, y), output_file) onnx_model = onnx.load(output_file) - extracted = extract_model(onnx_model, 'add:input', 'add:output') + extracted = extract_partition(onnx_model, 'add:input', 'add:output') assert extracted.graph.input[0].name == 'x' assert extracted.graph.input[1].name == 'y' diff --git a/tests/test_apis/test_onnx2ncnn.py b/tests/test_apis/test_onnx2ncnn.py index 8073b7754..ef6cf1f1d 100644 --- a/tests/test_apis/test_onnx2ncnn.py +++ b/tests/test_apis/test_onnx2ncnn.py @@ -55,13 +55,14 @@ def generate_onnx_file(model): @backend_checker(Backend.NCNN) def test_onnx2ncnn(): - from mmdeploy.apis.ncnn import onnx2ncnn + from mmdeploy.apis.ncnn import from_onnx model = test_model generate_onnx_file(model) work_dir, _ = osp.split(onnx_file) save_param, save_bin = get_output_model_file(onnx_file, work_dir=work_dir) - onnx2ncnn(onnx_file, save_param, save_bin) + file_name = osp.splitext(onnx_file)[0] + from_onnx(onnx_file, osp.join(work_dir, file_name)) assert osp.exists(work_dir) assert osp.exists(save_param) assert osp.exists(save_bin) diff --git a/tests/test_apis/test_onnx2openvino.py b/tests/test_apis/test_onnx2openvino.py index 885d00b31..09d2b84d4 100644 --- a/tests/test_apis/test_onnx2openvino.py +++ b/tests/test_apis/test_onnx2openvino.py @@ -80,8 +80,8 @@ def get_deploy_cfg_with_mo_args(): [get_base_deploy_cfg, get_deploy_cfg_with_mo_args]) @backend_checker(Backend.OPENVINO) def test_onnx2openvino(get_deploy_cfg): - from mmdeploy.apis.openvino import (get_mo_options_from_cfg, - get_output_model_file, onnx2openvino) + from mmdeploy.apis.openvino import (from_onnx, get_mo_options_from_cfg, + get_output_model_file) pytorch_model = TestModel().eval() export_img = torch.rand([1, 3, 8, 8]) onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name @@ -95,8 +95,7 @@ def test_onnx2openvino(get_deploy_cfg): openvino_dir = tempfile.TemporaryDirectory().name deploy_cfg = get_deploy_cfg() mo_options = get_mo_options_from_cfg(deploy_cfg) - onnx2openvino(input_info, output_names, onnx_file, openvino_dir, - mo_options) + from_onnx(onnx_file, openvino_dir, input_info, output_names, mo_options) openvino_model_path = get_output_model_file(onnx_file, openvino_dir) assert osp.exists(openvino_model_path), \ 'The file (.xml) for OpenVINO IR has not been created.' 
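As a sanity check on the argument reordering exercised above, here is a minimal sketch of the new OpenVINO entry point. The paths are placeholders and OpenVINO's Model Optimizer must be installed and on PATH.

```python
# Illustrative sketch: the ONNX path now comes first, followed by the output dir.
from mmdeploy.apis.openvino import from_onnx, get_output_model_file

onnx_file = 'work_dir/end2end.onnx'   # placeholder path
from_onnx(onnx_file, 'work_dir',
          {'input': [1, 3, 800, 1344]}, ['dets', 'labels'])
xml_path = get_output_model_file(onnx_file, 'work_dir')  # e.g. work_dir/end2end.xml
```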
@@ -117,8 +116,8 @@ def test_can_not_run_onnx2openvino_without_mo(): is_error = False try: - from mmdeploy.apis.openvino import onnx2openvino - onnx2openvino({}, ['output'], 'tmp.onnx', '/tmp') + from mmdeploy.apis.openvino import from_onnx + from_onnx('tmp.onnx', '/tmp', {}, ['output']) except RuntimeError: is_error = True diff --git a/tests/test_apis/test_onnx2tensorrt.py b/tests/test_apis/test_onnx2tensorrt.py index f6ef73eb9..f58c33e9f 100644 --- a/tests/test_apis/test_onnx2tensorrt.py +++ b/tests/test_apis/test_onnx2tensorrt.py @@ -75,7 +75,7 @@ def generate_onnx_file(model): @backend_checker(Backend.TENSORRT) def test_onnx2tensorrt(): from mmdeploy.apis.tensorrt import onnx2tensorrt - from mmdeploy.backend.tensorrt import load_trt_engine + from mmdeploy.backend.tensorrt import load model = test_model generate_onnx_file(model) deploy_cfg = get_deploy_cfg() @@ -85,5 +85,5 @@ def test_onnx2tensorrt(): onnx2tensorrt(work_dir, save_file, 0, deploy_cfg, onnx_file) assert osp.exists(work_dir) assert osp.exists(engine_file) - engine = load_trt_engine(engine_file) + engine = load(engine_file) assert engine is not None diff --git a/tests/test_apis/test_torch2onnx.py b/tests/test_apis/test_torch2onnx.py index 349a9c642..e0e620913 100644 --- a/tests/test_apis/test_torch2onnx.py +++ b/tests/test_apis/test_torch2onnx.py @@ -8,7 +8,9 @@ import torch import torch.nn as nn -from mmdeploy.apis import torch2onnx_impl +from mmdeploy.apis.onnx import export +from mmdeploy.utils.config_utils import (get_backend, get_dynamic_axes, + get_onnx_config) from mmdeploy.utils.test import get_random_name onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name @@ -63,7 +65,33 @@ def get_deploy_cfg(input_name, output_name, dynamic_axes): [dynamic_axes_dict, dynamic_axes_list]) def test_torch2onnx(input_name, output_name, dynamic_axes): deploy_cfg = get_deploy_cfg(input_name, output_name, dynamic_axes) - torch2onnx_impl(test_model, test_img, deploy_cfg, onnx_file) + + output_prefix = osp.splitext(onnx_file)[0] + context_info = dict(cfg=deploy_cfg) + backend = get_backend(deploy_cfg).value + onnx_cfg = get_onnx_config(deploy_cfg) + opset_version = onnx_cfg.get('opset_version', 11) + + input_names = onnx_cfg['input_names'] + output_names = onnx_cfg['output_names'] + axis_names = input_names + output_names + dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) + verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get( + 'verbose', False) + keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs', + True) + export( + test_model, + test_img, + context_info=context_info, + output_path_prefix=output_prefix, + backend=backend, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + verbose=verbose, + keep_initializers_as_inputs=keep_initializers_as_inputs) assert osp.exists(onnx_file) diff --git a/tests/test_apis/test_torch2torchscript.py b/tests/test_apis/test_torch2torchscript.py index 4bb1c5c99..c2f98db06 100644 --- a/tests/test_apis/test_torch2torchscript.py +++ b/tests/test_apis/test_torch2torchscript.py @@ -84,4 +84,5 @@ def test_torch2torchscript(input_name, output_name): model_cfg=get_model_cfg(), device='cpu') + print(ts_file) assert osp.exists(ts_file) diff --git a/tests/test_backend/test_wrapper.py b/tests/test_backend/test_wrapper.py index b177a2ee5..33ecb9ef9 100644 --- a/tests/test_backend/test_wrapper.py +++ b/tests/test_backend/test_wrapper.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
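The tests above (and `generate_torchscript_file` below) rely on the new `trace` API. A minimal stand-alone sketch follows, assuming a plain `torch.nn.Module` and a bare deploy config; all names and paths below are placeholders.

```python
# Illustrative sketch: trace a model to TorchScript with the new API.
import mmcv
import torch

from mmdeploy.apis.torch_jit import trace
from mmdeploy.utils import Backend

model = torch.nn.Conv2d(3, 8, 3).eval()
example_inputs = torch.rand(1, 3, 8, 8)
backend = Backend.TORCHSCRIPT.value
deploy_cfg = mmcv.Config({'backend_config': dict(type=backend)})

trace(
    model,
    example_inputs,
    output_path_prefix='work_dir/end2end',   # saves work_dir/end2end.pt
    backend=backend,
    context_info=dict(deploy_cfg=deploy_cfg))
```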
+import os.path as osp import subprocess import tempfile @@ -49,28 +50,35 @@ def generate_onnx_file(): def generate_torchscript_file(): import mmcv - from mmdeploy.apis import torch2torchscript_impl - deploy_cfg = mmcv.Config( - {'backend_config': dict(type=Backend.TORCHSCRIPT.value)}) - with torch.no_grad(): - torch2torchscript_impl(model, torch.rand(1, 3, 8, 8), deploy_cfg, - ts_file) + backend = Backend.TORCHSCRIPT.value + deploy_cfg = mmcv.Config({'backend_config': dict(type=backend)}) + + from mmdeploy.apis.torch_jit import trace + context_info = dict(deploy_cfg=deploy_cfg) + output_prefix = osp.splitext(ts_file)[0] + + example_inputs = torch.rand(1, 3, 8, 8) + trace( + model, + example_inputs, + output_path_prefix=output_prefix, + backend=backend, + context_info=context_info) def onnx2backend(backend, onnx_file): if backend == Backend.TENSORRT: - from mmdeploy.backend.tensorrt import (create_trt_engine, - save_trt_engine) + from mmdeploy.backend.tensorrt import from_onnx backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name - engine = create_trt_engine( - onnx_file, { + from_onnx( + onnx_file, + osp.splitext(backend_file)[0], { 'input': { 'min_shape': [1, 3, 8, 8], 'opt_shape': [1, 3, 8, 8], 'max_shape': [1, 3, 8, 8] } }) - save_trt_engine(engine, backend_file) return backend_file elif backend == Backend.ONNXRUNTIME: return onnx_file @@ -87,13 +95,13 @@ def onnx2backend(backend, onnx_file): subprocess.call([onnx2ncnn_path, onnx_file, param_file, bin_file]) return param_file, bin_file elif backend == Backend.OPENVINO: - from mmdeploy.apis.openvino import get_output_model_file, onnx2openvino + from mmdeploy.apis.openvino import from_onnx, get_output_model_file backend_dir = tempfile.TemporaryDirectory().name backend_file = get_output_model_file(onnx_file, backend_dir) input_info = {'input': test_img.shape} output_names = ['output'] work_dir = backend_dir - onnx2openvino(input_info, output_names, onnx_file, work_dir) + from_onnx(onnx_file, work_dir, input_info, output_names) return backend_file diff --git a/tests/test_ops/utils.py b/tests/test_ops/utils.py index dbe325894..90f3f4e5b 100644 --- a/tests/test_ops/utils.py +++ b/tests/test_ops/utils.py @@ -230,21 +230,27 @@ def run_and_validate(self, tolerate_small_mismatch) def onnx2ncnn(self, model, model_name, output_names, save_dir=None): - if save_dir is None: - onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name - ncnn_param_path = tempfile.NamedTemporaryFile(suffix='.param').name - ncnn_bin_path = tempfile.NamedTemporaryFile(suffix='.bin').name - else: + + def _from_onnx(self, model, model_name, output_names, save_dir=None): onnx_file_path = os.path.join(save_dir, model_name + '.onnx') ncnn_param_path = os.path.join(save_dir, model_name + '.param') ncnn_bin_path = os.path.join(save_dir, model_name + '.bin') - onnx.save_model(model, onnx_file_path) + onnx.save_model(model, onnx_file_path) - from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn - onnx2ncnn(onnx_file_path, ncnn_param_path, ncnn_bin_path) + from mmdeploy.backend.ncnn import from_onnx + from_onnx(onnx_file_path, os.path.join(save_dir, model_name)) - from mmdeploy.backend.ncnn import NCNNWrapper - ncnn_model = NCNNWrapper(ncnn_param_path, ncnn_bin_path, output_names) + from mmdeploy.backend.ncnn import NCNNWrapper + ncnn_model = NCNNWrapper(ncnn_param_path, ncnn_bin_path, + output_names) + + return ncnn_model - return ncnn_model + if save_dir is None: + with tempfile.TemporaryDirectory() as save_dir: + return _from_onnx( + self, model, model_name, 
output_names, save_dir=save_dir) + else: + return _from_onnx( + self, model, model_name, output_names, save_dir=save_dir) diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index 081a8e779..9f597d1d7 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -10,9 +10,9 @@ import torch.multiprocessing as mp import mmdeploy.utils as util +from mmdeploy.backend.sdk.export_info import export2SDK from mmdeploy.utils import target_wrapper from mmdeploy.utils.constants import Backend, Codebase, Task -from mmdeploy.utils.export_info import dump_info from mmdeploy.utils.test import get_random_name correct_model_path = 'tests/data/srgan.py' @@ -413,9 +413,11 @@ def test_AdvancedEnum(): assert k.value == v +@pytest.mark.skipif( + not importlib.util.find_spec('mmedit'), reason='requires mmedit') def test_export_info(): with tempfile.TemporaryDirectory() as dir: - dump_info(correct_deploy_cfg, correct_model_cfg, dir, '') + export2SDK(correct_deploy_cfg, correct_model_cfg, dir, '') deploy_json = os.path.join(dir, 'deploy.json') pipeline_json = os.path.join(dir, 'pipeline.json') detail_json = os.path.join(dir, 'detail.json') diff --git a/tools/deploy.py b/tools/deploy.py index 6efa6574f..77ada1294 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -8,14 +8,15 @@ import torch.multiprocessing as mp from torch.multiprocessing import Process, set_start_method -from mmdeploy.apis import (create_calib_table, extract_model, +from mmdeploy.apis import (create_calib_input_data, extract_model, get_predefined_partition_cfg, torch2onnx, torch2torchscript, visualize_model) +from mmdeploy.apis.core import PIPELINE_MANAGER +from mmdeploy.backend.sdk.export_info import export2SDK from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename, get_ir_config, get_model_inputs, get_partition_config, get_root_logger, load_config, target_wrapper) -from mmdeploy.utils.export_info import dump_info def parse_args(): @@ -91,7 +92,14 @@ def main(): args = parse_args() set_start_method('spawn') logger = get_root_logger() - logger.setLevel(args.log_level) + log_level = logging.getLevelName(args.log_level) + logger.setLevel(log_level) + + pipeline_funcs = [ + torch2onnx, torch2torchscript, extract_model, create_calib_input_data + ] + PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs) + PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs) deploy_cfg_path = args.deploy_cfg model_cfg_path = args.model_cfg @@ -106,7 +114,7 @@ def main(): mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) if args.dump_info: - dump_info(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path) + export2SDK(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path) ret_value = mp.Value('d', 0, lock=False) @@ -114,13 +122,14 @@ def main(): ir_config = get_ir_config(deploy_cfg) ir_save_file = ir_config['save_file'] ir_type = IR.get(ir_config['type']) - create_process( - f'torch2{ir_type.value}', - target=torch2ir(ir_type), - args=(args.img, args.work_dir, ir_save_file, deploy_cfg_path, - model_cfg_path, checkpoint_path), - kwargs=dict(device=args.device), - ret_value=ret_value) + torch2ir(ir_type)( + args.img, + args.work_dir, + ir_save_file, + deploy_cfg_path, + model_cfg_path, + checkpoint_path, + device=args.device) # convert backend ir_files = [osp.join(args.work_dir, ir_save_file)] @@ -146,12 +155,12 @@ def main(): end = partition_cfg['end'] dynamic_axes = partition_cfg.get('dynamic_axes', None) - create_process( - f'partition model {save_file} with start: {start}, end: {end}', - 
extract_model, - args=(origin_ir_file, start, end), - kwargs=dict(dynamic_axes=dynamic_axes, save_file=save_path), - ret_value=ret_value) + extract_model( + origin_ir_file, + start, + end, + dynamic_axes=dynamic_axes, + save_file=save_path) ir_files.append(save_path) @@ -159,17 +168,14 @@ def main(): calib_filename = get_calib_filename(deploy_cfg) if calib_filename is not None: calib_path = osp.join(args.work_dir, calib_filename) - - create_process( - 'calibration', - create_calib_table, - args=(calib_path, deploy_cfg_path, model_cfg_path, - checkpoint_path), - kwargs=dict( - dataset_cfg=args.calib_dataset_cfg, - dataset_type='val', - device=args.device), - ret_value=ret_value) + create_calib_input_data( + calib_path, + deploy_cfg_path, + model_cfg_path, + checkpoint_path, + dataset_cfg=args.calib_dataset_cfg, + dataset_type='val', + device=args.device) backend_files = ir_files # convert backend @@ -179,10 +185,14 @@ def main(): assert len(model_params) == len(ir_files) from mmdeploy.apis.tensorrt import is_available as trt_is_available - from mmdeploy.apis.tensorrt import onnx2tensorrt assert trt_is_available( ), 'TensorRT is not available,' \ + ' please install TensorRT and build TensorRT custom ops first.' + + from mmdeploy.apis.tensorrt import onnx2tensorrt + PIPELINE_MANAGER.enable_multiprocess(True, [onnx2tensorrt]) + PIPELINE_MANAGER.set_log_level(logging.INFO, [onnx2tensorrt]) + backend_files = [] for model_id, model_param, onnx_path in zip( range(len(ir_files)), model_params, ir_files): @@ -191,13 +201,14 @@ def main(): partition_type = 'end2end' if partition_cfgs is None \ else onnx_name - create_process( - f'onnx2tensorrt of {onnx_path}', - target=onnx2tensorrt, - args=(args.work_dir, save_file, model_id, deploy_cfg_path, - onnx_path), - kwargs=dict(device=args.device, partition_type=partition_type), - ret_value=ret_value) + onnx2tensorrt( + args.work_dir, + save_file, + model_id, + deploy_cfg_path, + onnx_path, + device=args.device, + partition_type=partition_type) backend_files.append(osp.join(args.work_dir, save_file)) @@ -208,18 +219,17 @@ def main(): logger.error('ncnn support is not available.') exit(1) - from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn + import mmdeploy.apis.ncnn as ncnn_api + from mmdeploy.apis.ncnn import get_output_model_file + + PIPELINE_MANAGER.set_log_level(log_level, [ncnn_api.from_onnx]) backend_files = [] for onnx_path in ir_files: model_param_path, model_bin_path = get_output_model_file( onnx_path, args.work_dir) - create_process( - f'onnx2ncnn with {onnx_path}', - target=onnx2ncnn, - args=(onnx_path, model_param_path, model_bin_path), - kwargs=dict(), - ret_value=ret_value) + onnx_name = osp.splitext(osp.split(onnx_path)[1])[0] + ncnn_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name)) if quant: from onnx2ncnn_quant_table import get_table @@ -256,23 +266,21 @@ def main(): assert is_available_openvino(), \ 'OpenVINO is not available, please install OpenVINO first.' 
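The ncnn branch above now only needs an output prefix. A minimal sketch of the call it wraps is shown below; the paths are placeholders and the bundled `onnx2ncnn` executable must be built first.

```python
# Illustrative sketch: prefix-based ncnn conversion.
from mmdeploy.apis.ncnn import from_onnx, get_output_model_file

onnx_path = 'work_dir/end2end.onnx'                      # placeholder path
param_path, bin_path = get_output_model_file(onnx_path, 'work_dir')
from_onnx(onnx_path, 'work_dir/end2end')                 # writes the files above
```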
+ import mmdeploy.apis.openvino as openvino_api from mmdeploy.apis.openvino import (get_input_info_from_cfg, get_mo_options_from_cfg, - get_output_model_file, - onnx2openvino) + get_output_model_file) + + PIPELINE_MANAGER.set_log_level(log_level, [openvino_api.from_onnx]) + openvino_files = [] for onnx_path in ir_files: model_xml_path = get_output_model_file(onnx_path, args.work_dir) input_info = get_input_info_from_cfg(deploy_cfg) output_names = get_ir_config(deploy_cfg).output_names mo_options = get_mo_options_from_cfg(deploy_cfg) - create_process( - f'onnx2openvino with {onnx_path}', - target=onnx2openvino, - args=(input_info, output_names, onnx_path, args.work_dir, - mo_options), - kwargs=dict(), - ret_value=ret_value) + openvino_api.from_onnx(onnx_path, args.work_dir, input_info, + output_names, mo_options) openvino_files.append(model_xml_path) backend_files = openvino_files @@ -281,7 +289,11 @@ def main(): assert is_available_pplnn(), \ 'PPLNN is not available, please install PPLNN first.' - from mmdeploy.apis.pplnn import onnx2pplnn + from mmdeploy.apis.pplnn import from_onnx + + pplnn_pipeline_funcs = [from_onnx] + PIPELINE_MANAGER.set_log_level(logging.INFO, pplnn_pipeline_funcs) + pplnn_files = [] for onnx_path in ir_files: algo_file = onnx_path.replace('.onnx', '.json') @@ -291,12 +303,12 @@ def main(): # PPLNN accepts only 1 input shape for optimization, # may get changed in the future input_shapes = [model_inputs.opt_shape] - create_process( - f'onnx2pplnn with {onnx_path}', - target=onnx2pplnn, - args=(algo_file, onnx_path), - kwargs=dict(device=args.device, input_shapes=input_shapes), - ret_value=ret_value) + algo_prefix = osp.splitext(algo_file)[0] + from_onnx( + onnx_path, + algo_prefix, + device=args.device, + input_shapes=input_shapes) pplnn_files += [onnx_path, algo_file] backend_files = pplnn_files diff --git a/tools/extract.py b/tools/extract.py index a98d59265..89d0df8c0 100644 --- a/tools/extract.py +++ b/tools/extract.py @@ -6,7 +6,7 @@ import onnx import onnx.helper -from mmdeploy.apis import extract_model +from mmdeploy.apis.onnx import extract_partition from mmdeploy.utils import get_root_logger @@ -53,7 +53,7 @@ def main(): marks = collect_avaiable_marks(model) logger.info('Available marks:\n {}'.format('\n '.join(marks))) - extracted_model = extract_model(model, args.start, args.end) + extracted_model = extract_partition(model, args.start, args.end) if osp.splitext(args.output_model)[-1] != '.onnx': args.output_model += '.onnx' diff --git a/tools/onnx2ncnn.py b/tools/onnx2ncnn.py index 0bddd6e03..9c17bcab3 100644 --- a/tools/onnx2ncnn.py +++ b/tools/onnx2ncnn.py @@ -2,15 +2,14 @@ import argparse import logging -from mmdeploy.apis.ncnn import onnx2ncnn +from mmdeploy.apis.ncnn import from_onnx from mmdeploy.utils import get_root_logger def parse_args(): parser = argparse.ArgumentParser(description='Convert ONNX to ncnn.') parser.add_argument('onnx_path', help='ONNX model path') - parser.add_argument('output_param', help='output ncnn param path') - parser.add_argument('output_bin', help='output bin path') + parser.add_argument('output_prefix', help='output ncnn model path') parser.add_argument( '--log-level', help='set log level', @@ -26,12 +25,11 @@ def main(): logger = get_root_logger(log_level=args.log_level) onnx_path = args.onnx_path - output_param = args.output_param - output_bin = args.output_bin + output_prefix = args.output_prefix logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ') try: - onnx2ncnn(onnx_path, output_param, output_bin) + 
from_onnx(onnx_path, output_prefix) logger.info('onnx2ncnn success.') except Exception as e: logger.error(e) diff --git a/tools/onnx2pplnn.py b/tools/onnx2pplnn.py index 6046132a5..5a26a4487 100644 --- a/tools/onnx2pplnn.py +++ b/tools/onnx2pplnn.py @@ -3,7 +3,7 @@ import collections import logging -from mmdeploy.apis.pplnn import onnx2pplnn +from mmdeploy.apis.pplnn import from_onnx from mmdeploy.utils import get_root_logger @@ -11,7 +11,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert ONNX to PPLNN.') parser.add_argument('onnx_path', help='ONNX model path') parser.add_argument( - 'output_path', help='output PPLNN algorithm path in json format') + 'output_prefix', help='output PPLNN algorithm prefix in json format') parser.add_argument( '--device', help='`the device of model during conversion', @@ -36,7 +36,7 @@ def main(): logger = get_root_logger(log_level=args.log_level) onnx_path = args.onnx_path - output_path = args.output_path + output_prefix = args.output_prefix device = args.device input_shapes = eval(args.opt_shapes) @@ -50,10 +50,10 @@ def main(): input_shapes = [input_shapes] logger.info(f'onnx2ppl: \n\tonnx_path: {onnx_path} ' - f'\n\toutput_path: {output_path}' + f'\n\toutput_prefix: {output_prefix}' f'\n\topt_shapes: {input_shapes}') try: - onnx2pplnn(output_path, onnx_path, device, input_shapes) + from_onnx(onnx_path, output_prefix, device, input_shapes) logger.info('onnx2tpplnn success.') except Exception as e: logger.error(e) diff --git a/tools/onnx2tensorrt.py b/tools/onnx2tensorrt.py index c852afd51..78a4558cc 100644 --- a/tools/onnx2tensorrt.py +++ b/tools/onnx2tensorrt.py @@ -2,7 +2,7 @@ import argparse import logging -from mmdeploy.backend.tensorrt import create_trt_engine, save_trt_engine +from mmdeploy.backend.tensorrt import from_onnx from mmdeploy.backend.tensorrt.utils import get_trt_log_level from mmdeploy.utils import (get_common_config, get_model_inputs, get_root_logger, load_config) @@ -12,7 +12,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert ONNX to TensorRT.') parser.add_argument('deploy_cfg', help='deploy config path') parser.add_argument('onnx_path', help='ONNX model path') - parser.add_argument('output', help='output TensorRT engine path') + parser.add_argument('output_prefix', help='output TensorRT engine prefix') parser.add_argument('--device-id', help='`the CUDA device id', default=0) parser.add_argument( '--calib-file', @@ -35,7 +35,7 @@ def main(): deploy_cfg_path = args.deploy_cfg deploy_cfg = load_config(deploy_cfg_path)[0] onnx_path = args.onnx_path - output_path = args.output + output_prefix = args.output_prefix device_id = args.device_id calib_file = args.calib_file @@ -56,8 +56,9 @@ def main(): logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} ' f'\n\tdeploy_cfg: {deploy_cfg_path}') try: - engine = create_trt_engine( + from_onnx( onnx_path, + output_prefix, input_shapes=final_params['input_shapes'], log_level=get_trt_log_level(), fp16_mode=final_params.get('fp16_mode', False), @@ -66,7 +67,6 @@ def main(): max_workspace_size=final_params.get('max_workspace_size', 0), device_id=device_id) - save_trt_engine(engine, output_path) logger.info('onnx2tensorrt success.') except Exception as e: logger.error(e)
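For completeness, a minimal sketch of the renamed TensorRT helpers this tool relies on: `from_onnx` now builds and serializes the engine in one call, and `load` replaces `load_trt_engine`. The paths and shapes below are placeholders, and TensorRT plus a CUDA device are required.

```python
# Illustrative sketch: build an engine with the prefix-based API, then reload it.
from mmdeploy.backend.tensorrt import from_onnx, load

from_onnx(
    'work_dir/end2end.onnx',          # input ONNX model (placeholder path)
    'work_dir/end2end',               # writes work_dir/end2end.engine
    input_shapes=dict(
        input=dict(
            min_shape=[1, 3, 160, 160],
            opt_shape=[1, 3, 320, 320],
            max_shape=[1, 3, 640, 640])),
    max_workspace_size=1 << 30)

engine = load('work_dir/end2end.engine')
```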