From 49328ea0b5fd73df2ffb6e7c02a25853e24993bc Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 15:30:14 -0400 Subject: [PATCH 1/5] Add Validation Script for Image Classification Models * Update pipelines and corresponding schemas to work with numpy arrays --- setup.py | 1 + .../image_classification/pipelines.py | 33 ++-- .../image_classification/schemas.py | 2 +- .../image_classification/validation_script.py | 150 ++++++++++++++++++ 4 files changed, 171 insertions(+), 15 deletions(-) create mode 100644 src/deepsparse/image_classification/validation_script.py diff --git a/setup.py b/setup.py index af9307e2b9..aa7188bfdb 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ ] _ic_integration_deps = [ + "click<8.1", "opencv-python", ] diff --git a/src/deepsparse/image_classification/pipelines.py b/src/deepsparse/image_classification/pipelines.py index 5c2b41b2e6..71be2064f1 100644 --- a/src/deepsparse/image_classification/pipelines.py +++ b/src/deepsparse/image_classification/pipelines.py @@ -90,10 +90,7 @@ def __init__( elif isinstance(class_names, dict): self.class_names = class_names else: - raise ValueError( - "class_names must be a dict or a json file path" - f" (got {type(class_names)} instead)" - ) + self.class_names = None def setup_onnx_file_path(self) -> str: """ @@ -113,22 +110,30 @@ def process_inputs(self, inputs: ImageClassificationInput) -> List[numpy.ndarray :return: list of preprocessed numpy arrays """ - image_batch = [] + if isinstance(inputs.images, numpy.ndarray): + image_batch = inputs.images + else: + + image_batch = [] - if isinstance(inputs.images, str): - inputs.images = [inputs.images] + if isinstance(inputs.images, str): + inputs.images = [inputs.images] - for image in inputs.images: - img = cv2.imread(image) if isinstance(image, str) else image + for image in inputs.images: + img = cv2.imread(image) if isinstance(image, str) else image - img = cv2.resize(img, dsize=self.input_shape) - img = img[:, :, ::-1].transpose(2, 0, 1) + img = cv2.resize(img, dsize=self.input_shape) + img = img[:, :, ::-1].transpose(2, 0, 1) + image_batch.append(img) - image_batch.append(img) + image_batch = numpy.stack(image_batch, axis=0) - image_batch = numpy.stack(image_batch, axis=0) + original_dtype = image_batch.dtype image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32) - image_batch /= 255.0 + + if original_dtype == numpy.uint8: + + image_batch /= 255 # normalize entire batch image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1)) diff --git a/src/deepsparse/image_classification/schemas.py b/src/deepsparse/image_classification/schemas.py index 4232702898..5a92b90e3b 100644 --- a/src/deepsparse/image_classification/schemas.py +++ b/src/deepsparse/image_classification/schemas.py @@ -27,7 +27,7 @@ class ImageClassificationInput(BaseModel): Input model for image classification """ - images: Union[str, List[numpy.ndarray], List[str]] + images: Union[str, numpy.ndarray, List[str]] class Config: arbitrary_types_allowed = True diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py new file mode 100644 index 0000000000..438d03d331 --- /dev/null +++ b/src/deepsparse/image_classification/validation_script.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: validation_script.py [OPTIONS] + + Validation Script for Image Classification Models + +Options: + --dataset-path, --dataset_path DIRECTORY + Path to the validation dataset [required] + --model-path, --model_path TEXT + Path/SparseZoo stub for the Image + Classification model to be evaluated. + Defaults to pruned resnet50 trained on + Imagenette [default: zoo:cv/classification/ + resnet_v1-50/pytorch/sparseml/imagenette/pru + ned-conservative] + --batch-size, --batch_size INTEGER + Test batch size, must divide the dataset + evenly [default: 1] + --help Show this message and exit. + +######### +EXAMPLES +######### + +########## +Example command for validating pruned resnet50 on imagenette dataset: +python validation_script.py \ + --dataset-path /path/to/imagenette/ + +""" +from tqdm import tqdm + +from deepsparse.pipeline import Pipeline +from torch.utils.data import DataLoader +from torchvision import transforms + + +try: + import torchvision + +except ModuleNotFoundError as torchvision_error: # noqa: F841 + print( + "Torchvision not installed. Please install it using the command:" + "pip install torchvision>=0.3.0,<=0.10.1" + ) + exit(1) + +import click + + +resnet50_imagenet_pruned = ( + "zoo:cv/classification/resnet_v1-50/pytorch/" + "sparseml/imagenette/pruned-conservative" +) + + +@click.command() +@click.option( + "--dataset-path", + "--dataset_path", + required=True, + type=click.Path(dir_okay=True, file_okay=False), + help="Path to the validation dataset", +) +@click.option( + "--model-path", + "--model_path", + type=str, + default=resnet50_imagenet_pruned, + help="Path/SparseZoo stub for the Image Classification model to be " + "evaluated. Defaults to pruned resnet50 trained on Imagenette", + show_default=True, +) +@click.option( + "--batch-size", + "--batch_size", + type=int, + default=1, + show_default=True, + help="Test batch size, must divide the dataset evenly", +) +def main(dataset_root: str, model_path: str, batch_size: int): + """ + Validation Script for Image Classification Models + """ + + dataset = torchvision.datasets.ImageFolder( + root=dataset_root, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(size=(224, 224)), + ] + ), + ) + + data_loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + ) + + pipeline = Pipeline.create( + task="image_classification", + model_path=model_path, + batch_size=batch_size, + ) + correct = total = 0 + progress_bar = tqdm(data_loader) + + for batch in progress_bar: + batch, actual_labels = batch + batch = batch.numpy() + outs = pipeline(images=batch) + predicted_labels = outs.labels + + for actual, predicted in zip(actual_labels, predicted_labels): + total += 1 + + if actual.item() == predicted: + correct += 1 + + if total > 0: + progress_bar.set_postfix( + {"Running Accuracy": f"{correct * 100 / total:.2f}%"} + ) + + # prevent division by zero + if total == 0: + epsilon = 1e-5 + total += epsilon + + print(f"Accuracy: {correct * 100 / total:.2f} %") + + +if __name__ == "__main__": + main() From 61032f64e21ffb133f0b0025deafe7689fc0d6c1 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 19:12:48 -0400 Subject: [PATCH 2/5] Bugfix if prediction to be converted to int if it's a string --- .../image_classification/validation_script.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py index 438d03d331..2b47aff7fb 100644 --- a/src/deepsparse/image_classification/validation_script.py +++ b/src/deepsparse/image_classification/validation_script.py @@ -23,7 +23,7 @@ --model-path, --model_path TEXT Path/SparseZoo stub for the Image Classification model to be evaluated. - Defaults to pruned resnet50 trained on + Defaults to resnet50 trained on Imagenette [default: zoo:cv/classification/ resnet_v1-50/pytorch/sparseml/imagenette/pru ned-conservative] @@ -63,8 +63,7 @@ resnet50_imagenet_pruned = ( - "zoo:cv/classification/resnet_v1-50/pytorch/" - "sparseml/imagenette/pruned-conservative" + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenette/base-none" ) @@ -82,7 +81,7 @@ type=str, default=resnet50_imagenet_pruned, help="Path/SparseZoo stub for the Image Classification model to be " - "evaluated. Defaults to pruned resnet50 trained on Imagenette", + "evaluated. Defaults to resnet50 trained on Imagenette", show_default=True, ) @click.option( @@ -93,13 +92,13 @@ show_default=True, help="Test batch size, must divide the dataset evenly", ) -def main(dataset_root: str, model_path: str, batch_size: int): +def main(dataset_path: str, model_path: str, batch_size: int): """ Validation Script for Image Classification Models """ dataset = torchvision.datasets.ImageFolder( - root=dataset_root, + root=dataset_path, transform=transforms.Compose( [ transforms.ToTensor(), @@ -129,7 +128,8 @@ def main(dataset_root: str, model_path: str, batch_size: int): for actual, predicted in zip(actual_labels, predicted_labels): total += 1 - + if isinstance(predicted, str): + predicted = int(predicted) if actual.item() == predicted: correct += 1 From 925f8d969f858b0d10f2a3bb35f92566b20d0798 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 21 Apr 2022 19:15:18 -0400 Subject: [PATCH 3/5] Update docstring --- src/deepsparse/image_classification/validation_script.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py index 2b47aff7fb..5f5a39b565 100644 --- a/src/deepsparse/image_classification/validation_script.py +++ b/src/deepsparse/image_classification/validation_script.py @@ -25,8 +25,8 @@ Classification model to be evaluated. Defaults to resnet50 trained on Imagenette [default: zoo:cv/classification/ - resnet_v1-50/pytorch/sparseml/imagenette/pru - ned-conservative] + resnet_v1-50/pytorch/sparseml/imagenette/ + base-none] --batch-size, --batch_size INTEGER Test batch size, must divide the dataset evenly [default: 1] From 51f2a8685e96a359268cf26d745011d9d106bdba Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Mon, 25 Apr 2022 12:32:10 -0400 Subject: [PATCH 4/5] Update src/deepsparse/image_classification/validation_script.py Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> --- src/deepsparse/image_classification/validation_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py index 5f5a39b565..f7da75d389 100644 --- a/src/deepsparse/image_classification/validation_script.py +++ b/src/deepsparse/image_classification/validation_script.py @@ -81,7 +81,7 @@ type=str, default=resnet50_imagenet_pruned, help="Path/SparseZoo stub for the Image Classification model to be " - "evaluated. Defaults to resnet50 trained on Imagenette", + "evaluated. Defaults to dense (vanilla) resnet50 trained on Imagenette", show_default=True, ) @click.option( From 6e2ca42fcfe071b362a7f19738cc24cb7353bc67 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 26 Apr 2022 10:49:51 -0400 Subject: [PATCH 5/5] Comments from @bogunowicz --- src/deepsparse/image_classification/validation_script.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/deepsparse/image_classification/validation_script.py b/src/deepsparse/image_classification/validation_script.py index f7da75d389..db9ed8fe16 100644 --- a/src/deepsparse/image_classification/validation_script.py +++ b/src/deepsparse/image_classification/validation_script.py @@ -29,7 +29,8 @@ base-none] --batch-size, --batch_size INTEGER Test batch size, must divide the dataset - evenly [default: 1] + evenly, else the last batch will be dropped + [default: 1] --help Show this message and exit. ######### @@ -90,7 +91,8 @@ type=int, default=1, show_default=True, - help="Test batch size, must divide the dataset evenly", + help="Test batch size, must divide the dataset evenly, else last " + "batch will be dropped", ) def main(dataset_path: str, model_path: str, batch_size: int): """ @@ -110,6 +112,7 @@ def main(dataset_path: str, model_path: str, batch_size: int): data_loader = DataLoader( dataset=dataset, batch_size=batch_size, + drop_last=True, ) pipeline = Pipeline.create(