From d606f3a68c1b572417f0444757e3f6be0ac48672 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 12 Apr 2022 17:31:27 -0400
Subject: [PATCH 1/6] Pipeline base class implementation

---
 src/deepsparse/pipeline.py | 206 +++++++++++++++++++++++++++++++++++++
 src/deepsparse/tasks.py    |   6 ++
 2 files changed, 212 insertions(+)
 create mode 100644 src/deepsparse/pipeline.py

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
new file mode 100644
index 0000000000..6721d3d6fd
--- /dev/null
+++ b/src/deepsparse/pipeline.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Classes and registry for end to end inference pipelines that wrap an underlying
+inference engine and include pre/postprocessing
+"""
+
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Type, Union
+
+import numpy
+from pydantic import BaseModel
+
+from deepsparse import Engine
+from deepsparse.benchmark import ORTEngine
+from deepsparse.tasks import SupportedTasks
+
+
+__all__ = [
+    "DEEPSPARSE_ENGINE",
+    "ORT_ENGINE",
+    "SUPPORTED_PIPELINE_ENGINES",
+    "Pipeline",
+]
+
+
+DEEPSPARSE_ENGINE = "deepsparse"
+ORT_ENGINE = "onnxruntime"
+
+SUPPORTED_PIPELINE_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE]
+
+
+_REGISTERED_PIPELINES = {}
+
+
+class Pipeline(ABC):
+    def __init__(
+        self,
+        model_path: str,
+        engine_type: str,
+        batch_size: int,
+        num_cores: int,
+        scheduler: Scheduler = None,
+        input_shapes: List[List[int]] = None,
+    ):
+        self._model_path_orig = model_path
+        self._model_path = model_path
+        self._engine_type = engine_type
+
+        self._engine_args = dict(
+            batch_size=batch_size,
+            num_cores=num_cores,
+            input_shapes=input_shapes,
+        )
+        if engine_type.lower() == DEEPSPARSE_ENGINE:
+            self._engine_args["scheduler"] = scheduler
+
+        self._onnx_file_path = self.setup_onnx_file_path()
+        self._engine = self.initialize_engine()
+        pass
+
+    def __call__(self, inputs: BaseModel) -> BaseModel:
+        engine_inputs: List[numpy.ndarray] = self.process_inputs(inputs)
+        engine_outputs: List[numpy.ndarray] = self.engine(engine_inputs)
+        return self.process_engine_outputs(engine_outputs)
+
+    @staticmethod
+    def create(
+        task: str,
+        model_path: str,
+        engine_type: str,
+        batch_size: int,
+        num_cores: int,
+        scheduler: Scheduler = None,
+        input_shapes: List[List[int]] = None,
+        **kwargs,
+    ):
+        task = task.lower().replace("-", "_")
+
+        # extra step to register pipelines for a given task domain
+        # for cases where imports should only happen once a user specifies
+        # that domain is to be used. (ie deepsparse.transformers will auto
+        # install extra packages so should only import and register once a
+        # transformers task is specified)
+        SupportedTasks.check_register_task(task)
+
+        if task not in _REGISTERED_PIPELINES:
+            raise ValueError(
+                f"Unknown Pipeline task {task}. Pipeline tasks must be "
+                "declared with the Pipeline.register decorator. Currently "
+                f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}"
+            )
+
+        return _REGISTERED_PIPELINES[task](
+            model_path=model_path,
+            engine_type=engine_type,
+            batch_size=batch_size,
+            num_cores=num_cores,
+            scheduler=scheduler,
+            input_shapes=input_shapes,
+            **kwargs,
+        )
+
+    @classmethod
+    def register(cls, task: str, task_aliases: Optional[List[str]] = None):
+        task_names = [task]
+        if task_aliases:
+            task_names.extend(task_aliases)
+
+        def _register_task(task_name, pipeline_class):
+            if task_name in _REGISTERED_PIPELINES and (
+                pipeline_class is not _REGISTERED_PIPELINES[task_name]
+            ):
+                raise RuntimeError(
+                    f"task {task_name} already registered by Pipeline.register. "
+                    f"attempting to register pipeline: {pipeline_class}, but "
+                    f"pipeline: {_REGISTERED_PIPELINES[task_name]}, already registered"
+                )
+            _REGISTERED_PIPELINES[task_name] = pipeline_class
+
+        def decorator(pipeline_class: Pipeline):
+            if not issubclass(pipeline_class, cls):
+                raise RuntimeError(
+                    f"Attempting to register pipeline {pipeline_class}. "
+                    f"Registered pipelines must inherit from {cls}"
+                )
+            for task_name in task_names:
+                _register_task(task_name, pipeline_class)
+
+            # set task and task_aliases as class level property
+            pipeline_class.task = task
+            pipeline_class.task_aliases = task_aliases
+
+        return decorator
+
+    @abstractmethod
+    def setup_onnx_file_path(self) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def process_inputs(self, *args, **kwargs) -> List[numpy.ndarray]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]):
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def input_model(self) -> Type[BaseModel]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def output_model(self) -> Type[BaseModel]:
+        raise NotImplementedError()
+
+    @property
+    def model_path_orig(self) -> str:
+        return self._model_path_orig
+
+    @property
+    def model_path(self) -> str:
+        return self._model_path
+
+    @property
+    def engine(self) -> Union[Engine, ORTEngine]:
+        return self._engine
+
+    @property
+    def engine_args(self) -> Dict[str, Any]:
+        return self._engine_args
+
+    @property
+    def engine_type(self) -> str:
+        return self._engine_type
+
+    @property
+    def onnx_file_path(self) -> str:
+        return self._onnx_file_path
+
+    def initialize_engine(self) -> Union[Engine, ORTEngine]:
+        engine_type = self.engine_type.lower()
+
+        if engine_type == DEEPSPARSE_ENGINE:
+            return Engine(self.onnx_file_path, **self._engine_args)
+        elif engine_type == ORT_ENGINE:
+            return ORTEngine(self.onnx_file_path, **self._engine_args)
+        else:
+            raise ValueError(
+                f"Unknown engine_type {self.engine_type}. Supported values include: "
+                f"{SUPPORTED_PIPELINE_ENGINES}"
+            )
diff --git a/src/deepsparse/tasks.py b/src/deepsparse/tasks.py
index 6ffaad7ec3..c8670707ca 100644
--- a/src/deepsparse/tasks.py
+++ b/src/deepsparse/tasks.py
@@ -78,6 +78,12 @@ class SupportedTasks:
         token_classification=AliasedTask("token_classification", ["ner"]),
     )
 
+    @classmethod
+    def check_register_task(cls, task: str):
+        if cls.is_nlp(task):
+            # trigger transformers pipelines to register with Pipeline.register
+            import deepsparse.transformers.pipelines
+
     @classmethod
     def is_nlp(cls, task: str) -> bool:
         """
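
A sketch of how a downstream task might implement the interface this first commit introduces. The task name, the pydantic schemas, and the assumption that `model_path` already points at a local ONNX file are all illustrative, not part of the patch:

```python
from typing import List, Type

import numpy
from pydantic import BaseModel

from deepsparse.pipeline import Pipeline


class EchoInput(BaseModel):
    values: List[float]


class EchoOutput(BaseModel):
    values: List[float]


class EchoPipeline(Pipeline):
    def setup_onnx_file_path(self) -> str:
        # assume model_path is already an inference-ready local ONNX file
        return self.model_path

    def process_inputs(self, inputs: EchoInput) -> List[numpy.ndarray]:
        return [numpy.array(inputs.values, dtype=numpy.float32)]

    def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> EchoOutput:
        return EchoOutput(values=engine_outputs[0].flatten().tolist())

    @property
    def input_model(self) -> Type[BaseModel]:
        return EchoInput

    @property
    def output_model(self) -> Type[BaseModel]:
        return EchoOutput


# registered via a direct call rather than @ syntax: at this commit the inner
# decorator does not yet return the class (addressed in PATCH 6/6), so the
# decorator form would rebind the class name to None
Pipeline.register(task="echo_task", task_aliases=["echo"])(EchoPipeline)
```
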
From 222cd3dc56614bd5b0f0136733b9a0571109f569 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 13 Apr 2022 16:08:45 -0400
Subject: [PATCH 2/6] constructor default values

---
 src/deepsparse/pipeline.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 6721d3d6fd..4b833ad00e 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -50,9 +50,9 @@ class Pipeline(ABC):
     def __init__(
         self,
         model_path: str,
-        engine_type: str,
-        batch_size: int,
-        num_cores: int,
+        engine_type: str = DEEPSPARSE_ENGINE,
+        batch_size: int = 1,
+        num_cores: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
     ):
@@ -81,9 +81,9 @@ def __call__(self, inputs: BaseModel) -> BaseModel:
     def create(
         task: str,
         model_path: str,
-        engine_type: str,
-        batch_size: int,
-        num_cores: int,
+        engine_type: str = DEEPSPARSE_ENGINE,
+        batch_size: int = 1,
+        num_cores: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
         **kwargs,
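
With these defaults, only the task and model are required at the call site; `engine_type`, `batch_size`, and `num_cores` fall back to the deepsparse engine, batch size 1, and all available cores. A minimal usage sketch (the SparseZoo stub is a placeholder):

```python
from deepsparse.pipeline import Pipeline

# scheduler and input_shapes likewise default to None
pipeline = Pipeline.create(
    task="token_classification",
    model_path="zoo:some/model/stub",  # placeholder stub, not a real model
)
```
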
From 63f1d775667d0ca65e5aeae92d6c43fc6aa97674 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 13 Apr 2022 16:22:06 -0400
Subject: [PATCH 3/6] __call__ inputs/outputs parsing + validation

---
 src/deepsparse/pipeline.py | 34 ++++++++++++++++++++++++++++------
 src/deepsparse/tasks.py    |  2 +-
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 4b833ad00e..5d1d4d94c1 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -24,7 +24,7 @@
 import numpy
 from pydantic import BaseModel
 
-from deepsparse import Engine
+from deepsparse import Engine, Scheduler
 from deepsparse.benchmark import ORTEngine
 from deepsparse.tasks import SupportedTasks
 
@@ -72,10 +72,32 @@ def __init__(
         self._engine = self.initialize_engine()
         pass
 
-    def __call__(self, inputs: BaseModel) -> BaseModel:
-        engine_inputs: List[numpy.ndarray] = self.process_inputs(inputs)
+    def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel:
+        if pipeline_inputs is None and kwargs:
+            # parse kwarg inputs into the expected input format
+            pipeline_inputs = self.input_model(**kwargs)
+
+        # validate inputs format
+        if not isinstance(pipeline_inputs, self.input_model):
+            raise ValueError(
+                f"Calling {self.__class__} requires passing inputs as an "
+                f"{self.input_model} object or keyword arguments used to create "
+                f"a {self.input_model} object"
+            )
+
+        # run pipeline
+        engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
         engine_outputs: List[numpy.ndarray] = self.engine(engine_inputs)
-        return self.process_engine_outputs(engine_outputs)
+        pipeline_outputs = self.process_engine_outputs(engine_outputs)
+
+        # validate outputs format
+        if not isinstance(pipeline_outputs, self.output_model):
+            raise ValueError(
+                f"Outputs of {self.__class__} must be instances of {self.output_model},"
+                f" found output of type {type(pipeline_outputs)}"
+            )
+
+        return pipeline_outputs
 
     @staticmethod
     def create(
@@ -151,11 +173,11 @@ def setup_onnx_file_path(self) -> str:
         raise NotImplementedError()
 
     @abstractmethod
-    def process_inputs(self, *args, **kwargs) -> List[numpy.ndarray]:
+    def process_inputs(self, inputs: BaseModel) -> List[numpy.ndarray]:
         raise NotImplementedError()
 
     @abstractmethod
-    def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]):
+    def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseModel:
         raise NotImplementedError()
 
     @property
diff --git a/src/deepsparse/tasks.py b/src/deepsparse/tasks.py
index c8670707ca..4b24c6d16c 100644
--- a/src/deepsparse/tasks.py
+++ b/src/deepsparse/tasks.py
@@ -82,7 +82,7 @@ class SupportedTasks:
     def check_register_task(cls, task: str):
         if cls.is_nlp(task):
             # trigger transformers pipelines to register with Pipeline.register
-            import deepsparse.transformers.pipelines
+            import deepsparse.transformers.pipelines  # noqa: F401
 
     @classmethod
     def is_nlp(cls, task: str) -> bool:
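
The contract `__call__` now enforces — accept either a prebuilt input object or raw kwargs, and validate the types in both directions — can be illustrated in isolation with a bare pydantic model; `TextInput` here is a hypothetical stand-in for a pipeline's `input_model`:

```python
from pydantic import BaseModel


class TextInput(BaseModel):
    sequences: str


def parse_inputs(pipeline_inputs: BaseModel = None, **kwargs) -> TextInput:
    # mirrors the kwarg-parsing and validation branches added to __call__
    if pipeline_inputs is None and kwargs:
        pipeline_inputs = TextInput(**kwargs)
    if not isinstance(pipeline_inputs, TextInput):
        raise ValueError("inputs must be a TextInput or kwargs for one")
    return pipeline_inputs


# both calling conventions yield the same validated object
assert parse_inputs(sequences="hello") == parse_inputs(TextInput(sequences="hello"))
```
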
From b42c40bab4fd6ebe0cdd2be2b813b0a99e42f73a Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 13 Apr 2022 17:12:33 -0400
Subject: [PATCH 4/6] documentation

---
 src/deepsparse/pipeline.py | 135 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 133 insertions(+), 2 deletions(-)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 5d1d4d94c1..73708089a3 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
 class Pipeline(ABC):
+    """
+    Generic Pipeline abstract class meant to wrap inference engine objects to include
+    data pre/post-processing. Inputs and outputs of pipelines should be serialized
+    as pydantic Models.
+
+    Pipelines should not be instantiated directly by their constructors, but rather
+    through the `Pipeline.create()` method. The task name given to `create` will be
+    used to load the appropriate pipeline. New task implementations should
+    inherit from `Pipeline` and implement the `setup_onnx_file_path`, `process_inputs`,
+    `process_engine_outputs`, `input_model`, and `output_model` abstract methods.
+
+    Finally, the class definition should be decorated by the `Pipeline.register`
+    function. This defines the task name and task aliases for the pipeline and
+    ensures that it will be accessible by `Pipeline.create`. The implemented
+    `Pipeline` subclass must be imported at runtime to be accessible.
+
+    Pipeline lifecycle:
+     - On instantiation
+         * `onnx_file_path` <- `setup_onnx_file_path`
+         * `engine` <- `_initialize_engine`
+
+     - On __call__:
+         * `pre_processed_inputs` <- `process_inputs(inputs: input_model)`
+         * `engine_outputs` <- `engine(pre_processed_inputs)`
+         * `outputs: output_model` <- `process_engine_outputs(engine_outputs)`
+
+    Example use of register:
+    ```python
+    @Pipeline.register(
+        task="example_task",
+        task_aliases=["example_alias_1", "example_alias_2"],
+    )
+    class PipelineImplementation(Pipeline):
+        # implementation of Pipeline abstract methods here
+    ```
+
+    Example use of pipeline:
+    ```python
+    example_pipeline = Pipeline.create(
+        task="example_task",
+        model_path="model.onnx",
+    )
+    pipeline_outputs = example_pipeline(pipeline_inputs)
+    ```
+
+    :param model_path: path on local system or SparseZoo stub to load the model from
+    :param engine_type: inference engine to use. Currently supported values include
+        'deepsparse' and 'onnxruntime'. Default is 'deepsparse'
+    :param batch_size: static batch size to use for inference. Default is 1
+    :param num_cores: number of CPU cores to allocate for inference engine. None
+        specifies all available cores. Default is None
+    :param scheduler: (deepsparse only) kind of scheduler to execute with.
+        Pass None for the default
+    :param input_shapes: list of shapes to set the ONNX inputs to. Pass None
+        to use model as-is. Default is None
+    """
+
     def __init__(
         self,
         model_path: str,

         self._onnx_file_path = self.setup_onnx_file_path()
-        self._engine = self.initialize_engine()
+        self._engine = self._initialize_engine()
         pass
 
     def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel:

         input_shapes: List[List[int]] = None,
         **kwargs,
     ):
+        """
+        :param task: name of task to create a pipeline for
+        :param model_path: path on local system or SparseZoo stub to load the model
+            from
+        :param engine_type: inference engine to use. Currently supported values
+            include 'deepsparse' and 'onnxruntime'. Default is 'deepsparse'
+        :param batch_size: static batch size to use for inference. Default is 1
+        :param num_cores: number of CPU cores to allocate for inference engine. None
+            specifies all available cores. Default is None
+        :param scheduler: (deepsparse only) kind of scheduler to execute with.
+            Pass None for the default
+        :param input_shapes: list of shapes to set the ONNX inputs to. Pass None
+            to use model as-is. Default is None
+        :param kwargs: extra task specific kwargs to be passed to task Pipeline
+            implementation
+        :return: pipeline object initialized for the given task
+        """
         task = task.lower().replace("-", "_")

     @classmethod
     def register(cls, task: str, task_aliases: Optional[List[str]] = None):
+        """
+        Pipeline implementer class decorator that registers the pipeline
+        task name and its aliases as valid tasks that can be used to load
+        the pipeline through `Pipeline.create()`.
+
+        Multiple pipelines may not have the same task name. An error will
+        be raised if two different pipelines attempt to register the same task name
+
+        :param task: main task name of this pipeline
+        :param task_aliases: list of extra task names that may be used to reference
+            this pipeline
+        """
         task_names = [task]
         if task_aliases:
             task_names.extend(task_aliases)

     @abstractmethod
     def setup_onnx_file_path(self) -> str:
+        """
+        Performs any setup to unwrap and process the given model_path and other
+        class properties into an inference ready onnx file to be compiled by the
+        engine of the pipeline
+
+        :return: file path to the ONNX file for the engine to compile
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def process_inputs(self, inputs: BaseModel) -> List[numpy.ndarray]:
+        """
+        :param inputs: inputs to the pipeline. Must be the type of the `input_model`
+            of this pipeline
+        :return: inputs of this model processed into a list of numpy arrays that
+            can be directly passed into the forward pass of the pipeline engine
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseModel:
+        """
+        :param engine_outputs: list of numpy arrays that are the output of the engine
+            forward pass
+        :return: outputs of engine post-processed into an object in the `output_model`
+            format of this pipeline
+        """
         raise NotImplementedError()
 
     @property
     @abstractmethod
     def input_model(self) -> Type[BaseModel]:
+        """
+        :return: pydantic model class that inputs to this pipeline must comply with
+        """
         raise NotImplementedError()
 
     @property
     @abstractmethod
     def output_model(self) -> Type[BaseModel]:
+        """
+        :return: pydantic model class that outputs of this pipeline must comply with
+        """
         raise NotImplementedError()
 
     @property
     def model_path_orig(self) -> str:
+        """
+        :return: value originally passed to the model_path argument to initialize this
+            Pipeline
+        """
         return self._model_path_orig
 
     @property
     def model_path(self) -> str:
+        """
+        :return: path on local system to the onnx file of this model or directory
+            containing a model.onnx file along with supporting files
+        """
         return self._model_path
 
     @property
     def engine(self) -> Union[Engine, ORTEngine]:
+        """
+        :return: engine instance used for model forward pass in pipeline
+        """
         return self._engine
 
     @property
     def engine_args(self) -> Dict[str, Any]:
+        """
+        :return: arguments besides onnx filepath used to instantiate engine
+        """
         return self._engine_args
 
     @property
     def engine_type(self) -> str:
+        """
+        :return: type of inference engine used for model forward pass
+        """
         return self._engine_type
 
     @property
     def onnx_file_path(self) -> str:
+        """
+        :return: onnx file path used to instantiate engine
+        """
         return self._onnx_file_path
 
-    def initialize_engine(self) -> Union[Engine, ORTEngine]:
+    def _initialize_engine(self) -> Union[Engine, ORTEngine]:
         engine_type = self.engine_type.lower()
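
As the new docstrings spell out, the engine backend is chosen at `_initialize_engine` time from `engine_type`. A usage sketch mirroring the docstring's own example (task and model path illustrative):

```python
from deepsparse.pipeline import Pipeline

# compiled with the deepsparse engine, the default engine_type
sparse_pipeline = Pipeline.create(task="example_task", model_path="model.onnx")

# same pipeline definition, but the forward pass runs through onnxruntime
ort_pipeline = Pipeline.create(
    task="example_task",
    model_path="model.onnx",
    engine_type="onnxruntime",
)
```
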
From edf72bea69745957765ed7438142141adde929c3 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 13 Apr 2022 17:37:47 -0400
Subject: [PATCH 5/6] pipeline 'alias' argument

---
 src/deepsparse/pipeline.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 73708089a3..858691964b 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
         Pass None for the default
     :param input_shapes: list of shapes to set the ONNX inputs to. Pass None
         to use model as-is. Default is None
+    :param alias: optional name to give this pipeline instance, useful when
+        inferencing with multiple models. Default is None
     """
 
     def __init__(
         self,
         model_path: str,
         engine_type: str = DEEPSPARSE_ENGINE,
         batch_size: int = 1,
         num_cores: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
+        alias: Optional[str] = None,
     ):
         self._model_path_orig = model_path
         self._model_path = model_path
         self._engine_type = engine_type
+        self._alias = alias
 
         self._engine_args = dict(
             batch_size=batch_size,

     @staticmethod
     def create(
         task: str,
         model_path: str,
         engine_type: str = DEEPSPARSE_ENGINE,
         batch_size: int = 1,
         num_cores: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
+        alias: Optional[str] = None,
         **kwargs,
     ):

         :param input_shapes: list of shapes to set the ONNX inputs to. Pass None
             to use model as-is. Default is None
         :param kwargs: extra task specific kwargs to be passed to task Pipeline
             implementation
+        :param alias: optional name to give this pipeline instance, useful when
+            inferencing with multiple models. Default is None
         :return: pipeline object initialized for the given task
         """

             num_cores=num_cores,
             scheduler=scheduler,
             input_shapes=input_shapes,
+            alias=alias,
             **kwargs,
         )

         raise NotImplementedError()
 
+    @property
+    def alias(self) -> str:
+        """
+        :return: optional name to give this pipeline instance, useful when
+            inferencing with multiple models
+        """
+        return self._alias
+
     @property
     def model_path_orig(self) -> str:
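
One intended use of `alias`: a caller holding several instances of the same task can key them by name rather than by model path. A sketch using the `ner` alias registered in `tasks.py` (the stubs are placeholders):

```python
from deepsparse.pipeline import Pipeline

pipelines = [
    Pipeline.create(task="ner", model_path="zoo:ner/base/stub", alias="ner_base"),
    Pipeline.create(task="ner", model_path="zoo:ner/pruned/stub", alias="ner_pruned"),
]

# route requests by alias instead of by model path
by_alias = {pipeline.alias: pipeline for pipeline in pipelines}
```
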
From 5f5707db80630cf2c801b4987aeaeefec611d2e7 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Thu, 14 Apr 2022 12:58:27 -0400
Subject: [PATCH 6/6] review fixes

---
 src/deepsparse/pipeline.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 858691964b..4ffca2af4a 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
             )
             _REGISTERED_PIPELINES[task_name] = pipeline_class
 
-        def decorator(pipeline_class: Pipeline):
+        def _register_pipeline_tasks_decorator(pipeline_class: Pipeline):
             if not issubclass(pipeline_class, cls):
                 raise RuntimeError(
                     f"Attempting to register pipeline {pipeline_class}. "
                     f"Registered pipelines must inherit from {cls}"
                 )
             for task_name in task_names:
                 _register_task(task_name, pipeline_class)
 
             # set task and task_aliases as class level property
             pipeline_class.task = task
             pipeline_class.task_aliases = task_aliases
 
-        return decorator
+            return pipeline_class
+
+        return _register_pipeline_tasks_decorator
 
     @abstractmethod
     def setup_onnx_file_path(self) -> str:
         """
-        Performs any setup to unwrap and process the given model_path and other
+        Performs any setup to unwrap and process the given `model_path` and other
         class properties into an inference ready onnx file to be compiled by the
         engine of the pipeline

     @property
     def model_path_orig(self) -> str:
         """
-        :return: value originally passed to the model_path argument to initialize this
-            Pipeline
+        :return: value originally passed to the `model_path` argument to initialize
+            this Pipeline
         """
         return self._model_path_orig
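
The `return pipeline_class` added above matters because a class decorator that returns None rebinds the decorated name to None. A minimal demonstration of the failure mode this review fix closes, independent of deepsparse:

```python
def register_without_return(cls):
    # pretend cls is recorded in a registry here; since nothing is returned,
    # using this as a decorator rebinds the class name to None
    pass


def register_with_return(cls):
    # record cls, then hand it back so the name still points at the class
    return cls


@register_with_return
class Works:
    pass


@register_without_return
class Broken:
    pass


assert Works is not None
assert Broken is None
```
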