Enable spark_udf to use column names from model signature by default #4236
Conversation
Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
Looks good, I left a few comments.
mlflow/pyfunc/__init__.py
Outdated
    input_names = input_schema.input_names()
    return udf(*input_names)
else:
    return udf(*args)
Can there be a valid udf that takes zero arguments? If not, we should raise an error informing the user that the model is missing a model signature and hence cannot be called with an empty argument list.
Actually it is possible to define a valid udf that takes zero arguments (maybe they're used to define constant/random values?), so I think this else clause is a valid case where no error is necessary.
I guess we should just remove *args here since it's an empty array.
I see. Can we catch the missing argument error and display the message then? I am worried that there is no way to tell whether the model udf can be called with no arguments, which can be confusing. If we can at least inform the user that this particular model has no schema or no column names, that would be helpful.
Sounds good. The missing argument error is actually raised by the predict function defined a few lines above. Instead of catching the error and displaying a different message, does it make sense to just clarify the error message if the model has no schema and no column-name arguments (line 840 in the diff)? Let me know if you actually meant something different.
    )
)
res = good_data.withColumn("res", udf()).select("res").toPandas()
assert res["res"][0] == ["a", "b", "c"]
Can we also test that if the user gives an argument list, we respect it? E.g. assert that the following raises:
res = good_data.withColumn("res", udf(struct("b", "c")))
)
with pytest.raises(AnalysisException, match=r"cannot resolve '`a`' given input columns"):
    bad_data.withColumn("res", udf())
Can we add a test case for what happens if the model has no signature and user calls it with empty argument list?
mlflow/pyfunc/__init__.py
Outdated
@functools.wraps(udf)
def udf_with_default_cols(*args):
    if len(args) == 0:
        model = SparkModelCache.get_or_load(archive_path)
Rather than loading the full model, can we read the model yaml file from local_model_path (line 780)? I think it may be safer to avoid loading the model on the driver.
I tried this suggestion, but it didn't work because we don't have access to local_model_path. It is deleted after the with TempDir() as local_tmpdir: scope ends. Do you have a recommendation on how to work around this? I used SparkModelCache.get_or_load(archive_path) here because that's what predict does at line 786.
I think you can load the Model while the tempdir exists (e.g. add the model = Model.load(local_model_path) to the with block on line 780) and use it later.
In fact, it may be useful to attach the model object to the returned function object, e.g. udf.metadata = model.
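The two suggestions in this thread (wrapping the raw UDF with functools.wraps and attaching the loaded model to the returned function object) could be sketched roughly as follows; `make_udf`, its arguments, and the dict-based metadata are hypothetical stand-ins, not actual mlflow code:

```python
import functools


def make_udf(raw_udf, model_metadata, input_names=None):
    """Sketch: wrap raw_udf so an empty argument list falls back to the
    column names from the model signature (input_names), and expose the
    loaded model metadata on the returned function object."""

    @functools.wraps(raw_udf)
    def udf_with_default_cols(*args):
        if len(args) == 0 and input_names is not None:
            # No explicit columns: use the names from the signature.
            return raw_udf(*input_names)
        return raw_udf(*args)

    # Reviewer suggestion: attach the model object to the returned function.
    udf_with_default_cols.metadata = model_metadata
    return udf_with_default_cols
```

A caller could then inspect `udf.metadata` without reloading the model from the artifact store.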
Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
Addressed some comments, and left some notes/questions.
Oh, I forgot to ask in my review, we also need to update docs (arguably that's the most important part, otherwise users will have no clue this is an option). Are you planning on updating docs in a followup? I think it would be better to do it as part of a single PR.
Thanks for the reminder. I'll update docs as part of this PR.
mlflow/pyfunc/__init__.py
Outdated
@@ -780,6 +782,7 @@ def spark_udf(spark, model_uri, result_type="double"):
    artifact_uri=model_uri, output_path=local_tmpdir.path()
)
archive_path = SparkModelCache.add_local_model(spark, local_model_path)
model = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME))
I've realized that we use model in the predict function as well. It should not matter, but it is confusing. Can you rename this to model_metadata?
Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
I added some documentation in the docstring as well as models.rst. Let me know what you think, and if there are other doc pages to update.
Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
    input_names = input_schema.input_names()
    return udf(*input_names)
else:
    return udf()
Can we check if the model has an input schema with no column names, and raise an error in that case?
Actually, would it be fair to say that we do not support calling models with an empty argument list when they have no signature? And just raise an error whenever there is an empty argument list but no signature (or one with unnamed columns)?
I think this would lead to a much better user experience for the common case. It is possible there is not a single model with no arguments that is being deployed this way.
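The rule proposed here could look roughly like the following check; `resolve_column_names` and its `input_names` argument are stand-ins for the real Schema handling (e.g. Schema.input_names()), not actual mlflow code:

```python
def resolve_column_names(args, input_names):
    """Sketch of the proposed rule: explicit arguments always win; with an
    empty argument list, require a signature whose inputs are all named,
    and raise immediately otherwise. input_names may be None (no signature)
    or contain None entries (unnamed columns)."""
    if len(args) > 0:
        # User supplied columns explicitly: respect them.
        return list(args)
    if not input_names or any(name is None for name in input_names):
        raise ValueError(
            "Cannot apply the udf without arguments: the model has no "
            "signature, or its signature has unnamed inputs. Pass column "
            "names explicitly, e.g. udf('a', 'b')."
        )
    return list(input_names)
```

Raising at this level (when the UDF is applied, on the driver) avoids the concern raised later in the thread about detecting failures inside predict on the workers.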
mlflow/pyfunc/__init__.py
Outdated
@@ -832,11 +838,20 @@ def predict(*args):
    result = result.select_dtypes(include=(np.number,)).astype(np.float64)

if len(result.columns) == 0:
    message = (
Did you check what scikit's behavior is when you pass it an empty dataframe?
mlflow/pyfunc/__init__.py
Outdated
)
if len(args) == 0 and input_schema is None:
    message += (
        "Apply the udf by specifying column name arguments, for example "
I meant to add this message at the top level. If that is not possible, I would not add it at all because I don't think we can add the message reliably at the python level - by which I mean some models will raise exceptions, some will produce empty output, it would be difficult to reliably detect when predict fails due to missing columns.
lgtm!
Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
All tests pass. Ready to merge.
…lflow#4236)
* basic implementation with udf.default property
* make it work by just calling udf() instead of udf.default
* autoformat
* address some comments
* yolo
* fix
* make tests work
* autoformat
* add documentation
* fix lint?
* fix tests?
* fix lint AGAIN
* test
* fix
* fix lint
* trigger GitHub actions because r failed

Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
Signed-off-by: Yiqing Wang <yiqing@wangemail.com>
Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
What changes are proposed in this pull request?
spark_udf returns a Spark UDF (a Python function). The user must specify the column names as input:
Use functools.wraps to create a wrapper for the UDF so that it can be used without specifying column names:
The wrapper fills in column names using the input schema logged in the model signature. If the input schema doesn't exist, it reverts to the old behavior.
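The before/after calling convention can be illustrated with a minimal plain-Python stand-in (no Spark required; `make_spark_udf` is a hypothetical sketch, not the mlflow.pyfunc.spark_udf implementation):

```python
def make_spark_udf(predict_fn, signature_input_names=None):
    """Stand-in for the wrapper returned by spark_udf: the callable accepts
    column names, defaulting to the signature's input names when omitted."""

    def udf(*col_names):
        if not col_names:
            if signature_input_names is None:
                # Old behavior had no fallback; explicit names are still
                # required when there is no input schema.
                raise ValueError("No signature: pass column names explicitly.")
            col_names = tuple(signature_input_names)
        return predict_fn(*col_names)

    return udf


# With a signature ["a", "b", "c"], udf() resolves the columns automatically,
# while udf("b", "c") is respected as-is.
udf = make_spark_udf(lambda *cols: list(cols),
                     signature_input_names=["a", "b", "c"])
```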
How is this patch tested?
pytest tests/pyfunc/test_spark.py --large
Release Notes
Is this a user-facing change?
Allow the UDF returned by mlflow.pyfunc.spark_udf to be called with no column name arguments, in which case input names from the model signature are used.

What component(s), interfaces, languages, and integrations does this PR affect?
Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/projects: MLproject format, project running backends
- area/scoring: Local serving, model deployment tools, spark UDFs
- area/server-infra: MLflow server, JavaScript dev server
- area/tracking: Tracking Service, tracking client APIs, autologging

Interface
- area/uiux: Front-end, user experience, JavaScript, plotting
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support

Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages

Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:
- rn/breaking-change: The PR will be mentioned in the "Breaking Changes" section
- rn/none: No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature: A new user-facing feature worth mentioning in the release notes
- rn/bug-fix: A user-facing bug fix worth mentioning in the release notes
- rn/documentation: A user-facing documentation change worth mentioning in the release notes