Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 44 additions & 53 deletions src/deepsparse/image_classification/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
import numpy
import onnx

from deepsparse import Scheduler
from deepsparse.image_classification.constants import (
IMAGENET_RGB_MEANS,
IMAGENET_RGB_STDS,
)
from deepsparse.pipeline import DEEPSPARSE_ENGINE, Pipeline
from image_classification.schemas import (
from deepsparse.image_classification.schemas import (
ImageClassificationInput,
ImageClassificationOutput,
)
from deepsparse.pipeline import Pipeline
from deepsparse.utils import model_to_path


try:
Expand Down Expand Up @@ -65,32 +65,42 @@ class ImageClassificationPipeline(Pipeline):

def __init__(
self,
model_path: str,
engine_type: str = DEEPSPARSE_ENGINE,
batch_size: int = 1,
num_cores: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
alias: Optional[str] = None,
class_names: Optional[Union[str, Dict[str, str]]] = None,
*,
class_names: Union[None, str, Dict[str, str]] = None,
**kwargs,
):
super().__init__(
model_path,
engine_type,
batch_size,
num_cores,
scheduler,
input_shapes,
alias,
)
self._input_shape = None
super().__init__(**kwargs)

if isinstance(class_names, str) and class_names.endswith(".json"):
self.class_names = json.load(open(class_names))
self._class_names = json.load(open(class_names))
elif isinstance(class_names, dict):
self.class_names = class_names
self._class_names = class_names
else:
self.class_names = None
self._class_names = None

self._image_size = self._infer_image_size()

@property
def class_names(self) -> Optional[Dict[str, str]]:
"""
:return: Optional dict, or json file of class names to use for
mapping class ids to class labels
"""
return self._class_names

@property
def input_model(self) -> Type[ImageClassificationInput]:
"""
:return: pydantic model class that inputs to this pipeline must comply to
"""
return ImageClassificationInput

@property
def output_model(self) -> Type[ImageClassificationOutput]:
"""
:return: pydantic model class that outputs of this pipeline must comply to
"""
return ImageClassificationOutput

def setup_onnx_file_path(self) -> str:
"""
Expand All @@ -100,7 +110,8 @@ class properties into an inference ready onnx file to be compiled by the

:return: file path to the ONNX file for the engine to compile
"""
return self.model_path

return model_to_path(self.model_path)

def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray]:
"""
Expand All @@ -120,9 +131,14 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray
inputs.images = [inputs.images]

for image in inputs.images:
if cv2 is None:
raise RuntimeError(
"cv2 is required to load image inputs from file "
f"Unable to import: {cv2_error}"
)
img = cv2.imread(image) if isinstance(image, str) else image

img = cv2.resize(img, dsize=self.input_shape)
img = cv2.resize(img, dsize=self._image_size)
img = img[:, :, ::-1].transpose(2, 0, 1)
image_batch.append(img)

Expand Down Expand Up @@ -161,38 +177,13 @@ def process_engine_outputs(
labels=labels,
)

@property
def input_model(self) -> Type[ImageClassificationInput]:
"""
:return: pydantic model class that inputs to this pipeline must comply to
"""
return ImageClassificationInput

@property
def output_model(self) -> Type[ImageClassificationOutput]:
"""
:return: pydantic model class that outputs of this pipeline must comply to
"""
return ImageClassificationOutput

@property
def input_shape(self) -> Tuple[int, ...]:
"""
Returns the expected shape of the input tensor

:return: The expected shape of the input tensor from onnx graph
"""
if self._input_shape is None:
self._input_shape = self._infer_image_shape()
return self._input_shape

def _infer_image_shape(self) -> Tuple[int, ...]:
def _infer_image_size(self) -> Tuple[int, ...]:
"""
Infer and return the expected shape of the input tensor

:return: The expected shape of the input tensor from onnx graph
"""
onnx_model = onnx.load(self.engine.model_path)
onnx_model = onnx.load(self.onnx_file_path)
input_tensor = onnx_model.graph.input[0]
return (
input_tensor.type.tensor_type.shape.dim[2].dim_value,
Expand Down
15 changes: 12 additions & 3 deletions src/deepsparse/image_classification/validation_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,18 @@
default=1,
show_default=True,
help="Test batch size, must divide the dataset evenly, else last "
"batch will be dropped",
"batch will be dropped",
)
def main(dataset_path: str, model_path: str, batch_size: int):
@click.option(
"--image-size",
"--image_size",
type=int,
default=224,
show_default=True,
help="Test batch size, must divide the dataset evenly, else last "
"batch will be dropped",
)
def main(dataset_path: str, model_path: str, batch_size: int, image_size: int):
"""
Validation Script for Image Classification Models
"""
Expand All @@ -104,7 +113,7 @@ def main(dataset_path: str, model_path: str, batch_size: int):
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize(size=(224, 224)),
transforms.Resize(size=(image_size, image_size)),
]
),
)
Expand Down
9 changes: 1 addition & 8 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
if engine_type.lower() == DEEPSPARSE_ENGINE:
self._engine_args["scheduler"] = scheduler

self._onnx_file_path = self.setup_onnx_file_path()
self.onnx_file_path = self.setup_onnx_file_path()
self._engine = self._initialize_engine()

def __call__(self, pipeline_inputs: BaseModel = None, **kwargs) -> BaseModel:
Expand Down Expand Up @@ -386,13 +386,6 @@ def engine_type(self) -> str:
"""
return self._engine_type

@property
def onnx_file_path(self) -> str:
"""
:return: onnx file path used to instantiate engine
"""
return self._onnx_file_path

def to_config(self) -> "PipelineConfig":
"""
:return: PipelineConfig that can be used to reload this object
Expand Down