From 1de3ec18501aa9905ecd5e4d23ba1d28a37ee67a Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Mon, 18 Apr 2022 08:16:18 -0400 Subject: [PATCH 1/8] Create a command line installable for image classification pipeline --- setup.py | 5 +++ .../image_classification/__init__.py | 13 +++++++ .../image_classification/pipelines.py | 35 +++++++++++++++++++ 3 files changed, 53 insertions(+) create mode 100644 src/deepsparse/image_classification/__init__.py create mode 100644 src/deepsparse/image_classification/pipelines.py diff --git a/setup.py b/setup.py index ed3950f517..787b89a512 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,9 @@ "onnxruntime>=1.7.0", ] +_ic_integration_deps = [ + "torch>=1.1.0,<=1.9.1", +] class OverrideInstall(install): """ @@ -173,6 +176,7 @@ def _setup_extras() -> Dict: "dev": _dev_deps, "server": _server_deps, "onnxruntime": _onnxruntime_deps, + "image_classification": _ic_integration_deps, } @@ -187,6 +191,7 @@ def _setup_entry_points() -> Dict: "deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability", "deepsparse.benchmark=deepsparse.benchmark_model.benchmark_model:main", "deepsparse.server=deepsparse.server.main:start_server", + "deepsparse.image_classification=deepsparse.image_classification.pipelines:main", ] } diff --git a/src/deepsparse/image_classification/__init__.py b/src/deepsparse/image_classification/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/src/deepsparse/image_classification/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py new file mode 100644 index 0000000000..c1cc9f14b0 --- /dev/null +++ b/src/deepsparse/image_classification/pipelines.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Image classification pipeline +""" +__status__ = "Under-Development" +try: + import torch + + torch_error = None +except ModuleNotFoundError as error: + torch = None + torch_error = error + + +def main(): + print(f"Currently this module is {__status__}") + if torch: + print("Torch version:", torch.__version__) + + +if __name__ == '__main__': + main() \ No newline at end of file From e8a28cff4146b1ad13750789dcf8e17d3ec77b25 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 19 Apr 2022 19:59:11 -0400 Subject: [PATCH 2/8] Intermediate Commit --- setup.py | 3 +- .../image_classification/constants.py | 16 +++ .../image_classification/pipelines.py | 124 ++++++++++++++++-- src/deepsparse/pipeline.py | 2 +- 4 files changed, 131 insertions(+), 14 deletions(-) create mode 100644 src/deepsparse/image_classification/constants.py diff --git a/setup.py b/setup.py index 787b89a512..cfedfa8112 100644 --- a/setup.py +++ b/setup.py @@ -82,9 +82,10 @@ ] _ic_integration_deps = [ - "torch>=1.1.0,<=1.9.1", + "opencv-python", ] + class OverrideInstall(install): """ Install class to run checks for supported systems before install diff --git a/src/deepsparse/image_classification/constants.py b/src/deepsparse/image_classification/constants.py new file mode 100644 index 0000000000..d035e44513 --- /dev/null +++ b/src/deepsparse/image_classification/constants.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +IMAGENET_RGB_MEANS = [0.485, 0.456, 0.406] +IMAGENET_RGB_STDS = [0.229, 0.224, 0.225] diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index c1cc9f14b0..bc4f321dca 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -15,21 +15,121 @@ """ Image classification pipeline """ -__status__ = "Under-Development" +import json +from typing import Dict, List, Union + +import numpy +import numpy as np +from pydantic import BaseModel + +from constants import IMAGENET_RGB_MEANS, IMAGENET_RGB_STDS +from deepsparse.pipeline import Pipeline + + try: - import torch + import cv2 +except ModuleNotFoundError as e: + cv2 = None + cv2_error = e + + +class ImageClassificationInput(BaseModel): + """ + Input model for image classification + """ + + images: Union[str, numpy.ndarray, List[str]] + + +class ImageClassificationOutput(BaseModel): + """ + Input model for image classification + """ + + labels: List[int] + scores: List[float] + + +@Pipeline.register(task="image_classification") +class ImageClassificationPipeline(Pipeline): + """ + Image classification pipeline for DeepSparse + """ + + def setup_onnx_file_path(self) -> str: + """ + Performs any setup to unwrap and process the given `model_path` and other + class properties into an inference ready onnx file to be compiled by the + engine of the pipeline + + :return: file path to the ONNX file for the engine to compile + """ + return self.model_path + + def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray]: + """ + Pre-Process the Inputs for DeepSparse Engine + + :param inputs: input model + :return: list of numpy arrays + """ + + # TODO: Check logic for 3-dim and 2-dim images + images = [] + non_rand_resize_scale = 256.0 / 224.0 # standard used + image_size = 224 + + scaled_image_size = non_rand_resize_scale * image_size + + for image_file in inputs.images: + img = cv2.imread(image_file) + if img is not None: + img = cv2.resize(img, (scaled_image_size, scaled_image_size)) + center = img.shape / 2 + x = center[1] - image_size / 2 + y = center[0] - image_size / 2 + + crop_img = img[ + int(y) : int(y + image_size), int(x) : int(x + image_size) + ] + + crop_img -= np.asarray(IMAGENET_RGB_MEANS) + crop_img /= np.asarray(IMAGENET_RGB_STDS) + images.append(crop_img) + + return images + + def process_engine_outputs( + self, + engine_outputs: List[numpy.ndarray], + ) -> ImageClassificationOutput: + return ImageClassificationOutput( + scores=numpy.max(engine_outputs[0], axis=1).tolist(), + labels=numpy.argmax(engine_outputs[0], axis=1).tolist(), + ) - torch_error = None -except ModuleNotFoundError as error: - torch = None - torch_error = error + @property + def input_model(self) -> BaseModel: + return ImageClassificationInput + @property + def output_model(self) -> BaseModel: + return ImageClassificationOutput -def main(): - print(f"Currently this module is {__status__}") - if torch: - print("Torch version:", torch.__version__) + def map_labels_to_classes( + self, + labels: List[int], + class_names: Union[str, Dict[int, str]], + ) -> List[str]: + """ + :param labels: predicted class ids + :param class_names: A json file containing the mapping of class ids to + class names, or a dictionary mapping class ids to class names. 
+ :return: Predicted class names from labels + """ + if isinstance(class_names, str) and class_names.endswith(".json"): + class_names = json.loads(class_names) -if __name__ == '__main__': - main() \ No newline at end of file + predicted_class_names = [class_names[label] for label in labels] + return predicted_class_names diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 6e210a721c..678eaa11b1 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -27,7 +27,7 @@ from pydantic import BaseModel, Field from deepsparse import Engine, Scheduler -from deepsparse.benchmark import ORTEngine +from deepsparse.benchmark_model.ort_engine import ORTEngine from deepsparse.tasks import SupportedTasks From 5e3c6fb739906743d36acf5b543731756d4526e3 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 07:25:51 -0400 Subject: [PATCH 3/8] Image Classification pipeline implementation --- setup.py | 4 +- .../image_classification/pipelines.py | 180 ++++++++++++------ .../image_classification/schemas.py | 42 ++++ src/deepsparse/pipeline.py | 5 +- src/deepsparse/tasks.py | 21 ++ 5 files changed, 188 insertions(+), 64 deletions(-) create mode 100644 src/deepsparse/image_classification/schemas.py diff --git a/setup.py b/setup.py index cfedfa8112..b55f8baa7f 100644 --- a/setup.py +++ b/setup.py @@ -184,6 +184,8 @@ def _setup_extras() -> Dict: def _setup_entry_points() -> Dict: data_api_entrypoint = "deepsparse.transformers.pipelines_cli:cli" eval_downstream = "deepsparse.transformers.eval_downstream:main" + ic_pipeline_entrypoint = "deepsparse.image_classification.pipelines:main" + return { "console_scripts": [ f"deepsparse.transformers.run_inference={data_api_entrypoint}", @@ -192,7 +194,7 @@ def _setup_entry_points() -> Dict: "deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability", "deepsparse.benchmark=deepsparse.benchmark_model.benchmark_model:main", "deepsparse.server=deepsparse.server.main:start_server", - "deepsparse.image_classification=deepsparse.image_classification.pipelines:main", + f"deepsparse.image_classification={ic_pipeline_entrypoint}", ] } diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index bc4f321dca..3a65aee4d2 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -16,46 +16,85 @@ Image classification pipeline """ import json -from typing import Dict, List, Union +from typing import Dict, List, Optional, Tuple, Type, Union import numpy -import numpy as np -from pydantic import BaseModel +import onnx -from constants import IMAGENET_RGB_MEANS, IMAGENET_RGB_STDS -from deepsparse.pipeline import Pipeline +from deepsparse import Scheduler +from deepsparse.image_classification.constants import ( + IMAGENET_RGB_MEANS, + IMAGENET_RGB_STDS, +) +from deepsparse.pipeline import DEEPSPARSE_ENGINE, Pipeline +from image_classification.schemas import ( + ImageClassificationInput, + ImageClassificationOutput, +) try: import cv2 -except ModuleNotFoundError as e: - cv2 = None - cv2_error = e - - -class ImageClassificationInput(BaseModel): - """ - Input model for image classification - """ - images: Union[str, numpy.ndarray, List[str]] - - -class ImageClassificationOutput(BaseModel): - """ - Input model for image classification - """ - - labels: List[int] - scores: List[float] + cv2_error = None +except ModuleNotFoundError as cv2_import_error: + cv2 = None + cv2_error = cv2_import_error 
@Pipeline.register(task="image_classification") class ImageClassificationPipeline(Pipeline): """ Image classification pipeline for DeepSparse + + :param model_path: path on local system or SparseZoo stub to load the model from + :param engine_type: inference engine to use. Currently supported values include + 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' + :param batch_size: static batch size to use for inference. Default is 1 + :param num_cores: number of CPU cores to allocate for inference engine. None + specifies all available cores. Default is None + :param scheduler: (deepsparse only) kind of scheduler to execute with. + Pass None for the default + :param input_shapes: list of shapes to set ONNX the inputs to. Pass None + to use model as-is. Default is None + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param class_names: Optional dict, or json file of class names to use for + mapping class ids to class labels. Default is None """ + def __init__( + self, + model_path: str, + engine_type: str = DEEPSPARSE_ENGINE, + batch_size: int = 1, + num_cores: int = None, + scheduler: Scheduler = None, + input_shapes: List[List[int]] = None, + alias: Optional[str] = None, + class_names: Optional[Union[str, Dict[str, str]]] = None, + ): + super().__init__( + model_path, + engine_type, + batch_size, + num_cores, + scheduler, + input_shapes, + alias, + ) + self._input_shape = None + + if isinstance(class_names, str) and class_names.endswith(".json"): + self.class_names = json.load(open(class_names)) + elif isinstance(class_names, dict): + self.class_names = class_names + else: + raise ValueError( + "class_names must be a dict or a json file path" + f" (got {type(class_names)} instead)" + ) + def setup_onnx_file_path(self) -> str: """ Performs any setup to unwrap and process the given `model_path` and other @@ -71,65 +110,86 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray Pre-Process the Inputs for DeepSparse Engine :param inputs: input model - :return: list of numpy arrays + :return: list of preprocessed numpy arrays """ - # TODO: Check logic for 3-dim and 2-dim images - images = [] - non_rand_resize_scale = 256.0 / 224.0 # standard used - image_size = 224 + image_batch = [] - scaled_image_size = non_rand_resize_scale * image_size + if isinstance(inputs.images, str): + inputs.images = [inputs.images] - for image_file in inputs.images: - img = cv2.imread(image_file) - if img is not None: - img = cv2.resize(img, (scaled_image_size, scaled_image_size)) - center = img.shape / 2 - x = center[1] - image_size / 2 - y = center[0] - image_size / 2 + for image in inputs.images: + img = cv2.imread(image) if isinstance(image, str) else image - crop_img = img[ - int(y) : int(y + image_size), int(x) : int(x + image_size) - ] + img = cv2.resize(img, dsize=self.input_shape) + img = img[:, :, ::-1].transpose(2, 0, 1) - crop_img -= np.asarray(IMAGENET_RGB_MEANS) - crop_img /= np.asarray(IMAGENET_RGB_STDS) - images.append(crop_img) + image_batch.append(img) - return images + image_batch = numpy.stack(image_batch, axis=0) + image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32) + image_batch /= 255.0 + + # normalize entire batch + image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1)) + image_batch /= numpy.asarray(IMAGENET_RGB_STDS).reshape((-1, 3, 1, 1)) + + return [image_batch] def process_engine_outputs( self, engine_outputs: List[numpy.ndarray], ) -> 
ImageClassificationOutput: + """ + :param engine_outputs: list of numpy arrays that are the output of the engine + forward pass + :return: outputs of engine post-processed into an object in the `output_model` + format of this pipeline + """ + labels = numpy.argmax(engine_outputs[0], axis=1).tolist() + + if self.class_names is not None: + labels = [self.class_names[str(class_id)] for class_id in labels] + return ImageClassificationOutput( scores=numpy.max(engine_outputs[0], axis=1).tolist(), - labels=numpy.argmax(engine_outputs[0], axis=1).tolist(), + labels=labels, ) @property - def input_model(self) -> BaseModel: + def input_model(self) -> Type[ImageClassificationInput]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ return ImageClassificationInput @property - def output_model(self) -> BaseModel: + def output_model(self) -> Type[ImageClassificationOutput]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ return ImageClassificationOutput - def map_labels_to_classes( - self, - labels: List[int], - class_names: Union[str, Dict[int, str]], - ) -> List[str]: + @property + def input_shape(self) -> Tuple[int, ...]: """ - :param labels: predicted class ids - :param class_names: A json file containing the mapping of class ids to - class names, or a dictionary mapping class ids to class names. - :return: Predicted class names from labels + Returns the expected shape of the input tensor + + :return: The expected shape of the input tensor from onnx graph """ + if self._input_shape is None: + self._input_shape = self._infer_input_shape() + return self._input_shape - if isinstance(class_names, str) and class_names.endswith(".json"): - class_names = json.loads(class_names) + def _infer_input_shape(self) -> Tuple[int, ...]: + """ + Infer and return the expected shape of the input tensor - predicted_class_names = [class_names[label] for label in labels] - return predicted_class_names + :return: The expected shape of the input tensor from onnx graph + """ + onnx_model = onnx.load(self.engine.model_path) + input_tensor = onnx_model.graph.input[0] + return ( + input_tensor.type.tensor_type.shape.dim[2].dim_value, + input_tensor.type.tensor_type.shape.dim[3].dim_value, + ) diff --git a/src/deepsparse/image_classification/schemas.py b/src/deepsparse/image_classification/schemas.py new file mode 100644 index 0000000000..f20e9d5819 --- /dev/null +++ b/src/deepsparse/image_classification/schemas.py @@ -0,0 +1,42 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Input/Output Schemas for Image Classification. 
+""" + +from typing import List, Union + +import numpy +from pydantic import BaseModel + + +class ImageClassificationInput(BaseModel): + """ + Input model for image classification + """ + + images: Union[str, List[numpy.ndarray], List[str]] + + class Config: + arbitrary_types_allowed = True + + +class ImageClassificationOutput(BaseModel): + """ + Input model for image classification + """ + + labels: List[Union[int, str]] + scores: List[float] diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 678eaa11b1..d26c3f77ee 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -134,7 +134,6 @@ def __init__( self._onnx_file_path = self.setup_onnx_file_path() self._engine = self._initialize_engine() - pass def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel: if pipeline_inputs is None and kwargs: @@ -222,7 +221,7 @@ def create( ) @classmethod - def register(cls, task: str, task_aliases: Optional[List[str]]): + def register(cls, task: str, task_aliases: Optional[List[str]] = None): """ Pipeline implementer class decorator that registers the pipeline task name and its aliases as valid tasks that can be used to load @@ -233,7 +232,7 @@ def register(cls, task: str, task_aliases: Optional[List[str]]): :param task: main task name of this pipeline :param task_aliases: list of extra task names that may be used to reference - this pipeline + this pipeline. Default is None """ task_names = [task] if task_aliases: diff --git a/src/deepsparse/tasks.py b/src/deepsparse/tasks.py index 4b24c6d16c..326f0721c9 100644 --- a/src/deepsparse/tasks.py +++ b/src/deepsparse/tasks.py @@ -78,12 +78,24 @@ class SupportedTasks: token_classification=AliasedTask("token_classification", ["ner"]), ) + image_classification = namedtuple("image_classification", ["image_classification"])( + image_classification=AliasedTask( + "image_classification", + ["image_classification"], + ), + ) + @classmethod def check_register_task(cls, task: str): if cls.is_nlp(task): # trigger transformers pipelines to register with Pipeline.register import deepsparse.transformers.pipelines # noqa: F401 + elif cls.is_image_classification(task): + # trigger image classification pipelines to + # register with Pipeline.register + import deepsparse.image_classification.pipelines # noqa: F401 + @classmethod def is_nlp(cls, task: str) -> bool: """ @@ -96,3 +108,12 @@ def is_nlp(cls, task: str) -> bool: or cls.nlp.text_classification.matches(task) or cls.nlp.token_classification.matches(task) ) + + @classmethod + def is_image_classification(cls, task: str) -> bool: + """ + :param task: the name of the task to check whether it is an image + classification task + :return: True if it is an image classification task, False otherwise + """ + return cls.image_classification.image_classification.matches(task) From 3b1eb2075e3fded863d9dfc5096b08289053ce7c Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 08:02:36 -0400 Subject: [PATCH 4/8] Remove faulty entry point --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index b55f8baa7f..af9307e2b9 100644 --- a/setup.py +++ b/setup.py @@ -184,7 +184,6 @@ def _setup_extras() -> Dict: def _setup_entry_points() -> Dict: data_api_entrypoint = "deepsparse.transformers.pipelines_cli:cli" eval_downstream = "deepsparse.transformers.eval_downstream:main" - ic_pipeline_entrypoint = "deepsparse.image_classification.pipelines:main" return { "console_scripts": [ @@ -194,7 +193,6 @@ def _setup_entry_points() -> Dict: 
"deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability", "deepsparse.benchmark=deepsparse.benchmark_model.benchmark_model:main", "deepsparse.server=deepsparse.server.main:start_server", - f"deepsparse.image_classification={ic_pipeline_entrypoint}", ] } From 6830b692d8ff877db083e5805393cefb5fbe0b15 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 11:33:40 -0400 Subject: [PATCH 5/8] Apply suggestions from @bogunowicz Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> --- src/deepsparse/image_classification/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/image_classification/schemas.py b/src/deepsparse/image_classification/schemas.py index f20e9d5819..4232702898 100644 --- a/src/deepsparse/image_classification/schemas.py +++ b/src/deepsparse/image_classification/schemas.py @@ -35,7 +35,7 @@ class Config: class ImageClassificationOutput(BaseModel): """ - Input model for image classification + Output model for image classification """ labels: List[Union[int, str]] From c4ad45c190c784d1be5fb183a66b903889291f8b Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 14:16:22 -0400 Subject: [PATCH 6/8] Changed function name from `_infer_input_shape` to `_infer_image_shape` --- src/deepsparse/image_classification/pipelines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index 3a65aee4d2..5c2b41b2e6 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -178,10 +178,10 @@ def input_shape(self) -> Tuple[int, ...]: :return: The expected shape of the input tensor from onnx graph """ if self._input_shape is None: - self._input_shape = self._infer_input_shape() + self._input_shape = self._infer_image_shape() return self._input_shape - def _infer_input_shape(self) -> Tuple[int, ...]: + def _infer_image_shape(self) -> Tuple[int, ...]: """ Infer and return the expected shape of the input tensor From 8d25217771e1b72c6ad51a7dd6aa5ae0fdf8110e Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 29 Apr 2022 14:28:35 -0400 Subject: [PATCH 7/8] Add validation script for Image Classification pipeline (#328) * Add Validation Script for Image Classification Models * Update pipelines and corresponding schemas to work with numpy arrays * Bugfix if prediction to be converted to int if it's a string * Update docstring * Update src/deepsparse/image_classification/validation_script.py --- setup.py | 1 + .../image_classification/pipelines.py | 33 ++-- .../image_classification/schemas.py | 2 +- .../image_classification/validation_script.py | 153 ++++++++++++++++++ 4 files changed, 174 insertions(+), 15 deletions(-) create mode 100644 src/deepsparse/image_classification/validation_script.py diff --git a/setup.py b/setup.py index af9307e2b9..aa7188bfdb 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ ] _ic_integration_deps = [ + "click<8.1", "opencv-python", ] diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index 5c2b41b2e6..71be2064f1 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -90,10 +90,7 @@ def __init__( elif isinstance(class_names, dict): self.class_names = class_names else: - raise ValueError( - "class_names must be a dict or a json file path" - f" (got {type(class_names)} instead)" - 
) + self.class_names = None def setup_onnx_file_path(self) -> str: """ @@ -113,22 +110,30 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray :return: list of preprocessed numpy arrays """ - image_batch = [] + if isinstance(inputs.images, numpy.ndarray): + image_batch = inputs.images + else: + + image_batch = [] - if isinstance(inputs.images, str): - inputs.images = [inputs.images] + if isinstance(inputs.images, str): + inputs.images = [inputs.images] - for image in inputs.images: - img = cv2.imread(image) if isinstance(image, str) else image + for image in inputs.images: + img = cv2.imread(image) if isinstance(image, str) else image - img = cv2.resize(img, dsize=self.input_shape) - img = img[:, :, ::-1].transpose(2, 0, 1) + img = cv2.resize(img, dsize=self.input_shape) + img = img[:, :, ::-1].transpose(2, 0, 1) + image_batch.append(img) - image_batch.append(img) + image_batch = numpy.stack(image_batch, axis=0) - image_batch = numpy.stack(image_batch, axis=0) + original_dtype = image_batch.dtype image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32) - image_batch /= 255.0 + + if original_dtype == numpy.uint8: + + image_batch /= 255 # normalize entire batch image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1)) diff --git a/src/deepsparse/image_classification/schemas.py b/src/deepsparse/image_classification/schemas.py index 4232702898..5a92b90e3b 100644 --- a/src/deepsparse/image_classification/schemas.py +++ b/src/deepsparse/image_classification/schemas.py @@ -27,7 +27,7 @@ class ImageClassificationInput(BaseModel): Input model for image classification """ - images: Union[str, List[numpy.ndarray], List[str]] + images: Union[str, numpy.ndarray, List[str]] class Config: arbitrary_types_allowed = True diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py new file mode 100644 index 0000000000..db9ed8fe16 --- /dev/null +++ b/src/deepsparse/image_classification/validation_script.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: validation_script.py [OPTIONS] + + Validation Script for Image Classification Models + +Options: + --dataset-path, --dataset_path DIRECTORY + Path to the validation dataset [required] + --model-path, --model_path TEXT + Path/SparseZoo stub for the Image + Classification model to be evaluated. + Defaults to resnet50 trained on + Imagenette [default: zoo:cv/classification/ + resnet_v1-50/pytorch/sparseml/imagenette/ + base-none] + --batch-size, --batch_size INTEGER + Test batch size, must divide the dataset + evenly, else the last batch will be dropped + [default: 1] + --help Show this message and exit. 
+ +######### +EXAMPLES +######### + +########## +Example command for validating pruned resnet50 on imagenette dataset: +python validation_script.py \ + --dataset-path /path/to/imagenette/ + +""" +from tqdm import tqdm + +from deepsparse.pipeline import Pipeline +from torch.utils.data import DataLoader +from torchvision import transforms + + +try: + import torchvision + +except ModuleNotFoundError as torchvision_error: # noqa: F841 + print( + "Torchvision not installed. Please install it using the command:" + "pip install torchvision>=0.3.0,<=0.10.1" + ) + exit(1) + +import click + + +resnet50_imagenet_pruned = ( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenette/base-none" +) + + +@click.command() +@click.option( + "--dataset-path", + "--dataset_path", + required=True, + type=click.Path(dir_okay=True, file_okay=False), + help="Path to the validation dataset", +) +@click.option( + "--model-path", + "--model_path", + type=str, + default=resnet50_imagenet_pruned, + help="Path/SparseZoo stub for the Image Classification model to be " + "evaluated. Defaults to dense (vanilla) resnet50 trained on Imagenette", + show_default=True, +) +@click.option( + "--batch-size", + "--batch_size", + type=int, + default=1, + show_default=True, + help="Test batch size, must divide the dataset evenly, else last " + "batch will be dropped", +) +def main(dataset_path: str, model_path: str, batch_size: int): + """ + Validation Script for Image Classification Models + """ + + dataset = torchvision.datasets.ImageFolder( + root=dataset_path, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(size=(224, 224)), + ] + ), + ) + + data_loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + drop_last=True, + ) + + pipeline = Pipeline.create( + task="image_classification", + model_path=model_path, + batch_size=batch_size, + ) + correct = total = 0 + progress_bar = tqdm(data_loader) + + for batch in progress_bar: + batch, actual_labels = batch + batch = batch.numpy() + outs = pipeline(images=batch) + predicted_labels = outs.labels + + for actual, predicted in zip(actual_labels, predicted_labels): + total += 1 + if isinstance(predicted, str): + predicted = int(predicted) + if actual.item() == predicted: + correct += 1 + + if total > 0: + progress_bar.set_postfix( + {"Running Accuracy": f"{correct * 100 / total:.2f}%"} + ) + + # prevent division by zero + if total == 0: + epsilon = 1e-5 + total += epsilon + + print(f"Accuracy: {correct * 100 / total:.2f} %") + + +if __name__ == "__main__": + main() From 404bf7fac599b72c939b4cc74bc5c2705bd2cc31 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Fri, 29 Apr 2022 15:02:59 -0400 Subject: [PATCH 8/8] [feature/Pipeline] fixes for ic-pipelines implementation (#336) * fixes for ic-pipelines implementation * sparsezoo support --- .../image_classification/pipelines.py | 97 +++++++++---------- .../image_classification/validation_script.py | 15 ++- src/deepsparse/pipeline.py | 9 +- 3 files changed, 57 insertions(+), 64 deletions(-) diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index 71be2064f1..b909dd12f1 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -21,16 +21,16 @@ import numpy import onnx -from deepsparse import Scheduler from deepsparse.image_classification.constants import ( IMAGENET_RGB_MEANS, IMAGENET_RGB_STDS, ) -from deepsparse.pipeline import DEEPSPARSE_ENGINE, Pipeline -from 
image_classification.schemas import ( +from deepsparse.image_classification.schemas import ( ImageClassificationInput, ImageClassificationOutput, ) +from deepsparse.pipeline import Pipeline +from deepsparse.utils import model_to_path try: @@ -65,32 +65,42 @@ class ImageClassificationPipeline(Pipeline): def __init__( self, - model_path: str, - engine_type: str = DEEPSPARSE_ENGINE, - batch_size: int = 1, - num_cores: int = None, - scheduler: Scheduler = None, - input_shapes: List[List[int]] = None, - alias: Optional[str] = None, - class_names: Optional[Union[str, Dict[str, str]]] = None, + *, + class_names: Union[None, str, Dict[str, str]] = None, + **kwargs, ): - super().__init__( - model_path, - engine_type, - batch_size, - num_cores, - scheduler, - input_shapes, - alias, - ) - self._input_shape = None + super().__init__(**kwargs) if isinstance(class_names, str) and class_names.endswith(".json"): - self.class_names = json.load(open(class_names)) + self._class_names = json.load(open(class_names)) elif isinstance(class_names, dict): - self.class_names = class_names + self._class_names = class_names else: - self.class_names = None + self._class_names = None + + self._image_size = self._infer_image_size() + + @property + def class_names(self) -> Optional[Dict[str, str]]: + """ + :return: Optional dict, or json file of class names to use for + mapping class ids to class labels + """ + return self._class_names + + @property + def input_model(self) -> Type[ImageClassificationInput]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + return ImageClassificationInput + + @property + def output_model(self) -> Type[ImageClassificationOutput]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + return ImageClassificationOutput def setup_onnx_file_path(self) -> str: """ @@ -100,7 +110,8 @@ class properties into an inference ready onnx file to be compiled by the :return: file path to the ONNX file for the engine to compile """ - return self.model_path + + return model_to_path(self.model_path) def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray]: """ @@ -120,9 +131,14 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray inputs.images = [inputs.images] for image in inputs.images: + if cv2 is None: + raise RuntimeError( + "cv2 is required to load image inputs from file " + f"Unable to import: {cv2_error}" + ) img = cv2.imread(image) if isinstance(image, str) else image - img = cv2.resize(img, dsize=self.input_shape) + img = cv2.resize(img, dsize=self._image_size) img = img[:, :, ::-1].transpose(2, 0, 1) image_batch.append(img) @@ -161,38 +177,13 @@ def process_engine_outputs( labels=labels, ) - @property - def input_model(self) -> Type[ImageClassificationInput]: - """ - :return: pydantic model class that inputs to this pipeline must comply to - """ - return ImageClassificationInput - - @property - def output_model(self) -> Type[ImageClassificationOutput]: - """ - :return: pydantic model class that outputs of this pipeline must comply to - """ - return ImageClassificationOutput - - @property - def input_shape(self) -> Tuple[int, ...]: - """ - Returns the expected shape of the input tensor - - :return: The expected shape of the input tensor from onnx graph - """ - if self._input_shape is None: - self._input_shape = self._infer_image_shape() - return self._input_shape - - def _infer_image_shape(self) -> Tuple[int, ...]: + def _infer_image_size(self) -> Tuple[int, ...]: """ 
Infer and return the expected shape of the input tensor :return: The expected shape of the input tensor from onnx graph """ - onnx_model = onnx.load(self.engine.model_path) + onnx_model = onnx.load(self.onnx_file_path) input_tensor = onnx_model.graph.input[0] return ( input_tensor.type.tensor_type.shape.dim[2].dim_value, diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py index db9ed8fe16..e176b4072c 100644 --- a/src/deepsparse/image_classification/validation_script.py +++ b/src/deepsparse/image_classification/validation_script.py @@ -92,9 +92,18 @@ default=1, show_default=True, help="Test batch size, must divide the dataset evenly, else last " - "batch will be dropped", + "batch will be dropped", ) -def main(dataset_path: str, model_path: str, batch_size: int): +@click.option( + "--image-size", + "--image_size", + type=int, + default=224, + show_default=True, + help="Test batch size, must divide the dataset evenly, else last " + "batch will be dropped", +) +def main(dataset_path: str, model_path: str, batch_size: int, image_size: int): """ Validation Script for Image Classification Models """ @@ -104,7 +113,7 @@ def main(dataset_path: str, model_path: str, batch_size: int): transform=transforms.Compose( [ transforms.ToTensor(), - transforms.Resize(size=(224, 224)), + transforms.Resize(size=(image_size, image_size)), ] ), ) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index d26c3f77ee..353ed596f7 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -132,7 +132,7 @@ def __init__( if engine_type.lower() == DEEPSPARSE_ENGINE: self._engine_args["scheduler"] = scheduler - self._onnx_file_path = self.setup_onnx_file_path() + self.onnx_file_path = self.setup_onnx_file_path() self._engine = self._initialize_engine() def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel: @@ -386,13 +386,6 @@ def engine_type(self) -> str: """ return self._engine_type - @property - def onnx_file_path(self) -> str: - """ - :return: onnx file path used to instantiate engine - """ - return self._onnx_file_path - def to_config(self) -> "PipelineConfig": """ :return: PipelineConfig that can be used to reload this object
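
Usage sketch (illustrative only, not part of the patch series): with these patches applied, the new task is driven through the generic `Pipeline.create` factory exactly as `validation_script.py` does. The SparseZoo stub below is that script's default model; "dog.jpg" is a hypothetical local image, and file-path inputs require `opencv-python` (installed via the `image_classification` extra added in `setup.py`).

from deepsparse.pipeline import Pipeline

# Compile the default Imagenette ResNet-50 stub used by validation_script.py.
ic_pipeline = Pipeline.create(
    task="image_classification",
    model_path=(
        "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenette/base-none"
    ),
    batch_size=1,
)

# File-path inputs are read with OpenCV in process_inputs(), resized to the
# spatial shape inferred from the ONNX graph, converted BGR->RGB and HWC->CHW,
# scaled to [0, 1], and normalized with the ImageNet means/stds from
# constants.py.
output = ic_pipeline(images=["dog.jpg"])
print(output.labels, output.scores)

Passing an already-batched NCHW numpy array (as the torchvision DataLoader in `validation_script.py` does) skips the OpenCV read/resize step but still applies the uint8-to-[0, 1] scaling and ImageNet normalization; if `class_names` was given to the pipeline as a dict or a ".json" path, the returned labels are the mapped class names rather than integer ids.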