Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming fileobj support to CRTTransferManager #277

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 104 additions & 24 deletions s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,24 @@
EventLoopGroup,
TlsContextOptions,
)
from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
from awscrt.s3 import (
S3ChecksumAlgorithm,
S3ChecksumConfig,
S3ChecksumLocation,
S3Client,
S3RequestTlsMode,
S3RequestType,
)
from botocore import UNSIGNED
from botocore.compat import urlsplit
from botocore.config import Config
from botocore.exceptions import NoCredentialsError

from s3transfer.compat import seekable
from s3transfer.constants import GB, MB
from s3transfer.exceptions import TransferNotDoneError
from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
from s3transfer.manager import TransferManager
from s3transfer.utils import CallArgs, OSUtils, get_callbacks

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,7 +76,7 @@ def create_s3_crt_client(
region,
botocore_credential_provider=None,
num_threads=None,
target_throughput=5 * GB / 8,
target_throughput=5_000_000_000.0 / 8,
part_size=8 * MB,
use_ssl=True,
verify=None,
Expand All @@ -86,8 +95,8 @@ def create_s3_crt_client(
is the number of processors in the machine.

:type target_throughput: Optional[int]
:param target_throughput: Throughput target in Bytes.
Default is 0.625 GB/s (which translates to 5 Gb/s).
:param target_throughput: Throughput target in bytes per second.
Default translates to 5.0 Gb/s or 0.582 GiB/s.

:type part_size: Optional[int]
:param part_size: Size, in Bytes, of parts that files will be downloaded
Expand Down Expand Up @@ -137,19 +146,24 @@ def create_s3_crt_client(
credentails_provider_adapter
)

target_gbps = target_throughput * 8 / GB
target_gigabits = target_throughput * 8 / 1_000_000_000.0
return S3Client(
bootstrap=bootstrap,
region=region,
credential_provider=provider,
part_size=part_size,
tls_mode=tls_mode,
tls_connection_options=tls_connection_options,
throughput_target_gbps=target_gbps,
throughput_target_gbps=target_gigabits,
)


class CRTTransferManager:

ALLOWED_DOWNLOAD_ARGS = TransferManager.ALLOWED_DOWNLOAD_ARGS
ALLOWED_UPLOAD_ARGS = TransferManager.ALLOWED_UPLOAD_ARGS
ALLOWED_DELETE_ARGS = TransferManager.ALLOWED_DELETE_ARGS
Comment on lines +163 to +165
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I make a common base class for TransferManager and CRTTransferManager, where I can put stuff like these lists, and def _validate_all_known_args() and def _validate_if_bucket_supported()?

Or move these to be standalone lists/functions that both classes can use?

Or just copy/paste the functions, and reference the lists, like I'm doing here?


def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
"""A transfer manager interface for Amazon S3 on CRT s3 client.

Expand Down Expand Up @@ -192,6 +206,8 @@ def download(
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_all_known_args(extra_args, TransferManager.ALLOWED_DOWNLOAD_ARGS)
# TODO: _validate_if_bucket_supported() ???
callargs = CallArgs(
bucket=bucket,
key=key,
Expand All @@ -206,6 +222,7 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_all_known_args(extra_args, TransferManager.ALLOWED_UPLOAD_ARGS)
callargs = CallArgs(
bucket=bucket,
key=key,
Expand All @@ -220,6 +237,7 @@ def delete(self, bucket, key, extra_args=None, subscribers=None):
extra_args = {}
if subscribers is None:
subscribers = {}
self._validate_all_known_args(extra_args, TransferManager.ALLOWED_DELETE_ARGS)
callargs = CallArgs(
bucket=bucket,
key=key,
Expand Down Expand Up @@ -260,6 +278,14 @@ def _shutdown(self, cancel=False):
def _release_semaphore(self, **kwargs):
self._semaphore.release()

def _validate_all_known_args(self, actual, allowed):
for kwarg in actual:
if kwarg not in allowed:
raise ValueError(
"Invalid extra_args key '%s', "
"must be one of: %s" % (kwarg, ', '.join(allowed))
)

def _submit_transfer(self, request_type, call_args):
on_done_after_calls = [self._release_semaphore]
coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
Expand Down Expand Up @@ -359,7 +385,7 @@ def set_exception(self, exception):


class BaseCRTRequestSerializer:
def serialize_http_request(self, transfer_type, future):
def serialize_http_request(self, transfer_type, future, fileobj):
"""Serialize CRT HTTP requests.

:type transfer_type: string
Expand Down Expand Up @@ -428,19 +454,12 @@ def _crt_request_from_aws_request(self, aws_request):
headers_list.append((name, str(value, 'utf-8')))

crt_headers = awscrt.http.HttpHeaders(headers_list)
# CRT requires body (if it exists) to be an I/O stream.
crt_body_stream = None
if aws_request.body:
if hasattr(aws_request.body, 'seek'):
crt_body_stream = aws_request.body
else:
crt_body_stream = BytesIO(aws_request.body)

crt_request = awscrt.http.HttpRequest(
method=aws_request.method,
path=crt_path,
headers=crt_headers,
body_stream=crt_body_stream,
body_stream=aws_request.body,
)
return crt_request

Expand All @@ -451,8 +470,24 @@ def _convert_to_crt_http_request(self, botocore_http_request):
# If host is not set, set it for the request before using CRT s3
url_parts = urlsplit(botocore_http_request.url)
crt_request.headers.set("host", url_parts.netloc)

# Remove bogus Content-MD5 value (see comment elsewhere in file)
if crt_request.headers.get('Content-MD5') is not None:
crt_request.headers.remove("Content-MD5")

# Explicitly set "Content-Length: 0" when there's no body.
# Botocore doesn't bother setting this, but CRT likes to know.
# Note that Content-Length SHOULD be absent if body is nonseekable.
if crt_request.headers.get('Content-Length') is None:
if botocore_http_request.body is None:
crt_request.headers.add('Content-Length', "0")

# Remove "Transfer-Encoding: chunked".
# Botocore sets this on nonseekable streams,
# but CRT currently chokes on this header (TODO: fix this in CRT)
if crt_request.headers.get('Transfer-Encoding') is not None:
crt_request.headers.remove('Transfer-Encoding')

return crt_request

def _capture_http_request(self, request, **kwargs):
Expand Down Expand Up @@ -556,22 +591,57 @@ def get_make_request_args(
self, request_type, call_args, coordinator, future, on_done_after_calls
):
recv_filepath = None
on_body = None
send_filepath = None
s3_meta_request_type = getattr(
S3RequestType, request_type.upper(), S3RequestType.DEFAULT
)
on_done_before_calls = []
checksum_config = S3ChecksumConfig()

if s3_meta_request_type == S3RequestType.GET_OBJECT:
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
file_ondone_call = RenameTempFileHandler(
coordinator, final_filepath, recv_filepath, self._os_utils
)
on_done_before_calls.append(file_ondone_call)
if isinstance(call_args.fileobj, str):
# fileobj is a filepath
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
file_ondone_call = RenameTempFileHandler(
coordinator, final_filepath, recv_filepath, self._os_utils
)
on_done_before_calls.append(file_ondone_call)

elif call_args.fileobj is not None:
# fileobj is a file-like object
response_handler = _FileobjResponseHandler(call_args.fileobj)
on_body = response_handler.on_body

# Only validate response checksums when downloading.
# (upload responses also have checksum headers, but they're just an
# echo of what was in the request, an upload response's body is empty)
checksum_config.validate_response = True

elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
send_filepath = call_args.fileobj
data_len = self._os_utils.get_file_size(send_filepath)
call_args.extra_args["ContentLength"] = data_len
if isinstance(call_args.fileobj, str):
# fileobj is a filepath
send_filepath = call_args.fileobj
data_len = self._os_utils.get_file_size(send_filepath)
call_args.extra_args["ContentLength"] = data_len

elif call_args.fileobj is not None:
# fileobj is a file-like object
call_args.extra_args["Body"] = call_args.fileobj

# We want the CRT S3Client to calculate checksums, not botocore.
# Default to CRC32.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is different than the default TransferManager. With this change CRT is defaulting to CRC32 instead of Content-MD5. The CRT team got grumpy about how Content-MD5 currently works in our code and pushed me to do it this way, but we can still make Content-MD5 work if we need to.

if call_args.extra_args.get('ChecksumAlgorithm') is not None:
algorithm_name = call_args.extra_args.pop('ChecksumAlgorithm')
checksum_config.algorithm = S3ChecksumAlgorithm[algorithm_name]
else:
checksum_config.algorithm = S3ChecksumAlgorithm.CRC32
checksum_config.location = S3ChecksumLocation.TRAILER

# Suppress botocore's MD5 calculation by setting a bogus value.
# (this header gets removed before the request is passed to CRT)
call_args.extra_args["ContentMD5"] = "bogus value deleted later"

crt_request = self._request_serializer.serialize_http_request(
request_type, future
Expand All @@ -582,6 +652,8 @@ def get_make_request_args(
'type': s3_meta_request_type,
'recv_filepath': recv_filepath,
'send_filepath': send_filepath,
'checksum_config': checksum_config,
'on_body': on_body,
'on_done': self.get_crt_callback(
future, 'done', on_done_before_calls, on_done_after_calls
),
Expand Down Expand Up @@ -642,3 +714,11 @@ def __init__(self, coordinator):

def __call__(self, **kwargs):
self._coordinator.set_done_callbacks_complete()


class _FileobjResponseHandler:
def __init__(self, fileobj):
self._fileobj = fileobj

def on_body(self, chunk: bytes, offset: int, **kwargs):
self._fileobj.write(chunk)
20 changes: 12 additions & 8 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,26 +506,30 @@ def write(self, b):
# kind of error even though writeable returns False.
raise io.UnsupportedOperation("write")

def read(self, n=-1):
return self._data.read(n)
def readinto(self, b):
return self._data.readinto(b)


class NonSeekableWriter(io.RawIOBase):
def __init__(self, fileobj):
super().__init__()
self._fileobj = fileobj
def create_nonseekable_writer(fileobj):

def seekable(self):
return False

fileobj.seekable = seekable

def writable(self):
return True

fileobj.writable = writable

def readable(self):
return False

def write(self, b):
self._fileobj.write(b)
fileobj.readable = readable

def read(self, n=-1):
raise io.UnsupportedOperation("read")

fileobj.read = read

return fileobj
Loading