From 0e0f5d7ab810477a9b491637dd7fd1d104027067 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sat, 27 Mar 2021 22:06:57 -0400 Subject: [PATCH 1/4] YOLOv3 benchmarking static input shape override --- .../ultralytics/deepsparse/benchmark.py | 25 +++- .../deepsparse/deepsparse_utils.py | 137 ++++++++++++++---- integrations/ultralytics/train.py | 6 +- 3 files changed, 132 insertions(+), 36 deletions(-) diff --git a/integrations/ultralytics/deepsparse/benchmark.py b/integrations/ultralytics/deepsparse/benchmark.py index 13bc1841b28..c9094026c06 100644 --- a/integrations/ultralytics/deepsparse/benchmark.py +++ b/integrations/ultralytics/deepsparse/benchmark.py @@ -110,7 +110,12 @@ from deepsparse import compile_model, cpu from deepsparse.benchmark import BenchmarkResults -from deepsparse_utils import load_image, postprocess_nms, pre_nms_postprocess +from deepsparse_utils import ( + YoloPostprocessor, + load_image, + modify_yolo_onnx_input_shape, + postprocess_nms, +) from sparseml.onnx.utils import override_model_batch_size from sparsezoo.models.detection import yolo_v3 as zoo_yolo_v3 from sparsezoo.utils import load_numpy_list @@ -288,6 +293,12 @@ def _load_model(args) -> Any: "environment variable" ) + # scale static ONNX graph to desired image shape + if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]: + args.model_filepath, _ = modify_yolo_onnx_input_shape( + args.model_filepath, args.image_shape + ) + # load model if args.engine == DEEPSPARSE_ENGINE: print(f"Compiling deepsparse model for {args.model_filepath}") @@ -376,7 +387,7 @@ def _run_model( def benchmark_yolo(args): model = _load_model(args) print("Loading dataset") - dataset = _load_images(args.data_path, args.image_shape) + dataset = _load_images(args.data_path, tuple(args.image_shape)) total_iterations = args.num_iterations + args.num_warmup_iterations data_loader = _iter_batches(dataset, args.batch_size, total_iterations) @@ -388,6 +399,12 @@ def benchmark_yolo(args): flush=True, ) + postprocessor = ( + YoloPostprocessor(args.image_shape) + if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE] + else None + ) + results = BenchmarkResults() progress_bar = tqdm(total=args.num_iterations) @@ -403,8 +420,8 @@ def benchmark_yolo(args): outputs = _run_model(args, model, batch) # post-processing - if args.engine != TORCH_ENGINE: - outputs = pre_nms_postprocess(outputs) + if postprocessor: + outputs = postprocessor.pre_nms_postprocess(outputs) # NMS outputs = postprocess_nms(outputs) diff --git a/integrations/ultralytics/deepsparse/deepsparse_utils.py b/integrations/ultralytics/deepsparse/deepsparse_utils.py index 9eecb8bf22f..f72a39762ed 100644 --- a/integrations/ultralytics/deepsparse/deepsparse_utils.py +++ b/integrations/ultralytics/deepsparse/deepsparse_utils.py @@ -20,12 +20,16 @@ """ -from typing import List, Tuple, Union +from tempfile import NamedTemporaryFile +from typing import List, Optional, Tuple, Union import cv2 import numpy +import onnx import torch +from sparsezoo import Zoo + # ultralytics/yolov5 imports from utils.general import non_max_suppression @@ -37,13 +41,6 @@ ] -def _get_grid(size: int) -> torch.Tensor: - # adapted from yolov5.yolo.Detect._make_grid - coords_y, coords_x = torch.meshgrid([torch.arange(size), torch.arange(size)]) - grid = torch.stack((coords_x, coords_y), 2) - return grid.view(1, 1, size, size, 2).float() - - # Yolo V3 specific variables _YOLO_V3_ANCHORS = [ torch.Tensor([[10, 13], [16, 30], [33, 23]]), @@ -51,8 +48,6 @@ def _get_grid(size: int) -> torch.Tensor: torch.Tensor([[116, 90], [156, 198], [373, 326]]), ] _YOLO_V3_ANCHOR_GRIDS = [t.clone().view(1, -1, 1, 1, 2) for t in _YOLO_V3_ANCHORS] -_YOLO_V3_OUTPUT_SHAPES = [80, 40, 20] -_YOLO_V3_GRIDS = [_get_grid(grid_size) for grid_size in _YOLO_V3_OUTPUT_SHAPES] def load_image( @@ -70,28 +65,50 @@ def load_image( return img -def pre_nms_postprocess(outputs: List[numpy.ndarray]) -> torch.Tensor: +class YoloPostprocessor: """ - :param outputs: raw outputs of a YOLOv3 model before anchor grid processing - :return: post-processed model outputs without NMS. + Class for performing postprocessing of YOLOv3 model predictions """ - # postprocess and transform raw outputs into single torch tensor - processed_outputs = [] - for idx, pred in enumerate(outputs): - pred = torch.from_numpy(pred) - pred = pred.sigmoid() - # get grid and stride - grid = _YOLO_V3_GRIDS[idx] - anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx] - stride = 640 / _YOLO_V3_OUTPUT_SHAPES[idx] - - # decode xywh box values - pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride - pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid - # flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5) - processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1))) - return torch.cat(processed_outputs, 1) + def __init__(self, image_size: Tuple[int]): + self._image_size = image_size + self._grids = {} # Dict[Tuple[int], torch.Tensor] + + def pre_nms_postprocess(self, outputs: List[numpy.ndarray]) -> torch.Tensor: + """ + :param outputs: raw outputs of a YOLOv3 model before anchor grid processing + :return: post-processed model outputs without NMS. + """ + # postprocess and transform raw outputs into single torch tensor + processed_outputs = [] + for idx, pred in enumerate(outputs): + pred = torch.from_numpy(pred) + pred = pred.sigmoid() + + # get grid and stride + grid_shape = pred.shape[2:4] + grid = self._get_grid(grid_shape) + anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx] + stride = self._image_size[0] / grid_shape[0] + + # decode xywh box values + pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride + pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid + # flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5) + processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1))) + return torch.cat(processed_outputs, 1) + + def _get_grid(self, grid_shape: Tuple[int]) -> torch.Tensor: + if grid_shape not in self._grids: + # adapted from yolov5.yolo.Detect._make_grid + coords_y, coords_x = torch.meshgrid( + [torch.arange(grid_shape[0]), torch.arange(grid_shape[1])] + ) + grid = torch.stack((coords_x, coords_y), 2) + self._grids[grid_shape] = grid.view( + 1, 1, grid_shape[0], grid_shape[1], 2 + ).float() + return self._grids[grid_shape] def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]: @@ -102,3 +119,65 @@ def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]: # run nms in PyTorch, only post-process first output nms_outputs = non_max_suppression(outputs) return [output.cpu().numpy() for output in nms_outputs] + + +def modify_yolo_onnx_input_shape( + model_path: str, image_shape: Tuple[int] +) -> Tuple[str, Optional[NamedTemporaryFile]]: + """ + + :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model + to be loaded + :param image_shape: 2-tuple of the image shape to resize this yolo model to + :return: filepath to an onnx model reshaped to the given input shape and tempfile + object that the modified model is written to. if the given model has the given + input shape, then the path to the original model will be returned with no + tempfile. tempfile returned so caller can control when the tempfile is destroyed + """ + original_model_path = model_path + if model_path.startswith("zoo:"): + # load SparseZoo Model from stub + model = Zoo.load_model_from_stub(model_path) + model_path = model.onnx_file.downloaded_path() + + model = onnx.load(model_path) + model_input = model.graph.input[0] + + initial_x = _get_onnx_tensor_idx_shape(model_input, 2) + initial_y = _get_onnx_tensor_idx_shape(model_input, 3) + + if not (isinstance(initial_x, int) and isinstance(initial_y, int)): + return model_path, None # model graph does not have static integer input shape + + if (initial_x, initial_y) == tuple(image_shape): + return model_path, None # no shape modification needed + + scale_x = initial_x / image_shape[0] + scale_y = initial_y / image_shape[1] + _set_onnx_tensor_idx_shape(model_input, 2, image_shape[0]) + _set_onnx_tensor_idx_shape(model_input, 3, image_shape[1]) + + for model_output in model.graph.output: + output_x = _get_onnx_tensor_idx_shape(model_output, 2) + output_y = _get_onnx_tensor_idx_shape(model_output, 3) + _set_onnx_tensor_idx_shape(model_output, 2, int(output_x / scale_x)) + _set_onnx_tensor_idx_shape(model_output, 3, int(output_y / scale_y)) + + tmp_file = NamedTemporaryFile() # file will be deleted after program exit + onnx.save(model, tmp_file.name) + + print( + f"Overwriting original model shape {(initial_x, initial_y)} to {image_shape}\n" + f"Original model path: {original_model_path}, new temporary model saved to " + f"{tmp_file.name}" + ) + + return tmp_file.name, tmp_file + + +def _get_onnx_tensor_idx_shape(tensor: onnx.TensorProto, idx: int) -> int: + return tensor.type.tensor_type.shape.dim[idx].dim_value + + +def _set_onnx_tensor_idx_shape(tensor: onnx.TensorProto, idx: int, value: int): + tensor.type.tensor_type.shape.dim[idx].dim_value = value diff --git a/integrations/ultralytics/train.py b/integrations/ultralytics/train.py index 5b5fedb46d5..c3bec048991 100644 --- a/integrations/ultralytics/train.py +++ b/integrations/ultralytics/train.py @@ -269,7 +269,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): #################################################################################### from sparseml.pytorch.nn import replace_activations from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer - from sparseml.pytorch.utils import PythonLogger, TensorBoardLogger + from sparseml.pytorch.utils import is_parallel_model, PythonLogger, TensorBoardLogger if not opt.no_leaky_relu_override: # use LeakyReLU activations model = replace_activations(model, "lrelu", inplace=True) @@ -277,7 +277,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): manager = ScheduledModifierManager.from_yaml(opt.sparseml_recipe) optimizer = ScheduledOptimizer( optimizer, - model, + model if not is_parallel_model(model) else model.module, manager, steps_per_epoch=len(dataloader), loggers=[PythonLogger(), TensorBoardLogger(writer=tb_writer)] @@ -521,7 +521,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ################################################################################# # Start SparseML ONNX Export ################################################################################# - from sparseml.pytorch.utils import ModuleExporter, is_parallel_model + from sparseml.pytorch.utils import ModuleExporter from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize onnx_path = f"{save_dir}/model.onnx" From 38c65a2056b3ce7a32af6585cf1bc20ed7c2abf8 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sun, 28 Mar 2021 18:48:07 -0400 Subject: [PATCH 2/4] updates from review comments --- .../deepsparse/deepsparse_utils.py | 29 ++++++++----------- src/sparseml/onnx/utils/helpers.py | 22 ++++++++++++++ 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/integrations/ultralytics/deepsparse/deepsparse_utils.py b/integrations/ultralytics/deepsparse/deepsparse_utils.py index f72a39762ed..4a09554b24b 100644 --- a/integrations/ultralytics/deepsparse/deepsparse_utils.py +++ b/integrations/ultralytics/deepsparse/deepsparse_utils.py @@ -28,6 +28,7 @@ import onnx import torch +from sparseml.onnx.utils import get_tensor_dim_shape, set_tensor_dim_shape from sparsezoo import Zoo # ultralytics/yolov5 imports @@ -68,6 +69,9 @@ def load_image( class YoloPostprocessor: """ Class for performing postprocessing of YOLOv3 model predictions + + :param image_size: size of input image to model. used to calculate stride based on + output shapes """ def __init__(self, image_size: Tuple[int]): @@ -125,7 +129,6 @@ def modify_yolo_onnx_input_shape( model_path: str, image_shape: Tuple[int] ) -> Tuple[str, Optional[NamedTemporaryFile]]: """ - :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model to be loaded :param image_shape: 2-tuple of the image shape to resize this yolo model to @@ -143,8 +146,8 @@ def modify_yolo_onnx_input_shape( model = onnx.load(model_path) model_input = model.graph.input[0] - initial_x = _get_onnx_tensor_idx_shape(model_input, 2) - initial_y = _get_onnx_tensor_idx_shape(model_input, 3) + initial_x = get_tensor_dim_shape(model_input, 2) + initial_y = get_tensor_dim_shape(model_input, 3) if not (isinstance(initial_x, int) and isinstance(initial_y, int)): return model_path, None # model graph does not have static integer input shape @@ -154,14 +157,14 @@ def modify_yolo_onnx_input_shape( scale_x = initial_x / image_shape[0] scale_y = initial_y / image_shape[1] - _set_onnx_tensor_idx_shape(model_input, 2, image_shape[0]) - _set_onnx_tensor_idx_shape(model_input, 3, image_shape[1]) + set_tensor_dim_shape(model_input, 2, image_shape[0]) + set_tensor_dim_shape(model_input, 3, image_shape[1]) for model_output in model.graph.output: - output_x = _get_onnx_tensor_idx_shape(model_output, 2) - output_y = _get_onnx_tensor_idx_shape(model_output, 3) - _set_onnx_tensor_idx_shape(model_output, 2, int(output_x / scale_x)) - _set_onnx_tensor_idx_shape(model_output, 3, int(output_y / scale_y)) + output_x = get_tensor_dim_shape(model_output, 2) + output_y = get_tensor_dim_shape(model_output, 3) + set_tensor_dim_shape(model_output, 2, int(output_x / scale_x)) + set_tensor_dim_shape(model_output, 3, int(output_y / scale_y)) tmp_file = NamedTemporaryFile() # file will be deleted after program exit onnx.save(model, tmp_file.name) @@ -173,11 +176,3 @@ def modify_yolo_onnx_input_shape( ) return tmp_file.name, tmp_file - - -def _get_onnx_tensor_idx_shape(tensor: onnx.TensorProto, idx: int) -> int: - return tensor.type.tensor_type.shape.dim[idx].dim_value - - -def _set_onnx_tensor_idx_shape(tensor: onnx.TensorProto, idx: int, value: int): - tensor.type.tensor_type.shape.dim[idx].dim_value = value diff --git a/src/sparseml/onnx/utils/helpers.py b/src/sparseml/onnx/utils/helpers.py index 35a81a14c75..d128e21ea55 100644 --- a/src/sparseml/onnx/utils/helpers.py +++ b/src/sparseml/onnx/utils/helpers.py @@ -70,6 +70,8 @@ "get_kernel_shape", "calculate_flops", "get_quantize_parent_for_dequantize_node", + "get_tensor_dim_shape", + "set_tensor_dim_shape", ] @@ -1189,3 +1191,23 @@ def get_quantize_parent_for_dequantize_node( input_nodes = get_node_input_nodes(quantized_model, curr_node) curr_node = input_nodes[0] if input_nodes else None return curr_node + + +def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int: + """ + :param tensor: ONNX tensor to get the shape of a dimension of + :param dim: dimension index of the tensor to get the shape of + :return: shape of the tensor at the given dimension + """ + return tensor.type.tensor_type.shape.dim[dim].dim_value + + +def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int): + """ + Sets the shape of the tensor at the given dimension to the given value + + :param tensor: ONNX tensor to modify the shape of + :param dim: dimension index of the tensor to modify the shape of + :param value: new shape for the given dimension + """ + tensor.type.tensor_type.shape.dim[dim].dim_value = value From 9dde4bc7fd33cb84d3cc31845c57f3a7aa4c9978 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sun, 28 Mar 2021 19:21:30 -0400 Subject: [PATCH 3/4] udate README with sparsezoo stubs --- integrations/ultralytics/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/ultralytics/README.md b/integrations/ultralytics/README.md index 8ab4d3d8f21..bea59b6053c 100644 --- a/integrations/ultralytics/README.md +++ b/integrations/ultralytics/README.md @@ -130,8 +130,8 @@ performance with DeepSparse. For a full list of options run `python benchmarkin To run a benchmark run: ```bash python benchmark.py - zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_90 \ - --batch-size 32 \ + zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \ + --batch-size 1 \ --quantized-inputs ``` From 5d338509adc184368f32b7f739ffed3af4f36f14 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 29 Mar 2021 17:35:29 -0400 Subject: [PATCH 4/4] support for num-sockets, default num cores to None, update docs, updates from review --- integrations/ultralytics/deepsparse/SERVER.md | 5 +- .../ultralytics/deepsparse/benchmark.py | 52 ++++++++++++++----- .../deepsparse/deepsparse_utils.py | 12 +++-- 3 files changed, 50 insertions(+), 19 deletions(-) diff --git a/integrations/ultralytics/deepsparse/SERVER.md b/integrations/ultralytics/deepsparse/SERVER.md index 3cd0a19b582..af582b86abd 100644 --- a/integrations/ultralytics/deepsparse/SERVER.md +++ b/integrations/ultralytics/deepsparse/SERVER.md @@ -49,11 +49,12 @@ pip install deepsparse sparseml flask flask-cors ### Server -First, start up the host `server.py` with your model of choice. +First, start up the host `server.py` with your model of choice, SparseZoo stubs are +also supported. Example command: ```bash -python server.py ~/models/yolov3-pruned_quant.onnx +python server.py zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 ``` You can leave that running as a detached process or in a spare terminal. diff --git a/integrations/ultralytics/deepsparse/benchmark.py b/integrations/ultralytics/deepsparse/benchmark.py index c9094026c06..ada02bb1afd 100644 --- a/integrations/ultralytics/deepsparse/benchmark.py +++ b/integrations/ultralytics/deepsparse/benchmark.py @@ -21,8 +21,9 @@ usage: benchmark.py [-h] [-e {deepsparse,onnxruntime,torch}] [--data-path DATA_PATH] [--image-shape IMAGE_SHAPE [IMAGE_SHAPE ...]] - [-b BATCH_SIZE] [-c NUM_CORES] [-i NUM_ITERATIONS] - [-w NUM_WARMUP_ITERATIONS] [-q] [--fp16] [--device DEVICE] + [-b BATCH_SIZE] [-c NUM_CORES] [-s NUM_SOCKETS] + [-i NUM_ITERATIONS] [-w NUM_WARMUP_ITERATIONS] [-q] + [--fp16] [--device DEVICE] model_filepath Benchmark sparsified YOLOv3 models @@ -53,7 +54,13 @@ The batch size to run the benchmark for -c NUM_CORES, --num-cores NUM_CORES The number of physical cores to run the benchmark on, - defaults to all physical cores available on the system + defaults to None where it uses all physical cores + available on the system. For DeepSparse benchmarks, + this value is the number of cores per socket + -s NUM_SOCKETS, --num-sockets NUM_SOCKETS + For DeepSparse benchmarks only. The number of physical + cores to run the benchmark on. Defaults to None where + is uses all sockets available on the system -i NUM_ITERATIONS, --num-iterations NUM_ITERATIONS The number of iterations the benchmark will be run for -w NUM_WARMUP_ITERATIONS, --num-warmup-iterations NUM_WARMUP_ITERATIONS @@ -121,8 +128,6 @@ from sparsezoo.utils import load_numpy_list -CORES_PER_SOCKET, AVX_TYPE, _ = cpu.cpu_details() - DEEPSPARSE_ENGINE = "deepsparse" ORT_ENGINE = "onnxruntime" TORCH_ENGINE = "torch" @@ -185,10 +190,22 @@ def parse_args(): "-c", "--num-cores", type=int, - default=CORES_PER_SOCKET, + default=None, help=( "The number of physical cores to run the benchmark on, " - "defaults to all physical cores available on the system" + "defaults to None where it uses all physical cores available on the system. " + "For DeepSparse benchmarks, this value is the number of cores per socket" + ), + ) + parser.add_argument( + "-s", + "--num-sockets", + type=int, + default=None, + help=( + "For DeepSparse benchmarks only. The number of physical cores to run the " + "benchmark on. Defaults to None where is uses all sockets available on the " + "system" ), ) parser.add_argument( @@ -232,7 +249,6 @@ def parse_args(): ) args = parser.parse_args() - if args.engine == TORCH_ENGINE and args.device is None: args.device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -278,20 +294,22 @@ def _load_model(args) -> Any: raise ValueError(f"half precision is not supported for {args.engine}") if args.quantized_inputs and args.engine == TORCH_ENGINE: raise ValueError(f"quantized inputs not supported for {args.engine}") - if args.num_cores != CORES_PER_SOCKET and args.engine == TORCH_ENGINE: + if args.num_cores is not None and args.engine == TORCH_ENGINE: raise ValueError( f"overriding default num_cores not supported for {args.engine}" ) if ( - args.num_cores != CORES_PER_SOCKET + args.num_cores is not None and args.engine == ORT_ENGINE and onnxruntime.__version__ < "1.7" ): - print( + raise ValueError( "overriding default num_cores not supported for onnxruntime < 1.7.0. " "If using an older build with OpenMP, try setting the OMP_NUM_THREADS " "environment variable" ) + if args.num_sockets is not None and args.engine != DEEPSPARSE_ENGINE: + raise ValueError(f"Overriding num_sockets is not supported for {args.engine}") # scale static ONNX graph to desired image shape if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]: @@ -302,12 +320,20 @@ def _load_model(args) -> Any: # load model if args.engine == DEEPSPARSE_ENGINE: print(f"Compiling deepsparse model for {args.model_filepath}") - model = compile_model(args.model_filepath, args.batch_size, args.num_cores) + model = compile_model( + args.model_filepath, args.batch_size, args.num_cores, args.num_sockets + ) + if args.quantized_inputs and not model.cpu_vnni: + print( + "WARNING: VNNI instructions not detected, " + "quantization speedup not well supported" + ) elif args.engine == ORT_ENGINE: print(f"loading onnxruntime model for {args.model_filepath}") sess_options = onnxruntime.SessionOptions() - sess_options.intra_op_num_threads = args.num_cores + if args.num_cores is not None: + sess_options.intra_op_num_threads = args.num_cores sess_options.log_severity_level = 3 sess_options.graph_optimization_level = ( onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL diff --git a/integrations/ultralytics/deepsparse/deepsparse_utils.py b/integrations/ultralytics/deepsparse/deepsparse_utils.py index 4a09554b24b..d2ce24bab1a 100644 --- a/integrations/ultralytics/deepsparse/deepsparse_utils.py +++ b/integrations/ultralytics/deepsparse/deepsparse_utils.py @@ -129,19 +129,23 @@ def modify_yolo_onnx_input_shape( model_path: str, image_shape: Tuple[int] ) -> Tuple[str, Optional[NamedTemporaryFile]]: """ + Creates a new YOLOv3 ONNX model from the given path that accepts the given input + shape. If the given model already has the given input shape no modifications are + made. Uses a tempfile to store the modified model file. + :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model to be loaded :param image_shape: 2-tuple of the image shape to resize this yolo model to - :return: filepath to an onnx model reshaped to the given input shape and tempfile - object that the modified model is written to. if the given model has the given - input shape, then the path to the original model will be returned with no - tempfile. tempfile returned so caller can control when the tempfile is destroyed + :return: filepath to an onnx model reshaped to the given input shape will be the + original path if the shape is the same. Additionally returns the + NamedTemporaryFile for managing the scope of the object for file deletion """ original_model_path = model_path if model_path.startswith("zoo:"): # load SparseZoo Model from stub model = Zoo.load_model_from_stub(model_path) model_path = model.onnx_file.downloaded_path() + print(f"Downloaded {original_model_path} to {model_path}") model = onnx.load(model_path) model_input = model.graph.input[0]