Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1] Copy gateway server CLI to mlflow deployments start-server #10426

Merged
merged 13 commits into from
Nov 21, 2023
3 changes: 2 additions & 1 deletion .github/workflows/deployments.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ jobs:
- uses: ./.github/actions/setup-python
- name: Install dependencies
run: |
pip install --no-dependencies tests/resources/mlflow-test-plugin
pip install .[gateway] \
pytest pytest-timeout pytest-asyncio httpx psutil
- name: Run tests
run: |
pytest tests/deployments/databricks tests/deployments/mlflow
pytest tests/deployments
42 changes: 42 additions & 0 deletions mlflow/deployments/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import click

from mlflow.deployments import interface
from mlflow.environment_variables import MLFLOW_DEPLOYMENTS_CONFIG
from mlflow.utils import cli_args
from mlflow.utils.annotations import experimental
from mlflow.utils.proto_json_utils import NumpyEncoder, _get_jsonable_obj


Expand Down Expand Up @@ -451,3 +453,43 @@ def get_endpoint(target, endpoint):
for key, val in desc.items():
click.echo(f"{key}: {val}")
click.echo("\n")


def validate_config_path(_ctx, _param, value):
from mlflow.gateway.config import _validate_config
dbczumar marked this conversation as resolved.
Show resolved Hide resolved

try:
_validate_config(value)
return value
except Exception as e:
raise click.BadParameter(str(e))


@experimental
@commands.command("start-server", help="Start the MLflow Deployments server")
@click.option(
"--config-path",
envvar=MLFLOW_DEPLOYMENTS_CONFIG.name,
callback=validate_config_path,
required=True,
help="The path to the deployments configuration file.",
)
@click.option(
"--host",
default="127.0.0.1",
help="The network address to listen on (default: 127.0.0.1).",
)
@click.option(
"--port",
default=5000,
help="The port to listen on (default: 5000).",
)
@click.option(
"--workers",
default=2,
help="The number of workers.",
)
def start_server(config_path: str, host: str, port: str, workers: int):
from mlflow.gateway.runner import run_app

run_app(config_path=config_path, host=host, port=port, workers=workers)
4 changes: 4 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ def get(self):
#: (default: ``None``)
MLFLOW_GATEWAY_CONFIG = _EnvironmentVariable("MLFLOW_GATEWAY_CONFIG", str, None)

#: Specifies the path of the config file for the MLflow Deployments server.
#: (default: ``None``)
MLFLOW_DEPLOYMENTS_CONFIG = _EnvironmentVariable("MLFLOW_DEPLOYMENTS_CONFIG", str, None)

#: Specifies whether to display the progress bar when uploading/downloading artifacts.
#: (default: ``True``)
MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR = _BooleanEnvironmentVariable(
Expand Down
53 changes: 53 additions & 0 deletions tests/deployments/test_server_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from click.testing import CliRunner

from mlflow.deployments import cli

pytest.importorskip("mlflow.gateway")


def test_start_help():
runner = CliRunner()
res = runner.invoke(
cli.start_server,
["--help"],
catch_exceptions=False,
)
assert res.exit_code == 0


def test_start_invalid_config(tmp_path):
runner = CliRunner()
config = tmp_path.joinpath("config.yml")
res = runner.invoke(
cli.start_server,
["--config-path", config],
catch_exceptions=False,
)
assert res.exit_code == 2
assert "does not exist" in res.output

config.write_text("\t")
res = runner.invoke(
cli.start_server,
["--config-path", config],
catch_exceptions=False,
)
assert res.exit_code == 2
assert "not a valid yaml file" in res.output

config.write_text(
"""
routes:
- model:
name: invalid
"""
)
res = runner.invoke(
cli.start_server,
["--config-path", config],
catch_exceptions=False,
)
assert res.exit_code == 2
assert "The gateway configuration is invalid" in res.output
assert "routes" in res.output