diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index 71be2064f1..b909dd12f1 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -21,16 +21,16 @@ import numpy import onnx -from deepsparse import Scheduler from deepsparse.image_classification.constants import ( IMAGENET_RGB_MEANS, IMAGENET_RGB_STDS, ) -from deepsparse.pipeline import DEEPSPARSE_ENGINE, Pipeline -from image_classification.schemas import ( +from deepsparse.image_classification.schemas import ( ImageClassificationInput, ImageClassificationOutput, ) +from deepsparse.pipeline import Pipeline +from deepsparse.utils import model_to_path try: @@ -65,32 +65,42 @@ class ImageClassificationPipeline(Pipeline): def __init__( self, - model_path: str, - engine_type: str = DEEPSPARSE_ENGINE, - batch_size: int = 1, - num_cores: int = None, - scheduler: Scheduler = None, - input_shapes: List[List[int]] = None, - alias: Optional[str] = None, - class_names: Optional[Union[str, Dict[str, str]]] = None, + *, + class_names: Union[None, str, Dict[str, str]] = None, + **kwargs, ): - super().__init__( - model_path, - engine_type, - batch_size, - num_cores, - scheduler, - input_shapes, - alias, - ) - self._input_shape = None + super().__init__(**kwargs) if isinstance(class_names, str) and class_names.endswith(".json"): - self.class_names = json.load(open(class_names)) + self._class_names = json.load(open(class_names)) elif isinstance(class_names, dict): - self.class_names = class_names + self._class_names = class_names else: - self.class_names = None + self._class_names = None + + self._image_size = self._infer_image_size() + + @property + def class_names(self) -> Optional[Dict[str, str]]: + """ + :return: Optional dict, or json file of class names to use for + mapping class ids to class labels + """ + return self._class_names + + @property + def input_model(self) -> Type[ImageClassificationInput]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + return ImageClassificationInput + + @property + def output_model(self) -> Type[ImageClassificationOutput]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + return ImageClassificationOutput def setup_onnx_file_path(self) -> str: """ @@ -100,7 +110,8 @@ class properties into an inference ready onnx file to be compiled by the :return: file path to the ONNX file for the engine to compile """ - return self.model_path + + return model_to_path(self.model_path) def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray]: """ @@ -120,9 +131,14 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray inputs.images = [inputs.images] for image in inputs.images: + if cv2 is None: + raise RuntimeError( + "cv2 is required to load image inputs from file " + f"Unable to import: {cv2_error}" + ) img = cv2.imread(image) if isinstance(image, str) else image - img = cv2.resize(img, dsize=self.input_shape) + img = cv2.resize(img, dsize=self._image_size) img = img[:, :, ::-1].transpose(2, 0, 1) image_batch.append(img) @@ -161,38 +177,13 @@ def process_engine_outputs( labels=labels, ) - @property - def input_model(self) -> Type[ImageClassificationInput]: - """ - :return: pydantic model class that inputs to this pipeline must comply to - """ - return ImageClassificationInput - - @property - def output_model(self) -> Type[ImageClassificationOutput]: - """ - :return: pydantic model class that outputs of this pipeline must comply to - """ - return ImageClassificationOutput - - @property - def input_shape(self) -> Tuple[int, ...]: - """ - Returns the expected shape of the input tensor - - :return: The expected shape of the input tensor from onnx graph - """ - if self._input_shape is None: - self._input_shape = self._infer_image_shape() - return self._input_shape - - def _infer_image_shape(self) -> Tuple[int, ...]: + def _infer_image_size(self) -> Tuple[int, ...]: """ Infer and return the expected shape of the input tensor :return: The expected shape of the input tensor from onnx graph """ - onnx_model = onnx.load(self.engine.model_path) + onnx_model = onnx.load(self.onnx_file_path) input_tensor = onnx_model.graph.input[0] return ( input_tensor.type.tensor_type.shape.dim[2].dim_value, diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py index db9ed8fe16..e176b4072c 100644 --- a/src/deepsparse/image_classification/validation_script.py +++ b/src/deepsparse/image_classification/validation_script.py @@ -92,9 +92,18 @@ default=1, show_default=True, help="Test batch size, must divide the dataset evenly, else last " - "batch will be dropped", + "batch will be dropped", ) -def main(dataset_path: str, model_path: str, batch_size: int): +@click.option( + "--image-size", + "--image_size", + type=int, + default=224, + show_default=True, + help="Test batch size, must divide the dataset evenly, else last " + "batch will be dropped", +) +def main(dataset_path: str, model_path: str, batch_size: int, image_size: int): """ Validation Script for Image Classification Models """ @@ -104,7 +113,7 @@ def main(dataset_path: str, model_path: str, batch_size: int): transform=transforms.Compose( [ transforms.ToTensor(), - transforms.Resize(size=(224, 224)), + transforms.Resize(size=(image_size, image_size)), ] ), ) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index d26c3f77ee..353ed596f7 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -132,7 +132,7 @@ def __init__( if engine_type.lower() == DEEPSPARSE_ENGINE: self._engine_args["scheduler"] = scheduler - self._onnx_file_path = self.setup_onnx_file_path() + self.onnx_file_path = self.setup_onnx_file_path() self._engine = self._initialize_engine() def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel: @@ -386,13 +386,6 @@ def engine_type(self) -> str: """ return self._engine_type - @property - def onnx_file_path(self) -> str: - """ - :return: onnx file path used to instantiate engine - """ - return self._onnx_file_path - def to_config(self) -> "PipelineConfig": """ :return: PipelineConfig that can be used to reload this object