Fix potential issues with PyFuncBackend in cli (#9053)
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
serena-ruan committed Jul 18, 2023
1 parent 330bf0b commit 6dde937
Showing 4 changed files with 241 additions and 19 deletions.
29 changes: 29 additions & 0 deletions mlflow/pyfunc/_mlflow_pyfunc_backend_predict.py
@@ -0,0 +1,29 @@
"""
This script should be executed in a fresh python interpreter process using `subprocess`.
"""
import argparse

from mlflow.pyfunc.scoring_server import _predict


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model-uri", required=True)
parser.add_argument("--input-path", required=False)
parser.add_argument("--output-path", required=False)
parser.add_argument("--content-type", required=True)
return parser.parse_args()


def main():
args = parse_args()
_predict(
model_uri=args.model_uri,
input_path=args.input_path if args.input_path else None,
output_path=args.output_path if args.output_path else None,
content_type=args.content_type,
)


if __name__ == "__main__":
main()
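
For illustration only (not part of this diff): a rough sketch of how a caller could run this helper in a subprocess, passing untrusted values as separate argv entries so the shell never gets a chance to interpret them. The model URI and file paths below are hypothetical.

import subprocess
import sys

from mlflow.pyfunc import _mlflow_pyfunc_backend_predict

# Hypothetical arguments; a real caller would supply its own model URI and paths.
cmd = [
    sys.executable,
    _mlflow_pyfunc_backend_predict.__file__,
    "--model-uri", "models:/my-model/1",
    "--content-type", "json",
    "--input-path", "/tmp/input with space.json",  # spaces are harmless in an argv list
    "--output-path", "/tmp/output.json",
]
subprocess.run(cmd, check=True)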
29 changes: 15 additions & 14 deletions mlflow/pyfunc/backend.py
@@ -3,6 +3,7 @@
import pathlib
import subprocess
import posixpath
import shlex
import sys
import warnings
import ctypes
@@ -24,6 +25,7 @@
from mlflow.utils.conda import get_or_create_conda_env, get_conda_bin_executable
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils import env_manager as _EnvManager
from mlflow.pyfunc import _mlflow_pyfunc_backend_predict
from mlflow.utils.file_utils import (
path_to_local_file_uri,
get_or_create_tmp_dir,
@@ -143,20 +145,19 @@ def predict(self, model_uri, input_path, output_path, content_type):
local_uri = path_to_local_file_uri(local_path)

if self._env_manager != _EnvManager.LOCAL:
command = (
'python -c "from mlflow.pyfunc.scoring_server import _predict; _predict('
"model_uri={model_uri}, "
"input_path={input_path}, "
"output_path={output_path}, "
"content_type={content_type})"
'"'
).format(
model_uri=repr(local_uri),
input_path=repr(input_path),
output_path=repr(output_path),
content_type=repr(content_type),
)
return self.prepare_env(local_path).execute(command)
predict_cmd = [
"python",
_mlflow_pyfunc_backend_predict.__file__,
"--model-uri",
str(local_uri),
"--content-type",
shlex.quote(str(content_type)),
]
if input_path:
predict_cmd += ["--input-path", shlex.quote(str(input_path))]
if output_path:
predict_cmd += ["--output-path", shlex.quote(str(output_path))]
return self.prepare_env(local_path).execute(" ".join(predict_cmd))
else:
scoring_server._predict(local_uri, input_path, output_path, content_type)
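
A minimal sketch (not from the commit) of what shlex.quote does to a hostile argument before the pieces above are joined into a single shell command for prepare_env(...).execute(); the malicious path is made up.

import shlex

# Hypothetical user-supplied path containing shell metacharacters.
input_path = 'input.json"; echo ThisIsABug! "'

# Joined naively, the embedded quote would end the argument and run `echo`.
unsafe = " ".join(["--input-path", input_path])

# shlex.quote wraps the value so the shell sees one literal token.
safe = " ".join(["--input-path", shlex.quote(input_path)])
print(safe)  # --input-path 'input.json"; echo ThisIsABug! "'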

12 changes: 8 additions & 4 deletions mlflow/pyfunc/scoring_server/__init__.py
@@ -17,6 +17,7 @@
import json
import logging
import os
import shlex
import sys
import traceback

@@ -31,6 +32,7 @@
from mlflow.types import Schema
from mlflow.utils import reraise
from mlflow.utils.file_utils import path_to_local_file_uri
from mlflow.utils.os import is_windows
from mlflow.utils.proto_json_utils import (
NumpyEncoder,
dataframe_from_parsed_json,
@@ -328,14 +330,16 @@ def get_cmd(
) -> Tuple[str, Dict[str, str]]:
local_uri = path_to_local_file_uri(model_uri)
timeout = timeout or MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get()

# NB: Absolute windows paths do not work with mlflow apis, use file uri to ensure
# platform compatibility.
if os.name != "nt":
if not is_windows():
args = [f"--timeout={timeout}"]
if port and host:
args.append(f"-b {host}:{port}")
address = shlex.quote(f"{host}:{port}")
args.append(f"-b {address}")
elif host:
args.append(f"-b {host}")
args.append(f"-b {shlex.quote(host)}")

if nworkers:
args.append(f"-w {nworkers}")
@@ -347,7 +351,7 @@
else:
args = []
if host:
args.append(f"--host={host}")
args.append(f"--host={shlex.quote(host)}")

if port:
args.append(f"--port={port}")
190 changes: 189 additions & 1 deletion tests/models/test_cli.py
@@ -28,6 +28,7 @@
from mlflow.environment_variables import MLFLOW_DISABLE_ENV_MANAGER_CONDA_WARNING
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import ErrorCode, BAD_REQUEST
from mlflow.pyfunc.backend import PyFuncBackend
from mlflow.pyfunc.scoring_server import (
CONTENT_TYPE_JSON,
CONTENT_TYPE_CSV,
@@ -36,6 +37,7 @@
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils import env_manager as _EnvManager
from mlflow.utils import PYTHON_VERSION
from mlflow.utils.process import ShellCommandException
from tests.helper_functions import (
pyfunc_build_image,
pyfunc_serve_from_docker_image,
@@ -174,7 +176,7 @@ def test_predict(iris_data, sk_model):
with mlflow.start_run() as active_run:
mlflow.sklearn.log_model(sk_model, "model", registered_model_name="impredicting")
model_uri = f"runs:/{active_run.info.run_id}/model"
model_registry_uri = "models:/{name}/{stage}".format(name="impredicting", stage="None")
model_registry_uri = "models:/impredicting/None"
input_json_path = tmp.path("input.json")
input_csv_path = tmp.path("input.csv")
output_json_path = tmp.path("output.json")
@@ -331,6 +333,173 @@ def test_predict(iris_data, sk_model):
assert all(expected == actual)


def test_predict_check_content_type(iris_data, sk_model, tmp_path):
with mlflow.start_run():
mlflow.sklearn.log_model(sk_model, "model", registered_model_name="impredicting")
model_registry_uri = "models:/impredicting/None"
input_json_path = tmp_path / "input.json"
input_csv_path = tmp_path / "input.csv"
output_json_path = tmp_path / "output.json"

x, _ = iris_data
with input_json_path.open("w") as f:
json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)

pd.DataFrame(x).to_csv(input_csv_path, index=False)

# Throw errors for invalid content_type
prc = subprocess.run(
[
"mlflow",
"models",
"predict",
"-m",
model_registry_uri,
"-i",
input_json_path,
"-o",
output_json_path,
"-t",
"invalid",
"--env-manager",
"local",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_with_tracking_uri(),
check=False,
)
assert prc.returncode != 0
assert "Unknown content type" in prc.stderr.decode("utf-8")


def test_predict_check_input_path(iris_data, sk_model, tmp_path):
with mlflow.start_run():
mlflow.sklearn.log_model(sk_model, "model", registered_model_name="impredicting")
model_registry_uri = "models:/impredicting/None"
input_json_path = tmp_path / "input with space.json"
input_csv_path = tmp_path / "input.csv"
output_json_path = tmp_path / "output.json"

x, _ = iris_data
with input_json_path.open("w") as f:
json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)

pd.DataFrame(x).to_csv(input_csv_path, index=False)

# Valid input path with space
prc = subprocess.run(
[
"mlflow",
"models",
"predict",
"-m",
model_registry_uri,
"-i",
f"{input_json_path}",
"-o",
output_json_path,
"--env-manager",
"local",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_with_tracking_uri(),
check=False,
text=True,
)
assert prc.returncode == 0

# Throw errors for invalid input_path
prc = subprocess.run(
[
"mlflow",
"models",
"predict",
"-m",
model_registry_uri,
"-i",
f'{input_json_path}"; echo ThisIsABug! "',
"-o",
output_json_path,
"--env-manager",
"local",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_with_tracking_uri(),
check=False,
text=True,
)
assert prc.returncode != 0
assert "ThisIsABug!" not in prc.stdout
assert "FileNotFoundError" in prc.stderr

prc = subprocess.run(
[
"mlflow",
"models",
"predict",
"-m",
model_registry_uri,
"-i",
f'{input_csv_path}"; echo ThisIsABug! "',
"-o",
output_json_path,
"-t",
"csv",
"--env-manager",
"local",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_with_tracking_uri(),
check=False,
text=True,
)
assert prc.returncode != 0
assert "ThisIsABug!" not in prc.stdout
assert "FileNotFoundError" in prc.stderr


def test_predict_check_output_path(iris_data, sk_model, tmp_path):
with mlflow.start_run():
mlflow.sklearn.log_model(sk_model, "model", registered_model_name="impredicting")
model_registry_uri = "models:/impredicting/None"
input_json_path = tmp_path / "input.json"
input_csv_path = tmp_path / "input.csv"
output_json_path = tmp_path / "output.json"

x, _ = iris_data
with input_json_path.open("w") as f:
json.dump({"dataframe_split": pd.DataFrame(x).to_dict(orient="split")}, f)

pd.DataFrame(x).to_csv(input_csv_path, index=False)

prc = subprocess.run(
[
"mlflow",
"models",
"predict",
"-m",
model_registry_uri,
"-i",
input_json_path,
"-o",
f'{output_json_path}"; echo ThisIsABug! "',
"--env-manager",
"local",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_with_tracking_uri(),
check=False,
text=True,
)
assert prc.returncode == 0
assert "ThisIsABug!" not in prc.stdout


def test_prepare_env_passes(sk_model):
if no_conda:
pytest.skip("This test requires conda.")
@@ -574,6 +743,25 @@ def test_env_manager_unsupported_value():
)


def test_host_invalid_value():
class MyModel(mlflow.pyfunc.PythonModel):
def predict(self, ctx, model_input):
return model_input

with mlflow.start_run():
model_info = mlflow.pyfunc.log_model(
python_model=MyModel(), artifact_path="test_model", registered_model_name="model"
)

with mock.patch("mlflow.models.cli.get_flavor_backend", return_value=PyFuncBackend({})):
with pytest.raises(ShellCommandException, match=r"Non-zero exit code: 1"):
CliRunner().invoke(
models_cli.serve,
["--model-uri", model_info.model_uri, "--host", "localhost & echo BUG"],
catch_exceptions=False,
)


def test_change_conda_env_root_location(tmp_path, sk_model):
env_root1_path = tmp_path / "root1"
env_root1_path.mkdir()
