Commit
Revert (#8351)
Signed-off-by: dbczumar <corey.zumar@databricks.com>
dbczumar committed Apr 28, 2023
1 parent af38edf commit 2b50b88
Showing 5 changed files with 7 additions and 363 deletions.
65 changes: 4 additions & 61 deletions mlflow/store/artifact/databricks_artifact_repo.py
@@ -40,10 +40,7 @@
from mlflow.utils.file_utils import (
download_file_using_http_uri,
relative_path_to_artifact_path,
parallelized_download_file_using_http_uri,
download_chunk,
)
from mlflow.utils.os import is_windows
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils import rest_utils
from mlflow.utils.file_utils import read_chunk
@@ -62,7 +59,7 @@
)

_logger = logging.getLogger(__name__)
_DOWNLOAD_CHUNK_SIZE = 10_000_000
_DOWNLOAD_CHUNK_SIZE = 100000000
_MULTIPART_UPLOAD_CHUNK_SIZE = 10_000_000 # 10 MB
_MAX_CREDENTIALS_REQUEST_SIZE = 2000 # Max number of artifact paths in a single credentials request
_SERVICE_AND_METHOD_TO_INFO = {
@@ -422,43 +419,6 @@ def _upload_to_cloud(
message="Cloud provider not supported.", error_code=INTERNAL_ERROR
)

def _parallelized_download_from_cloud(
self, cloud_credential_info, file_size, dst_local_file_path, dst_run_relative_artifact_path
):
try:
failed_downloads = parallelized_download_file_using_http_uri(
http_uri=cloud_credential_info.signed_uri,
download_path=dst_local_file_path,
file_size=file_size,
uri_type=cloud_credential_info.type,
chunk_size=_DOWNLOAD_CHUNK_SIZE,
headers=self._extract_headers_from_credentials(cloud_credential_info.headers),
)
download_errors = [
e for e in failed_downloads.values() if e.response.status_code not in (401, 403)
]
if download_errors:
raise MlflowException(
f"Failed to download artifact {dst_run_relative_artifact_path}: "
f"{download_errors}"
)

if failed_downloads:
new_cloud_creds = self._get_read_credential_infos(
self.run_id, dst_run_relative_artifact_path
)[0]
new_signed_uri = new_cloud_creds.signed_uri
new_headers = self._extract_headers_from_credentials(new_cloud_creds.headers)

for i in failed_downloads:
download_chunk(
i, _DOWNLOAD_CHUNK_SIZE, new_headers, dst_local_file_path, new_signed_uri
)
except Exception as err:
if os.path.exists(dst_local_file_path):
os.remove(dst_local_file_path)
raise MlflowException(err)

def _download_from_cloud(self, cloud_credential_info, dst_local_file_path):
"""
Download a file from the input `cloud_credential_info` and save it to `dst_local_file_path`.
@@ -748,16 +708,6 @@ def list_artifacts(self, path=None):
return infos

def _download_file(self, remote_file_path, local_path):
# list_artifacts API only returns a list of FileInfos at the specified path
# if it's a directory. To get file size, we need to iterate over FileInfos
# contained by the parent directory. A bad path could result in there being
# no matching FileInfos (by path), so fall back to None size to prevent
# parallelized download.
parent_dir, _ = posixpath.split(remote_file_path)
file_infos = self.list_artifacts(parent_dir)
file_info = [info for info in file_infos if info.path == remote_file_path]
file_size = file_info[0].file_size if len(file_info) == 1 else None

run_relative_remote_file_path = posixpath.join(
self.run_relative_artifact_repo_root_path, remote_file_path
)
@@ -767,16 +717,9 @@ def _download_file(self, remote_file_path, local_path):
# Read credentials for only one file were requested. So we expected only one value in
# the response.
assert len(read_credentials) == 1
# Windows doesn't support the 'fork' multiprocessing context.
if file_size is None or file_size <= _DOWNLOAD_CHUNK_SIZE or is_windows():
self._download_from_cloud(
cloud_credential_info=read_credentials[0],
dst_local_file_path=local_path,
)
else:
self._parallelized_download_from_cloud(
read_credentials[0], file_size, local_path, remote_file_path
)
self._download_from_cloud(
cloud_credential_info=read_credentials[0], dst_local_file_path=local_path
)

def delete_artifacts(self, artifact_path=None):
raise MlflowException("Not implemented yet")
66 changes: 3 additions & 63 deletions mlflow/store/artifact/databricks_models_artifact_repo.py
@@ -1,20 +1,13 @@
import logging
import json
import os
import posixpath

import mlflow.tracking
from mlflow.entities import FileInfo
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.store.artifact.artifact_repo import ArtifactRepository
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.file_utils import (
download_file_using_http_uri,
parallelized_download_file_using_http_uri,
download_chunk,
)
from mlflow.utils.os import is_windows
from mlflow.utils.file_utils import download_file_using_http_uri
from mlflow.utils.rest_utils import http_request
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri
from mlflow.store.artifact.utils.models import (
@@ -23,7 +16,7 @@
)

_logger = logging.getLogger(__name__)
_DOWNLOAD_CHUNK_SIZE = 10_000_000
_DOWNLOAD_CHUNK_SIZE = 100000000
# The constant REGISTRY_LIST_ARTIFACT_ENDPOINT is defined as @developer_stable
REGISTRY_LIST_ARTIFACTS_ENDPOINT = "/api/2.0/mlflow/model-versions/list-artifacts"
# The constant REGISTRY_ARTIFACT_PRESIGNED_URI_ENDPOINT is defined as @developer_stable
@@ -126,67 +119,14 @@ def _extract_headers_from_signed_url(self, headers):
filtered_headers = filter(lambda h: "name" in h and "value" in h, headers)
return {header.get("name"): header.get("value") for header in filtered_headers}

def _parallelized_download_from_cloud(
self, signed_uri, headers, file_size, dst_local_file_path, dst_run_relative_artifact_path
):
try:
failed_downloads = parallelized_download_file_using_http_uri(
http_uri=signed_uri,
download_path=dst_local_file_path,
file_size=file_size,
# URI type is not known in this context
uri_type=None,
chunk_size=_DOWNLOAD_CHUNK_SIZE,
headers=headers,
)
download_errors = [
e for e in failed_downloads.values() if e.response.status_code not in (401, 403)
]
if download_errors:
raise MlflowException(
f"Failed to download artifact {dst_run_relative_artifact_path}: "
f"{download_errors}"
)
if failed_downloads:
new_signed_uri, new_headers = self._get_signed_download_uri(
dst_run_relative_artifact_path
)
for i in failed_downloads:
download_chunk(
i, _DOWNLOAD_CHUNK_SIZE, new_headers, dst_local_file_path, new_signed_uri
)
except Exception as err:
if os.path.exists(dst_local_file_path):
os.remove(dst_local_file_path)
raise MlflowException(err)

def _download_file(self, remote_file_path, local_path):
parent_dir, _ = posixpath.split(remote_file_path)
file_infos = self.list_artifacts(parent_dir)
file_info = [info for info in file_infos if info.path == remote_file_path]
file_size = file_info[0].file_size if len(file_info) == 1 else None
try:
signed_uri, raw_headers = self._get_signed_download_uri(remote_file_path)
headers = {}
if raw_headers is not None:
# Don't send None to _extract_headers_from_signed_url
headers = self._extract_headers_from_signed_url(raw_headers)
# Windows doesn't support the 'fork' multiprocessing context.
if file_size is None or file_size <= _DOWNLOAD_CHUNK_SIZE or is_windows():
download_file_using_http_uri(
signed_uri,
local_path,
_DOWNLOAD_CHUNK_SIZE,
headers,
)
else:
self._parallelized_download_from_cloud(
signed_uri,
headers,
file_size,
local_path,
remote_file_path,
)
download_file_using_http_uri(signed_uri, local_path, _DOWNLOAD_CHUNK_SIZE, headers)
except Exception as err:
raise MlflowException(err)

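With the parallel branch reverted, `DatabricksModelsArtifactRepository._download_file` always streams the presigned URI serially through `download_file_using_http_uri`. A minimal standalone sketch of that kind of chunked streaming download (illustrative only, not the MLflow implementation) is:

```python
# Illustrative sketch, not MLflow code: stream a presigned URL to disk in
# fixed-size chunks so the whole file is never held in memory.
import requests


def stream_download(signed_uri, local_path, chunk_size=10_000_000, headers=None):
    with requests.get(signed_uri, stream=True, headers=headers or {}, timeout=60) as resp:
        resp.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
```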
77 changes: 0 additions & 77 deletions mlflow/utils/file_utils.py
@@ -1,8 +1,6 @@
import codecs
import errno
import gzip
import math
import multiprocessing
import os
import posixpath
import shutil
@@ -15,21 +13,18 @@

import urllib.parse
import urllib.request
from concurrent.futures import ProcessPoolExecutor
from urllib.parse import unquote
from urllib.request import pathname2url

import atexit

import requests
import yaml

try:
from yaml import CSafeLoader as YamlSafeLoader, CSafeDumper as YamlSafeDumper
except ImportError:
from yaml import SafeLoader as YamlSafeLoader, SafeDumper as YamlSafeDumper

from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialType
from mlflow.entities import FileInfo
from mlflow.exceptions import MissingConfigException
from mlflow.utils.rest_utils import cloud_storage_http_request, augmented_raise_for_status
@@ -39,7 +34,6 @@
from mlflow.utils.os import is_windows

ENCODING = "utf-8"
MAX_PARALLEL_DOWNLOAD_WORKERS = 32


def is_directory(name):
@@ -595,77 +589,6 @@ def download_file_using_http_uri(http_uri, download_path, chunk_size=100000000,
output_file.write(chunk)


def download_chunk(request_index, chunk_size, headers, download_path, http_uri):
range_start = chunk_size * request_index
range_end = range_start + chunk_size - 1
combined_headers = {**headers, "Range": f"bytes={range_start}-{range_end}"}
with cloud_storage_http_request(
"get", http_uri, stream=False, headers=combined_headers
) as response:
# File will have been created upstream. Use r+b to ensure chunks
# don't overwrite the entire file.
with open(download_path, "r+b") as f:
f.seek(range_start)
f.write(response.content)


def parallelized_download_file_using_http_uri(
http_uri, download_path, file_size, uri_type, chunk_size, headers=None
):
"""
Downloads a file specified using the `http_uri` to a local `download_path`. This function
sends multiple requests in parallel each specifying its own desired byte range as a header,
then reconstructs the file from the downloaded chunks. This allows for downloads of large files
without OOM risk.
Note : This function is meant to download files using presigned urls from various cloud
providers.
Returns a dict of chunk index : exception, if one was thrown for that index.
"""
num_requests = int(math.ceil(file_size / float(chunk_size)))
# Create file if it doesn't exist or erase the contents if it does. We should do this here
# before sending to the workers so they can each individually seek to their respective positions
# and write chunks without overwriting.
open(download_path, "w").close()
futures = {}
starting_index = 0
if uri_type == ArtifactCredentialType.GCP_SIGNED_URL or uri_type is None:
# GCP files could be transcoded, in which case the range header is ignored.
# Test if this is the case by downloading one chunk and seeing if it's larger than the
# requested size. If yes, let that be the file; if not, continue downloading more chunks.
download_chunk(0, chunk_size, headers, download_path, http_uri)
downloaded_size = os.path.getsize(download_path)
# If downloaded size was equal to the chunk size it would have been downloaded serially,
# so we don't need to consider this here
if downloaded_size > chunk_size:
return {}
else:
starting_index = 1

failed_downloads = {}
with ProcessPoolExecutor(
max_workers=MAX_PARALLEL_DOWNLOAD_WORKERS, mp_context=multiprocessing.get_context("fork")
) as executor:
for i in range(starting_index, num_requests):
fut = executor.submit(
download_chunk,
request_index=i,
chunk_size=chunk_size,
headers=headers,
download_path=download_path,
http_uri=http_uri,
)
futures[i] = fut

for i, fut in futures.items():
try:
fut.result()
except requests.HTTPError as e:
failed_downloads[i] = e

return failed_downloads


def _handle_readonly_on_windows(func, path, exc_info):
"""
This function should not be called directly but should be passed to `onerror` of
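
The helpers deleted from `file_utils.py` above implement byte-range fan-out over a presigned URL: the destination file is pre-created, each worker issues a `Range` request for its slice, seeks to its offset, and writes in place. A standalone sketch of the same technique (thread-based here rather than the fork-based `ProcessPoolExecutor` the removed code used, and assuming the server honors `Range` headers) could look like this:

```python
# Illustrative sketch of byte-range parallel download; not the MLflow implementation.
import math
from concurrent.futures import ThreadPoolExecutor

import requests


def download_range(url, path, index, chunk_size, headers=None):
    start = index * chunk_size
    end = start + chunk_size - 1
    range_headers = {**(headers or {}), "Range": f"bytes={start}-{end}"}
    resp = requests.get(url, headers=range_headers, timeout=60)
    resp.raise_for_status()
    # The destination file already exists; open it r+b so each worker writes
    # only its own slice instead of truncating the whole file.
    with open(path, "r+b") as f:
        f.seek(start)
        f.write(resp.content)


def parallel_download(url, path, file_size, chunk_size=10_000_000, headers=None, max_workers=8):
    open(path, "wb").close()  # create/truncate before workers seek into it
    num_chunks = math.ceil(file_size / chunk_size)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(download_range, url, path, i, chunk_size, headers)
            for i in range(num_chunks)
        ]
        for fut in futures:
            fut.result()  # re-raise any HTTP error from a worker
```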
64 changes: 0 additions & 64 deletions tests/store/artifact/test_databricks_artifact_repo.py
@@ -6,8 +6,6 @@

import pytest
import posixpath

import requests
from requests.models import Response
from unittest import mock
from unittest.mock import ANY
@@ -29,7 +27,6 @@
DatabricksArtifactRepository,
_MAX_CREDENTIALS_REQUEST_SIZE,
)
from mlflow.utils.os import is_windows

DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow.store.artifact.databricks_artifact_repo"
DATABRICKS_ARTIFACT_REPOSITORY = (
@@ -1464,64 +1461,3 @@ def test_multipart_upload_abort(databricks_artifact_repo, large_file, mock_chunk
headers={"header": "abort"},
timeout=None,
)


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
def test_parallelized_download_retries_failed_chunks(
databricks_artifact_repo, large_file, mock_chunk_size
):
mock_credential_info = ArtifactCredentialInfo(
signed_uri=MOCK_AWS_SIGNED_URI, type=ArtifactCredentialType.AWS_PRESIGNED_URL
)
response = Response()
response.status_code = 401
failed_downloads = {
2: requests.HTTPError(response=response),
5: requests.HTTPError(response=response),
}

with mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY}._get_read_credential_infos",
return_value=[mock_credential_info],
) as get_creds_mock, mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE}.parallelized_download_file_using_http_uri",
return_value=failed_downloads,
), mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE}.download_chunk"
) as download_chunk_mock, mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY}.list_artifacts",
side_effect=[[], [FileInfo(path="a.txt", is_dir=False, file_size=20_000_000)]],
):
databricks_artifact_repo.download_artifacts("a.txt")
assert get_creds_mock.call_count == 2 # Once for initial fetch, once for retries
assert {call.args[0] for call in download_chunk_mock.call_args_list} == {2, 5}


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
def test_parallelized_download_throws_for_other_errors(
databricks_artifact_repo, large_file, mock_chunk_size
):
mock_credential_info = ArtifactCredentialInfo(
signed_uri=MOCK_AWS_SIGNED_URI, type=ArtifactCredentialType.AWS_PRESIGNED_URL
)
response = Response()
response.status_code = 500
failed_downloads = {
2: requests.HTTPError(response=response),
5: requests.HTTPError(response=response),
}

with mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY}._get_read_credential_infos",
return_value=[mock_credential_info],
), mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE}.parallelized_download_file_using_http_uri",
return_value=failed_downloads,
), mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE}.download_chunk"
), mock.patch(
f"{DATABRICKS_ARTIFACT_REPOSITORY}.list_artifacts",
side_effect=[[], [FileInfo(path="a.txt", is_dir=False, file_size=20_000_000)]],
):
with pytest.raises(MlflowException, match="Failed to download artifact"):
databricks_artifact_repo.download_artifacts("a.txt")
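
The two removed tests above pin down the retry policy of the deleted `_parallelized_download_from_cloud`: chunks that fail with 401/403 are retried once against freshly issued credentials, while any other HTTP error aborts the download. A standalone sketch of that policy (illustrative only; `refresh_credentials` and `download_chunk` below are hypothetical callables, not MLflow APIs) might look like this:

```python
# Illustrative sketch of the retry policy exercised by the removed tests.
# refresh_credentials and download_chunk are hypothetical stand-ins, not MLflow APIs.
from typing import Callable, Dict, Tuple

import requests


def retry_failed_chunks(
    failed: Dict[int, requests.HTTPError],
    refresh_credentials: Callable[[], Tuple[str, dict]],
    download_chunk: Callable[[int, str, dict], None],
    artifact_path: str,
) -> None:
    # Anything other than an expired/forbidden credential error is fatal.
    hard_errors = [
        err for err in failed.values() if err.response.status_code not in (401, 403)
    ]
    if hard_errors:
        raise RuntimeError(f"Failed to download artifact {artifact_path}: {hard_errors}")
    if failed:
        # Fetch a fresh signed URI and headers once, then re-download only the
        # chunk indices that failed with 401/403.
        new_uri, new_headers = refresh_credentials()
        for index in failed:
            download_chunk(index, new_uri, new_headers)
```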
