diff --git a/.changes/next-release/bugfix-RequestPayer-75144.json b/.changes/next-release/bugfix-RequestPayer-75144.json new file mode 100644 index 00000000..a46f4628 --- /dev/null +++ b/.changes/next-release/bugfix-RequestPayer-75144.json @@ -0,0 +1,5 @@ +{ + "category": "``RequestPayer``", + "type": "bugfix", + "description": "Plumb ``RequestPayer` argument to the ``CompleteMultipartUpload` operation (`#103 `__)." +} diff --git a/s3transfer/copies.py b/s3transfer/copies.py index 5b598970..9c34d165 100644 --- a/s3transfer/copies.py +++ b/s3transfer/copies.py @@ -19,6 +19,7 @@ from s3transfer.tasks import CompleteMultipartUploadTask from s3transfer.utils import get_callbacks from s3transfer.utils import calculate_range_parameter +from s3transfer.utils import get_filtered_dict from s3transfer.utils import ChunksizeAdjuster @@ -61,6 +62,10 @@ class CopySubmissionTask(SubmissionTask): 'MetadataDirective' ] + COMPLETE_MULTIPART_ARGS = [ + 'RequestPayer' + ] + def _submit(self, client, config, osutil, request_executor, transfer_future): """ @@ -212,6 +217,8 @@ def _submit_multipart_request(self, client, config, osutil, ) ) + complete_multipart_extra_args = self._extra_complete_multipart_args( + call_args.extra_args) # Submit the request to complete the multipart upload. self._transfer_coordinator.submit( request_executor, @@ -220,7 +227,8 @@ def _submit_multipart_request(self, client, config, osutil, main_kwargs={ 'client': client, 'bucket': call_args.bucket, - 'key': call_args.key + 'key': call_args.key, + 'extra_args': complete_multipart_extra_args, }, pending_main_kwargs={ 'upload_id': create_multipart_future, @@ -244,11 +252,10 @@ def _get_head_object_request_from_copy_source(self, copy_source): def _extra_upload_part_args(self, extra_args): # Only the args in COPY_PART_ARGS actually need to be passed # onto the upload_part_copy calls. - upload_parts_args = {} - for key, value in extra_args.items(): - if key in self.UPLOAD_PART_COPY_ARGS: - upload_parts_args[key] = value - return upload_parts_args + return get_filtered_dict(extra_args, self.UPLOAD_PART_COPY_ARGS) + + def _extra_complete_multipart_args(self, extra_args): + return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS) def _get_transfer_size(self, part_size, part_index, num_parts, total_transfer_size): diff --git a/s3transfer/tasks.py b/s3transfer/tasks.py index ae4abdd5..1d314216 100644 --- a/s3transfer/tasks.py +++ b/s3transfer/tasks.py @@ -343,7 +343,7 @@ def _main(self, client, bucket, key, extra_args): class CompleteMultipartUploadTask(Task): """Task to complete a multipart upload""" - def _main(self, client, bucket, key, upload_id, parts): + def _main(self, client, bucket, key, upload_id, parts, extra_args): """ :param client: The client to use when calling CompleteMultipartUpload :param bucket: The name of the bucket to upload to @@ -355,7 +355,10 @@ def _main(self, client, bucket, key, upload_id, parts): Each element in the list consists of a return value from ``UploadPartTask.main()``. + :param extra_args: A dictionary of any extra arguments that may be + used in completing the multipart transfer. """ client.complete_multipart_upload( Bucket=bucket, Key=key, UploadId=upload_id, - MultipartUpload={'Parts': parts}) + MultipartUpload={'Parts': parts}, + **extra_args) diff --git a/s3transfer/upload.py b/s3transfer/upload.py index 0c2feda2..32cd4b90 100644 --- a/s3transfer/upload.py +++ b/s3transfer/upload.py @@ -21,6 +21,7 @@ from s3transfer.tasks import CreateMultipartUploadTask from s3transfer.tasks import CompleteMultipartUploadTask from s3transfer.utils import get_callbacks +from s3transfer.utils import get_filtered_dict from s3transfer.utils import DeferredOpenFile, ChunksizeAdjuster @@ -491,6 +492,10 @@ class UploadSubmissionTask(SubmissionTask): 'RequestPayer', ] + COMPLETE_MULTIPART_ARGS = [ + 'RequestPayer' + ] + def _get_upload_input_manager_cls(self, transfer_future): """Retieves a class for managing input for an upload based on file type @@ -636,6 +641,8 @@ def _submit_multipart_request(self, client, config, osutil, ) ) + complete_multipart_extra_args = self._extra_complete_multipart_args( + call_args.extra_args) # Submit the request to complete the multipart upload. self._transfer_coordinator.submit( request_executor, @@ -644,7 +651,8 @@ def _submit_multipart_request(self, client, config, osutil, main_kwargs={ 'client': client, 'bucket': call_args.bucket, - 'key': call_args.key + 'key': call_args.key, + 'extra_args': complete_multipart_extra_args, }, pending_main_kwargs={ 'upload_id': create_multipart_future, @@ -657,11 +665,10 @@ def _submit_multipart_request(self, client, config, osutil, def _extra_upload_part_args(self, extra_args): # Only the args in UPLOAD_PART_ARGS actually need to be passed # onto the upload_part calls. - upload_parts_args = {} - for key, value in extra_args.items(): - if key in self.UPLOAD_PART_ARGS: - upload_parts_args[key] = value - return upload_parts_args + return get_filtered_dict(extra_args, self.UPLOAD_PART_ARGS) + + def _extra_complete_multipart_args(self, extra_args): + return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS) def _get_upload_task_tag(self, upload_input_manager, operation_name): tag = None diff --git a/s3transfer/utils.py b/s3transfer/utils.py index fba56485..5ca4d9f0 100644 --- a/s3transfer/utils.py +++ b/s3transfer/utils.py @@ -126,6 +126,24 @@ def invoke_progress_callbacks(callbacks, bytes_transferred): callback(bytes_transferred=bytes_transferred) +def get_filtered_dict(original_dict, whitelisted_keys): + """Gets a dictionary filtered by whitelisted keys + + :param original_dict: The original dictionary of arguments to source keys + and values. + :param whitelisted_key: A list of keys to include in the filtered + dictionary. + + :returns: A dictionary containing key/values from the original dictionary + whose key was included in the whitelist + """ + filtered_dict = {} + for key, value in original_dict.items(): + if key in whitelisted_keys: + filtered_dict[key] = value + return filtered_dict + + class CallArgs(object): def __init__(self, **kwargs): """A class that records call arguments diff --git a/tests/functional/test_copy.py b/tests/functional/test_copy.py index 8c7d9713..53b239f0 100644 --- a/tests/functional/test_copy.py +++ b/tests/functional/test_copy.py @@ -408,7 +408,8 @@ def test_copy_with_extra_args(self): self.add_head_object_response(expected_params=head_params) self._add_params_to_expected_params( - add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args) + add_copy_kwargs, ['create_mpu', 'copy', 'complete_mpu'], + self.extra_args) self.add_successful_copy_responses(**add_copy_kwargs) call_kwargs = self.create_call_kwargs() diff --git a/tests/functional/test_upload.py b/tests/functional/test_upload.py index 06ee2b0c..9d7690bd 100644 --- a/tests/functional/test_upload.py +++ b/tests/functional/test_upload.py @@ -336,7 +336,8 @@ def add_upload_part_responses_with_default_expected_params( upload_part_response['expected_params'] = expected_params self.stubber.add_response(**upload_part_response) - def add_complete_multipart_response_with_default_expected_params(self): + def add_complete_multipart_response_with_default_expected_params( + self, extra_expected_params=None): expected_params = { 'Bucket': self.bucket, 'Key': self.key, 'UploadId': self.multipart_id, @@ -348,6 +349,8 @@ def add_complete_multipart_response_with_default_expected_params(self): ] } } + if extra_expected_params: + expected_params.update(extra_expected_params) response = self.create_stubbed_responses()[-1] response['expected_params'] = expected_params self.stubber.add_response(**response) @@ -360,7 +363,8 @@ def test_upload(self): extra_expected_params={'RequestPayer': 'requester'}) self.add_upload_part_responses_with_default_expected_params( extra_expected_params={'RequestPayer': 'requester'}) - self.add_complete_multipart_response_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'}) future = self.manager.upload( self.filename, self.bucket, self.key, self.extra_args) diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py index b126d8c1..e434246b 100644 --- a/tests/unit/test_tasks.py +++ b/tests/unit/test_tasks.py @@ -721,7 +721,8 @@ def test_main(self): 'bucket': self.bucket, 'key': self.key, 'upload_id': upload_id, - 'parts': parts + 'parts': parts, + 'extra_args': {} } ) self.stubber.add_response( @@ -737,3 +738,32 @@ def test_main(self): ) task() self.stubber.assert_no_pending_responses() + + def test_includes_extra_args(self): + upload_id = 'my-id' + parts = [{'ETag': 'etag', 'PartNumber': 0}] + task = self.get_task( + CompleteMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'parts': parts, + 'extra_args': {'RequestPayer': 'requester'} + } + ) + self.stubber.add_response( + method='complete_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, 'Key': self.key, + 'UploadId': upload_id, + 'MultipartUpload': { + 'Parts': parts + }, + 'RequestPayer': 'requester' + } + ) + task() + self.stubber.assert_no_pending_responses() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index f76b2699..26cb881e 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,6 +30,7 @@ from s3transfer.utils import random_file_extension from s3transfer.utils import invoke_progress_callbacks from s3transfer.utils import calculate_range_parameter +from s3transfer.utils import get_filtered_dict from s3transfer.utils import CallArgs from s3transfer.utils import FunctionContainer from s3transfer.utils import CountCallbackInvoker @@ -76,6 +77,19 @@ def test_get_callbacks_for_missing_type(self): self.assertEqual(len(callbacks), 0) +class TestGetFilteredDict(unittest.TestCase): + def test_get_filtered_dict(self): + original = { + 'Include': 'IncludeValue', + 'NotInlude': 'NotIncludeValue' + } + whitelist = ['Include'] + self.assertEqual( + get_filtered_dict(original, whitelist), + {'Include': 'IncludeValue'} + ) + + class TestCallArgs(unittest.TestCase): def test_call_args(self): call_args = CallArgs(foo='bar', biz='baz')