Skip to content

Commit

Permalink
Merge pull request #103 from kyleknap/request-payer
Browse files Browse the repository at this point in the history
Plumb request payer to complete multipart
  • Loading branch information
kyleknap committed Feb 14, 2018
2 parents f1c9ea7 + e377a20 commit 801d5f7
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/bugfix-RequestPayer-75144.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"category": "``RequestPayer``",
"type": "bugfix",
"description": "Plumb ``RequestPayer` argument to the ``CompleteMultipartUpload` operation (`#103 <https://github.com/boto/s3transfer/issues/103>`__)."
}
19 changes: 13 additions & 6 deletions s3transfer/copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from s3transfer.tasks import CompleteMultipartUploadTask
from s3transfer.utils import get_callbacks
from s3transfer.utils import calculate_range_parameter
from s3transfer.utils import get_filtered_dict
from s3transfer.utils import ChunksizeAdjuster


Expand Down Expand Up @@ -61,6 +62,10 @@ class CopySubmissionTask(SubmissionTask):
'MetadataDirective'
]

COMPLETE_MULTIPART_ARGS = [
'RequestPayer'
]

def _submit(self, client, config, osutil, request_executor,
transfer_future):
"""
Expand Down Expand Up @@ -212,6 +217,8 @@ def _submit_multipart_request(self, client, config, osutil,
)
)

complete_multipart_extra_args = self._extra_complete_multipart_args(
call_args.extra_args)
# Submit the request to complete the multipart upload.
self._transfer_coordinator.submit(
request_executor,
Expand All @@ -220,7 +227,8 @@ def _submit_multipart_request(self, client, config, osutil,
main_kwargs={
'client': client,
'bucket': call_args.bucket,
'key': call_args.key
'key': call_args.key,
'extra_args': complete_multipart_extra_args,
},
pending_main_kwargs={
'upload_id': create_multipart_future,
Expand All @@ -244,11 +252,10 @@ def _get_head_object_request_from_copy_source(self, copy_source):
def _extra_upload_part_args(self, extra_args):
# Only the args in COPY_PART_ARGS actually need to be passed
# onto the upload_part_copy calls.
upload_parts_args = {}
for key, value in extra_args.items():
if key in self.UPLOAD_PART_COPY_ARGS:
upload_parts_args[key] = value
return upload_parts_args
return get_filtered_dict(extra_args, self.UPLOAD_PART_COPY_ARGS)

def _extra_complete_multipart_args(self, extra_args):
return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS)

def _get_transfer_size(self, part_size, part_index, num_parts,
total_transfer_size):
Expand Down
7 changes: 5 additions & 2 deletions s3transfer/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def _main(self, client, bucket, key, extra_args):

class CompleteMultipartUploadTask(Task):
"""Task to complete a multipart upload"""
def _main(self, client, bucket, key, upload_id, parts):
def _main(self, client, bucket, key, upload_id, parts, extra_args):
"""
:param client: The client to use when calling CompleteMultipartUpload
:param bucket: The name of the bucket to upload to
Expand All @@ -355,7 +355,10 @@ def _main(self, client, bucket, key, upload_id, parts):
Each element in the list consists of a return value from
``UploadPartTask.main()``.
:param extra_args: A dictionary of any extra arguments that may be
used in completing the multipart transfer.
"""
client.complete_multipart_upload(
Bucket=bucket, Key=key, UploadId=upload_id,
MultipartUpload={'Parts': parts})
MultipartUpload={'Parts': parts},
**extra_args)
19 changes: 13 additions & 6 deletions s3transfer/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from s3transfer.tasks import CreateMultipartUploadTask
from s3transfer.tasks import CompleteMultipartUploadTask
from s3transfer.utils import get_callbacks
from s3transfer.utils import get_filtered_dict
from s3transfer.utils import DeferredOpenFile, ChunksizeAdjuster


Expand Down Expand Up @@ -491,6 +492,10 @@ class UploadSubmissionTask(SubmissionTask):
'RequestPayer',
]

COMPLETE_MULTIPART_ARGS = [
'RequestPayer'
]

def _get_upload_input_manager_cls(self, transfer_future):
"""Retieves a class for managing input for an upload based on file type
Expand Down Expand Up @@ -636,6 +641,8 @@ def _submit_multipart_request(self, client, config, osutil,
)
)

complete_multipart_extra_args = self._extra_complete_multipart_args(
call_args.extra_args)
# Submit the request to complete the multipart upload.
self._transfer_coordinator.submit(
request_executor,
Expand All @@ -644,7 +651,8 @@ def _submit_multipart_request(self, client, config, osutil,
main_kwargs={
'client': client,
'bucket': call_args.bucket,
'key': call_args.key
'key': call_args.key,
'extra_args': complete_multipart_extra_args,
},
pending_main_kwargs={
'upload_id': create_multipart_future,
Expand All @@ -657,11 +665,10 @@ def _submit_multipart_request(self, client, config, osutil,
def _extra_upload_part_args(self, extra_args):
# Only the args in UPLOAD_PART_ARGS actually need to be passed
# onto the upload_part calls.
upload_parts_args = {}
for key, value in extra_args.items():
if key in self.UPLOAD_PART_ARGS:
upload_parts_args[key] = value
return upload_parts_args
return get_filtered_dict(extra_args, self.UPLOAD_PART_ARGS)

def _extra_complete_multipart_args(self, extra_args):
return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS)

def _get_upload_task_tag(self, upload_input_manager, operation_name):
tag = None
Expand Down
18 changes: 18 additions & 0 deletions s3transfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,24 @@ def invoke_progress_callbacks(callbacks, bytes_transferred):
callback(bytes_transferred=bytes_transferred)


def get_filtered_dict(original_dict, whitelisted_keys):
"""Gets a dictionary filtered by whitelisted keys
:param original_dict: The original dictionary of arguments to source keys
and values.
:param whitelisted_key: A list of keys to include in the filtered
dictionary.
:returns: A dictionary containing key/values from the original dictionary
whose key was included in the whitelist
"""
filtered_dict = {}
for key, value in original_dict.items():
if key in whitelisted_keys:
filtered_dict[key] = value
return filtered_dict


class CallArgs(object):
def __init__(self, **kwargs):
"""A class that records call arguments
Expand Down
3 changes: 2 additions & 1 deletion tests/functional/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,8 @@ def test_copy_with_extra_args(self):
self.add_head_object_response(expected_params=head_params)

self._add_params_to_expected_params(
add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args)
add_copy_kwargs, ['create_mpu', 'copy', 'complete_mpu'],
self.extra_args)
self.add_successful_copy_responses(**add_copy_kwargs)

call_kwargs = self.create_call_kwargs()
Expand Down
8 changes: 6 additions & 2 deletions tests/functional/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def add_upload_part_responses_with_default_expected_params(
upload_part_response['expected_params'] = expected_params
self.stubber.add_response(**upload_part_response)

def add_complete_multipart_response_with_default_expected_params(self):
def add_complete_multipart_response_with_default_expected_params(
self, extra_expected_params=None):
expected_params = {
'Bucket': self.bucket,
'Key': self.key, 'UploadId': self.multipart_id,
Expand All @@ -348,6 +349,8 @@ def add_complete_multipart_response_with_default_expected_params(self):
]
}
}
if extra_expected_params:
expected_params.update(extra_expected_params)
response = self.create_stubbed_responses()[-1]
response['expected_params'] = expected_params
self.stubber.add_response(**response)
Expand All @@ -360,7 +363,8 @@ def test_upload(self):
extra_expected_params={'RequestPayer': 'requester'})
self.add_upload_part_responses_with_default_expected_params(
extra_expected_params={'RequestPayer': 'requester'})
self.add_complete_multipart_response_with_default_expected_params()
self.add_complete_multipart_response_with_default_expected_params(
extra_expected_params={'RequestPayer': 'requester'})

future = self.manager.upload(
self.filename, self.bucket, self.key, self.extra_args)
Expand Down
32 changes: 31 additions & 1 deletion tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,8 @@ def test_main(self):
'bucket': self.bucket,
'key': self.key,
'upload_id': upload_id,
'parts': parts
'parts': parts,
'extra_args': {}
}
)
self.stubber.add_response(
Expand All @@ -737,3 +738,32 @@ def test_main(self):
)
task()
self.stubber.assert_no_pending_responses()

def test_includes_extra_args(self):
upload_id = 'my-id'
parts = [{'ETag': 'etag', 'PartNumber': 0}]
task = self.get_task(
CompleteMultipartUploadTask,
main_kwargs={
'client': self.client,
'bucket': self.bucket,
'key': self.key,
'upload_id': upload_id,
'parts': parts,
'extra_args': {'RequestPayer': 'requester'}
}
)
self.stubber.add_response(
method='complete_multipart_upload',
service_response={},
expected_params={
'Bucket': self.bucket, 'Key': self.key,
'UploadId': upload_id,
'MultipartUpload': {
'Parts': parts
},
'RequestPayer': 'requester'
}
)
task()
self.stubber.assert_no_pending_responses()
14 changes: 14 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from s3transfer.utils import random_file_extension
from s3transfer.utils import invoke_progress_callbacks
from s3transfer.utils import calculate_range_parameter
from s3transfer.utils import get_filtered_dict
from s3transfer.utils import CallArgs
from s3transfer.utils import FunctionContainer
from s3transfer.utils import CountCallbackInvoker
Expand Down Expand Up @@ -76,6 +77,19 @@ def test_get_callbacks_for_missing_type(self):
self.assertEqual(len(callbacks), 0)


class TestGetFilteredDict(unittest.TestCase):
def test_get_filtered_dict(self):
original = {
'Include': 'IncludeValue',
'NotInlude': 'NotIncludeValue'
}
whitelist = ['Include']
self.assertEqual(
get_filtered_dict(original, whitelist),
{'Include': 'IncludeValue'}
)


class TestCallArgs(unittest.TestCase):
def test_call_args(self):
call_args = CallArgs(foo='bar', biz='baz')
Expand Down

0 comments on commit 801d5f7

Please sign in to comment.