diff --git a/mlflow/store/artifact/optimized_s3_artifact_repo.py b/mlflow/store/artifact/optimized_s3_artifact_repo.py index edde49a1ad593..40a4ee7c87302 100644 --- a/mlflow/store/artifact/optimized_s3_artifact_repo.py +++ b/mlflow/store/artifact/optimized_s3_artifact_repo.py @@ -55,7 +55,8 @@ def __init__( self._region_name = self._get_region_name() def _get_region_name(self): - # note: s3 client enforces path addressing style for get_bucket_location + from botocore.exceptions import ClientError + temp_client = _get_s3_client( addressing_style="path", access_key_id=self._access_key_id, @@ -63,7 +64,10 @@ def _get_region_name(self): session_token=self._session_token, s3_endpoint_url=self._s3_endpoint_url, ) - return temp_client.get_bucket_location(Bucket=self.bucket)["LocationConstraint"] + try: + return temp_client.head_bucket(Bucket=self.bucket)["BucketRegion"] + except ClientError as error: + return error.response["ResponseMetadata"]["HTTPHeaders"]["x-amz-bucket-region"] def _get_s3_client(self): return _get_s3_client( diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index aee905302b1ae..344f04698758d 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -45,4 +45,6 @@ openai<1.0 # Required for showing pytest stats psutil # SQLAlchemy == 2.0.25 requires typing_extensions >= 4.6.0 -typing_extensions>=4.6.0 \ No newline at end of file +typing_extensions>=4.6.0 +# Required for importing boto3 ClientError directly for testing +botocore>=1.34 \ No newline at end of file diff --git a/setup.py b/setup.py index 69238ee7d7f6b..c37b1c2e05808 100644 --- a/setup.py +++ b/setup.py @@ -146,6 +146,7 @@ def run(self): "requests-auth-aws-sigv4", # Required to log artifacts and models to AWS S3 artifact locations "boto3", + "botocore", # Required to log artifacts and models to GCS artifact locations "google-cloud-storage>=1.30.0", "azureml-core>=1.2.0", @@ -169,6 +170,7 @@ def run(self): "azure-storage-file-datalake>12", "google-cloud-storage>=1.30.0", "boto3>1", + "botocore>1.34", ], "gateway": GATEWAY_REQUIREMENTS, "genai": GATEWAY_REQUIREMENTS, diff --git a/tests/store/artifact/test_optimized_s3_artifact_repo.py b/tests/store/artifact/test_optimized_s3_artifact_repo.py index 2feea07307048..bf7416ef0a7d2 100644 --- a/tests/store/artifact/test_optimized_s3_artifact_repo.py +++ b/tests/store/artifact/test_optimized_s3_artifact_repo.py @@ -33,7 +33,7 @@ def test_get_s3_client_hits_cache(s3_artifact_root, monkeypatch): with mock.patch("boto3.client") as mock_get_s3_client: s3_client_mock = mock.Mock() mock_get_s3_client.return_value = s3_client_mock - s3_client_mock.get_bucket_location.return_value = {"LocationConstraint": "us-west-2"} + s3_client_mock.head_bucket.return_value = {"BucketRegion": "us-west-2"} # pylint: disable=no-value-for-parameter repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path")) @@ -98,12 +98,27 @@ def test_get_s3_client_verify_param_set_correctly( ) -def test_get_s3_client_region_name_set_correctly(s3_artifact_root): +@pytest.mark.parametrize("client_throws", [True, False]) +def test_get_s3_client_region_name_set_correctly(s3_artifact_root, client_throws): region_name = "us_random_region_42" with mock.patch("boto3.client") as mock_get_s3_client: + from botocore.exceptions import ClientError + s3_client_mock = mock.Mock() mock_get_s3_client.return_value = s3_client_mock - s3_client_mock.get_bucket_location.return_value = {"LocationConstraint": region_name} + if client_throws: + error = ClientError( + { + "Error": {"Code": "403", "Message": "Forbidden"}, + "ResponseMetadata": { + "HTTPHeaders": {"x-amz-bucket-region": region_name}, + }, + }, + "head_bucket", + ) + s3_client_mock.head_bucket.side_effect = error + else: + s3_client_mock.head_bucket.return_value = {"BucketRegion": region_name} repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path")) repo._get_s3_client()