Skip to content

Commit f731474

Browse files
authored
Disable ability to provide relative paths in sources (#8281)
* Disable ability to provide relative paths in sources
  Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
* no relative paths allowed
  Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
---------
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
1 parent 9ac6494 commit f731474

File tree

2 files changed

+129
-1
lines changed

2 files changed

+129
-1
lines changed

Diff for: mlflow/server/handlers.py

+34
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,36 @@ def _delete_registered_model_tag():
13231323
return _wrap_response(DeleteRegisteredModelTag.Response())
13241324

13251325

1326+
def _validate_non_local_source_contains_relative_paths(source: str):
1327+
"""
1328+
Validation check to ensure that sources that are provided that conform to the schemes:
1329+
http, https, or mlflow-artifacts do not contain relative path designations that are intended
1330+
to access local file system paths on the tracking server.
1331+
1332+
Example paths that this validation function is intended to find and raise an Exception if
1333+
passed:
1334+
"mlflow-artifacts://host:port/../../../../"
1335+
"http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
1336+
"https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
1337+
"/models/artifacts/../../../"
1338+
"s3:/my_bucket/models/path/../../other/path"
1339+
"file://path/to/../../../../some/where/you/should/not/be"
1340+
"""
1341+
source_path = urllib.parse.urlparse(source).path
1342+
resolved_source = pathlib.Path(source_path).resolve().as_posix()
1343+
# NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
1344+
# drive path of the pwd to a given path. We don't care about the drive here, though.
1345+
_, resolved_path = os.path.splitdrive(resolved_source)
1346+
1347+
if resolved_path != source_path:
1348+
raise MlflowException(
1349+
f"Invalid model version source: '{source}'. If supplying a source as an http, https, "
1350+
"local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be "
1351+
"provided without relative path references present. Please provide an absolute path.",
1352+
INVALID_PARAMETER_VALUE,
1353+
)
1354+
1355+
13261356
def _validate_source(source: str, run_id: str) -> None:
13271357
if is_local_uri(source):
13281358
if run_id:
@@ -1352,6 +1382,10 @@ def _validate_source(source: str, run_id: str) -> None:
13521382
INVALID_PARAMETER_VALUE,
13531383
)
13541384

1385+
# Checks if relative paths are present in the source (a security threat). If any are present,
1386+
# raises an Exception.
1387+
_validate_non_local_source_contains_relative_paths(source)
1388+
13551389

13561390
@catch_mlflow_exception
13571391
@_disable_if_artifacts_only

Diff for: tests/tracking/test_rest_tracking.py

+95-1
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ def get(self, key, default=None):
10451045

10461046

10471047
def test_create_model_version_with_path_source(mlflow_client):
1048-
name = "mode"
1048+
name = "model"
10491049
mlflow_client.create_registered_model(name)
10501050
exp_id = mlflow_client.create_experiment("test")
10511051
run = mlflow_client.create_run(experiment_id=exp_id)
@@ -1084,6 +1084,100 @@ def test_create_model_version_with_path_source(mlflow_client):
10841084
assert "To use a local path as a model version" in response.json()["message"]
10851085

10861086

1087+
def test_create_model_version_with_non_local_source(mlflow_client, monkeypatch):
    """
    Exercise the model-version create endpoint with local-path and remote-URI sources.

    Sources with clean absolute paths must be accepted (HTTP 200); sources whose path
    contains relative references ("..") must be rejected (HTTP 400) with the
    "absolute path must be provided" error message.
    """
    model_name = "model"
    mlflow_client.create_registered_model(model_name)
    experiment_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=experiment_id)

    def create_version(source):
        # Issue a single create request for ``source`` against the tracking server.
        return requests.post(
            f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
            json={
                "name": model_name,
                "source": source,
                "run_id": run.info.run_id,
            },
        )

    accepted_sources = [
        # The run's artifact location with the "file://" prefix stripped off
        run.info.artifact_uri[len("file://") :],
        # Test that remote uri's supplied as a source with absolute paths work fine
        "mlflow-artifacts:/models",
        "mlflow-artifacts://host:9000/models",
    ]
    for source in accepted_sources:
        response = create_version(source)
        assert response.status_code == 200

    # Test that invalid remote uri's cannot be created
    rejected_sources = [
        "mlflow-artifacts://host:9000/models/../../../",
        "http://host:9000/models/../../../",
        "https://host/api/2.0/mlflow-artifacts/artifacts/../../../",
        "s3a://my_bucket/api/2.0/mlflow-artifacts/artifacts/../../../",
        "ftp://host:8888/api/2.0/mlflow-artifacts/artifacts/../../../",
    ]
    for source in rejected_sources:
        response = create_version(source)
        assert response.status_code == 400
        assert "If supplying a source as an http, https," in response.json()["message"]
1179+
1180+
10871181
def test_create_model_version_with_file_uri(mlflow_client):
10881182
name = "test"
10891183
mlflow_client.create_registered_model(name)

0 commit comments

Comments
 (0)