Skip to content

Commit

Permalink
Validate path in HttpArtifactRepository.list_artifacts (#10585)
Browse files Browse the repository at this point in the history
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
harupy committed Dec 4, 2023
1 parent 400c226 commit 55c72d0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 26 deletions.
27 changes: 1 addition & 26 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,10 @@
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.utils.mime_type_utils import _guess_mime_type
from mlflow.utils.os import is_windows
from mlflow.utils.promptlab_utils import _create_promptlab_run_impl
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.uri import is_file_uri, is_local_uri
from mlflow.utils.uri import is_file_uri, is_local_uri, validate_path_is_safe
from mlflow.utils.validation import _validate_batch_log_api_req

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -541,30 +540,6 @@ def wrapper(*args, **kwargs):
return wrapper


# Path separators native to this platform other than "/" (for example the backslash on
# Windows). The list is empty when the platform only uses "/".
_OS_ALT_SEPS = [sep for sep in (os.sep, os.path.altsep) if sep not in (None, "/")]


def validate_path_is_safe(path):
    """
    Reject paths that could escape a trusted prefix once joined to it (path traversal guard).

    File URIs are first passed through ``local_file_uri_to_path``; the result is then rejected
    when it:
        contains an OS-specific separator other than '/'
        has '..' as one of its '/'-separated components
        is absolute under either Windows or POSIX rules
        has ':' as its second character while running on Windows

    Raises MlflowException (INVALID_PARAMETER_VALUE) for a rejected path; otherwise returns
    None.
    """
    candidate = local_file_uri_to_path(path) if is_file_uri(path) else path

    def _escapes_prefix():
        # Checks run lazily and in a fixed order; the first hit decides.
        for separator in _OS_ALT_SEPS:
            if separator in candidate:
                return True
        if ".." in candidate.split("/"):
            return True
        if pathlib.PureWindowsPath(candidate).is_absolute():
            return True
        if pathlib.PurePosixPath(candidate).is_absolute():
            return True
        return is_windows() and len(candidate) >= 2 and candidate[1] == ":"

    if _escapes_prefix():
        raise MlflowException(f"Invalid path: {candidate}", error_code=INVALID_PARAMETER_VALUE)


@catch_mlflow_exception
def get_artifact_handler():
from querystring_parser import parser
Expand Down
2 changes: 2 additions & 0 deletions mlflow/store/artifact/http_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from mlflow.utils.file_utils import read_chunk, relative_path_to_artifact_path
from mlflow.utils.mime_type_utils import _guess_mime_type
from mlflow.utils.rest_utils import augmented_raise_for_status, http_request
from mlflow.utils.uri import validate_path_is_safe

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,6 +89,7 @@ def list_artifacts(self, path=None):
augmented_raise_for_status(resp)
file_infos = []
for f in resp.json().get("files", []):
validate_path_is_safe(f["path"])
file_info = FileInfo(
posixpath.join(path, f["path"]) if path else f["path"],
f["is_dir"],
Expand Down
27 changes: 27 additions & 0 deletions mlflow/utils/uri.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pathlib
import posixpath
import re
Expand Down Expand Up @@ -405,3 +406,29 @@ def resolve_uri_if_local(local_uri):

def generate_tmp_dfs_path(dfs_tmp):
    """Return ``dfs_tmp`` joined (POSIX style) with a freshly generated random UUID4 string."""
    unique_component = str(uuid.uuid4())
    return posixpath.join(dfs_tmp, unique_component)


# Path separators native to this platform other than "/" (for example the backslash on
# Windows). The list is empty when the platform only uses "/".
_OS_ALT_SEPS = [sep for sep in (os.sep, os.path.altsep) if sep not in (None, "/")]


def validate_path_is_safe(path):
    """
    Reject paths that could escape a trusted prefix once joined to it (path traversal guard).

    File URIs are first passed through ``local_file_uri_to_path``; the result is then rejected
    when it:
        contains an OS-specific separator other than '/'
        has '..' as one of its '/'-separated components
        is absolute under either Windows or POSIX rules
        has ':' as its second character while running on Windows

    Raises MlflowException (INVALID_PARAMETER_VALUE) for a rejected path; otherwise returns
    None.
    """
    # Imported at call time rather than module level -- presumably to avoid a circular import
    # between the uri and file_utils modules (TODO confirm).
    from mlflow.utils.file_utils import local_file_uri_to_path

    candidate = local_file_uri_to_path(path) if is_file_uri(path) else path

    def _escapes_prefix():
        # Checks run lazily and in a fixed order; the first hit decides.
        for separator in _OS_ALT_SEPS:
            if separator in candidate:
                return True
        if ".." in candidate.split("/"):
            return True
        if pathlib.PureWindowsPath(candidate).is_absolute():
            return True
        if pathlib.PurePosixPath(candidate).is_absolute():
            return True
        return is_windows() and len(candidate) >= 2 and candidate[1] == ":"

    if _escapes_prefix():
        raise MlflowException(f"Invalid path: {candidate}", error_code=INVALID_PARAMETER_VALUE)
18 changes: 18 additions & 0 deletions tests/store/artifact/test_http_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MLFLOW_TRACKING_TOKEN,
MLFLOW_TRACKING_USERNAME,
)
from mlflow.exceptions import MlflowException
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.tracking._tracking_service.utils import _get_default_host_creds
Expand Down Expand Up @@ -245,6 +246,23 @@ def test_list_artifacts(http_artifact_repo):
http_artifact_repo.list_artifacts()


@pytest.mark.parametrize("path", ["/tmp/path", "../../path"])
def test_list_artifacts_malicious_path(http_artifact_repo, path):
    """
    list_artifacts must raise MlflowException when the server response contains an absolute
    path or a path with '..' components.
    """
    import re

    response = MockResponse({"files": [{"path": path, "is_dir": False, "file_size": 1}]}, 200)
    with mock.patch(
        "mlflow.store.artifact.http_artifact_repo.http_request",
        return_value=response,
    ):
        # `match` is a regular expression applied with re.search; escape the message so the
        # dots in "../../path" are compared literally instead of matching any character.
        with pytest.raises(MlflowException, match=re.escape(f"Invalid path: {path}")):
            http_artifact_repo.list_artifacts()


def read_file(path):
    """Return the entire contents of the file at ``path``, opened in default text mode."""
    with open(path) as handle:
        contents = handle.read()
    return contents
Expand Down

0 comments on commit 55c72d0

Please sign in to comment.