Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mlflow-artifacts scheme as wrapper around http artifact scheme #5070

Merged
merged 13 commits into from
Nov 24, 2021
Merged
2 changes: 2 additions & 0 deletions mlflow/store/artifact/artifact_repository_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.sftp_artifact_repo import SFTPArtifactRepository
from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository

from mlflow.utils.uri import get_uri_scheme

Expand Down Expand Up @@ -88,6 +89,7 @@ def get_artifact_repository(self, artifact_uri):
_artifact_repository_registry.register("models", ModelsArtifactRepository)
for scheme in ["http", "https"]:
_artifact_repository_registry.register(scheme, HttpArtifactRepository)
_artifact_repository_registry.register("mlflow-artifacts", MlflowArtifactsRepository)

_artifact_repository_registry.register_entrypoints()

Expand Down
3 changes: 2 additions & 1 deletion mlflow/store/artifact/http_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, artifact_uri):
self._session = requests.Session()

def __del__(self):
self._session.close()
if hasattr(self, "_session"):
self._session.close()

def log_artifact(self, local_file, artifact_path=None):
verify_artifact_path(artifact_path)
Expand Down
79 changes: 79 additions & 0 deletions mlflow/store/artifact/mlflow_artifacts_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from urllib.parse import urlparse
import posixpath
from collections import namedtuple

from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.tracking._tracking_service.utils import get_tracking_uri
from mlflow.exceptions import MlflowException


def _parse_artifact_uri(artifact_uri):
ParsedURI = namedtuple("ParsedURI", "scheme host port path")
parsed_uri = urlparse(artifact_uri)
return ParsedURI(parsed_uri.scheme, parsed_uri.hostname, parsed_uri.port, parsed_uri.path)


def _check_if_host_is_numeric(hostname):
if hostname:
try:
float(hostname)
return True
except ValueError:
return False
else:
return False


def _validate_port_mapped_to_hostname(uri_parse):
# This check is to catch an mlflow-artifacts uri that has a port designated but no
# hostname specified. `urllib.parse.urlparse` will treat such a uri as a filesystem
# definition, mapping the provided port as a hostname value if this condition is not
# validated.
if uri_parse.host and _check_if_host_is_numeric(uri_parse.host) and not uri_parse.port:
raise MlflowException(
f"The mlflow-artifacts uri was supplied with a port number: "
BenWilson2 marked this conversation as resolved.
Show resolved Hide resolved
f"{uri_parse.host}, but no host was defined."
)


class MlflowArtifactsRepository(HttpArtifactRepository):
"""Scheme wrapper around HttpArtifactRepository for mlflow-artifacts server functionality"""

def __init__(self, artifact_uri):

super().__init__(self.resolve_uri(artifact_uri))

@classmethod
def resolve_uri(cls, artifact_uri):
tracking_uri = get_tracking_uri()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we raise an exception if the scheme of tracking_uri is not http or https (e.g. sqlite:///foo/bar)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea :) added! I also simplified the parsing logic and temporarily set the tracking uri in the test function to validate against a valid server address (then reverted the global variable to not affect other tests with an inadvertent side-effect)

Copy link
Member

@harupy harupy Nov 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update! A tracking URI looks like http://localhost:5000 and doesn't contain /api/2.0/mlflow-artifacts/artifacts.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the tests and cleaned all of that up

Copy link
Member

@harupy harupy Nov 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BenWilson2 Sorry I wasn't clear. The resolved artifact URI needs to have /api/2.0/mlflow-artifacts/artifacts (e.g. http://localhost:5000/api/2.0/mlflow-artifacts/artifacts/path/to/dir).


track_parse = _parse_artifact_uri(tracking_uri)

uri_parse = _parse_artifact_uri(artifact_uri)

# Check to ensure that a port is present with no hostname
_validate_port_mapped_to_hostname(uri_parse)

api_path = "/api/2.0/mlflow-artifacts/artifacts"
harupy marked this conversation as resolved.
Show resolved Hide resolved

# If root directory is specified (empty path), `urllib.parse.urlparse` will pull
# the api path from the uri. This logic is to handle this.
if uri_parse.path != api_path:
request_path = posixpath.join(api_path, uri_parse.path.lstrip("/"))
else:
request_path = api_path

if uri_parse.host and uri_parse.port:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{uri_parse.host}:{uri_parse.port}{request_path}"
)
elif uri_parse.host and not uri_parse.port:
resolved_artifacts_uri = f"{track_parse.scheme}://{uri_parse.host}{request_path}"
elif not uri_parse.host and not uri_parse.port:
resolved_artifacts_uri = f"{tracking_uri}{request_path}"
else:
raise MlflowException(
f"The supplied artifact uri {artifact_uri} could not be resolved."
)

return resolved_artifacts_uri.replace("///", "/").rstrip("/")
281 changes: 281 additions & 0 deletions tests/store/artifact/test_mlflow_artifact_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
import os
from unittest import mock
import posixpath
import pytest
from urllib.parse import urlparse

from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.exceptions import MlflowException
from mlflow.tracking._tracking_service.utils import get_tracking_uri


def test_artifact_uri_factory():
repo = get_artifact_repository("mlflow-artifacts://test.com")
assert isinstance(repo, MlflowArtifactsRepository)


def test_mlflow_artifact_uri_formats_resolved():

tracking_uri = urlparse(get_tracking_uri())

conditions = [
(
"mlflow-artifacts://myhostname:4242/my/artifact/path/hostport",
f"{tracking_uri.scheme}:"
f"//myhostname:4242/api/2.0/mlflow-artifacts/artifacts/my/artifact/path/hostport",
harupy marked this conversation as resolved.
Show resolved Hide resolved
),
(
"mlflow-artifacts://myhostname/my/artifact/path/host",
f"{tracking_uri.scheme}:"
f"//myhostname/api/2.0/mlflow-artifacts/artifacts/my/artifact/path/host",
),
(
"mlflow-artifacts:/my/artifact/path/nohost",
f"{tracking_uri.scheme}:"
f"{tracking_uri.path}/api/2.0/mlflow-artifacts/artifacts/my/artifact/path/nohost",
),
(
"mlflow-artifacts:///my/artifact/path/redundant",
f"{tracking_uri.scheme}:"
f"{tracking_uri.path}/api/2.0/mlflow-artifacts/artifacts/my/artifact/path/redundant",
),
(
"mlflow-artifacts:/",
f"{tracking_uri.scheme}:{tracking_uri.path}/api/2.0/mlflow-artifacts/artifacts",
),
]
failing_condition = "mlflow-artifacts://5000/my/artifact/path"

for submit, resolved in conditions:
artifact_repo = MlflowArtifactsRepository(submit)
assert artifact_repo.resolve_uri(submit) == resolved
with pytest.raises(
MlflowException,
match="The mlflow-artifacts uri was supplied with a port number: 5000, but no "
"host was defined.",
):
uri = MlflowArtifactsRepository( # pylint: disable=unused-variable
failing_condition
).artifact_uri
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
uri = MlflowArtifactsRepository( # pylint: disable=unused-variable
failing_condition
).artifact_uri
MlflowArtifactsRepository(failing_condition)

nit

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to just the class instantiation



class MockResponse:
def __init__(self, data, status_code):
self.data = data
self.status_code = status_code

def json(self):
return self.data

def raise_for_status(self):
if self.status_code >= 400:
raise Exception("request failed")


class MockStreamResponse(MockResponse):
def iter_content(self, chunk_size): # pylint: disable=unused-argument
yield self.data.encode("utf-8")

def __enter__(self):
return self

def __exit__(self, *exc):
pass


class FileObjectMatcher:
def __init__(self, name, mode):
self.name = name
self.mode = mode

def __eq__(self, other):
return self.name == other.name and self.mode == other.mode


@pytest.fixture
def mlflow_artifact_repo():
artifact_uri = "mlflow-artifacts:/api/2.0/mlflow-artifacts/artifacts"
return MlflowArtifactsRepository(artifact_uri)


@pytest.fixture
def mlflow_artifact_repo_with_host():
artifact_uri = "mlflow-artifacts://test.com:5000/api/2.0/mlflow-artifacts/artifacts"
return MlflowArtifactsRepository(artifact_uri)


@pytest.mark.parametrize("artifact_path", [None, "dir", "path/to/artifacts/storage"])
def test_log_artifact(mlflow_artifact_repo, tmpdir, artifact_path):
tmp_path = tmpdir.join("a.txt")
tmp_path.write("0")
with mock.patch("requests.Session.put", return_value=MockResponse({}, 200)) as mock_put:
mlflow_artifact_repo.log_artifact(tmp_path, artifact_path)
paths = (artifact_path,) if artifact_path else ()
expected_url = posixpath.join(mlflow_artifact_repo.artifact_uri, *paths, tmp_path.basename)
mock_put.assert_called_once_with(
expected_url, data=FileObjectMatcher(tmp_path, "rb"), timeout=mock.ANY
)

with mock.patch("requests.Session.put", return_value=MockResponse({}, 400)) as mock_put:
with pytest.raises(Exception, match="request failed"):
mlflow_artifact_repo.log_artifact(tmp_path, artifact_path)


@pytest.mark.parametrize("artifact_path", [None, "dir", "path/to/artifacts/storage"])
def test_log_artifact_with_host_and_port(mlflow_artifact_repo_with_host, tmpdir, artifact_path):
tmp_path = tmpdir.join("a.txt")
tmp_path.write("0")
with mock.patch("requests.Session.put", return_value=MockResponse({}, 200)) as mock_put:
mlflow_artifact_repo_with_host.log_artifact(tmp_path, artifact_path)
paths = (artifact_path,) if artifact_path else ()
expected_url = posixpath.join(
mlflow_artifact_repo_with_host.artifact_uri, *paths, tmp_path.basename
)
mock_put.assert_called_once_with(
expected_url, data=FileObjectMatcher(tmp_path, "rb"), timeout=mock.ANY
)

with mock.patch("requests.Session.put", return_value=MockResponse({}, 400)) as mock_put:
with pytest.raises(Exception, match="request failed"):
mlflow_artifact_repo_with_host.log_artifact(tmp_path, artifact_path)


@pytest.mark.parametrize("artifact_path", [None, "dir", "path/to/artifacts/storage"])
def test_log_artifacts(mlflow_artifact_repo, tmpdir, artifact_path):
tmp_path_a = tmpdir.join("a.txt")
tmp_path_b = tmpdir.mkdir("dir").join("b.txt")
tmp_path_a.write("0")
tmp_path_b.write("1")

with mock.patch("requests.Session.put", return_value=MockResponse({}, 200)) as mock_put:
mlflow_artifact_repo.log_artifacts(tmpdir, artifact_path)
paths = (artifact_path,) if artifact_path else ()
expected_url_1 = posixpath.join(
mlflow_artifact_repo.artifact_uri, *paths, tmp_path_a.basename
)
expected_url_2 = posixpath.join(
mlflow_artifact_repo.artifact_uri, *paths, "dir", tmp_path_b.basename
)
calls = [(args[0], kwargs["data"]) for args, kwargs in mock_put.call_args_list]
assert calls == [
(expected_url_1, FileObjectMatcher(tmp_path_a, "rb")),
(expected_url_2, FileObjectMatcher(tmp_path_b, "rb")),
]

with mock.patch("requests.Session.put", return_value=MockResponse({}, 400)) as mock_put:
with pytest.raises(Exception, match="request failed"):
mlflow_artifact_repo.log_artifacts(tmpdir, artifact_path)


def test_list_artifacts(mlflow_artifact_repo):
with mock.patch("requests.Session.get", return_value=MockResponse({}, 200)) as mock_get:
assert mlflow_artifact_repo.list_artifacts() == []
mock_get.assert_called_once_with(
mlflow_artifact_repo.artifact_uri, params={"path": ""}, timeout=mock.ANY
)

with mock.patch(
"requests.Session.get",
return_value=MockResponse(
{
"files": [
{"path": "1.txt", "is_dir": False, "file_size": 1},
{"path": "dir", "is_dir": True},
]
},
200,
),
) as mock_get:
assert [a.path for a in mlflow_artifact_repo.list_artifacts()] == ["1.txt", "dir"]

with mock.patch(
"requests.Session.get",
return_value=MockResponse(
{
"files": [
{"path": "1.txt", "is_dir": False, "file_size": 1},
{"path": "dir", "is_dir": True},
]
},
200,
),
) as mock_get:
assert [a.path for a in mlflow_artifact_repo.list_artifacts(path="path")] == [
"path/1.txt",
"path/dir",
]

with mock.patch("requests.Session.get", return_value=MockResponse({}, 400)) as mock_get:
with pytest.raises(Exception, match="request failed"):
mlflow_artifact_repo.list_artifacts()


def read_file(path):
with open(path) as f:
return f.read()


@pytest.mark.parametrize("remote_file_path", ["a.txt", "dir/b.xtx"])
def test_download_file(mlflow_artifact_repo, tmpdir, remote_file_path):
with mock.patch(
"requests.Session.get", return_value=MockStreamResponse("data", 200)
) as mock_get:
tmp_path = tmpdir.join(posixpath.basename(remote_file_path))
mlflow_artifact_repo._download_file(remote_file_path, tmp_path)
expected_url = posixpath.join(mlflow_artifact_repo.artifact_uri, remote_file_path)
mock_get.assert_called_once_with(expected_url, stream=True, timeout=mock.ANY)
with open(tmp_path) as f:
assert f.read() == "data"

with mock.patch(
"requests.Session.get", return_value=MockStreamResponse("data", 400)
) as mock_get:
with pytest.raises(Exception, match="request failed"):
mlflow_artifact_repo._download_file(remote_file_path, tmp_path)


def test_download_artifacts(mlflow_artifact_repo, tmpdir):
# This test simulates downloading artifacts in the following structure:
# ---------
# - a.txt
# - dir
# - b.txt
# ---------
side_effect = [
# Response for `list_experiments("")` called by `_is_directory("")`
MockResponse(
{
"files": [
{"path": "a.txt", "is_dir": False, "file_size": 6},
{"path": "dir", "is_dir": True},
]
},
200,
),
# Response for `list_experiments("")`
MockResponse(
{
"files": [
{"path": "a.txt", "is_dir": False, "file_size": 6},
{"path": "dir", "is_dir": True},
]
},
200,
),
# Response for `_download_file("a.txt")`
MockStreamResponse("data_a", 200),
# Response for `list_experiments("dir")`
MockResponse({"files": [{"path": "b.txt", "is_dir": False, "file_size": 1}]}, 200),
# Response for `_download_file("dir/b.txt")`
MockStreamResponse("data_b", 200),
]
with mock.patch("requests.Session.get", side_effect=side_effect):
mlflow_artifact_repo.download_artifacts("", tmpdir)
paths = [os.path.join(root, f) for root, _, files in os.walk(tmpdir) for f in files]
assert [os.path.relpath(p, tmpdir) for p in paths] == [
"a.txt",
os.path.join("dir", "b.txt"),
]
assert read_file(paths[0]) == "data_a"
assert read_file(paths[1]) == "data_b"