diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index 73e1cd4..c7aa3c8 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -115,6 +115,34 @@ def prepare_response_headers( return headers +def _read_body_with_content_length( + environ: WSGIEnvironment, content_length: int +) -> bytes: + input_stream: BytesIO = environ["wsgi.input"] + + # Many app servers buffer the entire request before executing the app + # so do an optimistic read before looping. + chunk = input_stream.read(content_length) + if len(chunk) == content_length: + return chunk + + bytes_read = len(chunk) + chunks = [chunk] + while bytes_read < content_length: + to_read = content_length - bytes_read + chunk = input_stream.read(to_read) + if not chunk: + break + chunks.append(chunk) + bytes_read += len(chunk) + if bytes_read < content_length: + raise ConnectError( + Code.INVALID_ARGUMENT, + f"request truncated, expected {content_length} bytes but only received {bytes_read} bytes", + ) + return b"".join(chunks) + + def _read_body(environ: WSGIEnvironment) -> Iterator[bytes]: input_stream: BytesIO = environ["wsgi.input"] while True: @@ -257,7 +285,7 @@ def _handle_post_request( content_length = environ.get("CONTENT_LENGTH") content_length = 0 if not content_length else int(content_length) if content_length > 0: - req_body = environ["wsgi.input"].read(content_length) + req_body = _read_body_with_content_length(environ, content_length) else: req_body = b"".join(_read_body(environ))