Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace get_bucket_location with head_bucket #10731

Merged
merged 9 commits into from Jan 6, 2024
Merged
8 changes: 6 additions & 2 deletions mlflow/store/artifact/optimized_s3_artifact_repo.py
Expand Up @@ -6,6 +6,8 @@
import urllib.parse
from mimetypes import guess_type

from botocore.exceptions import ClientError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this as a top-level import? Can this be a localized import within the calling function to eliminate the need to include this as a dependency?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying that now.


from mlflow.entities import FileInfo
from mlflow.environment_variables import (
MLFLOW_ENABLE_MULTIPART_UPLOAD,
Expand Down Expand Up @@ -55,15 +57,17 @@ def __init__(
self._region_name = self._get_region_name()

def _get_region_name(self):
    """Resolve and return the AWS region name of ``self.bucket``.

    Issues a ``HeadBucket`` request with a temporary, path-addressed S3
    client and reads the ``BucketRegion`` field of the response. If the
    request fails (e.g. 403 Forbidden on a bucket the caller can head but
    not fully access), S3 still reports the bucket's region in the
    ``x-amz-bucket-region`` response header, so the region is recovered
    from the error response instead of propagating the exception.

    Returns:
        The bucket's region name as a string.
    """
    # Localized import so botocore is only required when this code path runs.
    from botocore.exceptions import ClientError

    # NOTE(review): path-style addressing was required by the previous
    # get_bucket_location implementation — confirm it is still needed for
    # head_bucket before removing.
    temp_client = _get_s3_client(
        addressing_style="path",
        access_key_id=self._access_key_id,
        secret_access_key=self._secret_access_key,
        session_token=self._session_token,
        s3_endpoint_url=self._s3_endpoint_url,
    )
    try:
        return temp_client.head_bucket(Bucket=self.bucket)["BucketRegion"]
    except ClientError as error:
        # HeadBucket failed, but S3 includes the bucket's region in the
        # error response headers; use that as the fallback.
        return error.response["ResponseMetadata"]["HTTPHeaders"]["x-amz-bucket-region"]

def _get_s3_client(self):
return _get_s3_client(
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -146,6 +146,7 @@ def run(self):
"requests-auth-aws-sigv4",
# Required to log artifacts and models to AWS S3 artifact locations
"boto3",
"botocore",
# Required to log artifacts and models to GCS artifact locations
"google-cloud-storage>=1.30.0",
"azureml-core>=1.2.0",
Expand Down
29 changes: 25 additions & 4 deletions tests/store/artifact/test_optimized_s3_artifact_repo.py
@@ -1,3 +1,5 @@
import botocore.errorfactory

Check failure on line 1 in tests/store/artifact/test_optimized_s3_artifact_repo.py

View workflow job for this annotation

GitHub Actions / lint

[*] Import block is un-sorted or un-formatted. Run `ruff --fix .` or comment `@mlflow-automation autoformat` to fix this error.
import botocore.session

Check failure on line 2 in tests/store/artifact/test_optimized_s3_artifact_repo.py

View workflow job for this annotation

GitHub Actions / lint

[*] `botocore.session` imported but unused. Run `ruff --fix .` or comment `@mlflow-automation autoformat` to fix this error.
import os
import posixpath
from datetime import datetime
Expand All @@ -6,6 +8,8 @@

import pytest

from botocore.exceptions import ClientError

from mlflow.protos.service_pb2 import FileInfo
from mlflow.store.artifact.optimized_s3_artifact_repo import OptimizedS3ArtifactRepository
from mlflow.store.artifact.s3_artifact_repo import (
Expand Down Expand Up @@ -33,7 +37,7 @@
with mock.patch("boto3.client") as mock_get_s3_client:
s3_client_mock = mock.Mock()
mock_get_s3_client.return_value = s3_client_mock
s3_client_mock.get_bucket_location.return_value = {"LocationConstraint": "us-west-2"}
s3_client_mock.head_bucket.return_value = {"BucketRegion": "us-west-2"}
kriscon-db marked this conversation as resolved.
Show resolved Hide resolved

# pylint: disable=no-value-for-parameter
repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path"))
Expand Down Expand Up @@ -98,12 +102,29 @@
)


def test_get_s3_client_region_name_set_correctly(s3_artifact_root):
region_name = "us_random_region_42"
@pytest.mark.parametrize("client_throws", [True, False])
def test_get_s3_client_region_name_set_correctly(s3_artifact_root, client_throws):
if client_throws:

Check failure on line 107 in tests/store/artifact/test_optimized_s3_artifact_repo.py

View workflow job for this annotation

GitHub Actions / lint

Use ternary operator `region_name = "us_random_throwing_region_42" if client_throws else "us_random_region_42"` instead of `if`-`else`-block. See https://beta.ruff.rs/docs/rules/if-else-block-instead-of-if-exp for how to fix this error.
kriscon-db marked this conversation as resolved.
Show resolved Hide resolved
region_name = "us_random_throwing_region_42"
else:
region_name = "us_random_region_42"

with mock.patch("boto3.client") as mock_get_s3_client:
s3_client_mock = mock.Mock()
mock_get_s3_client.return_value = s3_client_mock
s3_client_mock.get_bucket_location.return_value = {"LocationConstraint": region_name}
if client_throws:
error = ClientError(
{
"Error": {"Code": "403", "Message": "Forbidden"},
"ResponseMetadata": {
"HTTPHeaders": {"x-amz-bucket-region": region_name},
},
},
"head_bucket",
)
s3_client_mock.head_bucket.side_effect = error
else:
s3_client_mock.head_bucket.return_value = {"BucketRegion": region_name}

repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path"))
repo._get_s3_client()
Expand Down