Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mlflow-artifacts scheme as wrapper around http artifact scheme #5070

Merged
merged 13 commits into from
Nov 24, 2021
Merged
2 changes: 2 additions & 0 deletions mlflow/store/artifact/artifact_repository_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.sftp_artifact_repo import SFTPArtifactRepository
from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository

from mlflow.utils.uri import get_uri_scheme

Expand Down Expand Up @@ -88,6 +89,7 @@ def get_artifact_repository(self, artifact_uri):
_artifact_repository_registry.register("models", ModelsArtifactRepository)
for scheme in ["http", "https"]:
_artifact_repository_registry.register(scheme, HttpArtifactRepository)
_artifact_repository_registry.register("mlflow-artifacts", MlflowArtifactsRepository)

_artifact_repository_registry.register_entrypoints()

Expand Down
3 changes: 2 additions & 1 deletion mlflow/store/artifact/http_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, artifact_uri):
self._session = requests.Session()

def __del__(self):
self._session.close()
if hasattr(self, "_session"):
self._session.close()

def log_artifact(self, local_file, artifact_path=None):
verify_artifact_path(artifact_path)
Expand Down
96 changes: 96 additions & 0 deletions mlflow/store/artifact/mlflow_artifacts_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from urllib.parse import urlparse
from collections import namedtuple

from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.tracking._tracking_service.utils import get_tracking_uri
from mlflow.exceptions import MlflowException


def _parse_artifact_uri(artifact_uri):
ParsedURI = namedtuple("ParsedURI", "scheme host port path")
parsed_uri = urlparse(artifact_uri)
return ParsedURI(parsed_uri.scheme, parsed_uri.hostname, parsed_uri.port, parsed_uri.path)


def _check_if_host_is_numeric(hostname):
if hostname:
try:
float(hostname)
return True
except ValueError:
return False
else:
return False


def _validate_port_mapped_to_hostname(uri_parse):
# This check is to catch an mlflow-artifacts uri that has a port designated but no
# hostname specified. `urllib.parse.urlparse` will treat such a uri as a filesystem
# definition, mapping the provided port as a hostname value if this condition is not
# validated.
if uri_parse.host and _check_if_host_is_numeric(uri_parse.host) and not uri_parse.port:
raise MlflowException(
"The mlflow-artifacts uri was supplied with a port number: "
f"{uri_parse.host}, but no host was defined."
)


def _validate_uri_scheme(scheme):
allowable_schemes = {"http", "https"}
if scheme not in allowable_schemes:
raise MlflowException(
f"The configured tracking uri scheme: '{scheme}' is invalid for use with the proxy "
f"mlflow-artifact scheme. The allowed tracking schemes are: {allowable_schemes}"
)


class MlflowArtifactsRepository(HttpArtifactRepository):
"""Scheme wrapper around HttpArtifactRepository for mlflow-artifacts server functionality"""

def __init__(self, artifact_uri):

super().__init__(self.resolve_uri(artifact_uri))

@classmethod
def resolve_uri(cls, artifact_uri):
tracking_uri = get_tracking_uri()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we raise an exception if the scheme of tracking_uri is not http or https (e.g. sqlite:///foo/bar)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea :) added! I also simplified the parsing logic and temporarily set the tracking uri in the test function to validate against a valid server address (then reverted the global variable to not affect other tests with an inadvertent side-effect)

Copy link
Member

@harupy harupy Nov 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update! A tracking URI looks like http://localhost:5000 and doesn't contain /api/2.0/mlflow-artifacts/artifacts.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the tests and cleaned all of that up

Copy link
Member

@harupy harupy Nov 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BenWilson2 Sorry I wasn't clear. The resolved artifact URI needs to have /api/2.0/mlflow-artifacts/artifacts (e.g. http://localhost:5000/api/2.0/mlflow-artifacts/artifacts/path/to/dir).


track_parse = _parse_artifact_uri(tracking_uri)

uri_parse = _parse_artifact_uri(artifact_uri)

print(f"\nuri_parse: {uri_parse}")
print(f"track_parse: {track_parse}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(f"\nuri_parse: {uri_parse}")
print(f"track_parse: {track_parse}")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted my debug ;)


if track_parse.path == uri_parse.path:
resolved = "/"
else:
resolved = f"{track_parse.path}{uri_parse.path}"

# Check to ensure that a port is present with no hostname
_validate_port_mapped_to_hostname(uri_parse)

# Check that tracking uri is http or https
_validate_uri_scheme(track_parse.scheme)

if uri_parse.host and uri_parse.port:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{uri_parse.host}:{uri_parse.port}{resolved}"
)
elif uri_parse.host and not uri_parse.port:
resolved_artifacts_uri = f"{track_parse.scheme}://{uri_parse.host}{resolved}"
elif not uri_parse.host and not uri_parse.port and uri_parse.path == track_parse.path:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{track_parse.host}:{track_parse.port}{track_parse.path}"
)
elif not uri_parse.host and not uri_parse.port and uri_parse.path != track_parse.path:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{track_parse.host}:"
f"{track_parse.port}{track_parse.path}{uri_parse.path}"
)
else:
raise MlflowException(
f"The supplied artifact uri {artifact_uri} could not be resolved."
)

return resolved_artifacts_uri.replace("///", "/").rstrip("/")