Enable auto dependency inference in spark flavor #4759

Merged · 10 commits · Sep 3, 2021
Changes from 5 commits
43 changes: 26 additions & 17 deletions mlflow/spark.py
@@ -407,13 +407,33 @@ def _save_model_metadata(
if input_example is not None:
_save_example(mlflow_model, input_example, dst_dir)

conda_env, pip_requirements, pip_constraints = (
_process_pip_requirements(
get_default_pip_requirements(), pip_requirements, extra_pip_requirements,
)
if conda_env is None
else _process_conda_env(conda_env)
mlflow_model.add_flavor(
FLAVOR_NAME, pyspark_version=pyspark.__version__, model_data=_SPARK_MODEL_PATH_SUB
)
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.spark",
data=_SPARK_MODEL_PATH_SUB,
env=_CONDA_ENV_FILE_NAME,
)
mlflow_model.save(os.path.join(dst_dir, MLMODEL_FILE_NAME))

if conda_env is None:
if pip_requirements is None:
default_reqs = get_default_pip_requirements()
# To ensure `_load_pyfunc` can successfully load the model during the dependency
# inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
inferred_reqs = mlflow.models.infer_pip_requirements(
dst_dir, FLAVOR_NAME, fallback=default_reqs,
)
default_reqs = sorted(set(inferred_reqs).union(default_reqs))
else:
default_reqs = None
conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
default_reqs, pip_requirements, extra_pip_requirements,
)
else:
conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

with open(os.path.join(dst_dir, _CONDA_ENV_FILE_NAME), "w") as f:
yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
@@ -425,17 +445,6 @@ def _save_model_metadata(
# Save `requirements.txt`
write_to(os.path.join(dst_dir, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

mlflow_model.add_flavor(
FLAVOR_NAME, pyspark_version=pyspark.__version__, model_data=_SPARK_MODEL_PATH_SUB
)
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.spark",
data=_SPARK_MODEL_PATH_SUB,
env=_CONDA_ENV_FILE_NAME,
)
mlflow_model.save(os.path.join(dst_dir, MLMODEL_FILE_NAME))


def _validate_model(spark_model):
from pyspark.ml.util import MLReadable, MLWritable
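For context, here is a rough usage sketch (not part of the PR) of what this change means for a user saving a Spark model: requirements.txt should now contain the dependencies inferred by loading the model, with get_default_pip_requirements() kept as the fallback. The toy pipeline, the local output path, and the exact packages listed are illustrative assumptions only.

import os

import mlflow.spark
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(0.0, 1.0, 0), (1.0, 2.0, 1)], ["a", "b", "label"])
model = Pipeline(
    stages=[VectorAssembler(inputCols=["a", "b"], outputCol="features"), LogisticRegression()]
).fit(df)

# With this PR, save_model writes the MLmodel file first and then infers pip
# requirements by loading the saved model in a subprocess and capturing the
# imported modules, falling back to get_default_pip_requirements() on failure.
mlflow.spark.save_model(model, path="spark_model")

# requirements.txt should now list the inferred packages (e.g. pyspark)
# rather than only the static defaults.
with open(os.path.join("spark_model", "requirements.txt")) as f:
    print(f.read())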
6 changes: 6 additions & 0 deletions mlflow/utils/_capture_modules.py
@@ -13,6 +13,7 @@
from mlflow.utils.file_utils import write_to
from mlflow.pyfunc import MAIN
from mlflow.models.model import MLMODEL_FILE_NAME, Model
from mlflow.utils.databricks_utils import is_in_databricks_runtime


def _get_top_level_module(full_module_name):
@@ -84,6 +85,11 @@ def main():

cap_cm = _CaptureImportedModules()

if flavor == "spark" and is_in_databricks_runtime():
from dbruntime.spark_connection import initialize_spark_connection
Member Author:
This approach breaks when initialize_spark_connection is renamed or moved to a different module.

Member Author (@harupy), Aug 31, 2021:
Should we make the spark initialization process modifiable via monkey-patching or an environment variable?

Collaborator:
Add an MLR test to prevent this from breaking.


initialize_spark_connection()
Collaborator:
What about the case where we are not in the Databricks runtime?

Member Author:
@WeichenXu123 If not, a new Spark session is created in the following code:

mlflow/mlflow/spark.py, lines 687 to 700 in c4b8e84:

if spark is None:
    # NB: If there is no existing Spark context, create a new local one.
    # NB: We're disabling caching on the new context since we do not need it and we want to
    # avoid overwriting cache of underlying Spark cluster when executed on a Spark Worker
    # (e.g. as part of spark_udf).
    spark = (
        pyspark.sql.SparkSession.builder.config("spark.python.worker.reuse", True)
        .config("spark.databricks.io.cache.enabled", False)
        # In Spark 3.1 and above, we need to set this conf explicitly to enable creating
        # a SparkSession on the workers
        .config("spark.executor.allowSparkContext", "true")
        .master("local[1]")
        .getOrCreate()
    )

Collaborator:
Shall we add a try/catch here with fallback handling? For example, for the case where a user upgrades or downgrades the built-in MLflow of the Databricks runtime and finds that the dbruntime.spark_connection.initialize_spark_connection API no longer exists.

Collaborator:
Agree with @WeichenXu123, I think a try/catch is a good idea here (one possible shape is sketched after this file's diff).


# If `model_path` refers to an MLflow model directory, load the model using
# `mlflow.pyfunc.load_model`
if os.path.isdir(model_path) and MLMODEL_FILE_NAME in os.listdir(model_path):
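To make the try/catch suggestion above concrete, one possible shape for the guard is sketched below. This is only an illustration of the reviewers' idea, not the code that was merged: the helper name and the warning message are made up, and the fallback relies on the existing behavior harupy points out above (mlflow.spark creates a local SparkSession when no session is available).

import logging

from mlflow.utils.databricks_utils import is_in_databricks_runtime

_logger = logging.getLogger(__name__)


def _try_initialize_databricks_spark_connection():
    # Best-effort Spark initialization inside Databricks; degrade gracefully so
    # dependency inference can still fall back to a local SparkSession.
    if not is_in_databricks_runtime():
        return
    try:
        from dbruntime.spark_connection import initialize_spark_connection

        initialize_spark_connection()
    except Exception:
        # e.g. the API was renamed or moved, or the runtime's built-in MLflow was
        # upgraded/downgraded; mlflow.spark will create a local SparkSession instead.
        _logger.warning(
            "Failed to initialize the Databricks Spark connection; "
            "falling back to a local SparkSession.",
            exc_info=True,
        )

The environment-variable escape hatch harupy asks about could be layered on top by checking a (hypothetical) variable before attempting the import.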
7 changes: 3 additions & 4 deletions mlflow/utils/requirements_utils.py
@@ -17,7 +17,6 @@
from mlflow.exceptions import MlflowException
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.autologging_utils.versioning import _strip_dev_version_suffix
from mlflow.utils.databricks_utils import is_in_databricks_runtime
from packaging.version import Version, InvalidVersion

_logger = logging.getLogger(__name__)
@@ -193,9 +192,9 @@ def _get_installed_version(package, module=None):
# 1.9.0
version = __import__(module or package).__version__

# In Databricks, strip a dev version suffix for pyspark (e.g. '3.1.2.dev0' -> '3.1.2')
# and make it installable from PyPI.
if package == "pyspark" and is_in_databricks_runtime():
# Strip the suffix from `dev` versions of PySpark, which are not available for installation
# from Anaconda or PyPI
if package == "pyspark":
version = _strip_dev_version_suffix(version)

return version
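As a reference for the behavior the updated comment describes, a standalone sketch of dev-suffix stripping is below. The regex is an assumption inferred from the examples in the updated test; it is not copied from _strip_dev_version_suffix itself.

import re


def strip_dev_version_suffix(version):
    # Illustrative only: drop everything from the first "dev" marker onward,
    # e.g. "3.1.2.dev0" -> "3.1.2", "2.4.0dev.a" -> "2.4.0".
    return re.sub(r"\.?dev.*$", "", version)


assert all(
    strip_dev_version_suffix(v) == "2.4.0"
    for v in ["2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb"]
)
# Versions without a dev suffix are unaffected.
assert all(strip_dev_version_suffix(v) == v for v in ["2.0", "2.3.4", "2"])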
51 changes: 12 additions & 39 deletions tests/spark/test_spark_model_export.py
@@ -34,6 +34,7 @@
from tests.helper_functions import (
score_model_in_sagemaker_docker_container,
_compare_conda_env_requirements,
_get_pip_deps,
_assert_pip_requirements,
)
from tests.pyfunc.test_spark import score_model_as_udf, get_spark_session
@@ -601,13 +602,7 @@ def test_sparkml_model_save_without_specified_conda_env_uses_default_env_with_ex
spark_model_iris, model_path
):
sparkm.save_model(spark_model=spark_model_iris.model, path=model_path)

pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
with open(conda_env_path, "r") as f:
conda_env = yaml.safe_load(f)

assert conda_env == sparkm.get_default_conda_env()
_assert_pip_requirements(model_path, sparkm.get_default_pip_requirements())


@pytest.mark.large
@@ -617,47 +612,25 @@ def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_exp
artifact_path = "model"
with mlflow.start_run():
sparkm.log_model(spark_model=spark_model_iris.model, artifact_path=artifact_path)
model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)

model_path = _download_artifact_from_uri(artifact_uri=model_uri)
pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
with open(conda_env_path, "r") as f:
conda_env = yaml.safe_load(f)
model_uri = mlflow.get_artifact_uri(artifact_path)

assert conda_env == sparkm.get_default_conda_env()
_assert_pip_requirements(model_uri, sparkm.get_default_pip_requirements())


@pytest.mark.large
def test_default_conda_env_strips_dev_suffix_from_pyspark_version(spark_model_iris, model_path):
with mock.patch("importlib_metadata.version", return_value="2.4.0"):
default_conda_env_standard = sparkm.get_default_conda_env()

for dev_version in ["2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb"]:
with mock.patch("importlib_metadata.version", return_value=dev_version):
default_conda_env_dev = sparkm.get_default_conda_env()
assert default_conda_env_dev == default_conda_env_standard

def test_dev_version_suffix_for_pyspark_is_stripped(spark_model_iris):
unsuffixed_version = "2.4.0"
for dev_suffix in [".dev0", ".dev", ".dev1", "dev.a", ".devb"]:
with mock.patch("importlib_metadata.version", return_value=unsuffixed_version + dev_suffix):
with mlflow.start_run():
sparkm.log_model(spark_model=spark_model_iris.model, artifact_path="model")
model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path="model"
)

model_path = _download_artifact_from_uri(artifact_uri=model_uri)
pyfunc_conf = _get_flavor_configuration(
model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME
)
conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
with open(conda_env_path, "r") as f:
persisted_conda_env_dev = yaml.safe_load(f)
assert persisted_conda_env_dev == default_conda_env_standard
model_uri = mlflow.get_artifact_uri("model")
_assert_pip_requirements(model_uri, ["mlflow", f"pyspark=={unsuffixed_version}"])

for unaffected_version in ["2.0", "2.3.4", "2"]:
with mock.patch("importlib_metadata.version", return_value=unaffected_version):
assert unaffected_version in yaml.safe_dump(sparkm.get_default_conda_env())
pip_deps = _get_pip_deps(sparkm.get_default_conda_env())
assert any(x == f"pyspark=={unaffected_version}" for x in pip_deps)


@pytest.mark.large