Add support for Decimal object to Double cast in ML Flow #6600
Conversation
@shitaoli-db Thanks for the contribution! The DCO check failed. Please sign off your commits by following the instructions here: https://github.com/mlflow/mlflow/runs/8028008004. See https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.rst#sign-your-work for more details.
Signed-off-by: Shitao Li <shitao.li@databricks.com>
Force-pushed b06f1c0 to 6399e81 (Compare)
mlflow/pyfunc/__init__.py (Outdated)

```diff
@@ -397,6 +402,15 @@ def _enforce_mlflow_datatype(name, values: pandas.Series, t: DataType):
     raise MlflowException(
         "Failed to convert column {0} from type {1} to {2}.".format(name, values.dtype, t)
     )
+if t == DataType.double and values.dtype == decimal.Decimal:
```
I thought in our test `values.dtype` is `object`, while `type(values[0])` is `decimal.Decimal`. Can you confirm?
Also, our unit test and e2e notebook test all passed when the object is a Decimal.
Ah, I didn't know `dtype == object` and `dtype == decimal.Decimal` can both be true. LGTM!
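The behavior discussed above can be sketched outside MLflow: pandas stores `decimal.Decimal` values in an object-dtype Series, so a conversion helper has to inspect the elements rather than trust the dtype alone. This is a minimal illustration, not MLflow's actual implementation; the helper name is hypothetical.

```python
import decimal

import numpy as np
import pandas as pd


def cast_decimal_to_double(values: pd.Series) -> pd.Series:
    """Cast an object-dtype Series of decimal.Decimal values to float64.

    Pandas keeps Decimal objects in an object-dtype column, so we check
    the elements themselves before casting; non-Decimal columns are
    returned unchanged.
    """
    if values.dtype == object and all(
        isinstance(v, decimal.Decimal) for v in values.dropna()
    ):
        return values.astype(np.float64)
    return values
```

A Series built from `decimal.Decimal` values will report `dtype == object`, and after the helper runs its dtype is `float64`.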
mlflow/pyfunc/__init__.py (Outdated)

```diff
@@ -397,6 +402,15 @@ def _enforce_mlflow_datatype(name, values: pandas.Series, t: DataType):
     raise MlflowException(
         "Failed to convert column {0} from type {1} to {2}.".format(name, values.dtype, t)
     )
+if t == DataType.double and values.dtype == decimal.Decimal:
+    # NB: Pyspark Decimal columne get converted to decimal.Decimal when converted to pandas
```
nit: columne -> column.
in order to support ... - is part of the comment missing?
Thank you. Done.
Looks good.
Can you also update docs please?
Can we make this work only in spark_udf environment or protect it with an argument? E.g. we could use an env variable or extra argument to pyfunc.load_model. Would that be feasible?
Signed-off-by: Shitao Li <shitao.li@databricks.com>
Let me try with an extra argument first. Do you know how we generated this message? If we use the extra argument, then we have to customize this message in the UI to let customers know that if they want a model trained on decimal data to work in inference, they have to pass an extra argument.
mlflow/pyfunc/__init__.py (Outdated)

```python
model_uri: str,
suppress_warnings: bool = False,
dst_path: str = None,
support_decimal: bool = False,
```
Is there a way to accomplish the goal of providing conditional decimal enforcement without adding this flag to the `mlflow.pyfunc.load_model()` API? I'd rather not extend this API for a niche use case.
I personally don't think we need this flag at all. Can we just always enable this conversion? I don't see any real use case of setting that to false (or only add that flag later when we confirm it's a real use case).
@dbczumar @tomasatdatabricks what do you think?
I think Tom suggested two approaches to me. The other option I can think of off the top of my head is to use an environment variable: https://github.com/mlflow/mlflow/blob/master/mlflow/environment_variables.py. I am not sure which approach is preferable; let's reach a quick agreement.
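For context, the environment-variable option under discussion could look roughly like the sketch below. This is a hypothetical illustration only: the variable name `MLFLOW_ENABLE_DECIMAL_TO_DOUBLE_CAST` and the helper are invented for this example (the conversation ultimately decided to always enable the cast, so no such flag was added).

```python
import decimal
import os

import numpy as np
import pandas as pd

# Hypothetical flag name; a real one would be registered in
# mlflow/environment_variables.py.
_ENABLE_DECIMAL_CAST = "MLFLOW_ENABLE_DECIMAL_TO_DOUBLE_CAST"


def maybe_cast_decimal_to_double(values: pd.Series) -> pd.Series:
    """Cast a Decimal column to float64 only when the opt-in env var is set."""
    if os.environ.get(_ENABLE_DECIMAL_CAST, "false").lower() != "true":
        # Opt-in not set: leave the column untouched.
        return values
    if values.dtype == object and all(
        isinstance(v, decimal.Decimal) for v in values.dropna()
    ):
        return values.astype(np.float64)
    return values
```

The advantage of this pattern over a `load_model()` argument is that it gates the behavior without widening a public API surface.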
I agree with @yxiong here
Reverted the change; now we always enable the cast.
Force-pushed 89bdb01 to 99d17ad (Compare)
LGTM!
Signed-off-by: Shitao Li <shitao.li@databricks.com>
Head branch was pushed to by a user without write access
autoformat
Signed-off-by: Shitao Li <shitao.li@databricks.com>
…ved the enforce schema away so we have to do the same. Signed-off-by: Shitao Li <shitao.li@databricks.com>
Signed-off-by: Shitao Li <shitao.li@databricks.com>
automerge
Signed-off-by: Shitao Li <shitao.li@databricks.com>
Head branch was pushed to by a user without write access
* Add Decimal to double cast when enforcing the schema.
* Add unit test for casting from Decimal to double.
* Fix the comment in the Python file.
* Fix lint problem.
* Move the change to utils due to a merge conflict. The previous commit had moved the enforce schema away, so we have to do the same.
* Clean up __init__.py since the merge resolver did not handle the merge conflict.
* Lint the utils.py.

Signed-off-by: Shitao Li <shitao.li@databricks.com>
Signed-off-by: Prithvi Kannan <prithvi.kannan@databricks.com>
What changes are proposed in this pull request?
Add support for casting `decimal.Decimal` objects to double during model schema enforcement, so that models with a double schema can accept decimal input (e.g. PySpark Decimal columns converted to pandas).
How is this patch tested?
Added a unit test for the Decimal-to-double cast in schema enforcement. Also added an end-to-end test with Databricks AutoML model inference, where the model has a double schema; inference was successful.
Does this PR change the documentation?
Release Notes
Is this a user-facing change?
Model inference now allows users to pass Decimal objects as input. Note that decimal input is only allowed when the schema type is double; the data will be cast to double and thus may lose precision.
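The precision caveat can be illustrated with plain Python: an IEEE 754 double carries roughly 15-17 significant decimal digits, so a `Decimal` with more digits than that is rounded when cast. The values below are chosen only for illustration.

```python
import decimal

# 20 significant digits: more than a double can represent.
d = decimal.Decimal("1.0000000000000000001")

# Casting to double (Python float) silently drops the trailing digit.
f = float(d)
# f == 1.0 — the extra precision is lost.
```

This is why the release note warns that decimal input "may lose precision" once it is cast to the model's double schema.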
What component(s), interfaces, languages, and integrations does this PR affect?

Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/pipelines: Pipelines, Pipeline APIs, Pipeline configs, Pipeline Templates
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging

Interface
- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support

Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages

Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:
- rn/breaking-change: The PR will be mentioned in the "Breaking Changes" section
- rn/none: No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature: A new user-facing feature worth mentioning in the release notes
- rn/bug-fix: A user-facing bug fix worth mentioning in the release notes
- rn/documentation: A user-facing documentation change worth mentioning in the release notes