Fix provider options when several providers are passed (#653)
* fix provider options when several providers are passed

* add test
fxmarty committed Jan 2, 2023
1 parent 6a9d2f8 commit 8d1fe87
Showing 4 changed files with 90 additions and 7 deletions.
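Context for the fix: `InferenceSession` takes `providers` and `provider_options` as parallel lists and rejects them when their lengths differ. With `provider="TensorrtExecutionProvider"`, Optimum appends `"CUDAExecutionProvider"` as a fallback (following the ONNX Runtime advice linked in the diff), so `providers` holds two entries while the old code wrapped the single user-supplied options dict in a one-element list. The commit pads the options list with empty dicts instead. A minimal sketch of that normalization, written here as a standalone helper rather than Optimum's actual code:

from typing import Dict, List, Optional

def pad_provider_options(
    providers: List[str], provider_options: Optional[Dict]
) -> Optional[List[Dict]]:
    # Attach the user's options to the first provider; every fallback
    # provider gets an empty dict so both lists end up the same length.
    if provider_options is None:
        return None
    return [provider_options] + [{} for _ in range(len(providers) - 1)]

providers = ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
print(pad_provider_options(providers, {"trt_engine_cache_enable": True}))
# [{'trt_engine_cache_enable': True}, {}]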
10 changes: 8 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -436,11 +436,17 @@ def load_model(
             # follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python
             providers.append("CUDAExecutionProvider")

+        # `providers` and `provider_options` need to be of the same length
+        if provider_options is not None:
+            providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
+        else:
+            providers_options = None
+
         decoder_session = onnxruntime.InferenceSession(
             str(decoder_path),
             providers=providers,
             sess_options=session_options,
-            provider_options=None if provider_options is None else [provider_options],
+            provider_options=providers_options,
         )
         decoder_with_past_session = None
         # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs
@@ -450,7 +456,7 @@ def load_model(
                 str(decoder_with_past_path),
                 providers=providers,
                 sess_options=session_options,
-                provider_options=None if provider_options is None else [provider_options],
+                provider_options=None if provider_options is None else providers_options,
             )
         return decoder_session, decoder_with_past_session

9 changes: 7 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
@@ -288,12 +288,17 @@ def load_model(
         if not isinstance(path, str):
             path = str(path)

-        # `providers` list must of be of the same length as `provider_options` list
+        # `providers` and `provider_options` need to be of the same length
+        if provider_options is not None:
+            providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
+        else:
+            providers_options = None
+
         return ort.InferenceSession(
             path,
             providers=providers,
             sess_options=session_options,
-            provider_options=None if provider_options is None else [provider_options],
+            provider_options=providers_options,
         )

     def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ONNX_WEIGHTS_NAME, **kwargs):

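From the caller's side, the repaired path is exercised by the new test added below. A condensed usage sketch; the model ID is a placeholder and a TensorRT-enabled onnxruntime-gpu build is assumed:

from optimum.onnxruntime import ORTModel

# Placeholder model ID; the new test uses the suite's ONNX_MODEL_ID instead.
model = ORTModel.from_pretrained(
    "org/some-onnx-model",
    provider="TensorrtExecutionProvider",
    provider_options={"trt_engine_cache_enable": True},
)
# Options are reported per provider, with values serialized as strings.
opts = model.model.get_provider_options()["TensorrtExecutionProvider"]
assert opts["trt_engine_cache_enable"] == "1"

Before this commit, the same call failed at session creation because two providers (TensorRT plus the CUDA fallback) were paired with a one-element options list.
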
12 changes: 9 additions & 3 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -851,17 +851,23 @@ def load_model(
             # follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python
             providers.append("CUDAExecutionProvider")

+        # `providers` and `provider_options` need to be of the same length
+        if provider_options is not None:
+            providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
+        else:
+            providers_options = None
+
         encoder_session = ort.InferenceSession(
             str(encoder_path),
             providers=providers,
             sess_options=session_options,
-            provider_options=None if provider_options is None else [provider_options],
+            provider_options=providers_options,
         )
         decoder_session = ort.InferenceSession(
             str(decoder_path),
             providers=providers,
             sess_options=session_options,
-            provider_options=None if provider_options is None else [provider_options],
+            provider_options=providers_options,
         )

         decoder_with_past_session = None
@@ -872,7 +878,7 @@ def load_model(
                 str(decoder_with_past_path),
                 providers=providers,
                 sess_options=session_options,
-                provider_options=None if provider_options is None else [provider_options],
+                provider_options=providers_options,
             )

         return encoder_session, decoder_session, decoder_with_past_session

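The seq2seq model builds up to three sessions (encoder, decoder, and decoder with past key/values) from the same `providers`/`providers_options` pair, so the user's options reach every session, which is exactly what the seq2seq test below asserts. A condensed sketch under the same assumptions as the earlier one, with a placeholder model ID:

from optimum.onnxruntime import ORTModelForSeq2SeqLM

model = ORTModelForSeq2SeqLM.from_pretrained(
    "org/some-onnx-seq2seq-model",  # placeholder
    provider="TensorrtExecutionProvider",
    provider_options={"trt_engine_cache_enable": True},
    use_cache=True,
)
# All three sessions report the same TensorRT options.
for session in (model.encoder.session, model.decoder.session, model.decoder_with_past.session):
    assert session.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"] == "1"
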
66 changes: 66 additions & 0 deletions tests/onnxruntime/test_modeling.py
@@ -293,6 +293,21 @@ def test_passing_provider_options(self):
         )
         self.assertEqual(model.model.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0")

+        # two providers case
+        model = ORTModel.from_pretrained(self.ONNX_MODEL_ID, provider="TensorrtExecutionProvider")
+        self.assertEqual(
+            model.model.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"], "0"
+        )
+
+        model = ORTModel.from_pretrained(
+            self.ONNX_MODEL_ID,
+            provider="TensorrtExecutionProvider",
+            provider_options={"trt_engine_cache_enable": True},
+        )
+        self.assertEqual(
+            model.model.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"], "1"
+        )
+
     @unittest.skipIf(get_gpu_count() <= 1, "this test requires multi-gpu")
     def test_model_on_gpu_id(self):
         model = ORTModel.from_pretrained(self.ONNX_MODEL_ID)
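A note on the string literals asserted above and below: ONNX Runtime serializes provider option values to strings, so a Python `True` reads back as `"1"` and an unset boolean option as `"0"`. A minimal standalone check, with a placeholder model path and a CUDA-enabled onnxruntime build assumed:

import onnxruntime as ort

# Placeholder path; any exported ONNX model works for this check.
session = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"do_copy_in_default_stream": True}, {}],
)
print(session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"])  # "1"
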
@@ -316,18 +331,69 @@ def test_passing_provider_options_seq2seq(self):
         self.assertEqual(
             model.decoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "1"
         )
+        self.assertEqual(
+            model.decoder_with_past.session.get_provider_options()["CUDAExecutionProvider"][
+                "do_copy_in_default_stream"
+            ],
+            "1",
+        )

         model = ORTModelForSeq2SeqLM.from_pretrained(
             self.ONNX_SEQ2SEQ_MODEL_ID,
             provider="CUDAExecutionProvider",
             provider_options={"do_copy_in_default_stream": 0},
             use_cache=True,
         )
         self.assertEqual(
             model.encoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0"
         )
         self.assertEqual(
             model.decoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0"
         )
+        self.assertEqual(
+            model.decoder_with_past.session.get_provider_options()["CUDAExecutionProvider"][
+                "do_copy_in_default_stream"
+            ],
+            "0",
+        )
+
+        # two providers case
+        model = ORTModelForSeq2SeqLM.from_pretrained(
+            self.ONNX_SEQ2SEQ_MODEL_ID,
+            provider="TensorrtExecutionProvider",
+            use_cache=True,
+        )
+        self.assertEqual(
+            model.encoder.session.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"], "0"
+        )
+        self.assertEqual(
+            model.decoder.session.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"], "0"
+        )
+        self.assertEqual(
+            model.decoder_with_past.session.get_provider_options()["TensorrtExecutionProvider"][
+                "trt_engine_cache_enable"
+            ],
+            "0",
+        )
+
+        model = ORTModelForSeq2SeqLM.from_pretrained(
+            self.ONNX_SEQ2SEQ_MODEL_ID,
+            provider="TensorrtExecutionProvider",
+            provider_options={"trt_engine_cache_enable": True},
+            use_cache=True,
+        )
+        self.assertEqual(
+            model.encoder.session.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"], "1"
+        )
+        self.assertEqual(
+            model.decoder.session.get_provider_options()["TensorrtExecutionProvider"]["trt_engine_cache_enable"], "1"
+        )
+        self.assertEqual(
+            model.decoder_with_past.session.get_provider_options()["TensorrtExecutionProvider"][
+                "trt_engine_cache_enable"
+            ],
+            "1",
+        )
+
     def test_seq2seq_model_on_cpu(self):
         model = ORTModelForSeq2SeqLM.from_pretrained(self.ONNX_SEQ2SEQ_MODEL_ID, use_cache=True)
