Skip to content

Commit fae77a5

Browse files
authored
Disallow using a file URI as model version source (#8126)
* Add a test Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Add a test Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Disallow file URI Signed-off-by: harupy <hkawamura0130@gmail.com> * Add a test Signed-off-by: harupy <hkawamura0130@gmail.com> * Ensure artifact URI is a file URI Signed-off-by: harupy <hkawamura0130@gmail.com> * Add a new test Signed-off-by: harupy <hkawamura0130@gmail.com> * Fix doc comment Signed-off-by: harupy <hkawamura0130@gmail.com> * Add test case Signed-off-by: harupy <hkawamura0130@gmail.com> * Fix is_local_uri Signed-off-by: harupy <hkawamura0130@gmail.com> * Fix valiation logic Signed-off-by: harupy <hkawamura0130@gmail.com> * Include 127.0.0.1 Signed-off-by: harupy <hkawamura0130@gmail.com> * Update comment Signed-off-by: harupy <hkawamura0130@gmail.com> * Remove dot Signed-off-by: harupy <hkawamura0130@gmail.com> * Update mlflow/environment_variables.py Signed-off-by: Harutaka Kawamura <hkawamura0130@gmail.com> --------- Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> Signed-off-by: harupy <hkawamura0130@gmail.com> Signed-off-by: Harutaka Kawamura <hkawamura0130@gmail.com>
1 parent eae60d7 commit fae77a5

File tree

6 files changed

+130
-30
lines changed

6 files changed

+130
-30
lines changed

Diff for: mlflow/environment_variables.py

+8
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,11 @@ def get(self):
202202
MLFLOW_DEFAULT_PREDICTION_DEVICE = _EnvironmentVariable(
203203
"MLFLOW_DEFAULT_PREDICTION_DEVICE", str, None
204204
)
205+
206+
#: Specifies whether or not to allow using a file URI as a model version source.
207+
#: Please be aware that setting this environment variable to True is potentially risky
208+
#: because it can allow access to arbitrary files on the specified filesystem
209+
#: (default: ``False``).
210+
MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE = _BooleanEnvironmentVariable(
211+
"MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE", False
212+
)

Diff for: mlflow/server/handlers.py

+29-18
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@
8484
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
8585
from mlflow.utils.validation import _validate_batch_log_api_req
8686
from mlflow.utils.string_utils import is_string_type
87-
from mlflow.utils.uri import is_local_uri
87+
from mlflow.utils.uri import is_local_uri, is_file_uri
8888
from mlflow.utils.file_utils import local_file_uri_to_path
8989
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
90+
from mlflow.environment_variables import MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE
9091

9192
_logger = logging.getLogger(__name__)
9293
_tracking_store = None
@@ -1322,23 +1323,33 @@ def _delete_registered_model_tag():
13221323

13231324

13241325
def _validate_source(source: str, run_id: str) -> None:
1325-
if not is_local_uri(source):
1326-
return
1327-
1328-
if run_id:
1329-
store = _get_tracking_store()
1330-
run = store.get_run(run_id)
1331-
source = pathlib.Path(local_file_uri_to_path(source)).resolve()
1332-
run_artifact_dir = pathlib.Path(local_file_uri_to_path(run.info.artifact_uri)).resolve()
1333-
if run_artifact_dir in [source, *source.parents]:
1334-
return
1335-
1336-
raise MlflowException(
1337-
f"Invalid source: '{source}'. To use a local path as source, the run_id request parameter "
1338-
"has to be specified and the local path has to be contained within the artifact directory "
1339-
"of the run specified by the run_id.",
1340-
INVALID_PARAMETER_VALUE,
1341-
)
1326+
if is_local_uri(source):
1327+
if run_id:
1328+
store = _get_tracking_store()
1329+
run = store.get_run(run_id)
1330+
source = pathlib.Path(local_file_uri_to_path(source)).resolve()
1331+
run_artifact_dir = pathlib.Path(local_file_uri_to_path(run.info.artifact_uri)).resolve()
1332+
if run_artifact_dir in [source, *source.parents]:
1333+
return
1334+
1335+
raise MlflowException(
1336+
f"Invalid model version source: '{source}'. To use a local path as a model version "
1337+
"source, the run_id request parameter has to be specified and the local path has to be "
1338+
"contained within the artifact directory of the run specified by the run_id.",
1339+
INVALID_PARAMETER_VALUE,
1340+
)
1341+
1342+
# There might be file URIs that are local but can bypass the above check. To prevent this, we
1343+
# disallow using file URIs as model version sources by default unless it's explicitly allowed
1344+
# by setting the MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE environment variable to True.
1345+
if not MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE.get() and is_file_uri(source):
1346+
raise MlflowException(
1347+
f"Invalid model version source: '{source}'. MLflow tracking server doesn't allow using "
1348+
"a file URI as a model version source for security reasons. To disable this check, set "
1349+
f"the {MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE.name} environment variable to "
1350+
"True.",
1351+
INVALID_PARAMETER_VALUE,
1352+
)
13421353

13431354

13441355
@catch_mlflow_exception

Diff for: mlflow/utils/uri.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ def is_local_uri(uri):
2929
return False
3030

3131
parsed_uri = urllib.parse.urlparse(uri)
32-
if parsed_uri.hostname:
32+
if parsed_uri.hostname and not (
33+
parsed_uri.hostname == "."
34+
or parsed_uri.hostname.startswith("localhost")
35+
or parsed_uri.hostname.startswith("127.0.0.1")
36+
):
3337
return False
3438

3539
scheme = parsed_uri.scheme
@@ -42,6 +46,10 @@ def is_local_uri(uri):
4246
return False
4347

4448

49+
def is_file_uri(uri):
50+
return urllib.parse.urlparse(uri).scheme == "file"
51+
52+
4553
def is_http_uri(uri):
4654
scheme = urllib.parse.urlparse(uri).scheme
4755
return scheme == "http" or scheme == "https"

Diff for: tests/tracking/integration_test_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _terminate_server(process, timeout=10):
3838
process.wait(timeout=timeout)
3939

4040

41-
def _init_server(backend_uri, root_artifact_uri):
41+
def _init_server(backend_uri, root_artifact_uri, extra_env=None):
4242
"""
4343
Launch a new REST server using the tracking store specified by backend_uri and root artifact
4444
directory specified by root_artifact_uri.
@@ -57,6 +57,7 @@ def _init_server(backend_uri, root_artifact_uri):
5757
**os.environ,
5858
BACKEND_STORE_URI_ENV_VAR: backend_uri,
5959
ARTIFACT_ROOT_ENV_VAR: root_artifact_uri,
60+
**(extra_env or {}),
6061
},
6162
)
6263

Diff for: tests/tracking/test_rest_tracking.py

+77-10
Original file line numberDiff line numberDiff line change
@@ -1044,12 +1044,52 @@ def get(self, key, default=None):
10441044
)
10451045

10461046

1047-
def test_create_model_version_with_local_source(mlflow_client):
1047+
def test_create_model_version_with_path_source(mlflow_client):
10481048
name = "mode"
10491049
mlflow_client.create_registered_model(name)
10501050
exp_id = mlflow_client.create_experiment("test")
10511051
run = mlflow_client.create_run(experiment_id=exp_id)
10521052

1053+
response = requests.post(
1054+
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
1055+
json={
1056+
"name": name,
1057+
"source": run.info.artifact_uri[len("file://") :],
1058+
"run_id": run.info.run_id,
1059+
},
1060+
)
1061+
assert response.status_code == 200
1062+
1063+
# run_id is not specified
1064+
response = requests.post(
1065+
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
1066+
json={
1067+
"name": name,
1068+
"source": run.info.artifact_uri[len("file://") :],
1069+
},
1070+
)
1071+
assert response.status_code == 400
1072+
assert "To use a local path as a model version" in response.json()["message"]
1073+
1074+
# run_id is specified but source is not in the run's artifact directory
1075+
response = requests.post(
1076+
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
1077+
json={
1078+
"name": name,
1079+
"source": "/tmp",
1080+
"run_id": run.info.run_id,
1081+
},
1082+
)
1083+
assert response.status_code == 400
1084+
assert "To use a local path as a model version" in response.json()["message"]
1085+
1086+
1087+
def test_create_model_version_with_file_uri(mlflow_client):
1088+
name = "test"
1089+
mlflow_client.create_registered_model(name)
1090+
exp_id = mlflow_client.create_experiment("test")
1091+
run = mlflow_client.create_run(experiment_id=exp_id)
1092+
assert run.info.artifact_uri.startswith("file://")
10531093
response = requests.post(
10541094
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
10551095
json={
@@ -1090,38 +1130,65 @@ def test_create_model_version_with_local_source(mlflow_client):
10901130
)
10911131
assert response.status_code == 200
10921132

1133+
# run_id is not specified
10931134
response = requests.post(
10941135
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
10951136
json={
10961137
"name": name,
1097-
"source": run.info.artifact_uri[len("file://") :],
1098-
"run_id": run.info.run_id,
1138+
"source": run.info.artifact_uri,
10991139
},
11001140
)
1101-
assert response.status_code == 200
1141+
assert response.status_code == 400
1142+
assert "To use a local path as a model version" in response.json()["message"]
11021143

1144+
# run_id is specified but source is not in the run's artifact directory
11031145
response = requests.post(
11041146
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
11051147
json={
11061148
"name": name,
1107-
"source": run.info.artifact_uri,
1149+
"source": "file:///tmp",
11081150
},
11091151
)
11101152
assert response.status_code == 400
1111-
resp = response.json()
1112-
assert "Invalid source" in resp["message"]
1153+
assert "To use a local path as a model version" in response.json()["message"]
11131154

11141155
response = requests.post(
11151156
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
11161157
json={
11171158
"name": name,
1118-
"source": "/tmp",
1159+
"source": "file://123.456.789.123/path/to/source",
11191160
"run_id": run.info.run_id,
11201161
},
11211162
)
11221163
assert response.status_code == 400
1123-
resp = response.json()
1124-
assert "Invalid source" in resp["message"]
1164+
assert "MLflow tracking server doesn't allow" in response.json()["message"]
1165+
1166+
1167+
def test_create_model_version_with_file_uri_env_var(tmp_path):
1168+
backend_uri = tmp_path.joinpath("file").as_uri()
1169+
url, process = _init_server(
1170+
backend_uri,
1171+
root_artifact_uri=tmp_path.as_uri(),
1172+
extra_env={"MLFLOW_ALLOW_FILE_URI_AS_MODEL_VERSION_SOURCE": "true"},
1173+
)
1174+
try:
1175+
mlflow_client = MlflowClient(url)
1176+
1177+
name = "test"
1178+
mlflow_client.create_registered_model(name)
1179+
exp_id = mlflow_client.create_experiment("test")
1180+
run = mlflow_client.create_run(experiment_id=exp_id)
1181+
response = requests.post(
1182+
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
1183+
json={
1184+
"name": name,
1185+
"source": "file://123.456.789.123/path/to/source",
1186+
"run_id": run.info.run_id,
1187+
},
1188+
)
1189+
assert response.status_code == 200
1190+
finally:
1191+
_terminate_server(process)
11251192

11261193

11271194
def test_logging_model_with_local_artifact_uri(mlflow_client):

Diff for: tests/utils/test_uri.py

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def test_is_local_uri():
9191
assert is_local_uri("./mlruns")
9292
assert is_local_uri("file:///foo/mlruns")
9393
assert is_local_uri("file:foo/mlruns")
94+
assert is_local_uri("file://./mlruns")
95+
assert is_local_uri("file://localhost/mlruns")
96+
assert is_local_uri("file://localhost:5000/mlruns")
97+
assert is_local_uri("file://127.0.0.1/mlruns")
98+
assert is_local_uri("file://127.0.0.1:5000/mlruns")
9499

95100
assert not is_local_uri("file://myhostname/path/to/file")
96101
assert not is_local_uri("https://whatever")

0 commit comments

Comments
 (0)