diff --git a/README.md b/README.md
index cbacf41..8accc5c 100644
--- a/README.md
+++ b/README.md
@@ -45,10 +45,10 @@ With only **3k** lines of code, with **no dependencies** other than the [Python
* of course with [WebSocket support](https://nggit.github.io/tremolo-docs/reference/websocket/)
* Keep-Alive connections with [configurable limit](https://nggit.github.io/tremolo-docs/configuration.html#keepalive_connections)
* Stream chunked uploads
-* [Stream multipart uploads](https://nggit.github.io/tremolo-docs/body.html#multipart)
+* [Stream multipart uploads](https://nggit.github.io/tremolo-docs/body.html#multipart) with [per-part streaming](https://github.com/nggit/tremolo/pull/293)
* Download/upload speed throttling
* [Resumable downloads](https://nggit.github.io/tremolo-docs/resumable-downloads.html)
-* Framework features; routing, middleware, etc.
+* Framework features; routing, async/[sync handlers](https://nggit.github.io/tremolo-docs/handlers.html#synchronous-handlers), middleware, etc.
* ASGI server implementation
* PyPy compatible
diff --git a/examples/upload_and_save.py b/examples/upload_and_save.py
index dc4b487..1dc43e6 100644
--- a/examples/upload_and_save.py
+++ b/examples/upload_and_save.py
@@ -28,46 +28,55 @@ async def upload(request, response):
# it can still be read continuously bit by bit according to this size
files = request.files(max_file_size=16384) # 16KiB
- fp = None
- filename = None
+ # keep track of incomplete writings
+ incomplete = set()
try:
# read while writing the file(s).
+ # `part` represents a field/file received in a multipart request
async for part in files:
filename = quote(part.get('filename', ''))
if not filename:
continue
- if fp is None:
- fp = open('Uploaded_' + filename, 'wb')
-
- print('Writing %s (len=%d, eof=%s)' % (filename,
- len(part['data']),
- part['eof']))
- fp.write(part['data'])
-
- if part['eof']:
- fp.close()
- fp = None # the next iteration will be another file
- filename = filename.encode()
- content_type = quote(part['type']).encode()
-
- yield (
- b'File %s '
- b'was uploaded.
' % (content_type, filename, filename)
- )
+ with open('Uploaded_' + filename, 'wb') as fp:
+ incomplete.add(fp)
+
+ if part['eof']:
+ # this part is not larger than `max_file_size`.
+ # but you can skip this check and always use `part.stream`
+ fp.write(part['data'])
+ else:
+ # stream a (possibly) large part in chunks
+ async for data in part.stream():
+ print('Writing %s (len=%d, eof=%s)' % (filename,
+ len(data),
+ part['eof']))
+ fp.write(data)
+
+ incomplete.discard(fp) # completed :)
+
+ filename = filename.encode()
+ content_type = quote(part['type']).encode()
+
+ yield (
+ b'File %s '
+ b'was uploaded.
' % (content_type, filename, filename)
+ )
finally:
- if fp is not None:
- fp.close()
- print('Upload canceled, removing incomplete file: %s' % filename)
- os.unlink('Uploaded_' + filename)
+ while incomplete:
+ path = incomplete.pop().name
+ print('Upload canceled, removing incomplete file: %s' % path)
+ os.unlink(path)
yield b''
@app.route('/download')
async def download(request, response):
+ # prepend / append a hardcoded string.
+ # do not let the user freely determine the path
path = 'Uploaded_' + quote(request.query['filename'][0])
content_type = request.query['type'][0]
diff --git a/tea.yaml b/tea.yaml
deleted file mode 100644
index 0eebaba..0000000
--- a/tea.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# https://tea.xyz/what-is-this-file
----
-version: 1.0.0
-codeOwners:
- - '0x06E2F6ddDb4C7230D5694905312d645e047F48B6'
-quorum: 1
diff --git a/tremolo/lib/http_protocol.py b/tremolo/lib/http_protocol.py
index c4fbf38..8eedb4b 100644
--- a/tremolo/lib/http_protocol.py
+++ b/tremolo/lib/http_protocol.py
@@ -146,7 +146,7 @@ async def put_to_queue(self, data, name=0, rate=-1):
if queue_size <= self.options['max_queue_size']:
if data and rate > 0 and queue_size > 0:
- await asyncio.sleep(1 / (rate / queue_size / len(data)))
+ await asyncio.sleep(queue_size * len(data) / rate)
return True
diff --git a/tremolo/lib/http_request.py b/tremolo/lib/http_request.py
index c7e26d0..d9f1bf5 100644
--- a/tremolo/lib/http_request.py
+++ b/tremolo/lib/http_request.py
@@ -11,11 +11,23 @@
from .request import Request
+class MultipartFile(dict):
+ def __init__(self, files):
+ self.files = files
+
+ async def stream(self):
+ if 'data' in self:
+ yield self.pop('data')
+
+ while not self['eof']:
+ yield (await self.files.__anext__()).pop('data')
+
+
class HTTPRequest(Request):
__slots__ = ('_ip', '_scheme', 'header', 'headers', 'is_valid',
'host', 'method', 'url', 'path', 'query_string', 'version',
'content_length', 'http_continue', 'http_keepalive',
- '_body', '_stream', '_read_buf')
+ '_body', '_read_buf', '_stream', '_files')
def __init__(self, protocol, header):
super().__init__(protocol)
@@ -42,9 +54,11 @@ def __init__(self, protocol, header):
self.content_length = -1
self.http_continue = False
self.http_keepalive = False
+
self._body = bytearray()
- self._stream = None
self._read_buf = bytearray()
+ self._stream = None
+ self._files = None
@property
def ip(self):
@@ -284,7 +298,7 @@ def cookies(self):
return self.params['cookies']
- async def form(self, max_size=8 * 1048576, max_fields=100):
+ async def form(self, max_fields=100, *, max_size=8 * 1048576):
try:
return self.params['post']
except KeyError as exc:
@@ -308,10 +322,18 @@ async def form(self, max_size=8 * 1048576, max_fields=100):
return self.params['post']
- async def files(self, max_files=1024, max_file_size=100 * 1048576):
+ async def files(self, max_files=1024, *, max_file_size=100 * 1048576):
if self.eof():
return
+ if self._files is None:
+ self._files = self.files(max_files, max_file_size=max_file_size)
+
+ async for part in self._files:
+ yield part
+
+ return
+
for key, boundary in parse_fields(self.content_type):
if key == b'boundary' and boundary:
break
@@ -325,7 +347,7 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576):
body_size = 0
content_length = 0
paused = False
- part = None # represents a file received in a multipart request
+ part = None # represents a field/file received in a multipart request
if self._stream is None:
self._stream = self.stream()
@@ -359,7 +381,7 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576):
paused = False
else:
body.extend(self._read_buf[header_size + 2:])
- part = {}
+ part = MultipartFile(self._files)
# use find() instead of startswith() to ignore the preamble
if self._read_buf.find(b'--%s\r\n' % boundary,
@@ -393,10 +415,9 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576):
if body_size == -1:
if len(body) >= max_file_size > boundary_size + 4:
- sub_part = part.copy()
- sub_part['data'] = body[:-boundary_size - 4]
- sub_part['eof'] = False
- yield sub_part
+ part['data'] = bytes(body[:-boundary_size - 4])
+ part['eof'] = False
+ yield part
content_length = max(
content_length - (len(body) - boundary_size - 4), 0
@@ -406,7 +427,7 @@ async def files(self, max_files=1024, max_file_size=100 * 1048576):
paused = False
continue
- part['data'] = body[:body_size]
+ part['data'] = bytes(body[:body_size])
part['eof'] = True
yield part