Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mlflow-artifacts scheme as wrapper around http artifact scheme #5070

Merged
merged 13 commits into from
Nov 24, 2021
Merged
2 changes: 2 additions & 0 deletions mlflow/store/artifact/artifact_repository_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.sftp_artifact_repo import SFTPArtifactRepository
from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository

from mlflow.utils.uri import get_uri_scheme

Expand Down Expand Up @@ -88,6 +89,7 @@ def get_artifact_repository(self, artifact_uri):
# Built-in scheme registrations shown in this hunk; the original order is kept.
_artifact_repository_registry.register("models", ModelsArtifactRepository)
for scheme in ("http", "https"):
    # Both HTTP flavours are handled by the same repository class.
    _artifact_repository_registry.register(scheme, HttpArtifactRepository)
_artifact_repository_registry.register("mlflow-artifacts", MlflowArtifactsRepository)

# Entrypoint registration runs after the built-in registrations above.
_artifact_repository_registry.register_entrypoints()

Expand Down
25 changes: 25 additions & 0 deletions mlflow/store/artifact/mlflow_artifacts_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from urllib import parse
from collections import namedtuple
from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.tracking._tracking_service.utils import get_tracking_uri


def _resolve_connection_params(artifact_uri):
ParsedURI = namedtuple("ParsedURI", "scheme host port path")
parsed_uri = parse.urlparse(artifact_uri)
return ParsedURI(parsed_uri.scheme, parsed_uri.hostname, parsed_uri.port, parsed_uri.path)


class MlflowArtifactsRepository(HttpArtifactRepository):
    """Scheme wrapper around HttpArtifactRepository for mlflow-artifacts server functionality.

    An ``mlflow-artifacts:`` URI is rewritten against the current tracking URI
    and the result is handed to ``HttpArtifactRepository``.
    """

    def __init__(self, artifact_uri):
        """Resolve ``artifact_uri`` against the tracking URI and initialise the HTTP repo.

        :param artifact_uri: URI using the ``mlflow-artifacts`` scheme, with or
                             without a ``//host[:port]`` authority component.
        """
        parsed = _resolve_connection_params(artifact_uri)
        tracking_uri = get_tracking_uri()

        # NOTE(review): reviewers on the originating pull request asked that
        #   (a) an exception be raised when the tracking URI scheme is not
        #       http/https (e.g. ``sqlite:///foo/bar``), and
        #   (b) the resolved URI contain ``/api/2.0/mlflow-artifacts/artifacts``
        #       (e.g. ``http://localhost:5000/api/2.0/mlflow-artifacts/artifacts/path/to/dir``).
        # Neither is enforced here -- TODO confirm the intended behaviour before
        # adding validation, since callers may construct this class while a
        # non-HTTP tracking URI is active.

        # The previous implementation chained three ``str.replace`` calls. The
        # first one consumed the ``mlflow-artifacts:`` prefix, so the host and
        # host:port replacements could never match, and a URI with an authority
        # resolved to ``<tracking_uri>//host:port/path``. Building the result
        # from the parsed path swaps the whole ``scheme://host:port`` prefix for
        # the tracking URI, which is what the replace chain was aiming for.
        base_uri = str(tracking_uri).rstrip("/")
        relative_path = parsed.path.lstrip("/")
        if relative_path:
            resolved_artifacts_uri = f"{base_uri}/{relative_path}"
        else:
            resolved_artifacts_uri = base_uri

        super().__init__(resolved_artifacts_uri)
233 changes: 233 additions & 0 deletions tests/store/artifact/test_mlflow_artifact_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import os
from unittest import mock
import posixpath
import pytest

from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository


def test_artifact_uri_factory():
    """The registry maps the ``mlflow-artifacts`` scheme to MlflowArtifactsRepository."""
    assert isinstance(
        get_artifact_repository("mlflow-artifacts://test.com"),
        MlflowArtifactsRepository,
    )


class MockResponse:
    """Minimal response stand-in exposing ``json()`` and ``raise_for_status()``."""

    def __init__(self, data, status_code):
        self.data, self.status_code = data, status_code

    def json(self):
        """Return the payload this response was constructed with."""
        return self.data

    def raise_for_status(self):
        """Raise a generic exception for status codes of 400 and above."""
        if self.status_code < 400:
            return
        raise Exception("request failed")


class MockStreamResponse(MockResponse):
    """Streaming variant of MockResponse that is also usable as a context manager."""

    def iter_content(self, chunk_size):  # pylint: disable=unused-argument
        """Yield the whole payload as a single UTF-8 encoded chunk."""
        payload = self.data.encode("utf-8")
        yield payload

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Nothing to release; returning None lets exceptions propagate.
        return None


class FileObjectMatcher:
    """Equality helper that matches any object with the same ``name`` and ``mode``.

    Intended for use inside mock call assertions, where the real argument is an
    open file object whose identity cannot be known in advance.
    """

    def __init__(self, name, mode):
        self.name = name
        self.mode = mode

    def __eq__(self, other):
        # Objects without ``name``/``mode`` are simply "not equal" instead of
        # raising AttributeError, so a mismatched mock call produces a normal
        # assertion failure rather than an unrelated traceback.
        missing = object()
        other_name = getattr(other, "name", missing)
        other_mode = getattr(other, "mode", missing)
        if other_name is missing or other_mode is missing:
            return NotImplemented
        return self.name == other_name and self.mode == other_mode

    def __repr__(self):
        # Shown in mock assertion diffs when a call does not match.
        return f"{type(self).__name__}(name={self.name!r}, mode={self.mode!r})"


@pytest.fixture
def mlflow_artifact_repo():
    """Repository built from an ``mlflow-artifacts`` URI that has no host component."""
    return MlflowArtifactsRepository("mlflow-artifacts:/api/2.0/mlflow-artifacts/artifacts")


@pytest.fixture
def mlflow_artifact_repo_with_host():
    """Repository built from an ``mlflow-artifacts`` URI that includes host and port."""
    return MlflowArtifactsRepository(
        "mlflow-artifacts://test.com:5000/api/2.0/mlflow-artifacts/artifacts"
    )


@pytest.mark.parametrize("artifact_path", [None, "dir"])
def test_log_artifact(mlflow_artifact_repo, tmpdir, artifact_path):
    """``log_artifact`` PUTs the file to the expected URL and surfaces HTTP failures."""
    local_file = tmpdir.join("a.txt")
    local_file.write("0")

    ok_response = MockResponse({}, 200)
    with mock.patch("requests.Session.put", return_value=ok_response) as mock_put:
        mlflow_artifact_repo.log_artifact(local_file, artifact_path)

        # The optional artifact sub-directory sits between the repo URI and the file name.
        url_parts = [mlflow_artifact_repo.artifact_uri]
        if artifact_path:
            url_parts.append(artifact_path)
        url_parts.append(local_file.basename)

        mock_put.assert_called_once_with(
            posixpath.join(*url_parts),
            data=FileObjectMatcher(local_file, "rb"),
            timeout=mock.ANY,
        )

    failing_put = mock.patch("requests.Session.put", return_value=MockResponse({}, 400))
    with failing_put, pytest.raises(Exception, match="request failed"):
        mlflow_artifact_repo.log_artifact(local_file, artifact_path)


@pytest.mark.parametrize("artifact_path", [None, "dir"])
def test_log_artifact_with_host_and_port(mlflow_artifact_repo_with_host, tmpdir, artifact_path):
    """Same as ``test_log_artifact`` but for a repo whose URI carries host and port."""
    repo = mlflow_artifact_repo_with_host
    local_file = tmpdir.join("a.txt")
    local_file.write("0")

    ok_response = MockResponse({}, 200)
    with mock.patch("requests.Session.put", return_value=ok_response) as mock_put:
        repo.log_artifact(local_file, artifact_path)

        # The optional artifact sub-directory sits between the repo URI and the file name.
        url_parts = [repo.artifact_uri]
        if artifact_path:
            url_parts.append(artifact_path)
        url_parts.append(local_file.basename)

        mock_put.assert_called_once_with(
            posixpath.join(*url_parts),
            data=FileObjectMatcher(local_file, "rb"),
            timeout=mock.ANY,
        )

    failing_put = mock.patch("requests.Session.put", return_value=MockResponse({}, 400))
    with failing_put, pytest.raises(Exception, match="request failed"):
        repo.log_artifact(local_file, artifact_path)


@pytest.mark.parametrize("artifact_path", [None, "dir"])
def test_log_artifacts(mlflow_artifact_repo, tmpdir, artifact_path):
    """``log_artifacts`` uploads every file under a directory, keeping sub-directories."""
    file_a = tmpdir.join("a.txt")
    file_b = tmpdir.mkdir("dir").join("b.txt")
    file_a.write("0")
    file_b.write("1")

    ok_response = MockResponse({}, 200)
    with mock.patch("requests.Session.put", return_value=ok_response) as mock_put:
        mlflow_artifact_repo.log_artifacts(tmpdir, artifact_path)

        prefix = [mlflow_artifact_repo.artifact_uri]
        if artifact_path:
            prefix.append(artifact_path)
        url_a = posixpath.join(*prefix, file_a.basename)
        url_b = posixpath.join(*prefix, "dir", file_b.basename)

        # Compare only the URL and the uploaded file for each PUT, in call order.
        observed = []
        for args, kwargs in mock_put.call_args_list:
            observed.append((args[0], kwargs["data"]))
        assert observed == [
            (url_a, FileObjectMatcher(file_a, "rb")),
            (url_b, FileObjectMatcher(file_b, "rb")),
        ]

    failing_put = mock.patch("requests.Session.put", return_value=MockResponse({}, 400))
    with failing_put, pytest.raises(Exception, match="request failed"):
        mlflow_artifact_repo.log_artifacts(tmpdir, artifact_path)


def test_list_artifacts(mlflow_artifact_repo):
    """``list_artifacts`` handles empty, populated, nested and failing responses."""

    def listing_response():
        # Fresh response per use: one file and one directory at the listed level.
        return MockResponse(
            {
                "files": [
                    {"path": "1.txt", "is_dir": False, "file_size": 1},
                    {"path": "dir", "is_dir": True},
                ]
            },
            200,
        )

    # Empty payload -> empty listing, queried with an empty "path" parameter.
    with mock.patch("requests.Session.get", return_value=MockResponse({}, 200)) as mock_get:
        assert mlflow_artifact_repo.list_artifacts() == []
        mock_get.assert_called_once_with(
            mlflow_artifact_repo.artifact_uri, params={"path": ""}, timeout=mock.ANY
        )

    # Listing at the root returns the paths as reported by the server.
    with mock.patch("requests.Session.get", return_value=listing_response()):
        root_entries = mlflow_artifact_repo.list_artifacts()
        assert [entry.path for entry in root_entries] == ["1.txt", "dir"]

    # Listing a sub-path prefixes every returned entry with that path.
    with mock.patch("requests.Session.get", return_value=listing_response()):
        nested_entries = mlflow_artifact_repo.list_artifacts(path="path")
        assert [entry.path for entry in nested_entries] == [
            "path/1.txt",
            "path/dir",
        ]

    failing_get = mock.patch("requests.Session.get", return_value=MockResponse({}, 400))
    with failing_get, pytest.raises(Exception, match="request failed"):
        mlflow_artifact_repo.list_artifacts()


def read_file(path):
    """Return the full text content of the file at ``path``.

    The file is decoded as UTF-8 explicitly, matching the UTF-8 bytes that the
    mocked stream responses in this module produce, so the result does not
    depend on the platform's default encoding.
    """
    with open(path, encoding="utf-8") as f:
        return f.read()


@pytest.mark.parametrize("remote_file_path", ["a.txt", "dir/b.xtx"])
def test_download_file(mlflow_artifact_repo, tmpdir, remote_file_path):
    """``_download_file`` streams the remote content to disk and surfaces HTTP failures."""
    destination = tmpdir.join(posixpath.basename(remote_file_path))

    ok_stream = MockStreamResponse("data", 200)
    with mock.patch("requests.Session.get", return_value=ok_stream) as mock_get:
        mlflow_artifact_repo._download_file(remote_file_path, destination)
        mock_get.assert_called_once_with(
            posixpath.join(mlflow_artifact_repo.artifact_uri, remote_file_path),
            stream=True,
            timeout=mock.ANY,
        )
        with open(destination) as f:
            content = f.read()
        assert content == "data"

    failing_get = mock.patch("requests.Session.get", return_value=MockStreamResponse("data", 400))
    with failing_get, pytest.raises(Exception, match="request failed"):
        mlflow_artifact_repo._download_file(remote_file_path, destination)


def test_download_artifacts(mlflow_artifact_repo, tmpdir):
    """``download_artifacts`` recreates the remote tree locally.

    Simulated remote layout:

        a.txt
        dir/
            b.txt
    """

    def root_listing():
        # Listing of the repository root: one file and one directory.
        return MockResponse(
            {
                "files": [
                    {"path": "a.txt", "is_dir": False, "file_size": 6},
                    {"path": "dir", "is_dir": True},
                ]
            },
            200,
        )

    # Responses are consumed strictly in this order by successive GET requests.
    responses = [
        root_listing(),  # `list_artifacts("")` issued by `_is_directory("")`
        root_listing(),  # `list_artifacts("")`
        MockStreamResponse("data_a", 200),  # `_download_file("a.txt")`
        MockResponse(  # `list_artifacts("dir")`
            {"files": [{"path": "b.txt", "is_dir": False, "file_size": 1}]}, 200
        ),
        MockStreamResponse("data_b", 200),  # `_download_file("dir/b.txt")`
    ]

    with mock.patch("requests.Session.get", side_effect=responses):
        mlflow_artifact_repo.download_artifacts("", tmpdir)
        downloaded = []
        for root, _, files in os.walk(tmpdir):
            for filename in files:
                downloaded.append(os.path.join(root, filename))
        assert [os.path.relpath(p, tmpdir) for p in downloaded] == [
            "a.txt",
            os.path.join("dir", "b.txt"),
        ]
        assert read_file(downloaded[0]) == "data_a"
        assert read_file(downloaded[1]) == "data_b"