Skip to content

Commit

Permalink
Handle slashes in _validate_non_local_source_contains_relative_paths (
Browse files Browse the repository at this point in the history
#8338)

* Handle slashes in _validate_non_local_source_contains_relative_paths

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fixl lint

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
harupy committed Apr 27, 2023
1 parent 9e35947 commit af38edf
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mlflow/server/handlers.py
@@ -1,11 +1,11 @@
# Define all the service endpoint handlers here.
import json
import os
import re
import tempfile
import posixpath
import urllib
import pathlib
import re

import logging
from functools import wraps
Expand Down Expand Up @@ -1338,7 +1338,7 @@ def _validate_non_local_source_contains_relative_paths(source: str):
"s3:/my_bucket/models/path/../../other/path"
"file://path/to/../../../../some/where/you/should/not/be"
"""
source_path = urllib.parse.urlparse(source).path
source_path = re.sub(r"/+", "/", urllib.parse.urlparse(source).path.rstrip("/"))
resolved_source = pathlib.Path(source_path).resolve().as_posix()
# NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
# drive path of the pwd to a given path. We don't care about the drive here, though.
Expand Down
2 changes: 1 addition & 1 deletion mlflow/utils/requirements_utils.py
Expand Up @@ -370,7 +370,7 @@ def _infer_requirements(model_uri, flavor):

modules = _capture_imported_modules(model_uri, flavor)
packages = _flatten([_MODULES_TO_PACKAGES.get(module, []) for module in modules])
packages = map(_normalize_package_name, packages)
packages = set(map(_normalize_package_name, packages))
packages = _prune_packages(packages)
excluded_packages = [
# Certain packages (e.g. scikit-learn 0.24.2) imports `setuptools` or `pkg_resources`
Expand Down
33 changes: 33 additions & 0 deletions tests/tracking/test_rest_tracking.py
Expand Up @@ -1111,6 +1111,39 @@ def test_create_model_version_with_non_local_source(mlflow_client, monkeypatch):
)
assert response.status_code == 200

# A single trailing slash
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": "mlflow-artifacts:/models/",
"run_id": run.info.run_id,
},
)
assert response.status_code == 200

# Multiple trailing slashes
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": "mlflow-artifacts:/models///",
"run_id": run.info.run_id,
},
)
assert response.status_code == 200

# Multiple slashes
response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
"name": name,
"source": "mlflow-artifacts:/models/foo///bar",
"run_id": run.info.run_id,
},
)
assert response.status_code == 200

response = requests.post(
f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
json={
Expand Down

0 comments on commit af38edf

Please sign in to comment.