From 07706e036e4ca129f771dd875cf79b84ed3ccd83 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 13 Apr 2022 17:59:51 -0400 Subject: [PATCH 1/4] PipelineConfig pydantic model + Pipeline.from_config --- src/deepsparse/pipeline.py | 101 +++++++++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 4 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 4ffca2af4a..93357094c9 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -18,11 +18,13 @@ """ +import os from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, List, Optional, Union import numpy -from pydantic import BaseModel +from pydantic import BaseModel, Field from deepsparse import Engine, Scheduler from deepsparse.benchmark import ORTEngine @@ -34,6 +36,7 @@ "ORT_ENGINE", "SUPPORTED_PIPELINE_ENGINES", "Pipeline", + "PipelineConfig", ] @@ -171,7 +174,7 @@ def create( input_shapes: List[List[int]] = None, alias: Optional[str] = None, **kwargs, - ): + ) -> "Pipeline": """ :param task: name of task to create a pipeline for :param model_path: path on local system or SparseZoo stub to load the model @@ -185,10 +188,10 @@ def create( Pass None for the default :param input_shapes: list of shapes to set ONNX the inputs to. Pass None to use model as-is. Default is None - :param kwargs: extra task specific kwargs to be passed to task Pipeline - implementation :param alias: optional name to give this pipeline instance, useful when inferencing with multiple models. Default is None + :param kwargs: extra task specific kwargs to be passed to task Pipeline + implementation :return: pipeline object initialized for the given task """ task = task.lower().replace("-", "_") @@ -264,6 +267,34 @@ def _register_pipeline_tasks_decorator(pipeline_class: Pipeline): return _register_pipeline_tasks_decorator + @classmethod + def from_config(cls, config: Union["PipelineConfig", str, Path]) -> "Pipeline": + """ + :param config: PipelineConfig object, filepath to a json serialized + PipelineConfig, or raw string of a json serialized PipelineConfig + :return: loaded Pipeline object from the config + """ + if isinstance(config, Path) or ( + isinstance(config, str) and os.path.exists(config) + ): + if isinstance(config, str): + config = Path(config) + config = PipelineConfig.parse_file(config) + if isinstance(config, str): + config = PipelineConfig.parse_raw(config) + + return cls.create( + task=config.task, + model_path=config.model_path, + engine_type=config.engine_type, + batch_size=config.batch_size, + num_cores=config.num_cores, + scheduler=config.scheduler, + input_shapes=config.input_shapes, + alias=config.alias, + **config.kwargs, + ) + @abstractmethod def setup_onnx_file_path(self) -> str: """ @@ -375,3 +406,65 @@ def _initialize_engine(self) -> Union[Engine, ORTEngine]: f"Unknown engine_type {self.engine_type}. Supported values include: " f"{SUPPORTED_PIPELINE_ENGINES}" ) + + +class PipelineConfig(BaseModel): + """ + Configuration for creating a Pipeline object + + Can be used to create a Pipeline from a config object or file with + Pipeline.from_config(), or used as a building block for other configs + such as for deepsparse.server + """ + + task: str = Field( + description="name of task to create a pipeline for", + ) + model_path: str = Field( + description="path on local system or SparseZoo stub to load the model from", + ) + engine_type: str = Field( + default=DEEPSPARSE_ENGINE, + description=( + "inference engine to use. Currently supported values include " + "'deepsparse' and 'onnxruntime'. Default is 'deepsparse'" + ), + ) + batch_size: int = Field( + default=1, + description=("static batch size to use for inference. Default is 1"), + ) + num_cores: int = Field( + default=None, + description=( + "number of CPU cores to allocate for inference engine. None" + "specifies all available cores. Default is None" + ), + ) + scheduler: str = Field( + default="async", + description=( + "(deepsparse only) kind of scheduler to execute with. Defaults to async" + ), + ) + input_shapes: List[List[int]] = Field( + default=None, + description=( + "list of shapes to set ONNX the inputs to. Pass None to use model as-is. " + "Default is None" + ), + ) + alias: str = Field( + default=None, + description=( + "optional name to give this pipeline instance, useful when inferencing " + "with multiple models. Default is None" + ), + ) + kwargs: Dict[str, Any] = Field( + default={}, + description=( + "Additional arguments for inference with the model that will be passed " + "into the pipeline as kwargs" + ), + ) From 282710def4e68aa49f3c45e2b28049411a189fda Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 13 Apr 2022 18:42:53 -0400 Subject: [PATCH 2/4] Pipeline.to_config() function --- src/deepsparse/pipeline.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 93357094c9..6e210a721c 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -394,6 +394,37 @@ def onnx_file_path(self) -> str: """ return self._onnx_file_path + def to_config(self) -> "PipelineConfig": + """ + :return: PipelineConfig that can be used to reload this object + """ + + if not hasattr(self, "task"): + raise RuntimeError( + f"{self.__class__} instance has no attribute task. Pipeline objects " + "must have a task to be serialized to a config. Pipeline objects " + "must be declared with the Pipeline.register object to be assigned a " + "task" + ) + + # parse any additional properties as kwargs + kwargs = {} + for attr_name, attr in self.__class__.__dict__.items(): + if isinstance(attr, property) and attr_name not in dir(PipelineConfig): + kwargs[attr_name] = getattr(self, attr_name) + + return PipelineConfig( + task=self.task, + model_path=self.model_path_orig, + engine_type=self.engine_type, + batch_size=self.batch_size, + num_cores=self.num_cores, + scheduler=self.scheduler, + input_shapes=self.input_shapes, + alias=self.alias, + kwargs=kwargs, + ) + def _initialize_engine(self) -> Union[Engine, ORTEngine]: engine_type = self.engine_type.lower() From dcb0075c096cec6564c0d94943caa9ac62347b64 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 13 Apr 2022 19:04:43 -0400 Subject: [PATCH 3/4] refactor deepsparse.server to use deepsparse.Pipeline --- src/deepsparse/server/config.py | 72 ++---------------------- src/deepsparse/server/main.py | 33 +++++------ src/deepsparse/server/pipelines.py | 89 ------------------------------ 3 files changed, 23 insertions(+), 171 deletions(-) delete mode 100644 src/deepsparse/server/pipelines.py diff --git a/src/deepsparse/server/config.py b/src/deepsparse/server/config.py index d74ed10231..97c926ca45 100644 --- a/src/deepsparse/server/config.py +++ b/src/deepsparse/server/config.py @@ -19,18 +19,18 @@ import json import os from functools import lru_cache -from typing import Any, Dict, List +from typing import List import yaml from pydantic import BaseModel, Field +from deepsparse import PipelineConfig from deepsparse.cpu import cpu_architecture __all__ = [ "ENV_DEEPSPARSE_SERVER_CONFIG", "ENV_SINGLE_PREFIX", - "ServeModelConfig", "ServerConfig", ] @@ -39,75 +39,15 @@ ENV_SINGLE_PREFIX = "DEEPSPARSE_SINGLE_MODEL:" -class ServeModelConfig(BaseModel): - """ - Configuration for serving a model for a given task in the DeepSparse server - """ - - task: str = Field( - description=( - "The task the model_path is serving. For example, one of: " - "question_answering, text_classification, token_classification." - ), - ) - model_path: str = Field( - description=( - "The path to a model.onnx file, " - "a model folder containing the model.onnx and supporting files, " - "or a SparseZoo model stub." - ), - ) - batch_size: int = Field( - default=1, - description=( - "The batch size to instantiate the model with and use for serving" - ), - ) - alias: str = Field( - default=None, - description=( - "Alias name for model pipeline to be served. A convenience route of " - "/predict/alias will be added to the server if present. " - ), - ) - kwargs: Dict[str, Any] = Field( - default={}, - description=( - "Additional arguments for inference with the model that will be passed " - "into the pipeline as kwargs" - ), - ) - engine: str = Field( - default="deepsparse", - description=( - "The engine to use for serving the models such as deepsparse or onnxruntime" - ), - ) - num_cores: int = Field( - default=None, - description=( - "The number of physical cores to restrict the DeepSparse Engine to. " - "Defaults to all cores." - ), - ) - scheduler: str = Field( - default="async", - description=( - "The scheduler to use with the DeepSparse Engine such as sync or async. " - "Defaults to async" - ), - ) - - class ServerConfig(BaseModel): """ A configuration for serving models in the DeepSparse inference server """ - models: List[ServeModelConfig] = Field( + models: List[PipelineConfig] = Field( default=[], description=( - "The models to serve in the server defined by the additional arguments" + "The models to serve in the server defined by ServerConfig objects" ), ) workers: str = Field( @@ -140,7 +80,7 @@ def server_config_from_env(env_key: str = ENV_DEEPSPARSE_SERVER_CONFIG): config_dict = json.loads(config_file.replace(ENV_SINGLE_PREFIX, "")) config = ServerConfig() config.models.append( - ServeModelConfig( + PipelineConfig( task=config_dict["task"], model_path=config_dict["model_path"], batch_size=config_dict["batch_size"], @@ -150,7 +90,7 @@ def server_config_from_env(env_key: str = ENV_DEEPSPARSE_SERVER_CONFIG): with open(config_file) as file: config_dict = yaml.safe_load(file.read()) config_dict["models"] = ( - [ServeModelConfig(**model) for model in config_dict["models"]] + [PipelineConfig(**model) for model in config_dict["models"]] if "models" in config_dict else [] ) diff --git a/src/deepsparse/server/main.py b/src/deepsparse/server/main.py index 564bc5e42b..5512ea73bb 100644 --- a/src/deepsparse/server/main.py +++ b/src/deepsparse/server/main.py @@ -78,6 +78,7 @@ import click +from deepsparse import Pipeline from deepsparse.log import set_logging_level from deepsparse.server.asynchronous import execute_async, initialize_aysnc from deepsparse.server.config import ( @@ -85,7 +86,6 @@ server_config_from_env, server_config_to_env, ) -from deepsparse.server.pipelines import load_pipelines_definitions from deepsparse.server.utils import serializable_response from deepsparse.version import version @@ -123,29 +123,30 @@ def _home(): _LOGGER.info("created general routes, visit `/docs` to view available") -def _add_pipeline_route(app, pipeline_def, num_models: int, defined_tasks: set): +def _add_pipeline_route(app, pipeline: Pipeline, num_models: int, defined_tasks: set): path = "/predict" - if pipeline_def.config.alias: - path = f"/predict/{pipeline_def.config.alias}" + if pipeline.alias: + path = f"/predict/{pipeline.alias}" elif num_models > 1: - if pipeline_def.config.task in defined_tasks: + if pipeline.task in defined_tasks: raise ValueError( - f"Multiple tasks defined for {pipeline_def.config.task} and no alias " - f"given for {pipeline_def.config}. " + f"Multiple tasks defined for {pipeline.task} and no alias " + f"given for pipeline with model {pipeline.model_path_orig}. " "Either define an alias or supply a single model for the task" ) - path = f"/predict/{pipeline_def.config.task}" - defined_tasks.add(pipeline_def.config.task) + path = f"/predict/{pipeline.task}" + defined_tasks.add(pipeline.task) @app.post( path, - response_model=pipeline_def.response_model, + response_model=pipeline.output_model, tags=["prediction"], ) - async def _predict_func(request: pipeline_def.request_model): + async def _predict_func(request: pipeline.input_model): results = await execute_async( - pipeline_def.pipeline, **vars(request), **pipeline_def.kwargs + pipeline, + **vars(request), ) return serializable_response(results) @@ -167,12 +168,12 @@ def server_app_factory(): _LOGGER.debug("loaded server config %s", config) _add_general_routes(app, config) - pipeline_defs = load_pipelines_definitions(config) - _LOGGER.debug("loaded pipeline definitions from config %s", pipeline_defs) + pipelines = [Pipeline.from_config(model_config) for model_config in config.models] + _LOGGER.debug("loaded pipeline definitions from config %s", pipelines) num_tasks = len(config.models) defined_tasks = set() - for pipeline_def in pipeline_defs: - _add_pipeline_route(app, pipeline_def, num_tasks, defined_tasks) + for pipeline in pipelines: + _add_pipeline_route(app, pipeline, num_tasks, defined_tasks) return app diff --git a/src/deepsparse/server/pipelines.py b/src/deepsparse/server/pipelines.py deleted file mode 100644 index ef07c68ca2..0000000000 --- a/src/deepsparse/server/pipelines.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Pipelines that run preprocessing, postprocessing, and model inference -within the DeepSparse model server. -""" - -from typing import Any, Dict, List - -from pydantic import BaseModel, Field - -from deepsparse.server.config import ServeModelConfig, ServerConfig -from deepsparse.tasks import SupportedTasks - - -__all__ = ["PipelineDefinition", "load_pipelines_definitions"] - - -class PipelineDefinition(BaseModel): - """ - A definition of a pipeline to be served by the model server. - Used to create a prediction route on construction of the server app. - """ - - pipeline: Any = Field(description="the callable pipeline to invoke on each request") - request_model: Any = Field( - description="the pydantic model to validate the request body with" - ) - response_model: Any = Field( - description="the pydantic model to validate the response payload with" - ) - kwargs: Dict[str, Any] = Field( - description="any additional kwargs that should be passed into the pipeline" - ) - config: ServeModelConfig = Field( - description="the config for the model the pipeline is serving" - ) - - -def load_pipelines_definitions(config: ServerConfig) -> List[PipelineDefinition]: - """ - Load the pipeline definitions to use for creating prediction routes from - the given server configuration. - - :param config: the configuration to load pipeline definitions for - :return: the loaded pipeline definitions to use for serving inference requests - """ - defs = [] - - for model_config in config.models: - if SupportedTasks.is_nlp(model_config.task): - # dynamically import so we don't install dependencies when unneeded - from deepsparse.transformers.server import create_pipeline_definitions - - ( - pipeline, - request_model, - response_model, - kwargs, - ) = create_pipeline_definitions(model_config) - else: - raise ValueError( - f"unsupported task given of {model_config.task} " - f"for serve model config {model_config}" - ) - - defs.append( - PipelineDefinition( - pipeline=pipeline, - request_model=request_model, - response_model=response_model, - kwargs=kwargs, - config=model_config, - ) - ) - - return defs From ec05a35e509d44522b6980a1a64bf8cb3f9e1dca Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 14 Apr 2022 15:03:00 -0400 Subject: [PATCH 4/4] review nit fix remove files for separate feature --- src/deepsparse/server/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/server/config.py b/src/deepsparse/server/config.py index 97c926ca45..fe526d5ec2 100644 --- a/src/deepsparse/server/config.py +++ b/src/deepsparse/server/config.py @@ -47,7 +47,7 @@ class ServerConfig(BaseModel): models: List[PipelineConfig] = Field( default=[], description=( - "The models to serve in the server defined by ServerConfig objects" + "The models to serve in the server defined by PipelineConfig objects" ), ) workers: str = Field(