Update mlflow.spark.log_model() to accept descendants of pyspark.Model #1519
What changes are proposed in this pull request?
The mlflow.spark.log_model() method defines a spark_model parameter, corresponding to the PySpark model that the caller is persisting in MLflow format. Currently, spark_model must be of type pyspark.ml.PipelineModel; the mlflow.spark._validate_model() method throws an exception if spark_model is not a PipelineModel.
However, a PipelineModel is just a collection of PySpark stages, each of which is a descendant of pyspark.ml.Transformer. We should be able to handle any descendant of pyspark.ml.Model (itself a descendant of Transformer), because Model instances generate predictions, and we can wrap them under the hood in a PipelineModel that we then save. We cannot support arbitrary Transformers, because many of them generate "features" rather than predictions, which breaks other assumptions in mlflow.spark.
This commit extends the set of supported types for the spark_model parameter of mlflow.spark.log_model() to include any descendant of pyspark.ml.Model. The model is converted to a PipelineModel before serialization so that a consistent save/load implementation can be used.
How is this patch tested?
Added new tests covering different types of PySpark models.
Is this a user-facing change?
mlflow.spark.log_model and mlflow.spark.save_model can now serialize any descendant of pyspark.ml.Model that implements MLReadable and MLWritable.
What component(s) does this PR affect?
How should the PR be classified in the release notes? Choose one:
dbczumar left a comment
Looks great in general! Left a few nits. Additionally, we should update the parameter documentation for