Skip to content

Commit

Permalink
Disallow using a file URI as model version source (#8126)
Browse files Browse the repository at this point in the history
* Add a test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Add a test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Disallow file URI

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Add a test

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Ensure artifact URI is a file URI

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Add a new test

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix doc comment

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Add test case

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix is_local_uri

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix validation logic

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Include 127.0.0.1

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Update comment

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Remove dot

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Update mlflow/environment_variables.py

Signed-off-by: Harutaka Kawamura <hkawamura0130@gmail.com>

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: harupy <hkawamura0130@gmail.com>
Signed-off-by: Harutaka Kawamura <hkawamura0130@gmail.com>
  • Loading branch information
harupy committed Mar 31, 2023
1 parent eae60d7 commit fae77a5
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 30 deletions.
8 changes: 8 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,11 @@ def get(self):
MLFLOW_DEFAULT_PREDICTION_DEVICE = _EnvironmentVariable(
"MLFLOW_DEFAULT_PREDICTION_DEVICE", str, None
)

#: Specifies whether or not to allow using a file URI as a model version source.
#: The tracking server consults this flag when validating the source of a new model version.
#: Please be aware that setting this environment variable to True is potentially risky
#: because it can allow access to arbitrary files on the specified filesystem
#: (default: ``False``).
MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE = _BooleanEnvironmentVariable(
"MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE", False
)
47 changes: 29 additions & 18 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.validation import _validate_batch_log_api_req
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.uri import is_local_uri
from mlflow.utils.uri import is_local_uri, is_file_uri
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
from mlflow.environment_variables import MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE

_logger = logging.getLogger(__name__)
_tracking_store = None
Expand Down Expand Up @@ -1322,23 +1323,33 @@ def _delete_registered_model_tag():


def _validate_source(source: str, run_id: str) -> None:
    """
    Validate the ``source`` of a model version creation request.

    A source that ``is_local_uri`` reports as local is accepted only when ``run_id`` is given and
    the resolved source path is the artifact directory of that run or is nested inside it.
    Any other source that is a file URI is rejected unless the
    ``MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE`` environment variable is enabled.
    All remaining sources pass validation.

    :param source: Model version source taken from the request.
    :param run_id: Run ID taken from the request; may be empty.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when the source is rejected.
    """
    if is_local_uri(source):
        if run_id:
            store = _get_tracking_store()
            run = store.get_run(run_id)
            # Compare fully resolved paths so the containment check is made on the real
            # locations rather than on the literal strings. A separate name is used so the
            # ``source`` string sent by the caller stays available for the error message.
            source_path = pathlib.Path(local_file_uri_to_path(source)).resolve()
            run_artifact_dir = pathlib.Path(local_file_uri_to_path(run.info.artifact_uri)).resolve()
            if run_artifact_dir in [source_path, *source_path.parents]:
                return

        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the run_id request parameter has to be specified and the local path has to be "
            "contained within the artifact directory of the run specified by the run_id.",
            INVALID_PARAMETER_VALUE,
        )

    # There might be file URIs that are local but can bypass the above check. To prevent this, we
    # disallow using file URIs as model version sources by default unless it's explicitly allowed
    # by setting the MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE environment variable to True.
    if not MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE.get() and is_file_uri(source):
        raise MlflowException(
            f"Invalid model version source: '{source}'. MLflow tracking server doesn't allow using "
            "a file URI as a model version source for security reasons. To disable this check, set "
            f"the {MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE.name} environment variable to "
            "True.",
            INVALID_PARAMETER_VALUE,
        )


@catch_mlflow_exception
Expand Down
10 changes: 9 additions & 1 deletion mlflow/utils/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def is_local_uri(uri):
return False

parsed_uri = urllib.parse.urlparse(uri)
if parsed_uri.hostname:
if parsed_uri.hostname and not (
parsed_uri.hostname == "."
or parsed_uri.hostname.startswith("localhost")
or parsed_uri.hostname.startswith("127.0.0.1")
):
return False

scheme = parsed_uri.scheme
Expand All @@ -42,6 +46,10 @@ def is_local_uri(uri):
return False


def is_file_uri(uri):
    """Return True when ``uri`` parses with the ``file`` scheme, False otherwise."""
    parsed = urllib.parse.urlparse(uri)
    return parsed.scheme == "file"


def is_http_uri(uri):
    """Return True when ``uri`` parses with the ``http`` or ``https`` scheme."""
    return urllib.parse.urlparse(uri).scheme in ("http", "https")
Expand Down
3 changes: 2 additions & 1 deletion tests/tracking/integration_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _terminate_server(process, timeout=10):
process.wait(timeout=timeout)


def _init_server(backend_uri, root_artifact_uri):
def _init_server(backend_uri, root_artifact_uri, extra_env=None):
"""
Launch a new REST server using the tracking store specified by backend_uri and root artifact
directory specified by root_artifact_uri.
Expand All @@ -57,6 +57,7 @@ def _init_server(backend_uri, root_artifact_uri):
**os.environ,
BACKEND_STORE_URI_ENV_VAR: backend_uri,
ARTIFACT_ROOT_ENV_VAR: root_artifact_uri,
**(extra_env or {}),
},
)

Expand Down
87 changes: 77 additions & 10 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,12 +1044,52 @@ def get(self, key, default=None):
)


def test_create_model_version_with_local_source(mlflow_client):
def test_create_model_version_with_path_source(mlflow_client):
    """
    Exercise model version creation over REST with plain filesystem paths as the source.

    Covers three requests: the run's artifact path together with its run_id (expects 200),
    the same path without a run_id (expects 400), and ``/tmp`` with the run_id (expects 400).
    """
    name = "mode"
    mlflow_client.create_registered_model(name)
    run = mlflow_client.create_run(experiment_id=mlflow_client.create_experiment("test"))
    # Drop the "file://" prefix so that the source is sent as a bare path.
    artifact_path = run.info.artifact_uri[len("file://") :]

    def create_model_version(payload):
        endpoint = f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create"
        return requests.post(endpoint, json=payload)

    accepted = create_model_version(
        {"name": name, "source": artifact_path, "run_id": run.info.run_id}
    )
    assert accepted.status_code == 200

    # run_id is not specified
    missing_run_id = create_model_version({"name": name, "source": artifact_path})
    assert missing_run_id.status_code == 400
    assert "To use a local path as a model version" in missing_run_id.json()["message"]

    # run_id is specified but source is not in the run's artifact directory
    outside_artifacts = create_model_version(
        {"name": name, "source": "/tmp", "run_id": run.info.run_id}
    )
    assert outside_artifacts.status_code == 400
    assert "To use a local path as a model version" in outside_artifacts.json()["message"]


def test_create_model_version_with_file_uri(mlflow_client):
name = "test"
mlflow_client.create_registered_model(name)
exp_id = mlflow_client.create_experiment("test")
run = mlflow_client.create_run(experiment_id=exp_id)
assert run.info.artifact_uri.startswith("file://")
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
Expand Down Expand Up @@ -1090,38 +1130,65 @@ def test_create_model_version_with_local_source(mlflow_client):
)
assert response.status_code == 200

# run_id is not specified
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": run.info.artifact_uri[len("file://") :],
"run_id": run.info.run_id,
"source": run.info.artifact_uri,
},
)
assert response.status_code == 200
assert response.status_code == 400
assert "To use a local path as a model version" in response.json()["message"]

# run_id is specified but source is not in the run's artifact directory
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": run.info.artifact_uri,
"source": "file:///tmp",
},
)
assert response.status_code == 400
resp = response.json()
assert "Invalid source" in resp["message"]
assert "To use a local path as a model version" in response.json()["message"]

response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": "/tmp",
"source": "file://123.456.789.123/path/to/source",
"run_id": run.info.run_id,
},
)
assert response.status_code == 400
resp = response.json()
assert "Invalid source" in resp["message"]
assert "MLflow tracking server doesn't allow" in response.json()["message"]


def test_create_model_version_with_file_uri_env_var(tmp_path):
    """
    Start a server with MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE set to "true" and check
    that creating a model version from a file URI with a host component returns 200.
    """
    server_env = {"MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE": "true"}
    url, process = _init_server(
        (tmp_path / "file").as_uri(),
        root_artifact_uri=tmp_path.as_uri(),
        extra_env=server_env,
    )
    try:
        client = MlflowClient(url)
        model_name = "test"
        client.create_registered_model(model_name)
        run = client.create_run(experiment_id=client.create_experiment("test"))

        payload = {
            "name": model_name,
            "source": "file://123.456.789.123/path/to/source",
            "run_id": run.info.run_id,
        }
        endpoint = f"{client.tracking_uri}/api/2.0/mlflow/model-versions/create"
        response = requests.post(endpoint, json=payload)
        assert response.status_code == 200
    finally:
        # Always stop the spawned server, even when an assertion above fails.
        _terminate_server(process)


def test_logging_model_with_local_artifact_uri(mlflow_client):
Expand Down
5 changes: 5 additions & 0 deletions tests/utils/test_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def test_is_local_uri():
assert is_local_uri("./mlruns")
assert is_local_uri("file:///foo/mlruns")
assert is_local_uri("file:foo/mlruns")
assert is_local_uri("file://./mlruns")
assert is_local_uri("file://localhost/mlruns")
assert is_local_uri("file://localhost:5000/mlruns")
assert is_local_uri("file://127.0.0.1/mlruns")
assert is_local_uri("file://127.0.0.1:5000/mlruns")

assert not is_local_uri("file://myhostname/path/to/file")
assert not is_local_uri("https://whatever")
Expand Down

0 comments on commit fae77a5

Please sign in to comment.