Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable s3express support in the transfer manager when CRT is enabled #299

Merged
merged 1 commit into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/feature-s3-3423.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "feature",
"category": "``s3``",
"description": "Added CRT support for S3 Express One Zone"
}
22 changes: 19 additions & 3 deletions s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import awscrt.s3
import botocore.awsrequest
import botocore.session
from awscrt.auth import AwsCredentials, AwsCredentialsProvider
from awscrt.auth import (
AwsCredentials,
AwsCredentialsProvider,
AwsSigningAlgorithm,
AwsSigningConfig,
)
from awscrt.io import (
ClientBootstrap,
ClientTlsContext,
Expand All @@ -35,7 +40,12 @@
from s3transfer.constants import MB
from s3transfer.exceptions import TransferNotDoneError
from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
from s3transfer.utils import CallArgs, OSUtils, get_callbacks
from s3transfer.utils import (
CallArgs,
OSUtils,
get_callbacks,
is_s3express_bucket,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,6 +158,7 @@ def create_s3_crt_client(
tls_mode=tls_mode,
tls_connection_options=tls_connection_options,
throughput_target_gbps=target_gbps,
enable_s3express=True,
)


Expand Down Expand Up @@ -807,7 +818,7 @@ def _default_get_make_request_args(
on_done_before_calls,
on_done_after_calls,
):
return {
make_request_args = {
'request': self._request_serializer.serialize_http_request(
request_type, future
),
Expand All @@ -819,6 +830,11 @@ def _default_get_make_request_args(
),
'on_progress': self.get_crt_callback(future, 'progress'),
}
if is_s3express_bucket(call_args.bucket):
make_request_args['signing_config'] = AwsSigningConfig(
algorithm=AwsSigningAlgorithm.V4_S3EXPRESS
)
return make_request_args


class RenameTempFileHandler:
Expand Down
56 changes: 56 additions & 0 deletions tests/functional/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class TestCRTTransferManager(unittest.TestCase):
def setUp(self):
self.region = 'us-west-2'
self.bucket = "test_bucket"
self.s3express_bucket = 's3expressbucket--usw2-az5--x-s3'
self.key = "test_key"
self.expected_content = b'my content'
self.expected_download_content = b'new content'
Expand All @@ -77,6 +78,8 @@ def setUp(self):
)
self.expected_path = "/" + self.bucket + "/" + self.key
self.expected_host = "s3.%s.amazonaws.com" % (self.region)
self.expected_s3express_host = f'{self.s3express_bucket}.s3express-usw2-az5.us-west-2.amazonaws.com'
self.expected_s3express_path = f'/{self.key}'
self.s3_request = mock.Mock(awscrt.s3.S3Request)
self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
self.s3_crt_client.make_request.side_effect = (
Expand Down Expand Up @@ -134,6 +137,21 @@ def _assert_expected_crt_http_request(
for expected_missing_header in expected_missing_headers:
self.assertNotIn(expected_missing_header.lower(), header_names)

def _assert_exected_s3express_request(
self, make_request_kwargs, expected_http_method='GET'
):
self._assert_expected_crt_http_request(
make_request_kwargs["request"],
expected_host=self.expected_s3express_host,
expected_path=self.expected_s3express_path,
expected_http_method=expected_http_method,
)
self.assertIn('signing_config', make_request_kwargs)
self.assertEqual(
make_request_kwargs['signing_config'].algorithm,
awscrt.auth.AwsSigningAlgorithm.V4_S3EXPRESS,
)

def _assert_subscribers_called(self, expected_future=None):
self.assertTrue(self.record_subscriber.on_queued_called)
self.assertTrue(self.record_subscriber.on_done_called)
Expand Down Expand Up @@ -355,6 +373,20 @@ def test_upload_throws_error_for_unsupported_checksum(self):
[self.record_subscriber],
)

def test_upload_with_s3express(self):
future = self.transfer_manager.upload(
self.filename,
self.s3express_bucket,
self.key,
{},
[self.record_subscriber],
)
future.result()
self._assert_exected_s3express_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='PUT',
)

def test_download(self):
future = self.transfer_manager.download(
self.bucket, self.key, self.filename, {}, [self.record_subscriber]
Expand Down Expand Up @@ -457,6 +489,20 @@ def test_download_to_nonseekable_stream(self):
underlying_stream.getvalue(), self.expected_download_content
)

def test_download_with_s3express(self):
future = self.transfer_manager.download(
self.s3express_bucket,
self.key,
self.filename,
{},
[self.record_subscriber],
)
future.result()
self._assert_exected_s3express_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='GET',
)

def test_delete(self):
future = self.transfer_manager.delete(
self.bucket, self.key, {}, [self.record_subscriber]
Expand All @@ -480,6 +526,16 @@ def test_delete(self):
)
self._assert_subscribers_called(future)

def test_delete_with_s3express(self):
future = self.transfer_manager.delete(
self.s3express_bucket, self.key, {}, [self.record_subscriber]
)
future.result()
self._assert_exected_s3express_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='DELETE',
)

def test_blocks_when_max_requests_processes_reached(self):
self.s3_crt_client.make_request.return_value = self.s3_request
# We simulate blocking by not invoking the on_done callbacks for
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,7 @@ def test_target_throughput(
mock_s3_crt_client.call_args[1]['throughput_target_gbps']
== expected_gbps
)

def test_always_enables_s3express(self, mock_s3_crt_client):
s3transfer.crt.create_s3_crt_client('us-west-2')
assert mock_s3_crt_client.call_args[1]['enable_s3express'] is True