enhance onnxrt backend setting
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho committed Sep 6, 2023
1 parent a16332d commit 295535a
Showing 3 changed files with 17 additions and 4 deletions.
8 changes: 6 additions & 2 deletions neural_compressor/adaptor/onnxrt.py
@@ -83,6 +83,10 @@ def __init__(self, framework_specific_info):
             self.format = "integerops"
             if "format" in framework_specific_info and framework_specific_info["format"].lower() == "qdq":
                 logger.warning("Dynamic approach doesn't support QDQ format.")
 
+        # do not load TensorRT if backend is not TensorrtExecutionProvider
+        if self.backend != "TensorrtExecutionProvider":
+            os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"
+
         # get quantization config file according to backend
         config_file = None
@@ -700,9 +704,9 @@ def _detect_domain(self, model):
         # typically, NLP models have multiple inputs,
         # and the dimension of each input is usually 2 (batch_size, max_seq_len)
         if not model.is_large_model:
-            sess = ort.InferenceSession(model.model.SerializeToString(), providers=ort.get_available_providers())
+            sess = ort.InferenceSession(model.model.SerializeToString(), providers=["CPUExecutionProvider"])
         elif model.model_path is not None:  # pragma: no cover
-            sess = ort.InferenceSession(model.model_path, providers=ort.get_available_providers())
+            sess = ort.InferenceSession(model.model_path, providers=["CPUExecutionProvider"])
         else:  # pragma: no cover
             assert False, "Please use model path instead of onnx model object to quantize."
         input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
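For context, the guard added above opts out of TensorRT whenever a non-TensorRT backend is selected, and the detection sessions are pinned to the CPU provider. A minimal standalone sketch of that provider-selection idea follows; the helper name and the returned provider lists are illustrative, not part of this patch.

import os

def _providers_for(backend):
    """Illustrative helper: pick providers for internal probe sessions."""
    if backend != "TensorrtExecutionProvider":
        # Mirror the adaptor change: flag TensorRT as unavailable for non-TensorRT backends.
        os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"
        # The patch keeps probe sessions on CPU; real inference providers are configured elsewhere.
        return ["CPUExecutionProvider"]
    return ["TensorrtExecutionProvider", "CPUExecutionProvider"]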
4 changes: 2 additions & 2 deletions neural_compressor/model/model.py
@@ -83,9 +83,9 @@ def _is_onnxruntime(model):
 
         so.register_custom_ops_library(get_library_path())
         if isinstance(model, str):
-            ort.InferenceSession(model, so, providers=ort.get_available_providers())
+            ort.InferenceSession(model, so, providers=["CPUExecutionProvider"])
         else:
-            ort.InferenceSession(model.SerializeToString(), so, providers=ort.get_available_providers())
+            ort.InferenceSession(model.SerializeToString(), so, providers=["CPUExecutionProvider"])
     except Exception as e:  # pragma: no cover
         if "Message onnx.ModelProto exceeds maximum protobuf size of 2GB" in str(e):
             logger.warning("Please use model path instead of onnx model object to quantize")
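The model.py hunk above applies the same idea to the loadability probe: the throwaway InferenceSession used only to verify that onnxruntime can read the model now runs on the CPU provider alone, so no GPU or TensorRT provider is initialized as a side effect. A minimal sketch of that probe pattern, assuming onnxruntime is installed (the function name is illustrative and omits the custom-ops session options used in the real code):

import onnxruntime as ort

def _loads_with_ort(model_or_path):
    # Return True if onnxruntime can build a session for the model using only the CPU provider.
    try:
        if isinstance(model_or_path, str):
            ort.InferenceSession(model_or_path, providers=["CPUExecutionProvider"])
        else:
            ort.InferenceSession(model_or_path.SerializeToString(), providers=["CPUExecutionProvider"])
        return True
    except Exception:
        return False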
9 changes: 9 additions & 0 deletions test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py
@@ -1657,6 +1657,15 @@ def test_backend(self, mock_warning):
 
         self.assertEqual(mock_warning.call_count, 2)
 
+    def test_cuda_ep_env_set(self):
+        config = PostTrainingQuantConfig(approach="static", backend="onnxrt_cuda_ep", device="gpu", quant_level=1)
+        q_model = quantization.fit(
+            self.distilbert_model,
+            config,
+            calib_dataloader=DummyNLPDataloader_dict("distilbert-base-uncased-finetuned-sst-2-english")
+        )
+        self.assertEqual(os.environ.get("ORT_TENSORRT_UNAVAILABLE"), "1")
+
 
 if __name__ == "__main__":
     unittest.main()
