Enable auto dependency inference in spark flavor #4759

Merged: 10 commits, Sep 3, 2021
43 changes: 26 additions & 17 deletions mlflow/spark.py
@@ -407,13 +407,33 @@ def _save_model_metadata(
     if input_example is not None:
         _save_example(mlflow_model, input_example, dst_dir)

-    conda_env, pip_requirements, pip_constraints = (
-        _process_pip_requirements(
-            get_default_pip_requirements(), pip_requirements, extra_pip_requirements,
-        )
-        if conda_env is None
-        else _process_conda_env(conda_env)
-    )
+    mlflow_model.add_flavor(
+        FLAVOR_NAME, pyspark_version=pyspark.__version__, model_data=_SPARK_MODEL_PATH_SUB
+    )
+    pyfunc.add_to_model(
+        mlflow_model,
+        loader_module="mlflow.spark",
+        data=_SPARK_MODEL_PATH_SUB,
+        env=_CONDA_ENV_FILE_NAME,
+    )
+    mlflow_model.save(os.path.join(dst_dir, MLMODEL_FILE_NAME))
+
+    if conda_env is None:
+        if pip_requirements is None:
+            default_reqs = get_default_pip_requirements()
+            # To ensure `_load_pyfunc` can successfully load the model during the dependency
+            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
+            inferred_reqs = mlflow.models.infer_pip_requirements(
+                dst_dir, FLAVOR_NAME, fallback=default_reqs,
+            )
+            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
+        else:
+            default_reqs = None
+        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
+            default_reqs, pip_requirements, extra_pip_requirements,
+        )
+    else:
+        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

     with open(os.path.join(dst_dir, _CONDA_ENV_FILE_NAME), "w") as f:
         yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
@@ -425,17 +445,6 @@ def _save_model_metadata(
     # Save `requirements.txt`
     write_to(os.path.join(dst_dir, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

-    mlflow_model.add_flavor(
-        FLAVOR_NAME, pyspark_version=pyspark.__version__, model_data=_SPARK_MODEL_PATH_SUB
-    )
-    pyfunc.add_to_model(
-        mlflow_model,
-        loader_module="mlflow.spark",
-        data=_SPARK_MODEL_PATH_SUB,
-        env=_CONDA_ENV_FILE_NAME,
-    )
-    mlflow_model.save(os.path.join(dst_dir, MLMODEL_FILE_NAME))


 def _validate_model(spark_model):
     from pyspark.ml.util import MLReadable, MLWritable
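In short, `_save_model_metadata` now writes the MLmodel file before the environment files so that `mlflow.models.infer_pip_requirements` can load the model back and record what it imports, falling back to `get_default_pip_requirements()` if inference fails. A minimal sketch of the caller-side effect (the path and the fitted `pipeline_model` are illustrative, not part of the PR):

import mlflow.spark

# `pipeline_model` is assumed to be any fitted pyspark.ml PipelineModel.
mlflow.spark.save_model(spark_model=pipeline_model, path="/tmp/spark_model")

# With this change, requirements.txt contains the union of the inferred
# packages (whatever importing the model actually pulls in) and the static
# defaults, instead of the static defaults alone.
with open("/tmp/spark_model/requirements.txt") as f:
    print(f.read())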
12 changes: 12 additions & 0 deletions mlflow/utils/_capture_modules.py
@@ -13,6 +13,7 @@
 from mlflow.utils.file_utils import write_to
 from mlflow.pyfunc import MAIN
 from mlflow.models.model import MLMODEL_FILE_NAME, Model
+from mlflow.utils.databricks_utils import is_in_databricks_runtime


 def _get_top_level_module(full_module_name):
@@ -82,6 +83,17 @@ def main():
     # Mirror `sys.path` of the parent process
     sys.path = json.loads(args.sys_path)

+    if flavor == mlflow.spark.FLAVOR_NAME and is_in_databricks_runtime():
+        try:
+            # pylint: disable=import-error
+            from dbruntime.spark_connection import initialize_spark_connection
+
+            initialize_spark_connection()
+        except Exception as e:
+            raise Exception(
+                "Attempted to initialize a spark session to load the spark model, but failed"
+            ) from e
+
     cap_cm = _CaptureImportedModules()

     # If `model_path` refers to an MLflow model directory, load the model using
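The inference subprocess loads the model under `_CaptureImportedModules` and records the top-level modules that get imported; the block added above ensures a Spark session exists first on Databricks, because loading a Spark model triggers `pyspark` initialization. A rough sketch of the capture idea (not the actual `_CaptureImportedModules` implementation):

import builtins
import contextlib


@contextlib.contextmanager
def capture_imported_modules():
    """Record the top-level name of every module imported inside the block."""
    captured = set()
    original_import = builtins.__import__

    def patched_import(name, *args, **kwargs):
        # Keep only the top-level module name, e.g. "pyspark.ml" -> "pyspark".
        captured.add(name.split(".")[0])
        return original_import(name, *args, **kwargs)

    builtins.__import__ = patched_import
    try:
        yield captured
    finally:
        builtins.__import__ = original_import


# Usage: whatever the model load imports ends up in `modules`.
with capture_imported_modules() as modules:
    import json  # stand-in for loading the model

print(modules)  # e.g. {"json"}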
15 changes: 12 additions & 3 deletions mlflow/utils/requirements_utils.py
@@ -193,9 +193,9 @@ def _get_installed_version(package, module=None):
         # 1.9.0
         version = __import__(module or package).__version__

-    # In Databricks, strip a dev version suffix for pyspark (e.g. '3.1.2.dev0' -> '3.1.2')
-    # and make it installable from PyPI.
-    if package == "pyspark" and is_in_databricks_runtime():
+    # Strip the suffix from `dev` versions of PySpark, which are not available for installation
+    # from Anaconda or PyPI
+    if package == "pyspark":
         version = _strip_dev_version_suffix(version)

     return version
@@ -250,7 +250,16 @@ def _infer_requirements(model_uri, flavor):
     """
     global _MODULES_TO_PACKAGES
     if _MODULES_TO_PACKAGES is None:
+        # Note `importlib_metadata.packages_distributions` only captures packages installed into
+        # Python's site-packages directory via tools such as pip:
+        # https://importlib-metadata.readthedocs.io/en/latest/using.html#using-importlib-metadata
         _MODULES_TO_PACKAGES = importlib_metadata.packages_distributions()
+
+        # In Databricks, `_MODULES_TO_PACKAGES` doesn't contain pyspark since it's not installed
+        # via pip or conda. To work around this issue, manually add pyspark.
+        if is_in_databricks_runtime():
+            _MODULES_TO_PACKAGES.update({"pyspark": ["pyspark"]})
+
     modules = _capture_imported_modules(model_uri, flavor)
     packages = _flatten([_MODULES_TO_PACKAGES.get(module, []) for module in modules])
     packages = map(_canonicalize_package_name, packages)
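For reference, `packages_distributions()` maps importable module names to the distributions that provide them, which is how the captured modules are turned into pip requirements. The snippet below also includes a plausible stand-in for `_strip_dev_version_suffix`, whose implementation is not part of this diff:

import re

import importlib_metadata

# Maps module names to distribution names for pip-installed packages,
# e.g. {"yaml": ["PyYAML"], "sklearn": ["scikit-learn"]}. Packages not
# installed via pip (such as pyspark on Databricks) are simply absent.
modules_to_packages = importlib_metadata.packages_distributions()
print(modules_to_packages.get("yaml"))  # ['PyYAML'] if PyYAML is pip-installed


def strip_dev_version_suffix(version):
    # Assumed behavior: drop a trailing dev segment such as ".dev0" or "dev.a".
    return re.sub(r"\.?dev.*$", "", version)


assert strip_dev_version_suffix("3.1.2.dev0") == "3.1.2"
assert strip_dev_version_suffix("2.4.0") == "2.4.0"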
51 changes: 12 additions & 39 deletions tests/spark/test_spark_model_export.py
@@ -34,6 +34,7 @@
 from tests.helper_functions import (
     score_model_in_sagemaker_docker_container,
     _compare_conda_env_requirements,
+    _get_pip_deps,
     _assert_pip_requirements,
 )
 from tests.pyfunc.test_spark import score_model_as_udf, get_spark_session
@@ -601,13 +602,7 @@ def test_sparkml_model_save_without_specified_conda_env_uses_default_env_with_ex
     spark_model_iris, model_path
 ):
     sparkm.save_model(spark_model=spark_model_iris.model, path=model_path)
-
-    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
-    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
-    with open(conda_env_path, "r") as f:
-        conda_env = yaml.safe_load(f)
-
-    assert conda_env == sparkm.get_default_conda_env()
+    _assert_pip_requirements(model_path, sparkm.get_default_pip_requirements())


 @pytest.mark.large
@@ -617,47 +612,25 @@ def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_exp
     artifact_path = "model"
     with mlflow.start_run():
         sparkm.log_model(spark_model=spark_model_iris.model, artifact_path=artifact_path)
-        model_uri = "runs:/{run_id}/{artifact_path}".format(
-            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
-        )
-
-    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
-    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
-    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
-    with open(conda_env_path, "r") as f:
-        conda_env = yaml.safe_load(f)
+        model_uri = mlflow.get_artifact_uri(artifact_path)

-    assert conda_env == sparkm.get_default_conda_env()
+    _assert_pip_requirements(model_uri, sparkm.get_default_pip_requirements())


 @pytest.mark.large
-def test_default_conda_env_strips_dev_suffix_from_pyspark_version(spark_model_iris, model_path):
-    with mock.patch("importlib_metadata.version", return_value="2.4.0"):
-        default_conda_env_standard = sparkm.get_default_conda_env()
-
-    for dev_version in ["2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb"]:
-        with mock.patch("importlib_metadata.version", return_value=dev_version):
-            default_conda_env_dev = sparkm.get_default_conda_env()
-            assert default_conda_env_dev == default_conda_env_standard
-
+def test_pyspark_version_is_logged_without_dev_suffix(spark_model_iris):
+    unsuffixed_version = "2.4.0"
+    for dev_suffix in [".dev0", ".dev", ".dev1", "dev.a", ".devb"]:
+        with mock.patch("importlib_metadata.version", return_value=unsuffixed_version + dev_suffix):
             with mlflow.start_run():
                 sparkm.log_model(spark_model=spark_model_iris.model, artifact_path="model")
-                model_uri = "runs:/{run_id}/{artifact_path}".format(
-                    run_id=mlflow.active_run().info.run_id, artifact_path="model"
-                )
-
-            model_path = _download_artifact_from_uri(artifact_uri=model_uri)
-            pyfunc_conf = _get_flavor_configuration(
-                model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME
-            )
-            conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
-            with open(conda_env_path, "r") as f:
-                persisted_conda_env_dev = yaml.safe_load(f)
-            assert persisted_conda_env_dev == default_conda_env_standard
+                model_uri = mlflow.get_artifact_uri("model")
+            _assert_pip_requirements(model_uri, ["mlflow", f"pyspark=={unsuffixed_version}"])

     for unaffected_version in ["2.0", "2.3.4", "2"]:
         with mock.patch("importlib_metadata.version", return_value=unaffected_version):
-            assert unaffected_version in yaml.safe_dump(sparkm.get_default_conda_env())
+            pip_deps = _get_pip_deps(sparkm.get_default_conda_env())
+            assert any(x == f"pyspark=={unaffected_version}" for x in pip_deps)


 @pytest.mark.large
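The helpers `_get_pip_deps` and `_assert_pip_requirements` come from `tests.helper_functions` and are not shown in this diff. A rough sketch of the behavior the new assertions rely on, assumed from how the tests call them rather than taken from the helper code:

import os

import yaml

from mlflow.tracking.artifact_utils import _download_artifact_from_uri


def _get_pip_deps(conda_env):
    # Pull the list nested under the "pip" entry of a conda env dict, e.g.
    # {"dependencies": ["python=3.8", "pip", {"pip": ["mlflow", "pyspark==3.1.2"]}]}.
    for dep in conda_env.get("dependencies", []):
        if isinstance(dep, dict) and "pip" in dep:
            return dep["pip"]
    return []


def _assert_pip_requirements(model_uri, expected):
    # Compare the saved requirements.txt (and, in the real helper, the pip section
    # of conda.yaml) against the expected requirement strings.
    local_path = _download_artifact_from_uri(model_uri)
    with open(os.path.join(local_path, "requirements.txt")) as f:
        assert set(f.read().splitlines()) == set(expected)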