Skip to content

Commit

Permalink
Enable onnx model quantization with trt ep (#554)
Browse files Browse the repository at this point in the history
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
  • Loading branch information
mengniwang95 committed Feb 25, 2023
1 parent 4d09eeb commit ba42d00
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/onnxrt.py
Expand Up @@ -1146,6 +1146,8 @@ def evaluate(self, input_graph, dataloader, postprocess=None,
convert_attribute=False)
sess_options = ort.SessionOptions()
if self.backend == 'TensorrtExecutionProvider':
from neural_compressor.adaptor.ox_utils.util import trt_env_setup
trt_env_setup(input_graph.model)
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
if measurer:
# https://github.com/microsoft/onnxruntime/issues/7347
Expand Down
15 changes: 14 additions & 1 deletion neural_compressor/adaptor/ox_utils/util.py
Expand Up @@ -684,4 +684,17 @@ def insert_smooth_mul_op_per_op(scales, shape_infos, input_tensors_2_weights_nod
for index, input in enumerate(node.input):
if input == input_key:
node.input[index] = mul_output_name
return new_added_mul_nodes, new_init_tensors, name_2_nodes
return new_added_mul_nodes, new_init_tensors, name_2_nodes

def trt_env_setup(model):
    """Set the TensorRT Execution Provider INT8 environment variable.

    Inspects the ONNX graph for quantization ops and toggles
    ``ORT_TENSORRT_INT8_ENABLE`` so onnxruntime's TensorRT EP runs the
    model in INT8 mode only when the graph is actually quantized.

    Args:
        model: an ONNX ModelProto (only ``model.graph.node`` is read).
    """
    # A graph is considered quantized as soon as it contains a single
    # QuantizeLinear/DequantizeLinear node; any() short-circuits like the
    # original manual break-loop did.
    is_int8 = any(
        node.op_type in ("QuantizeLinear", "DequantizeLinear")
        for node in model.graph.node
    )
    os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1" if is_int8 else "0"

0 comments on commit ba42d00

Please sign in to comment.