Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix provider options when several providers are passed #653

Merged
merged 2 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,17 @@ def load_model(
# follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python
providers.append("CUDAExecutionProvider")

# `providers` and `provider_options` need to be of the same length
if provider_options is not None:
providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
else:
providers_options = None

decoder_session = onnxruntime.InferenceSession(
str(decoder_path),
providers=providers,
sess_options=session_options,
provider_options=None if provider_options is None else [provider_options],
provider_options=providers_options,
)
decoder_with_past_session = None
# If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs
Expand All @@ -450,7 +456,7 @@ def load_model(
str(decoder_with_past_path),
providers=providers,
sess_options=session_options,
provider_options=None if provider_options is None else [provider_options],
provider_options=None if provider_options is None else providers_options,
)
return decoder_session, decoder_with_past_session

Expand Down
9 changes: 7 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,17 @@ def load_model(
if not isinstance(path, str):
path = str(path)

# `providers` list must of be of the same length as `provider_options` list
# `providers` and `provider_options` need to be of the same length
if provider_options is not None:
providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
else:
providers_options = None

return ort.InferenceSession(
path,
providers=providers,
sess_options=session_options,
provider_options=None if provider_options is None else [provider_options],
provider_options=providers_options,
)

def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ONNX_WEIGHTS_NAME, **kwargs):
Expand Down
12 changes: 9 additions & 3 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,17 +851,23 @@ def load_model(
# follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python
providers.append("CUDAExecutionProvider")

# `providers` and `provider_options` need to be of the same length
if provider_options is not None:
providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]
else:
providers_options = None

encoder_session = ort.InferenceSession(
str(encoder_path),
providers=providers,
sess_options=session_options,
provider_options=None if provider_options is None else [provider_options],
provider_options=providers_options,
)
decoder_session = ort.InferenceSession(
str(decoder_path),
providers=providers,
sess_options=session_options,
provider_options=None if provider_options is None else [provider_options],
provider_options=providers_options,
)

decoder_with_past_session = None
Expand All @@ -872,7 +878,7 @@ def load_model(
str(decoder_with_past_path),
providers=providers,
sess_options=session_options,
provider_options=None if provider_options is None else [provider_options],
provider_options=providers_options,
)

return encoder_session, decoder_session, decoder_with_past_session
Expand Down