add ort session options #8433
Documentation preview for 82d4667 will be available here when this CircleCI job completes successfully. More info
@leqiao-1 Thank you for the contribution! Could you fix the following issue(s)? ⚠ DCO check: The DCO check failed. Please sign off your commit(s) by following the instructions here. See https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.md#sign-your-work for more details.
99e9bfa to 8a75bc6
mlflow/onnx.py
Outdated
@@ -176,6 +178,7 @@ def save_model(
    onnx_version=onnx.__version__,
    data=model_data_subpath,
    providers=onnx_execution_providers,
    options=onnx_session_options,
I think we need to add some validation here: validate the option keys and the value types.
Why not use onnx_session_options=onnx_session_options?
    sess_options.inter_op_num_threads = inter_op_num_threads
if intra_op_num_threads:
    sess_options.intra_op_num_threads = intra_op_num_threads
if execution_mode:
Why not require the execution_mode value to be either "SEQUENTIAL" or "PARALLEL"? We should also document it.
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel(
        graph_optimization_level
    )
if extra_session_config:
Can it be an arbitrary dict?
Overall good, but I have some comments.
@@ -255,9 +255,9 @@ def __init__(self, path, providers=None):
     if intra_op_num_threads:
         sess_options.intra_op_num_threads = intra_op_num_threads
     if execution_mode:
-        if execution_mode == 0:
+        if execution_mode.upper() == "SEQUENTIAL":
We should document that, when logging a model, "SEQUENTIAL" or "PARALLEL" should be used as the execution_mode session option value.
mlflow/onnx.py
Outdated
:param onnx_session_options: Dictionary of options to be passed to onnxruntime.InferenceSession.
                             For example:
                             ``{
                                 'graph_optimization_level': 99,
                                 'intra_op_num_threads': 1,
                                 'inter_op_num_threads': 1,
                                 'execution_mode': 'sequential'
                             }``
                             'execution_mode' can be set to 'sequential' or 'parallel'.
Please also add the doc to the "log_model" method.
It has been added
@@ -176,6 +188,7 @@ def save_model(
    onnx_version=onnx.__version__,
    data=model_data_subpath,
    providers=onnx_execution_providers,
    onnx_session_options=onnx_session_options,
We can add simple verification:
- check that the keys are valid
- for execution_mode, check that the value is valid
- for other option keys, check that the types are correct, e.g., inter_op_num_threads should be an integer that is > 0
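The checks listed above could be sketched roughly as follows. This is only an illustration, not the code merged in this PR: `validate_session_options`, `_ALLOWED_KEYS`, and `_EXECUTION_MODES` are hypothetical names, the set of keys is taken from the options read in the diff under review, and the type rules for `graph_optimization_level` and `extra_session_config` are assumptions.

```python
# Hypothetical helper; the key set mirrors the options read in the PR's diff.
_ALLOWED_KEYS = {
    "inter_op_num_threads",
    "intra_op_num_threads",
    "execution_mode",
    "graph_optimization_level",
    "extra_session_config",
}
_EXECUTION_MODES = {"SEQUENTIAL", "PARALLEL"}


def validate_session_options(options):
    """Raise ValueError if `options` has unknown keys or badly typed values."""
    if options is None:
        return
    if not isinstance(options, dict):
        raise ValueError("onnx_session_options must be a dict")
    for key, value in options.items():
        if key not in _ALLOWED_KEYS:
            raise ValueError(
                f"Key {key!r} in onnx_session_options is not a valid "
                "ONNX Runtime session options key"
            )
        if key in ("inter_op_num_threads", "intra_op_num_threads"):
            # bool is a subclass of int, so reject it explicitly.
            if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
                raise ValueError(f"{key} must be an integer > 0, got {value!r}")
        elif key == "execution_mode":
            # Case-insensitive, matching the .upper() comparison in the diff.
            if not isinstance(value, str) or value.upper() not in _EXECUTION_MODES:
                raise ValueError(
                    f"execution_mode must be 'sequential' or 'parallel', got {value!r}"
                )
        elif key == "graph_optimization_level":
            # Assumption: the level is given as a plain integer such as 99.
            if isinstance(value, bool) or not isinstance(value, int):
                raise ValueError(f"{key} must be an integer, got {value!r}")
        elif key == "extra_session_config":
            # Assumption: extra config entries are passed as a dict.
            if not isinstance(value, dict):
                raise ValueError(f"{key} must be a dict, got {value!r}")
```

Validating at save time means a bad option is reported when the model is saved or logged rather than when it is later loaded for inference.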
a7ad12c to 39c3a21
Signed-off-by: leqiao <leqiao@microsoft.com>
39c3a21 to 275ddca
if execution_mode.upper() == "SEQUENTIAL":
    sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
elif execution_mode.upper() == "PARALLEL":
    sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
nit: raise an error for other values.
OK, we don't need this. We already verify the options when saving the model.
LGTM
@leqiao-1 Could you fix the lint errors? https://github.com/mlflow/mlflow/actions/runs/5011618497/jobs/8987189752?pr=8433 Once you fix them, I will merge the PR.
Head branch was pushed to by a user without write access
Signed-off-by: leqiao <leqiao@microsoft.com>
049735c to 82d4667
Format fixed. Thanks.
raise ValueError(
f"Key {key} in onnx_session_options is not a valid \
ONNX Runtime session options key"
)
This is what the actual error message looks like:
key = 1
raise ValueError(
f"Key {key} in onnx_session_options is not a valid \
ONNX Runtime session options key"
)
Traceback (most recent call last):
File "a.py", line 2, in <module>
raise ValueError(
ValueError: Key 1 in onnx_session_options is not a valid ONNX Runtime session options key
Good catch!
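The pitfall the reviewer appears to be pointing at is that a backslash continuation inside a string literal keeps the leading whitespace of the continued line, so source indentation ends up in the error message. A small sketch of the problem and one common fix (adjacent string literals, which Python concatenates); both function names are hypothetical and exist only for this illustration:

```python
def message_with_backslash(key):
    # The spaces before "ONNX" on the continued line are part of the
    # string literal, so they appear in the resulting message.
    return f"Key {key} in onnx_session_options is not a valid \
        ONNX Runtime session options key"


def message_with_concatenation(key):
    # Adjacent string literals are joined at compile time, so source
    # indentation never leaks into the message.
    return (
        f"Key {key} in onnx_session_options is not a valid "
        "ONNX Runtime session options key"
    )


# repr() makes the embedded run of spaces easy to see.
print(repr(message_with_backslash(1)))
print(repr(message_with_concatenation(1)))
```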
inter_op_num_threads = options.get("inter_op_num_threads")
intra_op_num_threads = options.get("intra_op_num_threads")
execution_mode = options.get("execution_mode")
graph_optimization_level = options.get("graph_optimization_level")
extra_session_config = options.get("extra_session_config")
if inter_op_num_threads:
    sess_options.inter_op_num_threads = inter_op_num_threads
if intra_op_num_threads:
    sess_options.intra_op_num_threads = intra_op_num_threads
This can be simplified by using assignment expressions:
if inter_op_num_threads := options.get("inter_op_num_threads"):
    # do something with inter_op_num_threads
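Applied to the thread-count options from the diff, the suggestion might look like the sketch below. `apply_thread_options` is a hypothetical helper, and a `SimpleNamespace` stands in for `onnxruntime.SessionOptions` so the example runs without onnxruntime installed. Assignment expressions require Python 3.8 or later.

```python
from types import SimpleNamespace


def apply_thread_options(sess_options, options):
    """Copy thread-count options onto `sess_options` when they are set."""
    # The walrus operator binds the looked-up value and tests it in one step,
    # replacing the separate `x = options.get(...)` and `if x:` lines.
    if inter_op_num_threads := options.get("inter_op_num_threads"):
        sess_options.inter_op_num_threads = inter_op_num_threads
    if intra_op_num_threads := options.get("intra_op_num_threads"):
        sess_options.intra_op_num_threads = intra_op_num_threads
    return sess_options


# Stand-in object (assumption): any object with these attributes works here.
sess_options = SimpleNamespace(inter_op_num_threads=0, intra_op_num_threads=0)
apply_thread_options(sess_options, {"intra_op_num_threads": 2})
print(sess_options)
```

As in the original code, a missing key and a falsy value such as 0 are both skipped, so the behavior is unchanged; only the temporary variables are folded into the conditions.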
Related Issues/PRs
#8413
What changes are proposed in this pull request?
Add onnxruntime session options when saving and loading an ONNX model with pyfunc.
How is this patch tested?
Does this PR change the documentation?
Release Notes
Is this a user-facing change?
Users can provide onnxruntime session options when saving and loading an ONNX model with pyfunc.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging
Interface
- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support
Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages
Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations
How should the PR be classified in the release notes? Choose one:
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes