
Enable spark_udf to use column names from model signature by default #4236

Merged: 17 commits into mlflow:master on Apr 19, 2021

Conversation

@Loquats (Contributor) commented Apr 9, 2021

Signed-off-by: Andy Zhang <yue.zhang@databricks.com>

What changes are proposed in this pull request?

Currently, spark_udf returns a Spark UDF (a Python function), and the user must specify the input column names when applying it:

model_udf = mlflow.pyfunc.spark_udf(spark, path, result_type="string")
column_names = df.drop('label').schema.names
pred_df = df.withColumn('prediction', model_udf(*column_names))

This PR uses functools.wraps to create a wrapper for the UDF so that it can be applied without specifying column names:

model_udf = mlflow.pyfunc.spark_udf(spark, path)
pred_df = df.withColumn('prediction', model_udf())

The wrapper fills in the column names from the input schema logged in the model's signature. If no input schema exists, it falls back to the old behavior.
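
For reference, here is a minimal sketch of the wrapping approach (the names udf_with_default_cols and input_schema follow the review discussion below; the details are illustrative, not the exact implementation):

import functools

def _wrap_udf(udf, input_schema):
    # functools.wraps preserves the wrapped UDF's name and docstring.
    @functools.wraps(udf)
    def udf_with_default_cols(*args):
        if len(args) == 0 and input_schema is not None:
            # No columns given: fill them in from the model signature.
            return udf(*input_schema.input_names())
        # Columns given explicitly (or no schema available): old behavior.
        return udf(*args)
    return udf_with_default_cols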

How is this patch tested?

pytest tests/pyfunc/test_spark.py --large

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.
Allow the UDF returned by mlflow.pyfunc.spark_udf to be called with no column name arguments, in which case the input names from the model signature are used.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/projects: MLproject format, project running backends
  • area/scoring: Local serving, model deployment tools, spark UDFs
  • area/server-infra: MLflow server, JavaScript dev server
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, JavaScript, plotting
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

@github-actions bot added the area/scoring (MLflow Model server, model deployment tools, Spark UDFs) and rn/feature (Mention under Features in Changelogs) labels Apr 10, 2021

@Loquats changed the title from "basic implementation of spark_udf with default columns filled in" to "Enable spark_udf to use column names from model signature by default" Apr 10, 2021
@tomasatdatabricks (Contributor) reviewed:

Looks good, I left a few comments.

    input_names = input_schema.input_names()
    return udf(*input_names)
else:
    return udf(*args)
@tomasatdatabricks (Contributor):

Can there be a valid UDF that takes zero arguments? If not, we should raise an error informing the user that the model is missing a model signature and hence cannot be called with an empty argument list.

@Loquats (Contributor, Author):

Actually, it is possible to define a valid UDF that takes zero arguments (for example, to produce constant or random values), so I think this else clause is a valid case where no error is necessary.
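
(For illustration, a zero-argument UDF is indeed legal in PySpark; the snippet below is a hedged example, not part of the PR.)

import random
from pyspark.sql.functions import udf

# A UDF with no inputs is valid; it is evaluated once per output row.
rand_udf = udf(lambda: random.random(), "double")

# Assuming an existing DataFrame df:
df_with_noise = df.withColumn("noise", rand_udf())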

@Loquats (Contributor, Author):

I guess we should just remove *args here, since it's an empty tuple.

@tomasatdatabricks (Contributor):

I see. Can we catch the missing-argument error and display the message then? I am worried that there is no way to tell whether you can call the model UDF with no arguments, which can be confusing. If we can at least inform the user that this particular model has no schema or no column names, that would be helpful.

@Loquats (Contributor, Author):

Sounds good. The missing-argument error is actually raised by the predict function defined a few lines above. Instead of catching the error and displaying a different message, does it make sense to just clarify the error message when the model has no schema and no column-name arguments (line 840 in the diff)? Let me know if you meant something different.

mlflow/pyfunc/__init__.py (outdated comment, resolved)
)
)
res = good_data.withColumn("res", udf()).select("res").toPandas()
assert res["res"][0] == ["a", "b", "c"]
@tomasatdatabricks (Contributor):

Can we also test that if the user gives an argument list, we respect it? E.g. assert that the following raises:

res = good_data.withColumn("res", udf(struct("b", "c")))

)
with pytest.raises(AnalysisException, match=r"cannot resolve '`a`' given input columns"):
bad_data.withColumn("res", udf())

@tomasatdatabricks (Contributor):

Can we add a test case for what happens if the model has no signature and the user calls it with an empty argument list?
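
(A hedged sketch of what such a test could look like; the fixture names and the expected exception are assumptions, not the PR's actual test.)

import pytest
import mlflow.pyfunc

def test_udf_no_signature_empty_args(spark, model_uri_without_signature, df):
    # model_uri_without_signature: a model logged with no signature (hypothetical fixture).
    udf = mlflow.pyfunc.spark_udf(spark, model_uri_without_signature, result_type="string")
    # With no signature to supply column names, calling udf() should fail.
    with pytest.raises(Exception):
        df.withColumn("res", udf()).collect()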

@functools.wraps(udf)
def udf_with_default_cols(*args):
    if len(args) == 0:
        model = SparkModelCache.get_or_load(archive_path)
@tomasatdatabricks (Contributor):

Rather than loading the full model, can we read the model YAML file from local_model_path (line 780)? I think it may be safer to avoid loading the model on the driver.

@Loquats (Contributor, Author):

I tried this suggestion, but it didn't work because we don't have access to local_model_path: it is deleted after the with TempDir() as local_tmpdir: scope ends. Do you have a recommendation on how to work around this? I used SparkModelCache.get_or_load(archive_path) here because that's what predict does at line 786.

@tomasatdatabricks (Contributor):

I think you can load the Model while the temp dir exists (e.g. add model = Model.load(local_model_path) to the with block on line 780) and use it later.

@tomasatdatabricks (Contributor):

In fact, it may be useful to attach the model object to the returned function object, e.g. udf.metadata = model
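
(A minimal sketch of this suggestion, assuming the metadata is loaded while the temp dir still exists as discussed above; the variable names are illustrative.)

import os
from mlflow.models import Model

# Inside spark_udf, while local_model_path is still valid:
model_metadata = Model.load(os.path.join(local_model_path, "MLmodel"))

# Attach the metadata to the returned function object for later introspection:
udf_with_default_cols.metadata = model_metadata

# A caller could then inspect it, e.g.:
# model_udf = mlflow.pyfunc.spark_udf(spark, path)
# model_udf.metadata.get_input_schema()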

@Loquats (Contributor, Author) reviewed:

Addressed some comments, and left some notes/questions.

@tomasatdatabricks (Contributor) commented:

Oh, I forgot to ask in my review: we also need to update the docs (arguably that's the most important part, otherwise users will have no clue this option exists). Are you planning to update the docs in a follow-up? I think it would be better to do it as part of a single PR.

@Loquats (Contributor, Author) commented Apr 14, 2021:

Thanks for the reminder. I'll update the docs as part of this PR.

@@ -780,6 +782,7 @@ def spark_udf(spark, model_uri, result_type="double"):
artifact_uri=model_uri, output_path=local_tmpdir.path()
)
archive_path = SparkModelCache.add_local_model(spark, local_model_path)
model = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME))
@tomasatdatabricks (Contributor):

I've realized that we use "model" in the predict function as well. It shouldn't matter, but it is confusing. Can you rename this to model_metadata?

Loquats and others added 2 commits April 14, 2021 15:08
@Loquats (Contributor, Author) commented Apr 14, 2021:

I added some documentation to the docstring as well as models.rst. Let me know what you think, and whether there are other doc pages to update.

    input_names = input_schema.input_names()
    return udf(*input_names)
else:
    return udf()
@tomasatdatabricks (Contributor):

Can we check whether the model has an input schema with no column names, and raise an error in that case?

@tomasatdatabricks (Contributor):

Actually, would it be fair to say that we do not support calling with an empty argument list for models that have no signature? We could just raise an error whenever there is an empty argument list but no signature (or a signature with unnamed columns).

I think this would lead to a much better user experience for the common case. It is possible there is not a single zero-argument model being deployed this way.
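
(A hedged sketch of the proposed check; the exception type and the has_input_names helper are assumptions.)

from mlflow.exceptions import MlflowException

def udf_with_default_cols(*args):
    if len(args) == 0:
        if input_schema is None or not input_schema.has_input_names():
            # No signature, or a signature with unnamed columns: fail fast
            # with a clear message rather than a confusing Spark error.
            raise MlflowException(
                "Cannot apply UDF without column names: the model has no "
                "signature with named inputs. Pass column names explicitly."
            )
        args = input_schema.input_names()
    return udf(*args)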

@@ -832,11 +838,20 @@ def predict(*args):
result = result.select_dtypes(include=(np.number,)).astype(np.float64)

if len(result.columns) == 0:
message = (
@tomasatdatabricks (Contributor):

Did you check what scikit-learn's behavior is when you pass it an empty dataframe?

)
if len(args) == 0 and input_schema is None:
message += (
"Apply the udf by specifying column name arguments, for example "
@tomasatdatabricks (Contributor):

I meant to add this message at the top level. If that is not possible, I would not add it at all, because I don't think we can add the message reliably at the Python level: some models will raise exceptions and some will produce empty output, so it would be difficult to reliably detect when predict fails due to missing columns.

@tomasatdatabricks (Contributor) reviewed:

lgtm!

@Loquats (Contributor, Author) commented Apr 17, 2021:

All tests pass. Ready to merge.

@tomasatdatabricks merged commit 25379a3 into mlflow:master on Apr 19, 2021
YQ-Wang pushed a commit to YQ-Wang/mlflow that referenced this pull request May 29, 2021

…lflow#4236)

* basic implementation with udf.default property
* make it work by just calling udf() instead of udf.default
* autoformat
* address some comments
* yolo
* fix
* make tests work
* autoformat
* add documentation
* fix lint?
* fix tests?
* fix lint AGAIN
* test
* fix
* fix lint
* trigger GitHub actions because r failed

Signed-off-by: Andy Zhang <yue.zhang@databricks.com>
Signed-off-by: Yiqing Wang <yiqing@wangemail.com>
harupy pushed a commit to wamartin-aml/mlflow that referenced this pull request Jun 7, 2021 (same commit list as above)

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Labels: area/scoring (MLflow Model server, model deployment tools, Spark UDFs), rn/feature (Mention under Features in Changelogs)

2 participants