diff --git a/mlflow/cli.py b/mlflow/cli.py index 4271be78fc3f2..48c57dbbf7094 100644 --- a/mlflow/cli.py +++ b/mlflow/cli.py @@ -1,5 +1,4 @@ import contextlib -import importlib.metadata import json import logging import os @@ -27,7 +26,7 @@ from mlflow.tracking import _get_store from mlflow.utils import cli_args from mlflow.utils.logging_utils import eprint -from mlflow.utils.os import is_windows +from mlflow.utils.os import get_entry_points, is_windows from mlflow.utils.process import ShellCommandException from mlflow.utils.server_cli_utils import ( artifacts_only_config_validation, @@ -353,7 +352,7 @@ def _validate_static_prefix(ctx, param, value): # pylint: disable=unused-argume @click.option( "--app-name", default=None, - type=click.Choice([e.name for e in importlib.metadata.entry_points().get("mlflow.app", [])]), + type=click.Choice([e.name for e in get_entry_points("mlflow.app")]), show_default=True, help=( "Application name to be used for the tracking server. " diff --git a/mlflow/server/__init__.py b/mlflow/server/__init__.py index 3177d63e93104..2c57dac526008 100644 --- a/mlflow/server/__init__.py +++ b/mlflow/server/__init__.py @@ -22,7 +22,7 @@ search_datasets_handler, upload_artifact_handler, ) -from mlflow.utils.os import is_windows +from mlflow.utils.os import get_entry_points, is_windows from mlflow.utils.process import _exec_cmd from mlflow.version import VERSION @@ -142,7 +142,7 @@ def serve(): def _find_app(app_name: str) -> str: - apps = importlib.metadata.entry_points().get("mlflow.app", []) + apps = get_entry_points("mlflow.app") for app in apps: if app.name == app_name: return app.value @@ -177,7 +177,7 @@ def get_app_client(app_name: str, *args, **kwargs): Returns: An app client instance. """ - clients = importlib.metadata.entry_points().get("mlflow.app.client", []) + clients = get_entry_points("mlflow.app.client") for client in clients: if client.name == app_name: cls = client.load() diff --git a/mlflow/utils/os.py b/mlflow/utils/os.py index ff5be9b4d685c..ad50f005adf7d 100644 --- a/mlflow/utils/os.py +++ b/mlflow/utils/os.py @@ -1,4 +1,6 @@ +import importlib.metadata import os +import sys def is_windows(): @@ -10,3 +12,13 @@ def is_windows(): """ return os.name == "nt" + + +def get_entry_points(namespace): + if sys.version_info >= (3, 10): + return importlib.metadata.entry_points(group=namespace) + else: + try: + return importlib.metadata.entry_points().get(namespace, []) + except AttributeError: + return importlib.metadata.entry_points().select(group=namespace)