Skip to content

Commit

Permalink
Fix for windows traversal attack (#10647)
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
Co-authored-by: Harutaka Kawamura <hkawamura0130@gmail.com>
  • Loading branch information
BenWilson2 and harupy committed Dec 14, 2023
1 parent b6ea835 commit 3d3c146
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
36 changes: 20 additions & 16 deletions mlflow/data/http_dataset_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import posixpath
import re
from typing import Any, Dict
from urllib.parse import urlparse
Expand Down Expand Up @@ -41,6 +40,23 @@ def url(self):
def _get_source_type() -> str:
return "http"

def _extract_filename(self, response) -> str:
    """
    Extracts a filename from the Content-Disposition header or the URL's path.

    The ``filename`` parameter of the header is preferred when present. Its value ends at the
    closing quote (for a quoted value) or at the next ``;`` (for an unquoted one), so trailing
    header parameters such as ``; size=10`` are not absorbed into the filename.

    Args:
        response: The HTTP response for ``self.url``. Only ``response.headers.get`` is used.

    Returns:
        The filename taken from the header, or the basename of the URL's path when the header
        is missing or carries no ``filename`` parameter. The result may be an empty string.

    Raises:
        MlflowException: If the filename in the header is a path rather than a bare file name
            (as judged by ``_is_path``).
    """
    if content_disposition := response.headers.get("Content-Disposition"):
        # Parameter names are case-insensitive. A quoted value runs to its closing quote;
        # an unquoted value runs to the next ";" (or the end of the header).
        match = re.search(
            r"""filename=(?:"([^"]*)"|'([^']*)'|([^;]+))""",
            content_disposition,
            re.IGNORECASE,
        )
        if match:
            raw_value = next(group for group in match.groups() if group is not None)
            # Drop surrounding whitespace, then any stray or unbalanced quotes.
            filename = raw_value.strip().strip("'\"")
            # The header is server-controlled: never let it choose a directory.
            if _is_path(filename):
                raise MlflowException.invalid_parameter_value(
                    f"Invalid filename in Content-Disposition header: {filename}. "
                    "It must be a file name, not a path."
                )
            return filename

    # Extract basename from URL if no filename was found in Content-Disposition
    return os.path.basename(urlparse(self.url).path)

def load(self, dst_path=None) -> str:
"""
Downloads the dataset source to the local filesystem.
Expand All @@ -58,21 +74,9 @@ def load(self, dst_path=None) -> str:
)
augmented_raise_for_status(resp)

path = urlparse(self.url).path
content_disposition = resp.headers.get("Content-Disposition")
if content_disposition is not None and (
file_name := next(re.finditer(r"filename=(.+)", content_disposition), None)
):
# NB: If the filename is quoted, unquote it
basename = file_name[1].strip("'\"")
if _is_path(basename):
raise MlflowException.invalid_parameter_value(
f"Invalid filename in Content-Disposition header: {basename}. "
"It must be a file name, not a path."
)
elif path is not None and len(posixpath.basename(path)) > 0:
basename = posixpath.basename(path)
else:
basename = self._extract_filename(resp)

if not basename:
basename = "dataset_source"

if dst_path is None:
Expand Down
28 changes: 28 additions & 0 deletions tests/data/test_http_dataset_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from mlflow.data.dataset_source_registry import get_dataset_source_from_json, resolve_dataset_source
from mlflow.data.http_dataset_source import HTTPDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.utils.os import is_windows
from mlflow.utils.rest_utils import cloud_storage_http_request


Expand Down Expand Up @@ -155,3 +156,30 @@ def download_with_mock_content_disposition_headers(*args, **kwargs):

with pytest.raises(MlflowException, match="Invalid filename in Content-Disposition header"):
source.load()


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    "filename",
    [
        r"..\..\poc.txt",
        r"Users\User\poc.txt",
    ],
)
def test_source_load_with_content_disposition_header_invalid_filename_windows(filename):
    """
    Loading must fail when the Content-Disposition filename is a backslash-separated path.

    The real download is performed, but the response headers are replaced with a
    Content-Disposition header carrying the parametrized ``filename``.
    """

    def respond_with_spoofed_header(*args, **kwargs):
        # Perform the genuine request, then swap in the crafted header.
        resp = cloud_storage_http_request(*args, **kwargs)
        resp.headers = {"Content-Disposition": f"attachment; filename={filename}"}
        return resp

    patched_request = mock.patch(
        "mlflow.data.http_dataset_source.cloud_storage_http_request",
        side_effect=respond_with_spoofed_header,
    )
    with patched_request:
        source = HTTPDatasetSource(
            "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/winequality-red.csv"
        )

        # The path-like filename must be rejected with an MlflowException
        with pytest.raises(
            MlflowException, match="Invalid filename in Content-Disposition header"
        ):
            source.load()

0 comments on commit 3d3c146

Please sign in to comment.