From 41ab7e352c1e5c765a1d9d6ab625cb5ccbf1ee22 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Fri, 29 Apr 2022 16:31:48 -0400
Subject: [PATCH 1/2] rename input/output _model to _schema

---
 .../image_classification/pipelines.py         |  8 ++--
 src/deepsparse/pipeline.py                    | 42 +++++++++----------
 .../pipelines/question_answering.py           | 10 ++---
 .../pipelines/text_classification.py          | 24 +++++------
 .../pipelines/token_classification.py         | 24 +++++------
 5 files changed, 54 insertions(+), 54 deletions(-)

diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py
index b909dd12f1..aa28c9aa60 100644
--- a/src/deepsparse/image_classification/pipelines.py
+++ b/src/deepsparse/image_classification/pipelines.py
@@ -89,14 +89,14 @@ def class_names(self) -> Optional[Dict[str, str]]:
         return self._class_names
 
     @property
-    def input_model(self) -> Type[ImageClassificationInput]:
+    def input_schema(self) -> Type[ImageClassificationInput]:
         """
         :return: pydantic model class that inputs to this pipeline must comply to
         """
         return ImageClassificationInput
 
     @property
-    def output_model(self) -> Type[ImageClassificationOutput]:
+    def output_schema(self) -> Type[ImageClassificationOutput]:
         """
         :return: pydantic model class that outputs of this pipeline must comply to
         """
@@ -164,7 +164,7 @@ def process_engine_outputs(
         """
         :param engine_outputs: list of numpy arrays that are the output of the
             engine forward pass
-        :return: outputs of engine post-processed into an object in the `output_model`
+        :return: outputs of engine post-processed into an object in the `output_schema`
             format of this pipeline
         """
         labels = numpy.argmax(engine_outputs[0], axis=1).tolist()
@@ -172,7 +172,7 @@ def process_engine_outputs(
         if self.class_names is not None:
             labels = [self.class_names[str(class_id)] for class_id in labels]
 
-        return ImageClassificationOutput(
+        return self.output_schema(
             scores=numpy.max(engine_outputs[0], axis=1).tolist(),
             labels=labels,
         )
diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 155744fd90..dc30927613 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -59,7 +59,7 @@ class Pipeline(ABC):
     `Pipeline.create()` method. The task name given to `create` will be used to
     load the appropriate pipeline. When creating a Pipeline, the pipeline should
     inherit from `Pipeline` and implement the `setup_onnx_file_path`, `process_inputs`,
-    `process_engine_outputs`, `input_model`, and `output_model` abstract methods.
+    `process_engine_outputs`, `input_schema`, and `output_schema` abstract methods.
     Finally, the class definition should be decorated by the `Pipeline.register`
     function. This defines the task name and task aliases for the pipeline and
@@ -72,10 +72,10 @@ class Pipeline(ABC):
         * `engine` <- `_initialize_engine`
 
     - on __call__:
-        * `parsed_inputs: input_model` <- `parse_inputs(*args, **kwargs)`
+        * `parsed_inputs: input_schema` <- `parse_inputs(*args, **kwargs)`
         * `pre_processed_inputs` <- `process_inputs(parsed_inputs)`
         * `engine_outputs` <- `engine(pre_processed_inputs)`
-        * `outputs: output_model` <- `process_engine_outputs(engine_outputs)`
+        * `outputs: output_schema` <- `process_engine_outputs(engine_outputs)`
 
     Example use of register:
     ```python
@@ -137,12 +137,12 @@ def __init__(
         self.engine = self._initialize_engine()
 
     def __call__(self, *args, **kwargs) -> BaseModel:
-        # parse inputs into input_model schema if necessary
+        # parse inputs into the input_schema format if necessary
         pipeline_inputs = self.parse_inputs(*args, **kwargs)
-        if not isinstance(pipeline_inputs, self.input_model):
+        if not isinstance(pipeline_inputs, self.input_schema):
             raise RuntimeError(
                 f"Unable to parse {self.__class__} inputs into a "
-                f"{self.input_model} object. Inputs parsed to {type(pipeline_inputs)}"
+                f"{self.input_schema} object. Inputs parsed to {type(pipeline_inputs)}"
             )
 
         # run pipeline
@@ -159,10 +159,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
         )
 
         # validate outputs format
-        if not isinstance(pipeline_outputs, self.output_model):
+        if not isinstance(pipeline_outputs, self.output_schema):
             raise ValueError(
-                f"Outputs of {self.__class__} must be instances of {self.output_model}"
-                f" found output of type {type(pipeline_outputs)}"
+                f"Outputs of {self.__class__} must be instances of "
+                f"{self.output_schema}, found output of type {type(pipeline_outputs)}"
             )
 
         return pipeline_outputs
@@ -316,7 +316,7 @@ def process_inputs(
         inputs: BaseModel,
     ) -> Union[List[numpy.ndarray], Tuple[List[numpy.ndarray], Dict[str, Any]]]:
         """
-        :param inputs: inputs to the pipeline. Must be the type of the `input_model`
+        :param inputs: inputs to the pipeline. Must be the type of the `input_schema`
             of this pipeline
         :return: inputs of this model processed into a list of numpy arrays that
             can be directly passed into the forward pass of the pipeline engine. Can
@@ -335,14 +335,14 @@ def process_engine_outputs(
         """
         :param engine_outputs: list of numpy arrays that are the output of the
             engine forward pass
-        :return: outputs of engine post-processed into an object in the `output_model`
+        :return: outputs of engine post-processed into an object in the `output_schema`
             format of this pipeline
         """
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def input_model(self) -> Type[BaseModel]:
+    def input_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that inputs to this pipeline must comply to
         """
@@ -350,7 +350,7 @@ def input_schema(self) -> Type[BaseModel]:
 
     @property
     @abstractmethod
-    def output_model(self) -> Type[BaseModel]:
+    def output_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that outputs of this pipeline must comply to
         """
@@ -427,25 +427,25 @@ def to_config(self) -> "PipelineConfig":
 
     def parse_inputs(self, *args, **kwargs) -> BaseModel:
         """
-        :param args: ordered arguments to pipeline, only an input_model object
+        :param args: ordered arguments to pipeline, only an input_schema object
             is supported as an arg for this function
         :param kwargs: keyword arguments to pipeline
-        :return: pipeline arguments parsed into the given `input_model`
-            schema if necessary. If an instance of the `input_model` is provided
+        :return: pipeline arguments parsed into the given `input_schema`
+            schema if necessary. If an instance of the `input_schema` is provided
             it will be returned
         """
-        # passed input_model schema directly
-        if len(args) == 1 and isinstance(args[0], self.input_model) and not kwargs:
+        # passed an input_schema object directly
+        if len(args) == 1 and isinstance(args[0], self.input_schema) and not kwargs:
             return args[0]
 
         if args:
             raise ValueError(
-                f"pipeline {self.__class__} only supports either only a "
-                f"{self.input_model} object. or keyword arguments to be construct one. "
-                f"Found {len(args)} args and {len(kwargs)} kwargs"
+                f"pipeline {self.__class__} only supports either a "
+                f"{self.input_schema} object or keyword arguments to construct "
+                f"one. Found {len(args)} args and {len(kwargs)} kwargs"
            )
 
-        return self.input_model(**kwargs)
+        return self.input_schema(**kwargs)
 
     def _initialize_engine(self) -> Union[Engine, ORTEngine]:
         engine_type = self.engine_type.lower()
diff --git a/src/deepsparse/transformers/pipelines/question_answering.py b/src/deepsparse/transformers/pipelines/question_answering.py
index f15f3ba45d..125e2badd6 100644
--- a/src/deepsparse/transformers/pipelines/question_answering.py
+++ b/src/deepsparse/transformers/pipelines/question_answering.py
@@ -169,14 +169,14 @@ def max_question_length(self) -> int:
         return self._max_question_length
 
     @property
-    def input_model(self) -> Type[BaseModel]:
+    def input_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that inputs to this pipeline must comply to
         """
         return QuestionAnsweringInput
 
     @property
-    def output_model(self) -> Type[BaseModel]:
+    def output_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that outputs of this pipeline must comply to
         """
@@ -214,7 +214,7 @@ def process_engine_outputs(
         """
         :param engine_outputs: list of numpy arrays that are the output of the
             engine forward pass
-        :return: outputs of engine post-processed into an object in the `output_model`
+        :return: outputs of engine post-processed into an object in the `output_schema`
             format of this pipeline
         """
         features = kwargs["features"]
@@ -258,7 +258,7 @@ def process_engine_outputs(
         # decode start, end idx into text
         if not self.tokenizer.is_fast:
             char_to_word = numpy.array(example.char_to_word_offset)
-            return self.output_model(
+            return self.output_schema(
                 score=score.item(),
                 start=numpy.where(
                     char_to_word == features.token_to_orig_map[ans_start]
@@ -281,7 +281,7 @@ def process_engine_outputs(
         # Sometimes the max probability token is in the middle of a word so:
         # we start by finding the right word containing the token with
         # `token_to_word` then we convert this word in a character span
-        return self.output_model(
+        return self.output_schema(
             score=score.item(),
             start=features.encoding.word_to_chars(
                 features.encoding.token_to_word(ans_start),
diff --git a/src/deepsparse/transformers/pipelines/text_classification.py b/src/deepsparse/transformers/pipelines/text_classification.py
index 44449b5c46..2e37605c56 100644
--- a/src/deepsparse/transformers/pipelines/text_classification.py
+++ b/src/deepsparse/transformers/pipelines/text_classification.py
@@ -132,14 +132,14 @@ class TextClassificationPipeline(TransformersPipeline):
     """
 
     @property
-    def input_model(self) -> Type[BaseModel]:
+    def input_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that inputs to this pipeline must comply to
         """
         return TextClassificationInput
 
     @property
-    def output_model(self) -> Type[BaseModel]:
+    def output_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that outputs of this pipeline must comply to
         """
@@ -147,11 +147,11 @@ def output_model(self) -> Type[BaseModel]:
 
     def parse_inputs(self, *args, **kwargs) -> BaseModel:
         """
-        :param args: ordered arguments to pipeline, only an input_model object
+        :param args: ordered arguments to pipeline, only an input_schema object
             is supported as an arg for this function
         :param kwargs: keyword arguments to pipeline
-        :return: pipeline arguments parsed into the given `input_model`
-            schema if necessary. If an instance of the `input_model` is provided
+        :return: pipeline arguments parsed into the given `input_schema`
+            schema if necessary. If an instance of the `input_schema` is provided
             it will be returned
         """
         if args and kwargs:
@@ -162,14 +162,14 @@ def parse_inputs(self, *args, **kwargs) -> BaseModel:
 
         if args:
             if len(args) == 1:
-                # passed input_model schema directly
-                if isinstance(args[0], self.input_model):
+                # passed an input_schema object directly
+                if isinstance(args[0], self.input_schema):
                     return args[0]
-                return self.input_model(sequences=args[0])
+                return self.input_schema(sequences=args[0])
             else:
-                return self.input_model(sequences=args)
+                return self.input_schema(sequences=args)
 
-        return self.input_model(**kwargs)
+        return self.input_schema(**kwargs)
 
     def process_inputs(self, inputs: TextClassificationInput) -> List[numpy.ndarray]:
         """
@@ -191,7 +191,7 @@ def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseMod
         """
         :param engine_outputs: list of numpy arrays that are the output of the
             engine forward pass
-        :return: outputs of engine post-processed into an object in the `output_model`
+        :return: outputs of engine post-processed into an object in the `output_schema`
             format of this pipeline
         """
         outputs = engine_outputs
@@ -211,7 +211,7 @@ def process_engine_outputs(self, engine_outputs: List[numpy.ndarray]) -> BaseMod
             labels.append(self.config.id2label[score.argmax()])
             label_scores.append(score.max().item())
 
-        return self.output_model(
+        return self.output_schema(
             labels=labels,
             scores=label_scores,
         )
diff --git a/src/deepsparse/transformers/pipelines/token_classification.py b/src/deepsparse/transformers/pipelines/token_classification.py
index 6150085626..35db3ed0fc 100644
--- a/src/deepsparse/transformers/pipelines/token_classification.py
+++ b/src/deepsparse/transformers/pipelines/token_classification.py
@@ -195,14 +195,14 @@ def ignore_labels(self) -> List[str]:
         return self._ignore_labels
 
     @property
-    def input_model(self) -> Type[BaseModel]:
+    def input_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that inputs to this pipeline must comply to
         """
         return TokenClassificationInput
 
     @property
-    def output_model(self) -> Type[BaseModel]:
+    def output_schema(self) -> Type[BaseModel]:
         """
         :return: pydantic model class that outputs of this pipeline must comply to
         """
@@ -210,11 +210,11 @@ def output_model(self) -> Type[BaseModel]:
 
     def parse_inputs(self, *args, **kwargs) -> BaseModel:
         """
-        :param args: ordered arguments to pipeline, only an input_model object
+        :param args: ordered arguments to pipeline, only an input_schema object
             is supported as an arg for this function
         :param kwargs: keyword arguments to pipeline
-        :return: pipeline arguments parsed into the given `input_model`
-            schema if necessary. If an instance of the `input_model` is provided
+        :return: pipeline arguments parsed into the given `input_schema`
+            schema if necessary. If an instance of the `input_schema` is provided
             it will be returned
         """
         if args and kwargs:
@@ -225,14 +225,14 @@ def parse_inputs(self, *args, **kwargs) -> BaseModel:
 
         if args:
             if len(args) == 1:
-                # passed input_model schema directly
-                if isinstance(args[0], self.input_model):
+                # passed an input_schema object directly
+                if isinstance(args[0], self.input_schema):
                     return args[0]
-                return self.input_model(inputs=args[0])
+                return self.input_schema(inputs=args[0])
             else:
-                return self.input_model(inputs=args)
+                return self.input_schema(inputs=args)
 
-        return self.input_model(**kwargs)
+        return self.input_schema(**kwargs)
 
     def process_inputs(
         self,
@@ -278,7 +278,7 @@ def process_engine_outputs(
         """
         :param engine_outputs: list of numpy arrays that are the output of the
             engine forward pass
-        :return: outputs of engine post-processed into an object in the `output_model`
+        :return: outputs of engine post-processed into an object in the `output_schema`
             format of this pipeline
         """
         inputs = kwargs["inputs"]
@@ -316,7 +316,7 @@ def process_engine_outputs(
             current_results.append(TokenClassificationResult(**entity))
         predictions.append(current_results)
 
-        return self.output_model(predictions=predictions)
+        return self.output_schema(predictions=predictions)
 
 
 # utilities below adapted from transformers
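
The hunks above are a mechanical rename, but they change the interface every pipeline subclass implements. For reviewers, here is a minimal sketch of a custom pipeline written against the renamed interface. The `echo` task, the schema classes, the pass-through processing, and the assumption that the base class exposes `model_path` (as the `setup_onnx_file_path` docstrings imply) are all illustrative, not code from this patch set:

```python
from typing import List, Type

import numpy
from pydantic import BaseModel

from deepsparse import Pipeline


class EchoInput(BaseModel):
    values: List[float]


class EchoOutput(BaseModel):
    values: List[float]


# assumes Pipeline.register accepts the task name described in the
# Pipeline class docstring
@Pipeline.register(task="echo")
class EchoPipeline(Pipeline):
    @property
    def input_schema(self) -> Type[EchoInput]:
        # previously `input_model`
        return EchoInput

    @property
    def output_schema(self) -> Type[EchoOutput]:
        # previously `output_model`
        return EchoOutput

    def setup_onnx_file_path(self) -> str:
        # nothing to unwrap for this toy pipeline; hand the path to the engine
        return self.model_path

    def process_inputs(self, inputs: EchoInput) -> List[numpy.ndarray]:
        # pack the validated pydantic input into engine-ready numpy arrays
        return [numpy.array([inputs.values], dtype=numpy.float32)]

    def process_engine_outputs(
        self, engine_outputs: List[numpy.ndarray], **kwargs
    ) -> EchoOutput:
        # wrap raw engine outputs back into the declared output schema
        return self.output_schema(values=engine_outputs[0].flatten().tolist())
```

With the decorator in place, `Pipeline.create(task="echo", ...)` would resolve to this class, and `__call__` would enforce both schemas at runtime, as in the `pipeline.py` hunks above.
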
From f6ba36d0242e0a5fa7a2aeab6403d0d834e1a975 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Fri, 29 Apr 2022 17:57:48 -0400
Subject: [PATCH 2/2] refactor yolo pipeline

---
 src/deepsparse/yolo/pipelines.py | 48 ++++++++++++++++----------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/src/deepsparse/yolo/pipelines.py b/src/deepsparse/yolo/pipelines.py
index ea84e82ca7..d6d4a11574 100644
--- a/src/deepsparse/yolo/pipelines.py
+++ b/src/deepsparse/yolo/pipelines.py
@@ -101,6 +101,28 @@ def __init__(
         )
         self._model_config = model_config
 
+    @property
+    def model_config(self) -> str:
+        return self._model_config
+
+    @property
+    def class_names(self) -> Optional[Dict[str, str]]:
+        return self._class_names
+
+    @property
+    def input_schema(self) -> Type[YOLOInput]:
+        """
+        :return: pydantic model class that inputs to this pipeline must comply to
+        """
+        return YOLOInput
+
+    @property
+    def output_schema(self) -> Type[YOLOOutput]:
+        """
+        :return: pydantic model class that outputs of this pipeline must comply to
+        """
+        return YOLOOutput
+
     def setup_onnx_file_path(self) -> str:
         """
         Performs any setup to unwrap and process the given `model_path` and other
@@ -113,7 +135,7 @@ class properties into an inference ready onnx file to be compiled by the
 
     def process_inputs(self, inputs: YOLOInput) -> List[numpy.ndarray]:
         """
-        :param inputs: inputs to the pipeline. Must be the type of the `input_model`
+        :param inputs: inputs to the pipeline. Must be the type of the `input_schema`
             of this pipeline
         :return: inputs of this model processed into a list of numpy arrays that
             can be directly passed into the forward pass of the pipeline engine
@@ -147,7 +169,7 @@ def process_engine_outputs(
         """
         :param engine_outputs: list of numpy arrays that are the output of the
             engine forward pass
-        :return: outputs of engine post-processed into an object in the `output_model`
+        :return: outputs of engine post-processed into an object in the `output_schema`
             format of this pipeline
         """
 
@@ -182,28 +204,6 @@ def process_engine_outputs(
             labels=batch_labels,
         )
 
-    @property
-    def input_model(self) -> Type[YOLOInput]:
-        """
-        :return: pydantic model class that inputs to this pipeline must comply to
-        """
-        return YOLOInput
-
-    @property
-    def output_model(self) -> Type[YOLOOutput]:
-        """
-        :return: pydantic model class that outputs of this pipeline must comply to
-        """
-        return YOLOOutput
-
-    @property
-    def model_config(self) -> str:
-        return self._model_config
-
-    @property
-    def class_names(self):
-        return self._class_names
-
     def _infer_image_shape(self, onnx_model) -> Tuple[int, ...]:
         """
         Infer and return the expected shape of the input tensor
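
With both patches applied, callers interact with a pipeline only through its `input_schema` and `output_schema` properties. A short usage sketch; the task name, model path, and the `images` field are assumed placeholders rather than values taken from these diffs:

```python
from deepsparse import Pipeline

# hypothetical model source; substitute a real ONNX file or SparseZoo stub
yolo = Pipeline.create(task="yolo", model_path="path/to/yolo.onnx")

# keyword arguments are parsed into the pipeline's input_schema...
output = yolo(images=["sample.jpg"])

# ...or a pre-built input_schema instance can be passed positionally
inputs = yolo.input_schema(images=["sample.jpg"])
output = yolo(inputs)

# __call__ validates that the result is an instance of output_schema
assert isinstance(output, yolo.output_schema)
```
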