diff --git a/mangum/__init__.py b/mangum/__init__.py index f8090ec3..6364b08a 100644 --- a/mangum/__init__.py +++ b/mangum/__init__.py @@ -72,11 +72,10 @@ def on_response_start(self, headers: dict, status_code: int) -> None: self.response["headers"] = {k.decode(): v.decode() for k, v in headers.items()} def on_response_close(self) -> None: + body = self.body if self.binary: - body = base64.b64encode(self.body) - else: - body = self.body.decode() - self.response["body"] = body + body = base64.b64encode(body) + self.response["body"] = body.decode() class Mangum: @@ -131,16 +130,12 @@ def asgi(self, event: dict, context: dict) -> dict: } binary = event.get("isBase64Encoded", False) - body = event["body"] - - if body: - if binary: - body = base64.b64decode(body) + body = event["body"] or b"" - else: - body = body.encode() - else: - body = b"" + if binary: + body = base64.b64decode(body) + elif not isinstance(body, bytes): + body = body.encode() response = ASGICycle(scope, binary=binary)(self.app, body=body) return response diff --git a/tests/test_aws.py b/tests/test_aws.py index 894bdcd0..446b5af8 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -100,7 +100,7 @@ async def asgi(receive, send): "statusCode": 200, "isBase64Encoded": True, "headers": {"content-length": "3", "content-type": "text/plain; charset=utf-8"}, - "body": body_encoded, + "body": body_encoded.decode(), }