diff --git a/README.md b/README.md index cbacf41..8accc5c 100644 --- a/README.md +++ b/README.md @@ -45,10 +45,10 @@ With only **3k** lines of code, with **no dependencies** other than the [Python * of course with [WebSocket support](https://nggit.github.io/tremolo-docs/reference/websocket/) * Keep-Alive connections with [configurable limit](https://nggit.github.io/tremolo-docs/configuration.html#keepalive_connections) * Stream chunked uploads -* [Stream multipart uploads](https://nggit.github.io/tremolo-docs/body.html#multipart) +* [Stream multipart uploads](https://nggit.github.io/tremolo-docs/body.html#multipart) with [per-part streaming](https://github.com/nggit/tremolo/pull/293) * Download/upload speed throttling * [Resumable downloads](https://nggit.github.io/tremolo-docs/resumable-downloads.html) -* Framework features; routing, middleware, etc. +* Framework features; routing, async/[sync handlers](https://nggit.github.io/tremolo-docs/handlers.html#synchronous-handlers), middleware, etc. * ASGI server implementation * PyPy compatible diff --git a/examples/upload_and_save.py b/examples/upload_and_save.py index dc4b487..1dc43e6 100644 --- a/examples/upload_and_save.py +++ b/examples/upload_and_save.py @@ -28,46 +28,55 @@ async def upload(request, response): # it can still be read continuously bit by bit according to this size files = request.files(max_file_size=16384) # 16KiB - fp = None - filename = None + # keep track of incomplete writings + incomplete = set() try: # read while writing the file(s). + # `part` represents a field/file received in a multipart request async for part in files: filename = quote(part.get('filename', '')) if not filename: continue - if fp is None: - fp = open('Uploaded_' + filename, 'wb') - - print('Writing %s (len=%d, eof=%s)' % (filename, - len(part['data']), - part['eof'])) - fp.write(part['data']) - - if part['eof']: - fp.close() - fp = None # the next iteration will be another file - filename = filename.encode() - content_type = quote(part['type']).encode() - - yield ( - b'File %s ' - b'was uploaded.
' % (content_type, filename, filename) - ) + with open('Uploaded_' + filename, 'wb') as fp: + incomplete.add(fp) + + if part['eof']: + # this part is not larger than `max_file_size`. + # but you can skip this check and always use `part.stream` + fp.write(part['data']) + else: + # stream a (possibly) large part in chunks + async for data in part.stream(): + print('Writing %s (len=%d, eof=%s)' % (filename, + len(data), + part['eof'])) + fp.write(data) + + incomplete.discard(fp) # completed :) + + filename = filename.encode() + content_type = quote(part['type']).encode() + + yield ( + b'File %s ' + b'was uploaded.
' % (content_type, filename, filename) + ) finally: - if fp is not None: - fp.close() - print('Upload canceled, removing incomplete file: %s' % filename) - os.unlink('Uploaded_' + filename) + while incomplete: + path = incomplete.pop().name + print('Upload canceled, removing incomplete file: %s' % path) + os.unlink(path) yield b'' @app.route('/download') async def download(request, response): + # prepend / append a hardcoded string. + # do not let the user freely determine the path path = 'Uploaded_' + quote(request.query['filename'][0]) content_type = request.query['type'][0] diff --git a/tea.yaml b/tea.yaml deleted file mode 100644 index 0eebaba..0000000 --- a/tea.yaml +++ /dev/null @@ -1,6 +0,0 @@ -# https://tea.xyz/what-is-this-file ---- -version: 1.0.0 -codeOwners: - - '0x06E2F6ddDb4C7230D5694905312d645e047F48B6' -quorum: 1 diff --git a/tremolo/lib/http_protocol.py b/tremolo/lib/http_protocol.py index c4fbf38..8eedb4b 100644 --- a/tremolo/lib/http_protocol.py +++ b/tremolo/lib/http_protocol.py @@ -146,7 +146,7 @@ async def put_to_queue(self, data, name=0, rate=-1): if queue_size <= self.options['max_queue_size']: if data and rate > 0 and queue_size > 0: - await asyncio.sleep(1 / (rate / queue_size / len(data))) + await asyncio.sleep(queue_size * len(data) / rate) return True diff --git a/tremolo/lib/http_request.py b/tremolo/lib/http_request.py index c7e26d0..d9f1bf5 100644 --- a/tremolo/lib/http_request.py +++ b/tremolo/lib/http_request.py @@ -11,11 +11,23 @@ from .request import Request +class MultipartFile(dict): + def __init__(self, files): + self.files = files + + async def stream(self): + if 'data' in self: + yield self.pop('data') + + while not self['eof']: + yield (await self.files.__anext__()).pop('data') + + class HTTPRequest(Request): __slots__ = ('_ip', '_scheme', 'header', 'headers', 'is_valid', 'host', 'method', 'url', 'path', 'query_string', 'version', 'content_length', 'http_continue', 'http_keepalive', - '_body', '_stream', '_read_buf') + '_body', '_read_buf', '_stream', '_files') def __init__(self, protocol, header): super().__init__(protocol) @@ -42,9 +54,11 @@ def __init__(self, protocol, header): self.content_length = -1 self.http_continue = False self.http_keepalive = False + self._body = bytearray() - self._stream = None self._read_buf = bytearray() + self._stream = None + self._files = None @property def ip(self): @@ -284,7 +298,7 @@ def cookies(self): return self.params['cookies'] - async def form(self, max_size=8 * 1048576, max_fields=100): + async def form(self, max_fields=100, *, max_size=8 * 1048576): try: return self.params['post'] except KeyError as exc: @@ -308,10 +322,18 @@ async def form(self, max_size=8 * 1048576, max_fields=100): return self.params['post'] - async def files(self, max_files=1024, max_file_size=100 * 1048576): + async def files(self, max_files=1024, *, max_file_size=100 * 1048576): if self.eof(): return + if self._files is None: + self._files = self.files(max_files, max_file_size=max_file_size) + + async for part in self._files: + yield part + + return + for key, boundary in parse_fields(self.content_type): if key == b'boundary' and boundary: break @@ -325,7 +347,7 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576): body_size = 0 content_length = 0 paused = False - part = None # represents a file received in a multipart request + part = None # represents a field/file received in a multipart request if self._stream is None: self._stream = self.stream() @@ -359,7 +381,7 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576): paused = False else: body.extend(self._read_buf[header_size + 2:]) - part = {} + part = MultipartFile(self._files) # use find() instead of startswith() to ignore the preamble if self._read_buf.find(b'--%s\r\n' % boundary, @@ -393,10 +415,9 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576): if body_size == -1: if len(body) >= max_file_size > boundary_size + 4: - sub_part = part.copy() - sub_part['data'] = body[:-boundary_size - 4] - sub_part['eof'] = False - yield sub_part + part['data'] = bytes(body[:-boundary_size - 4]) + part['eof'] = False + yield part content_length = max( content_length - (len(body) - boundary_size - 4), 0 @@ -406,7 +427,7 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576): paused = False continue - part['data'] = body[:body_size] + part['data'] = bytes(body[:body_size]) part['eof'] = True yield part