Skip to content

Commit

Permalink
Fix local uri (#10651)
Browse files Browse the repository at this point in the history
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
  • Loading branch information
serena-ruan committed Jan 11, 2024
1 parent 1af5a88 commit 438a450
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 88 deletions.
9 changes: 0 additions & 9 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,6 @@ def get(self):
"MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE", str, "500MB"
)

#: Specifies whether or not to allow using a file URI as a model version source.
#: Please be aware that setting this environment variable to True is potentially risky
#: because it can allow access to arbitrary files on the specified filesystem
#: (default: ``False``).
MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE = _BooleanEnvironmentVariable(
"MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE", False
)


#: Specifies the name of the Databricks secret scope to use for storing OpenAI API keys.
MLFLOW_OPENAI_SECRET_SCOPE = _EnvironmentVariable("MLFLOW_OPENAI_SECRET_SCOPE", str, None)

Expand Down
19 changes: 2 additions & 17 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
from mlflow.entities import DatasetInput, ExperimentTag, FileInfo, Metric, Param, RunTag, ViewType
from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
from mlflow.entities.multipart_upload import MultipartUploadPart
from mlflow.environment_variables import (
MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE,
MLFLOW_DEPLOYMENTS_TARGET,
)
from mlflow.environment_variables import MLFLOW_DEPLOYMENTS_TARGET
from mlflow.exceptions import MlflowException, _UnsupportedMultipartUploadException
from mlflow.models import Model
from mlflow.protos import databricks_pb2
Expand Down Expand Up @@ -105,7 +102,7 @@
from mlflow.utils.promptlab_utils import _create_promptlab_run_impl
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.uri import is_file_uri, is_local_uri, validate_path_is_safe, validate_query_string
from mlflow.utils.uri import is_local_uri, validate_path_is_safe, validate_query_string
from mlflow.utils.validation import _validate_batch_log_api_req

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1604,18 +1601,6 @@ def _validate_source(source: str, run_id: str) -> None:
INVALID_PARAMETER_VALUE,
)

# There might be file URIs that are local but can bypass the above check. To prevent this, we
# disallow using file URIs as model version sources by default unless it's explicitly allowed
# by setting the MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE environment variable to True.
if not MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE.get() and is_file_uri(source):
raise MlflowException(
f"Invalid model version source: '{source}'. MLflow tracking server doesn't allow using "
"a file URI as a model version source for security reasons. To disable this check, set "
f"the {MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE} environment variable to "
"True.",
INVALID_PARAMETER_VALUE,
)

# Checks if relative paths are present in the source (a security threat). If any are present,
# raises an Exception.
_validate_non_local_source_contains_relative_paths(source)
Expand Down
18 changes: 13 additions & 5 deletions mlflow/utils/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,32 @@ def is_local_uri(uri, is_tracking_or_registry_uri=True):
if scheme == "":
return True

if parsed_uri.hostname and not (
is_remote_hostname = parsed_uri.hostname and not (
parsed_uri.hostname == "."
or parsed_uri.hostname.startswith("localhost")
or parsed_uri.hostname.startswith("127.0.0.1")
):
return False

)
if scheme == "file":
if is_remote_hostname:
raise MlflowException(
f"{uri} is not a valid remote uri. For remote access "
"on windows, please consider using a different scheme "
"such as SMB (e.g. smb://<hostname>/<path>)."
)
return True

if is_remote_hostname:
return False

if is_windows() and len(scheme) == 1 and scheme.lower() == pathlib.Path(uri).drive.lower()[0]:
return True

return False


def is_file_uri(uri):
return urllib.parse.urlparse(uri).scheme == "file"
scheme = urllib.parse.urlparse(uri).scheme
return scheme == "file"


def is_http_uri(uri):
Expand Down
22 changes: 1 addition & 21 deletions tests/artifacts/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.utils.file_utils import local_file_uri_to_path, mkdir, path_to_local_file_uri
from mlflow.utils.file_utils import mkdir, path_to_local_file_uri
from mlflow.utils.os import is_windows

Artifact = namedtuple("Artifact", ["uri", "content"])
Expand Down Expand Up @@ -247,23 +247,3 @@ def test_log_artifact_windows_path_with_hostname(text_artifact):
rf"{experiment_test_1_artifact_location}\{run.info.run_id}"
rf"\artifacts\{text_artifact.artifact_name}" == local_path
)

experiment_test_2_artifact_location = "file://my_server/my_path/my_sub_path"
experiment_test_2_id = mlflow.create_experiment(
"test_exp_e", experiment_test_2_artifact_location
)
with mlflow.start_run(experiment_id=experiment_test_2_id) as run:
with mock.patch("shutil.copy2") as copyfile_mock, mock.patch(
"os.path.exists", return_value=True
) as exists_mock:
mlflow.log_artifact(text_artifact.artifact_path)
copyfile_mock.assert_called_once()
exists_mock.assert_called_once()
local_path = mlflow.artifacts.download_artifacts(
run_id=run.info.run_id, artifact_path=text_artifact.artifact_name
)
assert (
local_file_uri_to_path(experiment_test_2_artifact_location)
+ rf"\{run.info.run_id}\artifacts\{text_artifact.artifact_name}"
== local_path
)
6 changes: 3 additions & 3 deletions tests/store/tracking/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,8 +2076,8 @@ def _assert_create_run_appends_to_artifact_uri_path_correctly(
("input_uri", "expected_uri"),
[
(
"file://my_server/my_path/my_sub_path",
"file://my_server/my_path/my_sub_path/{e}/{r}/artifacts",
"\\my_server/my_path/my_sub_path",
"file:///{drive}my_server/my_path/my_sub_path/{e}/{r}/artifacts",
),
("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"),
("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}/{r}/artifacts"),
Expand Down Expand Up @@ -2179,7 +2179,7 @@ def _assert_create_experiment_appends_to_artifact_uri_path_correctly(
@pytest.mark.parametrize(
("input_uri", "expected_uri"),
[
("file://my_server/my_path/my_sub_path", "file://my_server/my_path/my_sub_path/{e}"),
("\\my_server/my_path/my_sub_path", "file:///{drive}my_server/my_path/my_sub_path/{e}"),
("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"),
("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}"),
("#path/to/local/folder?", "file://{cwd}/{e}#path/to/local/folder?"),
Expand Down
6 changes: 3 additions & 3 deletions tests/store/tracking/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3738,7 +3738,7 @@ def _assert_create_experiment_appends_to_artifact_uri_path_correctly(
@pytest.mark.parametrize(
("input_uri", "expected_uri"),
[
("file://my_server/my_path/my_sub_path", "file://my_server/my_path/my_sub_path/{e}"),
("\\my_server/my_path/my_sub_path", "file:///{drive}my_server/my_path/my_sub_path/{e}"),
("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}"),
("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}"),
("#path/to/local/folder?", "file://{cwd}/{e}#path/to/local/folder?"),
Expand Down Expand Up @@ -3841,8 +3841,8 @@ def _assert_create_run_appends_to_artifact_uri_path_correctly(
("input_uri", "expected_uri"),
[
(
"file://my_server/my_path/my_sub_path",
"file://my_server/my_path/my_sub_path/{e}/{r}/artifacts",
"\\my_server/my_path/my_sub_path",
"file:///{drive}my_server/my_path/my_sub_path/{e}/{r}/artifacts",
),
("path/to/local/folder", "file://{cwd}/path/to/local/folder/{e}/{r}/artifacts"),
("/path/to/local/folder", "file:///{drive}path/to/local/folder/{e}/{r}/artifacts"),
Expand Down
28 changes: 2 additions & 26 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,32 +1389,8 @@ def test_create_model_version_with_file_uri(mlflow_client):
"run_id": run.info.run_id,
},
)
assert response.status_code == 400
assert "MLflow tracking server doesn't allow" in response.json()["message"]


def test_create_model_version_with_file_uri_env_var(tmp_path):
backend_uri = tmp_path.joinpath("file").as_uri()
with _init_server(
backend_uri,
root_artifact_uri=tmp_path.as_uri(),
extra_env={"MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE": "true"},
) as url:
mlflow_client = MlflowClient(url)

name = "test"
mlflow_client.create_registered_model(name)
exp_id = mlflow_client.create_experiment("test")
run = mlflow_client.create_run(experiment_id=exp_id)
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": "file://123.456.789.123/path/to/source",
"run_id": run.info.run_id,
},
)
assert response.status_code == 200
assert response.status_code == 500, response.json()
assert "is not a valid remote uri" in response.json()["message"]


def test_logging_model_with_local_artifact_uri(mlflow_client):
Expand Down
1 change: 0 additions & 1 deletion tests/utils/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ def test_handle_readonly_on_windows(tmp_path):
@pytest.mark.parametrize(
("input_uri", "expected_path"),
[
("file://my_server/my_path/my_sub_path", r"\\my_server\my_path\my_sub_path"),
(r"\\my_server\my_path\my_sub_path", r"\\my_server\my_path\my_sub_path"),
],
)
Expand Down
8 changes: 5 additions & 3 deletions tests/utils/test_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,15 @@ def test_is_local_uri():
assert is_local_uri("//proc/self/root")
assert is_local_uri("/proc/self/root")

assert not is_local_uri("file://myhostname/path/to/file")
assert not is_local_uri("https://whatever")
assert not is_local_uri("http://whatever")
assert not is_local_uri("databricks")
assert not is_local_uri("databricks:whatever")
assert not is_local_uri("databricks://whatever")

with pytest.raises(MlflowException, match="is not a valid remote uri."):
is_local_uri("file://myhostname/path/to/file")


@pytest.mark.skipif(not is_windows(), reason="Windows-only test")
def test_is_local_uri_windows():
Expand Down Expand Up @@ -682,7 +684,7 @@ def _assert_resolve_uri_if_local(input_uri, expected_uri):
[
("my/path", "{cwd}/my/path"),
("#my/path?a=b", "{cwd}/#my/path?a=b"),
("file://myhostname/my/path", "file://myhostname/my/path"),
("file://localhost/my/path", "file://localhost/my/path"),
("file:///my/path", "file:///{drive}my/path"),
("file:my/path", "file://{cwd}/my/path"),
("/home/my/path", "/home/my/path"),
Expand All @@ -700,7 +702,7 @@ def test_resolve_uri_if_local(input_uri, expected_uri):
[
("my/path", "file://{cwd}/my/path"),
("#my/path?a=b", "file://{cwd}/#my/path?a=b"),
("file://myhostname/my/path", "file://myhostname/my/path"),
("\\myhostname/my/path", "file:///{drive}myhostname/my/path"),
("file:///my/path", "file:///{drive}my/path"),
("file:my/path", "file://{cwd}/my/path"),
("/home/my/path", "file:///{drive}home/my/path"),
Expand Down

0 comments on commit 438a450

Please sign in to comment.