Skip to content

Commit

Permalink
Validate path parameter is safe (#7170)
Browse files Browse the repository at this point in the history
* Check path is safe

Signed-off-by: harupy <hkawamura0130@gmail.com>

* add tests

Signed-off-by: harupy <hkawamura0130@gmail.com>

Signed-off-by: harupy <hkawamura0130@gmail.com>
  • Loading branch information
harupy committed Oct 28, 2022
1 parent d1c7621 commit ac4b697
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
34 changes: 29 additions & 5 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,24 +524,46 @@ def wrapper(*args, **kwargs):
return wrapper


_os_alt_seps = list(sep for sep in [os.sep, os.path.altsep] if sep is not None and sep != "/")


def validate_path_is_safe(path):
"""
Validates that the specified path is safe to join with a trusted prefix. This is a security
measure to prevent path traversal attacks. The implementation is based on
`werkzeug.security.safe_join` (https://github.com/pallets/werkzeug/blob/a3005e6acda7246fe0a684c71921bf4882b4ba1c/src/werkzeug/security.py#L110).
"""
if path != "":
path = posixpath.normpath(path)
if (
any(sep in path for sep in _os_alt_seps)
or os.path.isabs(path)
or path == ".."
or path.startswith("../")
):
raise MlflowException(f"Invalid path: {path}", error_code=INVALID_PARAMETER_VALUE)


@catch_mlflow_exception
def get_artifact_handler():
from querystring_parser import parser

query_string = request.query_string.decode("utf-8")
request_dict = parser.parse(query_string, normalized=True)
run_id = request_dict.get("run_id") or request_dict.get("run_uuid")
path = request_dict["path"]
validate_path_is_safe(path)
run = _get_tracking_store().get_run(run_id)

if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
artifact_repo = _get_artifact_repo_mlflow_artifacts()
artifact_path = _get_proxied_run_artifact_destination_path(
proxied_artifact_root=run.info.artifact_uri,
relative_path=request_dict["path"],
relative_path=path,
)
else:
artifact_repo = _get_artifact_repo(run)
artifact_path = request_dict["path"]
artifact_path = path

return _send_artifact(artifact_repo, artifact_path)

Expand Down Expand Up @@ -896,6 +918,7 @@ def _list_artifacts():
response_message = ListArtifacts.Response()
if request_message.HasField("path"):
path = request_message.path
validate_path_is_safe(path)
else:
path = None
run_id = request_message.run_id or request_message.run_uuid
Expand Down Expand Up @@ -1273,17 +1296,18 @@ def get_model_version_artifact_handler():
request_dict = parser.parse(query_string, normalized=True)
name = request_dict.get("name")
version = request_dict.get("version")
path = request_dict["path"]
validate_path_is_safe(path)
artifact_uri = _get_model_registry_store().get_model_version_download_uri(name, version)

if _is_servable_proxied_run_artifact_root(artifact_uri):
artifact_repo = _get_artifact_repo_mlflow_artifacts()
artifact_path = _get_proxied_run_artifact_destination_path(
proxied_artifact_root=artifact_uri,
relative_path=request_dict["path"],
relative_path=path,
)
else:
artifact_repo = get_artifact_repository(artifact_uri)
artifact_path = request_dict["path"]
artifact_path = path

return _send_artifact(artifact_repo, artifact_path)

Expand Down
62 changes: 62 additions & 0 deletions tests/tracking/test_rest_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import tempfile
import time
import urllib.parse
import requests

import mlflow.experiments
from mlflow.exceptions import MlflowException
from mlflow.entities import Metric, Param, RunTag, ViewType
from mlflow.store.tracking.file_store import FileStore
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
from mlflow.server.handlers import validate_path_is_safe
from mlflow.models import Model

import mlflow.pyfunc
Expand Down Expand Up @@ -501,6 +503,66 @@ def assert_bad_request(payload, expected_error_message):
assert response.status_code == 200


@pytest.mark.parametrize(
"path",
[
"path",
"path/",
"path/to/file",
"path/../to/file",
],
)
def test_validate_path_is_safe_good(path):
validate_path_is_safe(path)


@pytest.mark.parametrize(
"path",
[
"/path",
"../path",
"../../path",
"./../path",
"path/../../to/file",
],
)
def test_validate_path_is_safe_bad(path):
with pytest.raises(MlflowException, match="Invalid path"):
validate_path_is_safe(path)


def test_path_validation(mlflow_client):
experiment_id = mlflow_client.create_experiment("tags validation")
created_run = mlflow_client.create_run(experiment_id)
run_id = created_run.info.run_id
invalid_path = "../path"

def assert_response(resp):
assert resp.status_code == 400
assert response.json() == {
"error_code": "INVALID_PARAMETER_VALUE",
"message": f"Invalid path: {invalid_path}",
}

response = requests.get(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/artifacts/list",
params={"run_id": run_id, "path": invalid_path},
)
assert_response(response)

response = requests.get(
f"{mlflow_client.tracking_uri}/get-artifact",
params={"run_id": run_id, "path": invalid_path},
)
assert_response(response)

response = requests.get(
f"{mlflow_client.tracking_uri}//model-versions/get-artifact",
params={"name": "model", "version": 1, "path": invalid_path},
)
assert_response(response)


def test_set_experiment_tag(mlflow_client):
experiment_id = mlflow_client.create_experiment("SetExperimentTagTest")
mlflow_client.set_experiment_tag(experiment_id, "dataset", "imagenet1K")
Expand Down

0 comments on commit ac4b697

Please sign in to comment.