Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
]

_ic_integration_deps = [
"click<8.1",
"opencv-python",
]

Expand Down
33 changes: 19 additions & 14 deletions src/deepsparse/image_classification/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ def __init__(
elif isinstance(class_names, dict):
self.class_names = class_names
else:
raise ValueError(
"class_names must be a dict or a json file path"
f" (got {type(class_names)} instead)"
)
self.class_names = None

def setup_onnx_file_path(self) -> str:
"""
Expand All @@ -113,22 +110,30 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray
:return: list of preprocessed numpy arrays
"""

image_batch = []
if isinstance(inputs.images, numpy.ndarray):
image_batch = inputs.images
else:

image_batch = []

if isinstance(inputs.images, str):
inputs.images = [inputs.images]
if isinstance(inputs.images, str):
inputs.images = [inputs.images]

for image in inputs.images:
img = cv2.imread(image) if isinstance(image, str) else image
for image in inputs.images:
img = cv2.imread(image) if isinstance(image, str) else image

img = cv2.resize(img, dsize=self.input_shape)
img = img[:, :, ::-1].transpose(2, 0, 1)
img = cv2.resize(img, dsize=self.input_shape)
img = img[:, :, ::-1].transpose(2, 0, 1)
image_batch.append(img)

image_batch.append(img)
image_batch = numpy.stack(image_batch, axis=0)

image_batch = numpy.stack(image_batch, axis=0)
original_dtype = image_batch.dtype
image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32)
image_batch /= 255.0

if original_dtype == numpy.uint8:

image_batch /= 255

# normalize entire batch
image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1))
Expand Down
2 changes: 1 addition & 1 deletion src/deepsparse/image_classification/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ImageClassificationInput(BaseModel):
Input model for image classification
"""

images: Union[str, List[numpy.ndarray], List[str]]
images: Union[str, numpy.ndarray, List[str]]

class Config:
arbitrary_types_allowed = True
Expand Down
153 changes: 153 additions & 0 deletions src/deepsparse/image_classification/validation_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage: validation_script.py [OPTIONS]

Validation Script for Image Classification Models

Options:
--dataset-path, --dataset_path DIRECTORY
Path to the validation dataset [required]
--model-path, --model_path TEXT
Path/SparseZoo stub for the Image
Classification model to be evaluated.
Defaults to resnet50 trained on
Imagenette [default: zoo:cv/classification/
resnet_v1-50/pytorch/sparseml/imagenette/
base-none]
--batch-size, --batch_size INTEGER
Test batch size, must divide the dataset
evenly, else the last batch will be dropped
[default: 1]
--help Show this message and exit.

#########
EXAMPLES
#########

##########
Example command for validating pruned resnet50 on imagenette dataset:
python validation_script.py \
--dataset-path /path/to/imagenette/

"""
from tqdm import tqdm

from deepsparse.pipeline import Pipeline
from torch.utils.data import DataLoader
from torchvision import transforms


try:
import torchvision

except ModuleNotFoundError as torchvision_error: # noqa: F841
print(
"Torchvision not installed. Please install it using the command:"
"pip install torchvision>=0.3.0,<=0.10.1"
)
exit(1)

import click


resnet50_imagenet_pruned = (
"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenette/base-none"
)


@click.command()
@click.option(
"--dataset-path",
"--dataset_path",
required=True,
type=click.Path(dir_okay=True, file_okay=False),
help="Path to the validation dataset",
)
@click.option(
"--model-path",
"--model_path",
type=str,
default=resnet50_imagenet_pruned,
help="Path/SparseZoo stub for the Image Classification model to be "
"evaluated. Defaults to dense (vanilla) resnet50 trained on Imagenette",
show_default=True,
)
@click.option(
"--batch-size",
"--batch_size",
type=int,
default=1,
show_default=True,
help="Test batch size, must divide the dataset evenly, else last "
"batch will be dropped",
)
def main(dataset_path: str, model_path: str, batch_size: int):
"""
Validation Script for Image Classification Models
"""

dataset = torchvision.datasets.ImageFolder(
root=dataset_path,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize(size=(224, 224)),
]
),
)

data_loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
drop_last=True,
)

pipeline = Pipeline.create(
task="image_classification",
model_path=model_path,
batch_size=batch_size,
)
correct = total = 0
progress_bar = tqdm(data_loader)

for batch in progress_bar:
batch, actual_labels = batch
batch = batch.numpy()
outs = pipeline(images=batch)
predicted_labels = outs.labels

for actual, predicted in zip(actual_labels, predicted_labels):
total += 1
if isinstance(predicted, str):
predicted = int(predicted)
if actual.item() == predicted:
correct += 1

if total > 0:
progress_bar.set_postfix(
{"Running Accuracy": f"{correct * 100 / total:.2f}%"}
)

# prevent division by zero
if total == 0:
epsilon = 1e-5
total += epsilon

print(f"Accuracy: {correct * 100 / total:.2f} %")


if __name__ == "__main__":
main()