diff --git a/examples/benchmark/check_correctness.py b/examples/benchmark/check_correctness.py index 81fee551e9..67d1af429c 100644 --- a/examples/benchmark/check_correctness.py +++ b/examples/benchmark/check_correctness.py @@ -44,14 +44,13 @@ import argparse -import onnxruntime - from deepsparse import compile_model, cpu +from deepsparse.benchmark_model.ort_engine import ORTEngine from deepsparse.utils import ( generate_random_inputs, - get_input_names, - get_output_names, - override_onnx_batch_size, + model_to_path, + override_onnx_input_shapes, + parse_input_shapes, verify_outputs, ) @@ -81,30 +80,47 @@ def parse_args(): help="The batch size to run the analysis for", ) + parser.add_argument( + "-shapes", + "--input_shapes", + type=str, + default="", + help="Override the shapes of the inputs, " + 'i.e., -shapes "[1,2,3],[4,5,6],[7,8,9]" results in ' + "input0=[1,2,3] input1=[4,5,6] input2=[7,8,9]. ", + ) + return parser.parse_args() def main(): args = parse_args() - onnx_filepath = args.onnx_filepath + onnx_filepath = model_to_path(args.onnx_filepath) batch_size = args.batch_size - inputs = generate_random_inputs(onnx_filepath, batch_size) - input_names = get_input_names(onnx_filepath) - output_names = get_output_names(onnx_filepath) - inputs_dict = {name: value for name, value in zip(input_names, inputs)} + input_shapes = parse_input_shapes(args.input_shapes) + + if input_shapes: + with override_onnx_input_shapes(onnx_filepath, input_shapes) as model_path: + inputs = generate_random_inputs(model_path, args.batch_size) + else: + inputs = generate_random_inputs(onnx_filepath, args.batch_size) # ONNXRuntime inference print("Executing model with ONNXRuntime...") - sess_options = onnxruntime.SessionOptions() - with override_onnx_batch_size(onnx_filepath, batch_size) as override_onnx_filepath: - ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) - - ort_outputs = ort_network.run(output_names, inputs_dict) + ort_network = ORTEngine( + model=onnx_filepath, + batch_size=batch_size, + num_cores=None, + input_shapes=input_shapes, + ) + ort_outputs = ort_network.run(inputs) # DeepSparse Engine inference print("Executing model with DeepSparse Engine...") - dse_network = compile_model(onnx_filepath, batch_size=batch_size) + dse_network = compile_model( + onnx_filepath, batch_size=batch_size, input_shapes=input_shapes + ) dse_outputs = dse_network(inputs) verify_outputs(dse_outputs, ort_outputs)