Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace get_bucket_location with head_bucket #10731

Merged
merged 9 commits into from Jan 6, 2024
27 changes: 24 additions & 3 deletions tests/store/artifact/test_optimized_s3_artifact_repo.py
@@ -1,3 +1,5 @@
import botocore.errorfactory

Check failure on line 1 in tests/store/artifact/test_optimized_s3_artifact_repo.py

View workflow job for this annotation

GitHub Actions / lint

[*] Import block is un-sorted or un-formatted. Run `ruff --fix .` or comment `@mlflow-automation autoformat` to fix this error.
import botocore.session

Check failure on line 2 in tests/store/artifact/test_optimized_s3_artifact_repo.py

View workflow job for this annotation

GitHub Actions / lint

[*] `botocore.session` imported but unused. Run `ruff --fix .` or comment `@mlflow-automation autoformat` to fix this error.
import os
import posixpath
from datetime import datetime
Expand All @@ -6,6 +8,8 @@

import pytest

from botocore.exceptions import ClientError

from mlflow.protos.service_pb2 import FileInfo
from mlflow.store.artifact.optimized_s3_artifact_repo import OptimizedS3ArtifactRepository
from mlflow.store.artifact.s3_artifact_repo import (
Expand Down Expand Up @@ -98,12 +102,29 @@
)


def test_get_s3_client_region_name_set_correctly(s3_artifact_root):
region_name = "us_random_region_42"
@pytest.mark.parametrize("client_throws", [True, False])
def test_get_s3_client_region_name_set_correctly(s3_artifact_root, client_throws):
if client_throws:

Check failure on line 107 in tests/store/artifact/test_optimized_s3_artifact_repo.py

View workflow job for this annotation

GitHub Actions / lint

Use ternary operator `region_name = "us_random_throwing_region_42" if client_throws else "us_random_region_42"` instead of `if`-`else`-block. See https://beta.ruff.rs/docs/rules/if-else-block-instead-of-if-exp for how to fix this error.
kriscon-db marked this conversation as resolved.
Show resolved Hide resolved
region_name = "us_random_throwing_region_42"
else:
region_name = "us_random_region_42"

with mock.patch("boto3.client") as mock_get_s3_client:
s3_client_mock = mock.Mock()
mock_get_s3_client.return_value = s3_client_mock
s3_client_mock.head_bucket.return_value = {"BucketRegion": region_name}
if client_throws:
error = ClientError(
{
"Error": {"Code": "403", "Message": "Forbidden"},
"ResponseMetadata": {
"HTTPHeaders": {"x-amz-bucket-region": region_name},
},
},
"head_bucket",
)
s3_client_mock.head_bucket.side_effect = error
else:
s3_client_mock.head_bucket.return_value = {"BucketRegion": region_name}

repo = OptimizedS3ArtifactRepository(posixpath.join(s3_artifact_root, "some/path"))
repo._get_s3_client()
Expand Down