Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 128 additions & 4 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
"""


import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy
from pydantic import BaseModel
from pydantic import BaseModel, Field

from deepsparse import Engine, Scheduler
from deepsparse.benchmark import ORTEngine
Expand All @@ -34,6 +36,7 @@
"ORT_ENGINE",
"SUPPORTED_PIPELINE_ENGINES",
"Pipeline",
"PipelineConfig",
]


Expand Down Expand Up @@ -171,7 +174,7 @@ def create(
input_shapes: List[List[int]] = None,
alias: Optional[str] = None,
**kwargs,
):
) -> "Pipeline":
"""
:param task: name of task to create a pipeline for
:param model_path: path on local system or SparseZoo stub to load the model
Expand All @@ -185,10 +188,10 @@ def create(
Pass None for the default
:param input_shapes: list of shapes to set ONNX the inputs to. Pass None
to use model as-is. Default is None
:param kwargs: extra task specific kwargs to be passed to task Pipeline
implementation
:param alias: optional name to give this pipeline instance, useful when
inferencing with multiple models. Default is None
:param kwargs: extra task specific kwargs to be passed to task Pipeline
implementation
:return: pipeline object initialized for the given task
"""
task = task.lower().replace("-", "_")
Expand Down Expand Up @@ -264,6 +267,34 @@ def _register_pipeline_tasks_decorator(pipeline_class: Pipeline):

return _register_pipeline_tasks_decorator

@classmethod
def from_config(cls, config: Union["PipelineConfig", str, Path]) -> "Pipeline":
"""
:param config: PipelineConfig object, filepath to a json serialized
PipelineConfig, or raw string of a json serialized PipelineConfig
:return: loaded Pipeline object from the config
"""
if isinstance(config, Path) or (
isinstance(config, str) and os.path.exists(config)
):
if isinstance(config, str):
config = Path(config)
config = PipelineConfig.parse_file(config)
if isinstance(config, str):
config = PipelineConfig.parse_raw(config)

return cls.create(
task=config.task,
model_path=config.model_path,
engine_type=config.engine_type,
batch_size=config.batch_size,
num_cores=config.num_cores,
scheduler=config.scheduler,
input_shapes=config.input_shapes,
alias=config.alias,
**config.kwargs,
)

@abstractmethod
def setup_onnx_file_path(self) -> str:
"""
Expand Down Expand Up @@ -363,6 +394,37 @@ def onnx_file_path(self) -> str:
"""
return self._onnx_file_path

def to_config(self) -> "PipelineConfig":
"""
:return: PipelineConfig that can be used to reload this object
"""

if not hasattr(self, "task"):
raise RuntimeError(
f"{self.__class__} instance has no attribute task. Pipeline objects "
"must have a task to be serialized to a config. Pipeline objects "
"must be declared with the Pipeline.register object to be assigned a "
"task"
)

# parse any additional properties as kwargs
kwargs = {}
for attr_name, attr in self.__class__.__dict__.items():
if isinstance(attr, property) and attr_name not in dir(PipelineConfig):
kwargs[attr_name] = getattr(self, attr_name)

return PipelineConfig(
task=self.task,
model_path=self.model_path_orig,
engine_type=self.engine_type,
batch_size=self.batch_size,
num_cores=self.num_cores,
scheduler=self.scheduler,
input_shapes=self.input_shapes,
alias=self.alias,
kwargs=kwargs,
)

def _initialize_engine(self) -> Union[Engine, ORTEngine]:
engine_type = self.engine_type.lower()

Expand All @@ -375,3 +437,65 @@ def _initialize_engine(self) -> Union[Engine, ORTEngine]:
f"Unknown engine_type {self.engine_type}. Supported values include: "
f"{SUPPORTED_PIPELINE_ENGINES}"
)


class PipelineConfig(BaseModel):
"""
Configuration for creating a Pipeline object

Can be used to create a Pipeline from a config object or file with
Pipeline.from_config(), or used as a building block for other configs
such as for deepsparse.server
"""

task: str = Field(
description="name of task to create a pipeline for",
)
model_path: str = Field(
description="path on local system or SparseZoo stub to load the model from",
)
engine_type: str = Field(
default=DEEPSPARSE_ENGINE,
description=(
"inference engine to use. Currently supported values include "
"'deepsparse' and 'onnxruntime'. Default is 'deepsparse'"
),
)
batch_size: int = Field(
default=1,
description=("static batch size to use for inference. Default is 1"),
)
num_cores: int = Field(
default=None,
description=(
"number of CPU cores to allocate for inference engine. None"
"specifies all available cores. Default is None"
),
)
scheduler: str = Field(
default="async",
description=(
"(deepsparse only) kind of scheduler to execute with. Defaults to async"
),
)
input_shapes: List[List[int]] = Field(
default=None,
description=(
"list of shapes to set ONNX the inputs to. Pass None to use model as-is. "
"Default is None"
),
)
alias: str = Field(
default=None,
description=(
"optional name to give this pipeline instance, useful when inferencing "
"with multiple models. Default is None"
),
)
kwargs: Dict[str, Any] = Field(
default={},
description=(
"Additional arguments for inference with the model that will be passed "
"into the pipeline as kwargs"
),
)
72 changes: 6 additions & 66 deletions src/deepsparse/server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@
import json
import os
from functools import lru_cache
from typing import Any, Dict, List
from typing import List

import yaml
from pydantic import BaseModel, Field

from deepsparse import PipelineConfig
from deepsparse.cpu import cpu_architecture


__all__ = [
"ENV_DEEPSPARSE_SERVER_CONFIG",
"ENV_SINGLE_PREFIX",
"ServeModelConfig",
"ServerConfig",
]

Expand All @@ -39,75 +39,15 @@
ENV_SINGLE_PREFIX = "DEEPSPARSE_SINGLE_MODEL:"


class ServeModelConfig(BaseModel):
"""
Configuration for serving a model for a given task in the DeepSparse server
"""

task: str = Field(
description=(
"The task the model_path is serving. For example, one of: "
"question_answering, text_classification, token_classification."
),
)
model_path: str = Field(
description=(
"The path to a model.onnx file, "
"a model folder containing the model.onnx and supporting files, "
"or a SparseZoo model stub."
),
)
batch_size: int = Field(
default=1,
description=(
"The batch size to instantiate the model with and use for serving"
),
)
alias: str = Field(
default=None,
description=(
"Alias name for model pipeline to be served. A convenience route of "
"/predict/alias will be added to the server if present. "
),
)
kwargs: Dict[str, Any] = Field(
default={},
description=(
"Additional arguments for inference with the model that will be passed "
"into the pipeline as kwargs"
),
)
engine: str = Field(
default="deepsparse",
description=(
"The engine to use for serving the models such as deepsparse or onnxruntime"
),
)
num_cores: int = Field(
default=None,
description=(
"The number of physical cores to restrict the DeepSparse Engine to. "
"Defaults to all cores."
),
)
scheduler: str = Field(
default="async",
description=(
"The scheduler to use with the DeepSparse Engine such as sync or async. "
"Defaults to async"
),
)


class ServerConfig(BaseModel):
"""
A configuration for serving models in the DeepSparse inference server
"""

models: List[ServeModelConfig] = Field(
models: List[PipelineConfig] = Field(
default=[],
description=(
"The models to serve in the server defined by the additional arguments"
"The models to serve in the server defined by PipelineConfig objects"
),
)
workers: str = Field(
Expand Down Expand Up @@ -140,7 +80,7 @@ def server_config_from_env(env_key: str = ENV_DEEPSPARSE_SERVER_CONFIG):
config_dict = json.loads(config_file.replace(ENV_SINGLE_PREFIX, ""))
config = ServerConfig()
config.models.append(
ServeModelConfig(
PipelineConfig(
task=config_dict["task"],
model_path=config_dict["model_path"],
batch_size=config_dict["batch_size"],
Expand All @@ -150,7 +90,7 @@ def server_config_from_env(env_key: str = ENV_DEEPSPARSE_SERVER_CONFIG):
with open(config_file) as file:
config_dict = yaml.safe_load(file.read())
config_dict["models"] = (
[ServeModelConfig(**model) for model in config_dict["models"]]
[PipelineConfig(**model) for model in config_dict["models"]]
if "models" in config_dict
else []
)
Expand Down
33 changes: 17 additions & 16 deletions src/deepsparse/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@

import click

from deepsparse import Pipeline
from deepsparse.log import set_logging_level
from deepsparse.server.asynchronous import execute_async, initialize_aysnc
from deepsparse.server.config import (
ServerConfig,
server_config_from_env,
server_config_to_env,
)
from deepsparse.server.pipelines import load_pipelines_definitions
from deepsparse.server.utils import serializable_response
from deepsparse.version import version

Expand Down Expand Up @@ -123,29 +123,30 @@ def _home():
_LOGGER.info("created general routes, visit `/docs` to view available")


def _add_pipeline_route(app, pipeline_def, num_models: int, defined_tasks: set):
def _add_pipeline_route(app, pipeline: Pipeline, num_models: int, defined_tasks: set):
path = "/predict"

if pipeline_def.config.alias:
path = f"/predict/{pipeline_def.config.alias}"
if pipeline.alias:
path = f"/predict/{pipeline.alias}"
elif num_models > 1:
if pipeline_def.config.task in defined_tasks:
if pipeline.task in defined_tasks:
raise ValueError(
f"Multiple tasks defined for {pipeline_def.config.task} and no alias "
f"given for {pipeline_def.config}. "
f"Multiple tasks defined for {pipeline.task} and no alias "
f"given for pipeline with model {pipeline.model_path_orig}. "
"Either define an alias or supply a single model for the task"
)
path = f"/predict/{pipeline_def.config.task}"
defined_tasks.add(pipeline_def.config.task)
path = f"/predict/{pipeline.task}"
defined_tasks.add(pipeline.task)

@app.post(
path,
response_model=pipeline_def.response_model,
response_model=pipeline.output_model,
tags=["prediction"],
)
async def _predict_func(request: pipeline_def.request_model):
async def _predict_func(request: pipeline.input_model):
results = await execute_async(
pipeline_def.pipeline, **vars(request), **pipeline_def.kwargs
pipeline,
**vars(request),
)
return serializable_response(results)

Expand All @@ -167,12 +168,12 @@ def server_app_factory():
_LOGGER.debug("loaded server config %s", config)
_add_general_routes(app, config)

pipeline_defs = load_pipelines_definitions(config)
_LOGGER.debug("loaded pipeline definitions from config %s", pipeline_defs)
pipelines = [Pipeline.from_config(model_config) for model_config in config.models]
_LOGGER.debug("loaded pipeline definitions from config %s", pipelines)
num_tasks = len(config.models)
defined_tasks = set()
for pipeline_def in pipeline_defs:
_add_pipeline_route(app, pipeline_def, num_tasks, defined_tasks)
for pipeline in pipelines:
_add_pipeline_route(app, pipeline, num_tasks, defined_tasks)

return app

Expand Down
Loading