Skip to content

Commit

Permalink
CRTTransferManager validates ExtraArgs and Bucket, same as classic Tr…
Browse files Browse the repository at this point in the history
…ansferManager (#294)

* CRTTransferManager validates ExtraArgs and Bucket, same as classic TransferManager.

---------

Co-authored-by: Nate Prewitt <nate.prewitt@gmail.com>
  • Loading branch information
graebm and nateprewitt committed Jun 18, 2024
1 parent 407232b commit 6461de3
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
36 changes: 36 additions & 0 deletions s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from s3transfer.constants import MB
from s3transfer.exceptions import TransferNotDoneError
from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
from s3transfer.manager import TransferManager
from s3transfer.utils import (
CallArgs,
OSUtils,
Expand Down Expand Up @@ -181,6 +182,14 @@ def _get_crt_throughput_target_gbps(provided_throughput_target_bytes=None):


class CRTTransferManager:
ALLOWED_DOWNLOAD_ARGS = TransferManager.ALLOWED_DOWNLOAD_ARGS
ALLOWED_UPLOAD_ARGS = TransferManager.ALLOWED_UPLOAD_ARGS
ALLOWED_DELETE_ARGS = TransferManager.ALLOWED_DELETE_ARGS

VALIDATE_SUPPORTED_BUCKET_VALUES = True

_UNSUPPORTED_BUCKET_PATTERNS = TransferManager._UNSUPPORTED_BUCKET_PATTERNS

def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
"""A transfer manager interface for Amazon S3 on CRT s3 client.
Expand Down Expand Up @@ -226,6 +235,8 @@ def download(
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS)
self._validate_if_bucket_supported(bucket)
callargs = CallArgs(
bucket=bucket,
key=key,
Expand All @@ -240,6 +251,8 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS)
self._validate_if_bucket_supported(bucket)
self._validate_checksum_algorithm_supported(extra_args)
callargs = CallArgs(
bucket=bucket,
Expand All @@ -255,6 +268,8 @@ def delete(self, bucket, key, extra_args=None, subscribers=None):
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_all_known_args(extra_args, self.ALLOWED_DELETE_ARGS)
self._validate_if_bucket_supported(bucket)
callargs = CallArgs(
bucket=bucket,
key=key,
Expand All @@ -266,6 +281,27 @@ def delete(self, bucket, key, extra_args=None, subscribers=None):
def shutdown(self, cancel=False):
self._shutdown(cancel)

def _validate_if_bucket_supported(self, bucket):
# s3 high level operations don't support some resources
# (eg. S3 Object Lambda) only direct API calls are available
# for such resources
if self.VALIDATE_SUPPORTED_BUCKET_VALUES:
for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items():
match = pattern.match(bucket)
if match:
raise ValueError(
f'TransferManager methods do not support {resource} '
'resource. Use direct client calls instead.'
)

def _validate_all_known_args(self, actual, allowed):
for kwarg in actual:
if kwarg not in allowed:
raise ValueError(
f"Invalid extra_args key '{kwarg}', "
f"must be one of: {', '.join(allowed)}"
)

def _validate_checksum_algorithm_supported(self, extra_args):
checksum_algorithm = extra_args.get('ChecksumAlgorithm')
if checksum_algorithm is None:
Expand Down
45 changes: 45 additions & 0 deletions tests/functional/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,28 @@ def test_upload_throws_error_for_unsupported_checksum(self):
[self.record_subscriber],
)

def test_upload_throws_error_for_unsupported_arg(self):
with self.assertRaisesRegex(
ValueError, "Invalid extra_args key 'ContentMD5'"
):
self.transfer_manager.upload(
self.filename,
self.bucket,
self.key,
{'ContentMD5': '938c2cc0dcc05f2b68c4287040cfcf71'},
[self.record_subscriber],
)

def test_upload_throws_error_on_s3_object_lambda_resource(self):
s3_object_lambda_arn = (
'arn:aws:s3-object-lambda:us-west-2:123456789012:'
'accesspoint:my-accesspoint'
)
with self.assertRaisesRegex(ValueError, 'methods do not support'):
self.transfer_manager.upload(
self.filename, s3_object_lambda_arn, self.key
)

def test_upload_with_s3express(self):
future = self.transfer_manager.upload(
self.filename,
Expand Down Expand Up @@ -489,6 +511,18 @@ def test_download_to_nonseekable_stream(self):
underlying_stream.getvalue(), self.expected_download_content
)

def test_download_throws_error_for_unsupported_arg(self):
with self.assertRaisesRegex(
ValueError, "Invalid extra_args key 'Range'"
):
self.transfer_manager.download(
self.bucket,
self.key,
self.filename,
{'Range': 'bytes:0-1023'},
[self.record_subscriber],
)

def test_download_with_s3express(self):
future = self.transfer_manager.download(
self.s3express_bucket,
Expand Down Expand Up @@ -526,6 +560,17 @@ def test_delete(self):
)
self._assert_subscribers_called(future)

def test_delete_throws_error_for_unsupported_arg(self):
with self.assertRaisesRegex(
ValueError, "Invalid extra_args key 'BypassGovernanceRetention'"
):
self.transfer_manager.delete(
self.bucket,
self.key,
{'BypassGovernanceRetention': True},
[self.record_subscriber],
)

def test_delete_with_s3express(self):
future = self.transfer_manager.delete(
self.s3express_bucket, self.key, {}, [self.record_subscriber]
Expand Down

0 comments on commit 6461de3

Please sign in to comment.