From 30ee4a7ca011e835b407643bd7c6a88bc3a6285f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 5 Apr 2022 12:52:21 -0400 Subject: [PATCH 1/2] Allow input overrides in check_correctness.py --- examples/benchmark/check_correctness.py | 38 ++++++++++++++++++++----- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/examples/benchmark/check_correctness.py b/examples/benchmark/check_correctness.py index 81fee551e9..9ab2bb79c1 100644 --- a/examples/benchmark/check_correctness.py +++ b/examples/benchmark/check_correctness.py @@ -51,7 +51,10 @@ generate_random_inputs, get_input_names, get_output_names, + model_to_path, override_onnx_batch_size, + override_onnx_input_shapes, + parse_input_shapes, verify_outputs, ) @@ -81,15 +84,32 @@ def parse_args(): help="The batch size to run the analysis for", ) + parser.add_argument( + "-shapes", + "--input_shapes", + type=str, + default="", + help="Override the shapes of the inputs, " + 'i.e., -shapes "[1,2,3],[4,5,6],[7,8,9]" results in ' + "input0=[1,2,3] input1=[4,5,6] input2=[7,8,9]. ", + ) + return parser.parse_args() def main(): args = parse_args() - onnx_filepath = args.onnx_filepath + onnx_filepath = model_to_path(args.onnx_filepath) batch_size = args.batch_size - inputs = generate_random_inputs(onnx_filepath, batch_size) + input_shapes = parse_input_shapes(args.input_shapes) + + if input_shapes: + with override_onnx_input_shapes(onnx_filepath, input_shapes) as model_path: + inputs = generate_random_inputs(model_path, args.batch_size) + else: + inputs = generate_random_inputs(onnx_filepath, args.batch_size) + input_names = get_input_names(onnx_filepath) output_names = get_output_names(onnx_filepath) inputs_dict = {name: value for name, value in zip(input_names, inputs)} @@ -97,14 +117,18 @@ def main(): # ONNXRuntime inference print("Executing model with ONNXRuntime...") sess_options = onnxruntime.SessionOptions() - with override_onnx_batch_size(onnx_filepath, batch_size) as override_onnx_filepath: - ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) - - ort_outputs = ort_network.run(output_names, inputs_dict) + if input_shapes: + with override_onnx_input_shapes(onnx_filepath, input_shapes) as override_onnx_filepath: + ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) + ort_outputs = ort_network.run(output_names, inputs_dict) + else: + with override_onnx_batch_size(onnx_filepath, batch_size) as override_onnx_filepath: + ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) + ort_outputs = ort_network.run(output_names, inputs_dict) # DeepSparse Engine inference print("Executing model with DeepSparse Engine...") - dse_network = compile_model(onnx_filepath, batch_size=batch_size) + dse_network = compile_model(onnx_filepath, batch_size=batch_size, input_shapes=input_shapes) dse_outputs = dse_network(inputs) verify_outputs(dse_outputs, ort_outputs) From ee478e1089636745538eb982b639395897171cfb Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 5 Apr 2022 16:44:16 -0400 Subject: [PATCH 2/2] review comments, style + quality --- examples/benchmark/check_correctness.py | 30 +++++++++---------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/examples/benchmark/check_correctness.py b/examples/benchmark/check_correctness.py index 9ab2bb79c1..67d1af429c 100644 --- a/examples/benchmark/check_correctness.py +++ b/examples/benchmark/check_correctness.py @@ -44,15 +44,11 @@ import argparse -import onnxruntime - from deepsparse import compile_model, cpu +from deepsparse.benchmark_model.ort_engine import ORTEngine from deepsparse.utils import ( generate_random_inputs, - get_input_names, - get_output_names, model_to_path, - override_onnx_batch_size, override_onnx_input_shapes, parse_input_shapes, verify_outputs, ) @@ -110,25 +106,21 @@ def main(): else: inputs = generate_random_inputs(onnx_filepath, args.batch_size) - input_names = get_input_names(onnx_filepath) - output_names = get_output_names(onnx_filepath) - inputs_dict = {name: value for name, value in zip(input_names, inputs)} - # ONNXRuntime inference print("Executing model with ONNXRuntime...") - sess_options = onnxruntime.SessionOptions() - if input_shapes: - with override_onnx_input_shapes(onnx_filepath, input_shapes) as override_onnx_filepath: - ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) - ort_outputs = ort_network.run(output_names, inputs_dict) - else: - with override_onnx_batch_size(onnx_filepath, batch_size) as override_onnx_filepath: - ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) - ort_outputs = ort_network.run(output_names, inputs_dict) + ort_network = ORTEngine( + model=onnx_filepath, + batch_size=batch_size, + num_cores=None, + input_shapes=input_shapes, + ) + ort_outputs = ort_network.run(inputs) # DeepSparse Engine inference print("Executing model with DeepSparse Engine...") - dse_network = compile_model(onnx_filepath, batch_size=batch_size, input_shapes=input_shapes) + dse_network = compile_model( + onnx_filepath, batch_size=batch_size, input_shapes=input_shapes + ) dse_outputs = dse_network(inputs) verify_outputs(dse_outputs, ort_outputs)