Skip to content

Commit

Permalink
Fix proxy artifact URI handling for fetching trace data (#12147)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Signed-off-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Co-authored-by: Harutaka Kawamura <hkawamura0130@gmail.com>
  • Loading branch information
B-Step62 and harupy committed May 28, 2024
1 parent 688cf08 commit 86d5b10
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
43 changes: 36 additions & 7 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mlflow.entities import DatasetInput, ExperimentTag, FileInfo, Metric, Param, RunTag, ViewType
from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
from mlflow.entities.multipart_upload import MultipartUploadPart
from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_status import TraceStatus
from mlflow.environment_variables import MLFLOW_DEPLOYMENTS_TARGET
from mlflow.exceptions import MlflowException, _UnsupportedMultipartUploadException
Expand Down Expand Up @@ -187,6 +188,39 @@ def _get_artifact_repo_mlflow_artifacts():
return _artifact_repo


def _get_trace_artifact_repo(trace_info: TraceInfo):
    """
    Resolve the artifact repository for fetching data for the given trace.

    If the trace's artifact URI is a servable proxied artifact root
    (e.g. mlflow-artifacts://...), it is first rewritten to point at the actual
    artifact destination configured for this server.

    Args:
        trace_info: The trace info object containing metadata about the trace.

    Returns:
        The artifact repository created by ``get_artifact_repository`` for the
        (possibly resolved) trace artifact URI.

    Raises:
        MlflowException: With error code ``BAD_REQUEST`` if the proxied artifact URI
            does not contain a subpath to the trace data directory.
    """
    artifact_uri = get_artifact_uri_for_trace(trace_info)

    if _is_servable_proxied_run_artifact_root(artifact_uri):
        # If the artifact location is a proxied run artifact root (e.g. mlflow-artifacts://...),
        # we need to resolve it to the actual artifact location.
        from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

        path = _get_proxied_run_artifact_destination_path(artifact_uri)
        if not path:
            # NOTE: the two literals below are joined by implicit string concatenation.
            # A comma between them would pass the second one positionally and collide
            # with the `error_code` keyword, raising TypeError instead of this error.
            raise MlflowException(
                f"Failed to resolve the proxied run artifact URI: {artifact_uri}. "
                "Trace artifact URI must contain subpath to the trace data directory.",
                error_code=BAD_REQUEST,
            )
        root = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
        artifact_uri = posixpath.join(root, path)

    # We don't set it to global var unlike run artifact, because the artifact repo has
    # to be created with full trace artifact URI including request_id.
    # e.g. s3://<experiment_id>/traces/<request_id>
    return get_artifact_repository(artifact_uri)


def _is_serving_proxied_artifacts():
"""
Returns:
Expand Down Expand Up @@ -2455,15 +2489,10 @@ def get_trace_artifact_handler():
)

trace_info = _get_tracking_store().get_trace_info(request_id)
artifact_uri = get_artifact_uri_for_trace(trace_info)

if _is_servable_proxied_run_artifact_root(artifact_uri):
artifact_repo = _get_artifact_repo_mlflow_artifacts()
else:
artifact_repo = get_artifact_repository(artifact_uri)
trace_data = _get_trace_artifact_repo(trace_info).download_trace_data()

# Write data to a BytesIO buffer instead of needing to save a temp file
buf = io.BytesIO()
trace_data = artifact_repo.download_trace_data()
buf.write(json.dumps(trace_data).encode())
buf.seek(0)

Expand Down
41 changes: 40 additions & 1 deletion tests/server/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RegisteredModel,
RegisteredModelTag,
)
from mlflow.entities.trace_info import TraceInfo
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE, ErrorCode
from mlflow.protos.model_registry_pb2 import (
Expand All @@ -38,7 +39,12 @@
UpdateRegisteredModel,
)
from mlflow.protos.service_pb2 import CreateExperiment, SearchRuns
from mlflow.server import BACKEND_STORE_URI_ENV_VAR, SERVE_ARTIFACTS_ENV_VAR, app
from mlflow.server import (
ARTIFACTS_DESTINATION_ENV_VAR,
BACKEND_STORE_URI_ENV_VAR,
SERVE_ARTIFACTS_ENV_VAR,
app,
)
from mlflow.server.handlers import (
_convert_path_parameter_to_flask_format,
_create_experiment,
Expand All @@ -56,6 +62,7 @@
_get_model_version_download_uri,
_get_registered_model,
_get_request_message,
_get_trace_artifact_repo,
_log_batch,
_rename_registered_model,
_search_model_versions,
Expand All @@ -71,11 +78,15 @@
catch_mlflow_exception,
get_endpoints,
)
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.model_registry import (
SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
)
from mlflow.utils.mlflow_tags import MLFLOW_ARTIFACT_LOCATION
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils.validation import MAX_BATCH_LOG_REQUEST_SIZE

Expand Down Expand Up @@ -857,3 +868,31 @@ def test_local_file_read_write_by_pass_vulnerability(uri):
),
):
_validate_source("/local/path/xyz", run_id)


@pytest.mark.parametrize(
    ("location", "expected_class", "expected_uri"),
    [
        # Non-proxied locations are expected to be used as-is
        ("file:///0/traces/123", LocalArtifactRepository, "file:///0/traces/123"),
        ("s3://bucket/0/traces/123", S3ArtifactRepository, "s3://bucket/0/traces/123"),
        (
            "wasbs://container@account.blob.core.windows.net/bucket/1/traces/123",
            AzureBlobArtifactRepository,
            "wasbs://container@account.blob.core.windows.net/bucket/1/traces/123",
        ),
        # Proxy URI must be resolved to the actual storage URI
        (
            "https://127.0.0.1/api/2.0/mlflow-artifacts/artifacts/2/traces/123",
            S3ArtifactRepository,
            "s3://bucket/2/traces/123",
        ),
        ("mlflow-artifacts:/1/traces/123", S3ArtifactRepository, "s3://bucket/1/traces/123"),
    ],
)
def test_get_trace_artifact_repo(location, expected_class, expected_uri, monkeypatch):
    """
    Check that the repository resolved for a trace has the expected class and URI.

    The server environment is set up to serve artifacts with ``s3://bucket`` as the
    artifact destination; the trace's artifact location is supplied through its tags.
    """
    # Environment for this test case: artifact serving enabled, S3 as destination.
    for env_var, value in (
        (ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket"),
        (SERVE_ARTIFACTS_ENV_VAR, "true"),
    ):
        monkeypatch.setenv(env_var, value)

    location_tag = {MLFLOW_ARTIFACT_LOCATION: location}
    trace = TraceInfo("123", "0", 0, 1, "OK", tags=location_tag)

    resolved_repo = _get_trace_artifact_repo(trace)

    assert isinstance(resolved_repo, expected_class)
    assert resolved_repo.artifact_uri == expected_uri

0 comments on commit 86d5b10

Please sign in to comment.