diff --git a/s3transfer/crt.py b/s3transfer/crt.py index 7b5d1301..85539169 100644 --- a/s3transfer/crt.py +++ b/s3transfer/crt.py @@ -25,15 +25,24 @@ EventLoopGroup, TlsContextOptions, ) -from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType +from awscrt.s3 import ( + S3ChecksumAlgorithm, + S3ChecksumConfig, + S3ChecksumLocation, + S3Client, + S3RequestTlsMode, + S3RequestType, +) from botocore import UNSIGNED from botocore.compat import urlsplit from botocore.config import Config from botocore.exceptions import NoCredentialsError +from s3transfer.compat import seekable from s3transfer.constants import GB, MB from s3transfer.exceptions import TransferNotDoneError from s3transfer.futures import BaseTransferFuture, BaseTransferMeta +from s3transfer.manager import TransferManager from s3transfer.utils import CallArgs, OSUtils, get_callbacks logger = logging.getLogger(__name__) @@ -67,7 +76,7 @@ def create_s3_crt_client( region, botocore_credential_provider=None, num_threads=None, - target_throughput=5 * GB / 8, + target_throughput=5_000_000_000.0 / 8, part_size=8 * MB, use_ssl=True, verify=None, @@ -86,8 +95,8 @@ def create_s3_crt_client( is the number of processors in the machine. :type target_throughput: Optional[int] - :param target_throughput: Throughput target in Bytes. - Default is 0.625 GB/s (which translates to 5 Gb/s). + :param target_throughput: Throughput target in bytes per second. + Default translates to 5.0 Gb/s or 0.582 GiB/s. 
:type part_size: Optional[int] :param part_size: Size, in Bytes, of parts that files will be downloaded @@ -137,7 +146,7 @@ def create_s3_crt_client( credentails_provider_adapter ) - target_gbps = target_throughput * 8 / GB + target_gigabits = target_throughput * 8 / 1_000_000_000.0 return S3Client( bootstrap=bootstrap, region=region, @@ -145,11 +154,16 @@ def create_s3_crt_client( part_size=part_size, tls_mode=tls_mode, tls_connection_options=tls_connection_options, - throughput_target_gbps=target_gbps, + throughput_target_gbps=target_gigabits, ) class CRTTransferManager: + + ALLOWED_DOWNLOAD_ARGS = TransferManager.ALLOWED_DOWNLOAD_ARGS + ALLOWED_UPLOAD_ARGS = TransferManager.ALLOWED_UPLOAD_ARGS + ALLOWED_DELETE_ARGS = TransferManager.ALLOWED_DELETE_ARGS + def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): """A transfer manager interface for Amazon S3 on CRT s3 client. @@ -192,6 +206,8 @@ def download( extra_args = {} if subscribers is None: subscribers = {} + self._validate_all_known_args(extra_args, TransferManager.ALLOWED_DOWNLOAD_ARGS) + # TODO: _validate_if_bucket_supported() ??? 
callargs = CallArgs( bucket=bucket, key=key, @@ -206,6 +222,7 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): extra_args = {} if subscribers is None: subscribers = {} + self._validate_all_known_args(extra_args, TransferManager.ALLOWED_UPLOAD_ARGS) callargs = CallArgs( bucket=bucket, key=key, @@ -220,6 +237,7 @@ def delete(self, bucket, key, extra_args=None, subscribers=None): extra_args = {} if subscribers is None: subscribers = {} + self._validate_all_known_args(extra_args, TransferManager.ALLOWED_DELETE_ARGS) callargs = CallArgs( bucket=bucket, key=key, @@ -260,6 +278,14 @@ def _shutdown(self, cancel=False): def _release_semaphore(self, **kwargs): self._semaphore.release() + def _validate_all_known_args(self, actual, allowed): + for kwarg in actual: + if kwarg not in allowed: + raise ValueError( + "Invalid extra_args key '%s', " + "must be one of: %s" % (kwarg, ', '.join(allowed)) + ) + def _submit_transfer(self, request_type, call_args): on_done_after_calls = [self._release_semaphore] coordinator = CRTTransferCoordinator(transfer_id=self._id_counter) @@ -359,7 +385,7 @@ def set_exception(self, exception): class BaseCRTRequestSerializer: - def serialize_http_request(self, transfer_type, future): + def serialize_http_request(self, transfer_type, future, fileobj): """Serialize CRT HTTP requests. :type transfer_type: string @@ -428,19 +454,12 @@ def _crt_request_from_aws_request(self, aws_request): headers_list.append((name, str(value, 'utf-8'))) crt_headers = awscrt.http.HttpHeaders(headers_list) - # CRT requires body (if it exists) to be an I/O stream. 
- crt_body_stream = None - if aws_request.body: - if hasattr(aws_request.body, 'seek'): - crt_body_stream = aws_request.body - else: - crt_body_stream = BytesIO(aws_request.body) crt_request = awscrt.http.HttpRequest( method=aws_request.method, path=crt_path, headers=crt_headers, - body_stream=crt_body_stream, + body_stream=aws_request.body, ) return crt_request @@ -451,8 +470,24 @@ def _convert_to_crt_http_request(self, botocore_http_request): # If host is not set, set it for the request before using CRT s3 url_parts = urlsplit(botocore_http_request.url) crt_request.headers.set("host", url_parts.netloc) + + # Remove bogus Content-MD5 value (see comment elsewhere in file) if crt_request.headers.get('Content-MD5') is not None: crt_request.headers.remove("Content-MD5") + + # Explicitly set "Content-Length: 0" when there's no body. + # Botocore doesn't bother setting this, but CRT likes to know. + # Note that Content-Length SHOULD be absent if body is nonseekable. + if crt_request.headers.get('Content-Length') is None: + if botocore_http_request.body is None: + crt_request.headers.add('Content-Length', "0") + + # Remove "Transfer-Encoding: chunked". 
+ # Botocore sets this on nonseekable streams, + # but CRT currently chokes on this header (TODO: fix this in CRT) + if crt_request.headers.get('Transfer-Encoding') is not None: + crt_request.headers.remove('Transfer-Encoding') + return crt_request def _capture_http_request(self, request, **kwargs): @@ -556,22 +591,57 @@ def get_make_request_args( self, request_type, call_args, coordinator, future, on_done_after_calls ): recv_filepath = None + on_body = None send_filepath = None s3_meta_request_type = getattr( S3RequestType, request_type.upper(), S3RequestType.DEFAULT ) on_done_before_calls = [] + checksum_config = S3ChecksumConfig() + if s3_meta_request_type == S3RequestType.GET_OBJECT: - final_filepath = call_args.fileobj - recv_filepath = self._os_utils.get_temp_filename(final_filepath) - file_ondone_call = RenameTempFileHandler( - coordinator, final_filepath, recv_filepath, self._os_utils - ) - on_done_before_calls.append(file_ondone_call) + if isinstance(call_args.fileobj, str): + # fileobj is a filepath + final_filepath = call_args.fileobj + recv_filepath = self._os_utils.get_temp_filename(final_filepath) + file_ondone_call = RenameTempFileHandler( + coordinator, final_filepath, recv_filepath, self._os_utils + ) + on_done_before_calls.append(file_ondone_call) + + elif call_args.fileobj is not None: + # fileobj is a file-like object + response_handler = _FileobjResponseHandler(call_args.fileobj) + on_body = response_handler.on_body + + # Only validate response checksums when downloading. 
+ # (upload responses also have checksum headers, but they're just an + # echo of what was in the request, an upload response's body is empty) + checksum_config.validate_response = True + elif s3_meta_request_type == S3RequestType.PUT_OBJECT: - send_filepath = call_args.fileobj - data_len = self._os_utils.get_file_size(send_filepath) - call_args.extra_args["ContentLength"] = data_len + if isinstance(call_args.fileobj, str): + # fileobj is a filepath + send_filepath = call_args.fileobj + data_len = self._os_utils.get_file_size(send_filepath) + call_args.extra_args["ContentLength"] = data_len + + elif call_args.fileobj is not None: + # fileobj is a file-like object + call_args.extra_args["Body"] = call_args.fileobj + + # We want the CRT S3Client to calculate checksums, not botocore. + # Default to CRC32. + if call_args.extra_args.get('ChecksumAlgorithm') is not None: + algorithm_name = call_args.extra_args.pop('ChecksumAlgorithm') + checksum_config.algorithm = S3ChecksumAlgorithm[algorithm_name] + else: + checksum_config.algorithm = S3ChecksumAlgorithm.CRC32 + checksum_config.location = S3ChecksumLocation.TRAILER + + # Suppress botocore's MD5 calculation by setting a bogus value. 
+ # (this header gets removed before the request is passed to CRT) + call_args.extra_args["ContentMD5"] = "bogus value deleted later" crt_request = self._request_serializer.serialize_http_request( request_type, future @@ -582,6 +652,8 @@ def get_make_request_args( 'type': s3_meta_request_type, 'recv_filepath': recv_filepath, 'send_filepath': send_filepath, + 'checksum_config': checksum_config, + 'on_body': on_body, 'on_done': self.get_crt_callback( future, 'done', on_done_before_calls, on_done_after_calls ), @@ -642,3 +714,11 @@ def __init__(self, coordinator): def __call__(self, **kwargs): self._coordinator.set_done_callbacks_complete() + + +class _FileobjResponseHandler: + def __init__(self, fileobj): + self._fileobj = fileobj + + def on_body(self, chunk: bytes, offset: int, **kwargs): + self._fileobj.write(chunk) diff --git a/tests/__init__.py b/tests/__init__.py index e36c4936..0954ba38 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -506,26 +506,30 @@ def write(self, b): # kind of error even though writeable returns False. 
raise io.UnsupportedOperation("write") - def read(self, n=-1): - return self._data.read(n) + def readinto(self, b): + return self._data.readinto(b) -class NonSeekableWriter(io.RawIOBase): - def __init__(self, fileobj): - super().__init__() - self._fileobj = fileobj +def create_nonseekable_writer(fileobj): def seekable(self): return False + fileobj.seekable = seekable + def writable(self): return True + fileobj.writable = writable + def readable(self): return False - def write(self, b): - self._fileobj.write(b) + fileobj.readable = readable def read(self, n=-1): raise io.UnsupportedOperation("read") + + fileobj.read = read + + return fileobj diff --git a/tests/functional/test_crt.py b/tests/functional/test_crt.py index 0ead2959..a440eb48 100644 --- a/tests/functional/test_crt.py +++ b/tests/functional/test_crt.py @@ -14,15 +14,22 @@ import threading import time from concurrent.futures import Future +from io import BytesIO from botocore.session import Session from s3transfer.subscribers import BaseSubscriber -from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest +from tests import ( + FileCreator, + HAS_CRT, + NonSeekableReader, + mock, + requires_crt, + unittest, +) if HAS_CRT: import awscrt - import s3transfer.crt @@ -61,7 +68,8 @@ def setUp(self): self.bucket = "test_bucket" self.key = "test_key" self.files = FileCreator() - self.filename = self.files.create_file('myfile', 'my content') + self.content = b'my content' + self.filename = self.files.create_file('myfile', self.content, mode='wb') self.expected_path = "/" + self.bucket + "/" + self.key self.expected_host = "s3.%s.amazonaws.com" % (self.region) self.s3_request = mock.Mock(awscrt.s3.S3Request) @@ -99,6 +107,7 @@ def _invoke_done_callbacks(self, **kwargs): on_done(error=None) def _simulate_file_download(self, recv_filepath): + # Create file that RenameTempFileHandler expects to be there self.files.create_file(recv_filepath, "fake response") def _simulate_make_request_side_effect(self, 
**kwargs): @@ -107,6 +116,93 @@ def _simulate_make_request_side_effect(self, **kwargs): self._invoke_done_callbacks() return mock.DEFAULT + def _assert_expected_make_request_callargs_for_upload_helper( + self, expecting_filepath, expecting_content_length + ): + call_kwargs = self.s3_crt_client.make_request.call_args[1] + crt_request = call_kwargs["request"] + + if expecting_filepath: + self.assertEqual(self.filename, call_kwargs.get("send_filepath")) + self.assertIsNone(crt_request.body_stream) + else: + self.assertIsNone(call_kwargs.get("send_filepath")) + self.assertIsNotNone(crt_request.body_stream) + + if expecting_content_length: + self.assertEqual(str(len(self.content)), crt_request.headers.get("content-length")) + else: + self.assertIsNone(crt_request.headers.get("content-length")) + + self.assertIsNone(call_kwargs.get("recv_filepath")) + self.assertIsNone(call_kwargs.get("on_body")) + self.assertEqual(awscrt.s3.S3RequestType.PUT_OBJECT, call_kwargs.get("type")) + self.assertEqual("PUT", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + + # The CRT should be doing checksums, and those settings are passed via checksum_config. + # Botocore should NOT be adding headers for checksums or Content-MD5. 
+ self.assertIsNone(crt_request.headers.get("content-md5")) + self.assertIsNone(crt_request.headers.get("x-amz-sdk-checksum-algorithm")) + self.assertFalse(any([k.lower().startswith('x-amz-checksum') for k, v in crt_request.headers])) + + def _assert_expected_make_request_callargs_for_upload_file(self): + self._assert_expected_make_request_callargs_for_upload_helper( + expecting_filepath=True, + expecting_content_length=True, + ) + + def _assert_expected_make_request_callargs_for_upload_fileobj(self): + self._assert_expected_make_request_callargs_for_upload_helper( + expecting_filepath=False, + expecting_content_length=True, + ) + + def _assert_expected_make_request_callargs_for_upload_nonseekable_fileobj(self): + self._assert_expected_make_request_callargs_for_upload_helper( + expecting_filepath=False, + expecting_content_length=False, + ) + + def _assert_expected_make_request_callargs_for_download_helper( + self, expecting_filepath + ): + call_kwargs = self.s3_crt_client.make_request.call_args[1] + crt_request = call_kwargs["request"] + + if expecting_filepath: + # the recv_filepath will be set to a temporary file path with some + # random suffix + self.assertTrue( + fnmatch.fnmatch( + call_kwargs.get("recv_filepath"), + f'{self.filename}.*', + ) + ) + self.assertIsNone(call_kwargs.get("on_body")) + else: + self.assertIsNone(call_kwargs.get("recv_filepath")) + self.assertIsNotNone(call_kwargs.get("on_body")) + + self.assertIsNone(call_kwargs.get("send_filepath")) + self.assertIsNone(crt_request.body_stream) + self.assertEqual("0", crt_request.headers.get("content-length")) + self.assertEqual(awscrt.s3.S3RequestType.GET_OBJECT, call_kwargs.get("type")) + self.assertEqual("GET", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + + def _assert_expected_make_request_callargs_for_download_file(self): + self._assert_expected_make_request_callargs_for_download_helper( + 
expecting_filepath=True, + ) + + def _assert_expected_make_request_callargs_for_download_fileobj(self): + self._assert_expected_make_request_callargs_for_download_helper( + expecting_filepath=False, + ) + def test_upload(self): self.s3_crt_client.make_request.side_effect = ( self._simulate_make_request_side_effect @@ -116,19 +212,73 @@ def test_upload(self): ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertEqual(callargs_kwargs["send_filepath"], self.filename) - self.assertIsNone(callargs_kwargs["recv_filepath"]) - self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT + self._assert_expected_make_request_callargs_for_upload_file() + self._assert_subscribers_called(future) + + def test_upload_for_fileobj(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect ) - crt_request = callargs_kwargs["request"] - self.assertEqual("PUT", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) + with open(self.filename, 'rb') as f: + future = self.transfer_manager.upload( + f, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + self._assert_expected_make_request_callargs_for_upload_fileobj() self._assert_subscribers_called(future) + def test_upload_for_seekable_filelike_obj(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) + bytes_io = BytesIO(self.content) + future = self.transfer_manager.upload( + bytes_io, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + self._assert_expected_make_request_callargs_for_upload_fileobj() + self._assert_subscribers_called(future) + + def test_upload_for_non_seekable_filelike_obj(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) + body = 
NonSeekableReader(self.content) + future = self.transfer_manager.upload( + body, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + self._assert_expected_make_request_callargs_for_upload_nonseekable_fileobj() + self._assert_subscribers_called(future) + + def test_upload_with_checksum_algorithm(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) + extra_args = { + 'ChecksumAlgorithm': 'SHA256', + } + future = self.transfer_manager.upload( + self.filename, self.bucket, self.key, extra_args, [self.record_subscriber] + ) + future.result() + + self._assert_expected_make_request_callargs_for_upload_file() + self._assert_subscribers_called(future) + + call_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual(call_kwargs['checksum_config'].algorithm.name, 'SHA256') + + def test_allowed_upload_args(self): + # assert that only ALLOWED_UPLOAD_ARGS are permitted + # (ContentMD5 is not currently in the list) + with self.assertRaises(ValueError) as cm: + extra_args = {'ContentMD5': 'e484175540065aec988a26593675785d'} + self.transfer_manager.upload(self.filename, self.bucket, self.key, extra_args) + def test_download(self): self.s3_crt_client.make_request.side_effect = ( self._simulate_make_request_side_effect @@ -138,28 +288,21 @@ ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - # the recv_filepath will be set to a temporary file path with some - # random suffix - self.assertTrue( - fnmatch.fnmatch( - callargs_kwargs["recv_filepath"], - f'{self.filename}.*', - ) - ) - self.assertIsNone(callargs_kwargs["send_filepath"]) - self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT + self._assert_expected_make_request_callargs_for_download_file() + self._assert_subscribers_called(future) + + def test_download_for_fileobj(self): + self.s3_crt_client.make_request.side_effect = ( + 
self._simulate_make_request_side_effect ) - crt_request = callargs_kwargs["request"] - self.assertEqual("GET", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) + with open(self.filename, 'wb') as f: + future = self.transfer_manager.download( + self.bucket, self.key, f, {}, [self.record_subscriber] + ) + future.result() + + self._assert_expected_make_request_callargs_for_download_fileobj() self._assert_subscribers_called(future) - with open(self.filename, 'rb') as f: - # Check the fake response overwrites the file because of download - self.assertEqual(f.read(), b'fake response') def test_delete(self): self.s3_crt_client.make_request.side_effect = ( diff --git a/tests/functional/test_download.py b/tests/functional/test_download.py index f458721d..33f9ceb0 100644 --- a/tests/functional/test_download.py +++ b/tests/functional/test_download.py @@ -26,10 +26,10 @@ from tests import ( BaseGeneralInterfaceTest, FileSizeProvider, - NonSeekableWriter, RecordingOSUtils, RecordingSubscriber, StreamWithError, + create_nonseekable_writer, skip_if_using_serial_implementation, skip_if_windows, ) @@ -181,7 +181,7 @@ def test_download_for_nonseekable_filelike_obj(self): with open(self.filename, 'wb') as f: future = self.manager.download( - self.bucket, self.key, NonSeekableWriter(f), self.extra_args + self.bucket, self.key, create_nonseekable_writer(f), self.extra_args ) future.result() diff --git a/tests/integration/test_crt.py b/tests/integration/test_crt.py index 157ae2dc..4e4855fb 100644 --- a/tests/integration/test_crt.py +++ b/tests/integration/test_crt.py @@ -11,11 +11,18 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
import glob +from io import BytesIO import os from s3transfer.subscribers import BaseSubscriber from s3transfer.utils import OSUtils -from tests import HAS_CRT, assert_files_equal, requires_crt +from tests import ( + HAS_CRT, + NonSeekableReader, + assert_files_equal, + create_nonseekable_writer, + requires_crt, +) from tests.integration import BaseTransferManagerIntegTest if HAS_CRT: @@ -40,8 +47,7 @@ def on_done(self, **kwargs): self.on_done_called = True -@requires_crt -class TestCRTS3Transfers(BaseTransferManagerIntegTest): +class BaseCRTS3TransfersTest(BaseTransferManagerIntegTest): """Tests for the high level s3transfer based on CRT implementation.""" def _create_s3_transfer(self): @@ -76,88 +82,79 @@ def _assert_subscribers_called(self, expected_bytes_transferred=None): expected_bytes_transferred, ) - def test_upload_below_multipart_chunksize(self): - transfer = self._create_s3_transfer() - file_size = 1024 * 1024 - filename = self.files.create_file_with_size( - 'foo.txt', filesize=file_size - ) - self.addCleanup(self.delete_object, 'foo.txt') - with transfer: - future = transfer.upload( - filename, - self.bucket_name, - 'foo.txt', - subscribers=[self.record_subscriber], - ) - future.result() +@requires_crt +class TestCRTUpload(BaseCRTS3TransfersTest): + # CRTTransferManager upload tests. Defaults to using filepath, but + # subclasses override the function below to use streaming fileobj instead. 
- self.assertTrue(self.object_exists('foo.txt')) - self._assert_subscribers_called(file_size) + def get_input_fileobj(self, name, contents): + # override this in subclasses to upload via fileobj instead of filepath + mode = 'w' if isinstance(contents, str) else 'wb' + return self.files.create_file(name, contents, mode) + + def get_input_fileobj_with_size(self, name, size): + return self.get_input_fileobj(name, b'a' * size) + + def _assert_object_exists(self, key, expected_content_length, extra_args={}): + self.assertTrue(self.object_exists(key, extra_args)) + response = self.client.head_object(Bucket=self.bucket_name, Key=key, **extra_args) + self.assertEqual(response['ContentLength'], expected_content_length) - def test_upload_above_multipart_chunksize(self): + def _test_basic_upload(self, key, file_size, extra_args=None): transfer = self._create_s3_transfer() - file_size = 20 * 1024 * 1024 - filename = self.files.create_file_with_size( - '20mb.txt', filesize=file_size - ) - self.addCleanup(self.delete_object, '20mb.txt') + file = self.get_input_fileobj_with_size(key, file_size) + self.addCleanup(self.delete_object, key) with transfer: future = transfer.upload( - filename, + file, self.bucket_name, - '20mb.txt', + key, + extra_args, subscribers=[self.record_subscriber], ) future.result() - self.assertTrue(self.object_exists('20mb.txt')) + + self._assert_object_exists(key, file_size) self._assert_subscribers_called(file_size) - def test_upload_file_above_threshold_with_acl(self): - transfer = self._create_s3_transfer() + def test_below_multipart_chunksize(self): + self._test_basic_upload('1mb.txt', file_size=1024 * 1024) + + def test_above_multipart_chunksize(self): + self._test_basic_upload('20mb.txt', file_size=20 * 1024 * 1024) + + def test_empty_file(self): + self._test_basic_upload('0mb.txt', file_size=0) + + def test_file_above_threshold_with_acl(self): + key = '6mb.txt' file_size = 6 * 1024 * 1024 - filename = self.files.create_file_with_size( - '6mb.txt', 
filesize=file_size - ) extra_args = {'ACL': 'public-read'} - self.addCleanup(self.delete_object, '6mb.txt') - - with transfer: - future = transfer.upload( - filename, - self.bucket_name, - '6mb.txt', - extra_args=extra_args, - subscribers=[self.record_subscriber], - ) - future.result() + self._test_basic_upload(key, file_size, extra_args) - self.assertTrue(self.object_exists('6mb.txt')) response = self.client.get_object_acl( - Bucket=self.bucket_name, Key='6mb.txt' + Bucket=self.bucket_name, Key=key ) self._assert_has_public_read_acl(response) - self._assert_subscribers_called(file_size) - def test_upload_file_above_threshold_with_ssec(self): + def test_file_above_threshold_with_ssec(self): key_bytes = os.urandom(32) extra_args = { 'SSECustomerKey': key_bytes, 'SSECustomerAlgorithm': 'AES256', } + key = '6mb.txt' file_size = 6 * 1024 * 1024 transfer = self._create_s3_transfer() - filename = self.files.create_file_with_size( - '6mb.txt', filesize=file_size - ) - self.addCleanup(self.delete_object, '6mb.txt') + file = self.get_input_fileobj_with_size(key, file_size) + self.addCleanup(self.delete_object, key) with transfer: future = transfer.upload( - filename, + file, self.bucket_name, - '6mb.txt', + key, extra_args=extra_args, subscribers=[self.record_subscriber], ) @@ -165,147 +162,248 @@ def test_upload_file_above_threshold_with_ssec(self): # A head object will fail if it has a customer key # associated with it and it's not provided in the HeadObject # request so we can use this to verify our functionality. 
- oringal_extra_args = { + original_extra_args = { 'SSECustomerKey': key_bytes, 'SSECustomerAlgorithm': 'AES256', } - self.wait_object_exists('6mb.txt', oringal_extra_args) + self._assert_object_exists(key, file_size, original_extra_args) response = self.client.head_object( - Bucket=self.bucket_name, Key='6mb.txt', **oringal_extra_args + Bucket=self.bucket_name, Key=key, **original_extra_args ) self.assertEqual(response['SSECustomerAlgorithm'], 'AES256') self._assert_subscribers_called(file_size) - def test_can_send_extra_params_on_download(self): - # We're picking the customer provided sse feature - # of S3 to test the extra_args functionality of - # S3. - key_bytes = os.urandom(32) - extra_args = { - 'SSECustomerKey': key_bytes, - 'SSECustomerAlgorithm': 'AES256', - } - filename = self.files.create_file('foo.txt', 'hello world') - self.upload_file(filename, 'foo.txt', extra_args) - transfer = self._create_s3_transfer() + def test_checksum_algorithm(self): + key = 'sha1.txt' + file_size = 1 * 1024 * 1024 + extra_args = {'ChecksumAlgorithm': 'SHA1'} + self._test_basic_upload(key, file_size, extra_args) - download_path = os.path.join(self.files.rootdir, 'downloaded.txt') + response = self.client.head_object( + Bucket=self.bucket_name, Key=key, ChecksumMode='ENABLED', + ) + self.assertIsNotNone(response.get('ChecksumSHA1')) + + def test_many_files(self): + transfer = self._create_s3_transfer() + keys = [] + file_size = 1024 * 1024 + files = [] + base_key = 'foo' + suffix = '.txt' + for i in range(10): + key = base_key + str(i) + suffix + keys.append(key) + file = self.get_input_fileobj_with_size(key, file_size) + files.append(file) + self.addCleanup(self.delete_object, key) with transfer: - future = transfer.download( - self.bucket_name, - 'foo.txt', - download_path, - extra_args=extra_args, - subscribers=[self.record_subscriber], - ) + for file, key in zip(files, keys): + transfer.upload(file, self.bucket_name, key) + + for key in keys: + 
self._assert_object_exists(key, file_size) + + def test_cancel(self): + transfer = self._create_s3_transfer() + key = '20mb.txt' + file_size = 20 * 1024 * 1024 + file = self.get_input_fileobj_with_size(key, file_size) + future = None + try: + with transfer: + future = transfer.upload(file, self.bucket_name, key) + raise KeyboardInterrupt() + except KeyboardInterrupt: + pass + + with self.assertRaises(AwsCrtError) as cm: future.result() - file_size = self.osutil.get_file_size(download_path) - self._assert_subscribers_called(file_size) - with open(download_path, 'rb') as f: - self.assertEqual(f.read(), b'hello world') + self.assertEqual(cm.exception.name, 'AWS_ERROR_S3_CANCELED') + self.assertTrue(self.object_not_exists('20mb.txt')) + + +@requires_crt +class TestCRTUploadSeekableStream(TestCRTUpload): + # Repeat upload tests, but use seekable streams + def get_input_fileobj(self, name, contents): + return BytesIO(contents) + + +@requires_crt +class TestCRTUploadNonSeekableStream(TestCRTUpload): + # Repeat upload tests, but use nonseekable streams + def get_input_fileobj(self, name, contents): + return NonSeekableReader(contents) + - def test_download_below_threshold(self): +@requires_crt +class TestCRTDownload(BaseCRTS3TransfersTest): + # CRTTransferManager download tests. Defaults to using filepath, but + # subclasses override the function below to use streaming fileobj instead. 
+ + def get_output_fileobj(self, name): + # override this in subclasses to download via fileobj instead of filepath + return os.path.join(self.files.rootdir, name) + + def _assert_files_equal(self, orig_file, download_file): + # download_file is either a path or a file-like object + if isinstance(download_file, str): + assert_files_equal(orig_file, download_file) + else: + download_file.close() + assert_files_equal(orig_file, download_file.name) + + def _test_basic_download(self, file_size): transfer = self._create_s3_transfer() - filename = self.files.create_file_with_size( - 'foo.txt', filesize=1024 * 1024 - ) - self.upload_file(filename, 'foo.txt') + key = 'foo.txt' + orig_file = self.files.create_file_with_size(key, file_size) + self.upload_file(orig_file, key) - download_path = os.path.join(self.files.rootdir, 'downloaded.txt') + download_file = self.get_output_fileobj('downloaded.txt') with transfer: future = transfer.download( self.bucket_name, - 'foo.txt', - download_path, + key, + download_file, subscribers=[self.record_subscriber], ) future.result() - file_size = self.osutil.get_file_size(download_path) + self._assert_files_equal(orig_file, download_file) self._assert_subscribers_called(file_size) - assert_files_equal(filename, download_path) - def test_download_above_threshold(self): - transfer = self._create_s3_transfer() - filename = self.files.create_file_with_size( - 'foo.txt', filesize=20 * 1024 * 1024 - ) - self.upload_file(filename, 'foo.txt') + def test_below_threshold(self): + self._test_basic_download(file_size=1024 * 1024) + + def test_above_threshold(self): + self._test_basic_download(file_size=20 * 1024 * 1024) - download_path = os.path.join(self.files.rootdir, 'downloaded.txt') + def test_empty_file(self): + self._test_basic_download(file_size=0) + + def test_can_send_extra_params(self): + # We're picking the customer provided sse feature + # of S3 to test the extra_args functionality of S3. 
+ key_bytes = os.urandom(32) + extra_args = { + 'SSECustomerKey': key_bytes, + 'SSECustomerAlgorithm': 'AES256', + } + key = 'foo.txt' + orig_file = self.files.create_file(key, 'hello world') + self.upload_file(orig_file, key, extra_args) + + transfer = self._create_s3_transfer() + download_file = self.get_output_fileobj('downloaded.txt') with transfer: future = transfer.download( self.bucket_name, - 'foo.txt', - download_path, + key, + download_file, + extra_args=extra_args, subscribers=[self.record_subscriber], ) future.result() - assert_files_equal(filename, download_path) - file_size = self.osutil.get_file_size(download_path) - self._assert_subscribers_called(file_size) + self._assert_files_equal(orig_file, download_file) + self._assert_subscribers_called(len('hello world')) - def test_delete(self): + def test_many_files(self): transfer = self._create_s3_transfer() - filename = self.files.create_file_with_size( - 'foo.txt', filesize=1024 * 1024 - ) - self.upload_file(filename, 'foo.txt') + key = '1mb.txt' + file_size = 1024 * 1024 + orig_file = self.files.create_file_with_size(key, file_size) + self.upload_file(orig_file, key) + + files = [] + base_filename = os.path.join(self.files.rootdir, 'file') + for i in range(10): + files.append(self.get_output_fileobj(base_filename + str(i))) with transfer: - future = transfer.delete(self.bucket_name, 'foo.txt') - future.result() - self.assertTrue(self.object_not_exists('foo.txt')) + for file in files: + transfer.download(self.bucket_name, key, file) + for download_file in files: + self._assert_files_equal(orig_file, download_file) - def test_many_files_download(self): + def test_cancel(self): transfer = self._create_s3_transfer() + key = 'foo.txt' + file_size = 20 * 1024 * 1024 + orig_file = self.files.create_file_with_size(key, file_size) + self.upload_file(orig_file, key) - filename = self.files.create_file_with_size( - '1mb.txt', filesize=1024 * 1024 - ) - self.upload_file(filename, '1mb.txt') + download_file = 
self.get_output_fileobj('downloaded.txt') + future = None + try: + with transfer: + future = transfer.download( + self.bucket_name, + key, + download_file, + subscribers=[self.record_subscriber], + ) + raise KeyboardInterrupt() + except KeyboardInterrupt: + pass - filenames = [] - base_filename = os.path.join(self.files.rootdir, 'file') - for i in range(10): - filenames.append(base_filename + str(i)) + with self.assertRaises(AwsCrtError) as err: + future.result() + self.assertEqual(err.name, 'AWS_ERROR_S3_CANCELED') - with transfer: - for filename in filenames: - transfer.download(self.bucket_name, '1mb.txt', filename) - for download_path in filenames: - assert_files_equal(filename, download_path) + # if passing filepath, assert that the file (and/or temp file) was removed + if isinstance(download_file, str): + possible_matches = glob.glob('%s*' % download_file) + self.assertEqual(possible_matches, []) + else: + download_file.close() + + self._assert_subscribers_called() - def test_many_files_upload(self): + +@requires_crt +class TestCRTDownloadSeekableStream(TestCRTDownload): + # Repeat download tests, but use seekable streams + def get_output_fileobj(self, name): + # Open stream to file on disk (vs just streaming to memory). + # This lets tests check the results of a download in the same way + # whether file path or file-like object was used. 
+ filepath = super().get_output_fileobj(name) + return open(filepath, 'wb') + + +@requires_crt +class TestCRTDownloadNonSeekableStream(TestCRTDownload): + # Repeat download tests, but use nonseekable streams + def get_output_fileobj(self, name): + filepath = super().get_output_fileobj(name) + return create_nonseekable_writer(open(filepath, 'wb')) + + +@requires_crt +class TestCRTS3Transfers(BaseCRTS3TransfersTest): + # for misc non-upload-or-download CRTTransferManager tests + + def test_delete(self): transfer = self._create_s3_transfer() - keys = [] - filenames = [] - base_key = 'foo' - sufix = '.txt' - for i in range(10): - key = base_key + str(i) + sufix - keys.append(key) - filename = self.files.create_file_with_size( - key, filesize=1024 * 1024 - ) - filenames.append(filename) - self.addCleanup(self.delete_object, key) - with transfer: - for filename, key in zip(filenames, keys): - transfer.upload(filename, self.bucket_name, key) + key = 'foo.txt' + filename = self.files.create_file_with_size(key, filesize=1) + self.upload_file(filename, key) - for key in keys: - self.assertTrue(self.object_exists(key)) + with transfer: + future = transfer.delete(self.bucket_name, key) + future.result() + self.assertTrue(self.object_not_exists(key)) def test_many_files_delete(self): transfer = self._create_s3_transfer() keys = [] base_key = 'foo' - sufix = '.txt' + suffix = '.txt' filename = self.files.create_file_with_size( - '1mb.txt', filesize=1024 * 1024 + '1mb.txt', filesize=1 ) for i in range(10): - key = base_key + str(i) + sufix + key = base_key + str(i) + suffix keys.append(key) self.upload_file(filename, key) @@ -314,52 +412,3 @@ def test_many_files_delete(self): transfer.delete(self.bucket_name, key) for key in keys: self.assertTrue(self.object_not_exists(key)) - - def test_upload_cancel(self): - transfer = self._create_s3_transfer() - filename = self.files.create_file_with_size( - '20mb.txt', filesize=20 * 1024 * 1024 - ) - future = None - try: - with transfer: - 
future = transfer.upload( - filename, self.bucket_name, '20mb.txt' - ) - raise KeyboardInterrupt() - except KeyboardInterrupt: - pass - - with self.assertRaises(AwsCrtError) as cm: - future.result() - self.assertEqual(cm.name, 'AWS_ERROR_S3_CANCELED') - self.assertTrue(self.object_not_exists('20mb.txt')) - - def test_download_cancel(self): - transfer = self._create_s3_transfer() - filename = self.files.create_file_with_size( - 'foo.txt', filesize=20 * 1024 * 1024 - ) - self.upload_file(filename, 'foo.txt') - - download_path = os.path.join(self.files.rootdir, 'downloaded.txt') - future = None - try: - with transfer: - future = transfer.download( - self.bucket_name, - 'foo.txt', - download_path, - subscribers=[self.record_subscriber], - ) - raise KeyboardInterrupt() - except KeyboardInterrupt: - pass - - with self.assertRaises(AwsCrtError) as err: - future.result() - self.assertEqual(err.name, 'AWS_ERROR_S3_CANCELED') - - possible_matches = glob.glob('%s*' % download_path) - self.assertEqual(possible_matches, []) - self._assert_subscribers_called() diff --git a/tests/integration/test_download.py b/tests/integration/test_download.py index 6a07f933..12eb91af 100644 --- a/tests/integration/test_download.py +++ b/tests/integration/test_download.py @@ -18,9 +18,9 @@ from s3transfer.manager import TransferConfig from tests import ( - NonSeekableWriter, RecordingSubscriber, assert_files_equal, + create_nonseekable_writer, skip_if_using_serial_implementation, skip_if_windows, ) @@ -248,7 +248,7 @@ def test_below_threshold_for_nonseekable_fileobj(self): download_path = os.path.join(self.files.rootdir, '1mb.txt') with open(download_path, 'wb') as f: future = transfer_manager.download( - self.bucket_name, '1mb.txt', NonSeekableWriter(f) + self.bucket_name, '1mb.txt', create_nonseekable_writer(f) ) future.result() assert_files_equal(filename, download_path) @@ -264,7 +264,7 @@ def test_above_threshold_for_nonseekable_fileobj(self): download_path = 
os.path.join(self.files.rootdir, '20mb.txt') with open(download_path, 'wb') as f: future = transfer_manager.download( - self.bucket_name, '20mb.txt', NonSeekableWriter(f) + self.bucket_name, '20mb.txt', create_nonseekable_writer(f) ) future.result() assert_files_equal(filename, download_path) diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index 1907437d..201dd724 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -41,9 +41,9 @@ BaseSubmissionTaskTest, BaseTaskTest, FileCreator, - NonSeekableWriter, RecordingExecutor, StreamWithError, + create_nonseekable_writer, mock, unittest, ) @@ -539,7 +539,7 @@ def tests_submits_tag_for_get_object_nonseekable_fileobj(self): self.add_get_responses() with open(self.filename, 'wb') as f: - self.use_fileobj_in_call_args(NonSeekableWriter(f)) + self.use_fileobj_in_call_args(create_nonseekable_writer(f)) self.submission_task = self.get_download_submission_task() self.wait_and_assert_completed_successfully(self.submission_task) @@ -554,7 +554,7 @@ def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self): self.add_get_responses() with open(self.filename, 'wb') as f: - self.use_fileobj_in_call_args(NonSeekableWriter(f)) + self.use_fileobj_in_call_args(create_nonseekable_writer(f)) self.submission_task = self.get_download_submission_task() self.wait_and_assert_completed_successfully(self.submission_task) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 83ce1265..8024edc6 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -43,7 +43,7 @@ invoke_progress_callbacks, random_file_extension, ) -from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest +from tests import RecordingSubscriber, create_nonseekable_writer, mock, unittest class TestGetCallbacks(unittest.TestCase): @@ -361,7 +361,7 @@ def recording_open_function(self, filename, mode): def open_nonseekable(self, filename, mode): self.open_call_args.append((filename, 
mode)) - return NonSeekableWriter(BytesIO(self.content)) + return create_nonseekable_writer(BytesIO(self.content)) def test_instantiation_does_not_open_file(self): DeferredOpenFile(