Skip to content

Commit

Permalink
Merge pull request #280 from kyleknap/crt-checksums
Browse files Browse the repository at this point in the history
Turn on checksum validation for CRT S3 transfer manager
  • Loading branch information
kyleknap committed Nov 13, 2023
2 parents b8906b3 + d85a0fa commit db20da4
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-crt-30257.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``crt``",
"description": "Automatically configure CRC32 checksums for uploads and checksum validation for downloads through the CRT transfer manager."
}
27 changes: 24 additions & 3 deletions s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_checksum_algorithm_supported(extra_args)
callargs = CallArgs(
bucket=bucket,
key=key,
Expand All @@ -231,6 +232,17 @@ def delete(self, bucket, key, extra_args=None, subscribers=None):
def shutdown(self, cancel=False):
    """Shut down this transfer manager by delegating to ``_shutdown``.

    :param cancel: Passed through unchanged to ``_shutdown``. Presumably
        requests cancellation of in-progress transfers rather than
        waiting for them -- confirm against ``_shutdown``.
    """
    self._shutdown(cancel)

def _validate_checksum_algorithm_supported(self, extra_args):
    """Fail fast when ``extra_args`` names an unknown checksum algorithm.

    The ``ChecksumAlgorithm`` entry is optional; when it is missing or
    ``None`` nothing is checked. Otherwise its upper-cased value must be
    the name of a member of ``awscrt.s3.S3ChecksumAlgorithm``, so the
    comparison is case-insensitive.

    :param extra_args: Mapping of extra arguments supplied for the
        transfer. Only the ``ChecksumAlgorithm`` key is read.
    :raises ValueError: If the requested algorithm is not a member of
        ``awscrt.s3.S3ChecksumAlgorithm``.
    """
    checksum_algorithm = extra_args.get('ChecksumAlgorithm')
    if checksum_algorithm is not None:
        # Enum member names are the accepted spellings of each algorithm.
        supported_algorithms = list(awscrt.s3.S3ChecksumAlgorithm.__members__)
        is_supported = checksum_algorithm.upper() in supported_algorithms
        if not is_supported:
            raise ValueError(
                f'ChecksumAlgorithm: {checksum_algorithm} not supported. '
                f'Supported algorithms are: {supported_algorithms}'
            )

def _cancel_transfers(self):
for coordinator in self._future_coordinators:
if not coordinator.done():
Expand Down Expand Up @@ -623,11 +635,17 @@ def _get_make_request_args_put_object(
else:
call_args.extra_args["Body"] = call_args.fileobj

checksum_algorithm = call_args.extra_args.pop(
'ChecksumAlgorithm', 'CRC32'
).upper()
checksum_config = awscrt.s3.S3ChecksumConfig(
algorithm=awscrt.s3.S3ChecksumAlgorithm[checksum_algorithm],
location=awscrt.s3.S3ChecksumLocation.TRAILER,
)
# Suppress botocore's automatic MD5 calculation by setting an override
# value that will get deleted in the BotocoreCRTRequestSerializer.
# The CRT S3 client is able to automatically compute checksums as part of
# requests it makes, and the intention is to configure automatic
# checksums in a future update.
# As part of the CRT S3 request, we request the CRT S3 client to
# automatically add trailing checksums to its uploads.
call_args.extra_args["ContentMD5"] = "override-to-be-removed"

make_request_args = self._default_get_make_request_args(
Expand All @@ -639,6 +657,7 @@ def _get_make_request_args_put_object(
on_done_after_calls=on_done_after_calls,
)
make_request_args['send_filepath'] = send_filepath
make_request_args['checksum_config'] = checksum_config
return make_request_args

def _get_make_request_args_get_object(
Expand All @@ -652,6 +671,7 @@ def _get_make_request_args_get_object(
):
recv_filepath = None
on_body = None
checksum_config = awscrt.s3.S3ChecksumConfig(validate_response=True)
if isinstance(call_args.fileobj, str):
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
Expand All @@ -673,6 +693,7 @@ def _get_make_request_args_get_object(
)
make_request_args['recv_filepath'] = recv_filepath
make_request_args['on_body'] = on_body
make_request_args['checksum_config'] = checksum_config
return make_request_args

def _default_get_make_request_args(
Expand Down
111 changes: 109 additions & 2 deletions tests/functional/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ def _assert_expected_crt_http_request(
str(expected_content_length),
)
if expected_missing_headers is not None:
header_names = [header[0] for header in crt_http_request.headers]
header_names = [
header[0].lower() for header in crt_http_request.headers
]
for expected_missing_header in expected_missing_headers:
self.assertNotIn(expected_missing_header, header_names)
self.assertNotIn(expected_missing_header.lower(), header_names)

def _assert_subscribers_called(self, expected_future=None):
self.assertTrue(self.record_subscriber.on_queued_called)
Expand All @@ -143,6 +145,21 @@ def _assert_subscribers_called(self, expected_future=None):
self.record_subscriber.on_done_future, expected_future
)

def _get_expected_upload_checksum_config(self, **overrides):
    """Return the ``S3ChecksumConfig`` expected on upload requests.

    The baseline is a CRC32 checksum sent in the trailer; any keyword
    arguments given in ``overrides`` replace or extend those settings.
    """
    upload_defaults = dict(
        algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32,
        location=awscrt.s3.S3ChecksumLocation.TRAILER,
    )
    # Later entries win, so caller-supplied overrides take precedence.
    return awscrt.s3.S3ChecksumConfig(**{**upload_defaults, **overrides})

def _get_expected_download_checksum_config(self, **overrides):
    """Return the ``S3ChecksumConfig`` expected on download requests.

    The baseline only enables response validation; any keyword arguments
    given in ``overrides`` replace or extend that setting.
    """
    download_defaults = dict(validate_response=True)
    # Later entries win, so caller-supplied overrides take precedence.
    return awscrt.s3.S3ChecksumConfig(**{**download_defaults, **overrides})

def _invoke_done_callbacks(self, **kwargs):
callargs = self.s3_crt_client.make_request.call_args
callargs_kwargs = callargs[1]
Expand Down Expand Up @@ -180,6 +197,7 @@ def test_upload(self):
'send_filepath': self.filename,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand All @@ -206,6 +224,7 @@ def test_upload_from_seekable_stream(self):
'send_filepath': None,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand Down Expand Up @@ -237,6 +256,7 @@ def test_upload_from_nonseekable_stream(self):
'send_filepath': None,
'on_progress': mock.ANY,
'on_done': mock.ANY,
'checksum_config': self._get_expected_upload_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand All @@ -251,6 +271,90 @@ def test_upload_from_nonseekable_stream(self):
)
self._assert_subscribers_called(future)

def test_upload_override_checksum_algorithm(self):
    """An explicit ``ChecksumAlgorithm`` replaces the default CRC32.

    Uploads with ``ChecksumAlgorithm='CRC32C'`` and verifies that the CRT
    client's ``make_request`` is given a CRC32C checksum config, and that
    the HTTP request it receives has none of the ``Content-MD5``,
    ``x-amz-sdk-checksum-algorithm`` or ``X-Amz-Trailer`` headers.
    """
    extra_args = {'ChecksumAlgorithm': 'CRC32C'}
    future = self.transfer_manager.upload(
        self.filename,
        self.bucket,
        self.key,
        extra_args,
        [self.record_subscriber],
    )
    future.result()

    # Keyword arguments of the most recent make_request() call.
    make_request_kwargs = self.s3_crt_client.make_request.call_args[1]
    expected_checksum_config = self._get_expected_upload_checksum_config(
        algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C
    )
    expected_kwargs = {
        'request': mock.ANY,
        'type': awscrt.s3.S3RequestType.PUT_OBJECT,
        'send_filepath': self.filename,
        'on_progress': mock.ANY,
        'on_done': mock.ANY,
        'checksum_config': expected_checksum_config,
    }
    self.assertEqual(make_request_kwargs, expected_kwargs)

    headers_expected_absent = [
        'Content-MD5',
        'x-amz-sdk-checksum-algorithm',
        'X-Amz-Trailer',
    ]
    self._assert_expected_crt_http_request(
        make_request_kwargs["request"],
        expected_http_method='PUT',
        expected_content_length=len(self.expected_content),
        expected_missing_headers=headers_expected_absent,
    )
    self._assert_subscribers_called(future)

def test_upload_override_checksum_algorithm_accepts_lowercase(self):
    """A lowercase ``ChecksumAlgorithm`` value is accepted.

    Uploads with ``ChecksumAlgorithm='crc32c'`` and verifies that the CRT
    client's ``make_request`` is still given a CRC32C checksum config,
    and that the HTTP request it receives has none of the
    ``Content-MD5``, ``x-amz-sdk-checksum-algorithm`` or
    ``X-Amz-Trailer`` headers.
    """
    extra_args = {'ChecksumAlgorithm': 'crc32c'}
    future = self.transfer_manager.upload(
        self.filename,
        self.bucket,
        self.key,
        extra_args,
        [self.record_subscriber],
    )
    future.result()

    # Keyword arguments of the most recent make_request() call.
    make_request_kwargs = self.s3_crt_client.make_request.call_args[1]
    expected_checksum_config = self._get_expected_upload_checksum_config(
        algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C
    )
    expected_kwargs = {
        'request': mock.ANY,
        'type': awscrt.s3.S3RequestType.PUT_OBJECT,
        'send_filepath': self.filename,
        'on_progress': mock.ANY,
        'on_done': mock.ANY,
        'checksum_config': expected_checksum_config,
    }
    self.assertEqual(make_request_kwargs, expected_kwargs)

    headers_expected_absent = [
        'Content-MD5',
        'x-amz-sdk-checksum-algorithm',
        'X-Amz-Trailer',
    ]
    self._assert_expected_crt_http_request(
        make_request_kwargs["request"],
        expected_http_method='PUT',
        expected_content_length=len(self.expected_content),
        expected_missing_headers=headers_expected_absent,
    )
    self._assert_subscribers_called(future)

def test_upload_throws_error_for_unsupported_checksum(self):
    """An unrecognized ``ChecksumAlgorithm`` makes ``upload`` raise.

    The call itself must raise ``ValueError`` with a message naming the
    rejected algorithm.
    """
    extra_args = {'ChecksumAlgorithm': 'UNSUPPORTED'}
    expected_message = 'ChecksumAlgorithm: UNSUPPORTED not supported'
    with self.assertRaisesRegex(ValueError, expected_message):
        self.transfer_manager.upload(
            self.filename,
            self.bucket,
            self.key,
            extra_args,
            [self.record_subscriber],
        )

def test_download(self):
future = self.transfer_manager.download(
self.bucket, self.key, self.filename, {}, [self.record_subscriber]
Expand All @@ -267,6 +371,7 @@ def test_download(self):
'on_progress': mock.ANY,
'on_done': mock.ANY,
'on_body': None,
'checksum_config': self._get_expected_download_checksum_config(),
},
)
# the recv_filepath will be set to a temporary file path with some
Expand Down Expand Up @@ -304,6 +409,7 @@ def test_download_to_seekable_stream(self):
'on_progress': mock.ANY,
'on_done': mock.ANY,
'on_body': mock.ANY,
'checksum_config': self._get_expected_download_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand Down Expand Up @@ -338,6 +444,7 @@ def test_download_to_nonseekable_stream(self):
'on_progress': mock.ANY,
'on_done': mock.ANY,
'on_body': mock.ANY,
'checksum_config': self._get_expected_download_checksum_config(),
},
)
self._assert_expected_crt_http_request(
Expand Down

0 comments on commit db20da4

Please sign in to comment.