Consume request body in middleware is problematic #495

Closed
yihuang opened this issue Apr 30, 2019 · 35 comments · Fixed by #1692

Comments

@yihuang

yihuang commented Apr 30, 2019

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware


class SampleMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        _ = await request.form()
        return await call_next(request)


app = Starlette()


@app.route('/test', methods=['POST'])
async def test(request):
    _ = await request.form()  # blocked, because middleware already consumed request body
    return PlainTextResponse('Hello, world!')


app.add_middleware(SampleMiddleware)

$ uvicorn test:app --reload
$ curl -d "a=1" http://127.0.0.1:8000/test
# request is blocked

@blueyed
Contributor

blueyed commented May 1, 2019

Where is it blocking?

Is there a second request instance involved? (otherwise it should just return, shouldn't it?)

async def form(self) -> FormData:
    if not hasattr(self, "_form"):
        assert (
            parse_options_header is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        content_type_header = self.headers.get("Content-Type")
        content_type, options = parse_options_header(content_type_header)
        if content_type == b"multipart/form-data":
            multipart_parser = MultiPartParser(self.headers, self.stream())
            self._form = await multipart_parser.parse()
        elif content_type == b"application/x-www-form-urlencoded":
            form_parser = FormParser(self.headers, self.stream())
            self._form = await form_parser.parse()
        else:
            self._form = FormData()
    return self._form

@blueyed
Contributor

blueyed commented May 1, 2019

Looks like it. A Request is instantiated in two separate places:

request = Request(scope, receive=receive)

and

request = Request(scope, receive=receive)

Can you write a test for starlette itself?

(just for reference: #498)

@tomchristie
Member

"Consume request body in middleware is problematic"

Indeed. Consuming request data in middleware is problematic.
Not just in Starlette, but generally, everywhere.

On the whole you should avoid doing so if at all possible.

There's some work we could do to make it work better, but there will always be cases where it can't work (e.g. if you stream the request data, then it's just not going to be available anymore later down the line).

@wyfo

wyfo commented Jul 19, 2019

Coming from FastAPI issue referenced above.

@tomchristie I don't understand the issue with consuming the request in a middleware; could you explain this point? I need to log every request received by my production server (a common requirement where I work). Is there a better place than a middleware to do it while avoiding duplicated code in every endpoint? (I would like something like this: https://github.com/Rhumbix/django-request-logging)

For now, I found this workaround (but that's not very pretty):

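from starlette.requests import Request
from starlette.types import Message


async def set_body(request: Request, body: bytes) -> None:
    # Replace the receive channel with one that replays the cached body.
    async def receive() -> Message:
        return {"type": "http.request", "body": body}

    request._receive = receive


async def get_body(request: Request) -> bytes:
    body = await request.body()
    # set_body is a coroutine function, so it must be awaited to take effect.
    await set_body(request, body)
    return body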

but there will always be cases where it can't work (eg. if you stream the request data, then it's just not going to be available anymore later down the line)

I kind of disagree with your example. Stream data is not stored by default, but stream metadata (whether the stream is closed) is; an understandable error is raised if someone tries to stream twice, and that is enough imho. That's why, if the body is cached, the stream consumption has to be cached too.

@lvalladares

There are plenty of use cases like the one @wyfo mentions. In my case, I'm using a JWT signature to check the integrity of all the data: I decode the token and then compare the decoded result with the body, query, and path params of the request. I don't know of any better way of doing this.

@amorey

amorey commented Jan 16, 2020

I'm working on a WTForms plugin for Starlette and I'm running into a similar issue. What is the recommended way to consume request.form() in middleware?

@dmontagu
Contributor

dmontagu commented Jan 16, 2020

Currently, this is one of the better options: https://fastapi.tiangolo.com/advanced/custom-request-and-route/#accessing-the-request-body-in-an-exception-handler but unfortunately there still isn't a great way to do this as far as I'm aware.

@amorey

amorey commented Jan 17, 2020

Thanks! Is there an equivalent to before_request() in Flask?
https://flask.palletsprojects.com/en/1.1.x/api/#flask.Flask.before_request

@dmontagu
Contributor

dmontagu commented Jan 18, 2020

@amorey Depending on exactly what you want to do, you can either create a custom middleware that only does things before the request, or you can create a dependency that does whatever setup you need, and include it in the router or endpoint decorator if you don't need access to its return value and don't want to have an unused injected argument.

I think that's the closest you'll get to a direct translation of the before_request function, but I'm not very knowledgeable about flask.

@JHBalaji

JHBalaji commented Jun 8, 2020

@dmontagu I am using FastAPI and trying to log only the 500 Internal Server Error responses by using the ServerErrorMiddleware from Starlette via add_middleware in FastAPI.
Is there a way to get the request JSON body in this case? It appears that I can consume the JSON body for HTTPException and RequestValidationError with add_exception_handler, but not with ServerErrorMiddleware. ServerErrorMiddleware does have the request info, but not the JSON body.

Conversation from Gitter, but let me know if anybody has feedback.

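    from starlette.middleware.errors import ServerErrorMiddleware
    from starlette.responses import PlainTextResponse, Response

    async def exception_handler(request, exc) -> Response:
        print(exc)
        return PlainTextResponse('Yes! it works!')

    app.add_middleware(ServerErrorMiddleware, handler=exception_handler)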

Okay, I found a way to use starlette-context to retrieve my payload only when I need logging:
https://github.com/tomwojcik/starlette-context/blob/e222f739f113b74c2dad772d417d7fcc6f82f0ae/examples/example_with_logger/logger.py

I am writing every request into the context using an incoming middleware from starlette-context, and retrieving it through context.data in the HTTPException handler and in ServerErrorMiddleware. This is not required for ValidationError, as FastAPI supports it natively via exc.body. All of this is a roundabout route that is only needed when you want to log failed responses.

You don't need a middleware for logging incoming requests; if you don't want one, simply use Depends in your router:

application.include_router(api_router, prefix=API_V1_STR, dependencies=[Depends(log_json)])

@talarari

talarari commented Mar 4, 2021

Is there an update on this?
It seems like this is something a lot of people need for very legitimate use cases (mainly logging, it seems).

Is there a plan to allow consuming the body in a middleware?
I see the body is cached on the request object as _body. Is it possible to cache it on the scope so it is accessible from everywhere after it is read?

Any other solution would also be OK, but I do feel this is needed.

@JHBalaji how are you logging the request body to the context? Are you not running into the same issue when trying to access the body in the incoming request middleware?

@JivanRoquet

Another use case for this: decompressing a POST request's body.

Example (which does not work, but would be amazing if it did):

@app.middleware("http")
async def decompress_if_gzip(request: Request, call_next):
    if request.headers.get("content-encoding", "") == "gzip":
        body = await request.body()
        dec = gzip.decompress(body)
        request._body = dec
    response = await call_next(request)
    return response

Any better place to do this than in the middleware?
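
For reference, here is a minimal sketch of how request decompression could be done as a pure ASGI middleware rather than with @app.middleware("http"). It is not part of Starlette: the class name GunzipRequestMiddleware is made up for illustration, and the approach assumes buffering the whole body is acceptable. It reads the compressed body from receive, inflates it, fixes up the headers, and hands downstream a replacement receive that replays the decompressed bytes.

```python
import asyncio
import gzip


class GunzipRequestMiddleware:
    """Hypothetical pure-ASGI middleware: inflates gzip-encoded request bodies."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        headers = dict(scope.get("headers", [])) if scope["type"] == "http" else {}
        if headers.get(b"content-encoding", b"").lower() != b"gzip":
            await self.app(scope, receive, send)
            return

        # Buffer the whole compressed body (this gives up streaming).
        compressed = b""
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                return  # client went away before the body was complete
            compressed += message.get("body", b"")
            if not message.get("more_body", False):
                break
        body = gzip.decompress(compressed)

        # Rewrite headers so downstream sees a plain body of the right length.
        new_headers = [
            (k, v)
            for k, v in scope["headers"]
            if k not in (b"content-encoding", b"content-length")
        ]
        new_headers.append((b"content-length", str(len(body)).encode()))
        new_scope = dict(scope, headers=new_headers)

        replayed = False

        async def replay_receive():
            # First call replays the decompressed body; later calls fall
            # through to the real channel (e.g. to observe http.disconnect).
            nonlocal replayed
            if not replayed:
                replayed = True
                return {"type": "http.request", "body": body, "more_body": False}
            return await receive()

        await self.app(new_scope, replay_receive, send)


async def _demo() -> bytes:
    seen = {}

    async def echo_app(scope, receive, send):
        seen["body"] = (await receive())["body"]

    messages = [
        {"type": "http.request", "body": gzip.compress(b"a=1"), "more_body": False}
    ]

    async def receive():
        return messages.pop(0)

    async def send(message):
        pass

    scope = {"type": "http", "headers": [(b"content-encoding", b"gzip")]}
    await GunzipRequestMiddleware(echo_app)(scope, receive, send)
    return seen["body"]


demo_body = asyncio.run(_demo())
```

With Starlette this would presumably be registered via app.add_middleware(GunzipRequestMiddleware). A real deployment would also want a cap on the decompressed size, which this sketch leaves out.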

@JHBalaji

JHBalaji commented Mar 25, 2021

@talarari
I use ContextMiddleware from starlette-context:

from starlette.requests import Request
from starlette_context.middleware import ContextMiddleware


class ContextFromMiddleware(ContextMiddleware):
    """
    This class helps in setting a Context for a Request-Response which is missing in Starlette default implementation.
    Initialize an empty dict to write data
    """

    async def set_context(self, request: Request) -> dict:
        return {}

You would add the middleware to the app, and then you can access the data as context.data.

@JivanRoquet
There is a gzip middleware. You could import as

from fastapi.middleware.gzip import GZipMiddleware

@JivanRoquet

@JHBalaji unless I'm mistaken, this GZipMiddleware is to gzip the response's body — not unzipping the request's body

@JHBalaji

@JivanRoquet yes but the FastAPI documentation does have what you might be looking for.

@JivanRoquet

@JHBalaji yes I've seen that since then, thank you for following up.

@kovalevvlad

kovalevvlad commented Oct 28, 2021

@kigawas I am using this in production with multiple middlewares before and after this one. Perhaps you are doing something that I am not and hence hitting this problem

Edit: trying to reuse this solution later in another project causes problems. Could be FastAPI version dependent.

@jacksbox

any updates?
getting the request body in the middleware feels like something which should work...

@alex-oleshkevich
Member

I think we can do it in this way:

  1. when you do Request(scope, receive, send) it will set self instance into scope.
  2. next time, when you instantiate a new request object, it will get the request instance from the scope (achievable by implementing __new__).
    This means we will always have the same Request object.

@zevisert

Does Starlette.exception_handler count as registering a middleware function? I can't seem to consume the (or access the cached) body in custom exception handlers either.

@tomchristie
Member

Same issue, yes.

@zevisert

Thanks for the confirmation. I ended up adapting an idea from FastAPI. Essentially my whole app runs as one middleware, and I pulled in the same handler lookup from ExceptionMiddleware, so that I can pass the same starlette.Request to all of my custom handlers. I should be safe this way, since I'm only catching my application's base Exception; everything else will unwind to the built-in exception handlers.

@heyfavour

uvicorn -> app(receive, send) -> receive = queue.get(). If you log the message by calling receive, the message is consumed, and the request then blocks because there is nothing left for queue.get() to return.
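
The hang described above can be reproduced without any server at all. The following is only a sketch: the asyncio.Queue stands in for whatever the server uses internally (an assumption, not uvicorn's actual code), and the timeout exists purely so the demonstration terminates.

```python
import asyncio


async def main() -> str:
    # Stand-in for the server side: one queued message holding the whole body.
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put({"type": "http.request", "body": b"a=1", "more_body": False})

    async def receive() -> dict:
        return await queue.get()

    # A logging middleware reads the body first...
    first = await receive()
    assert first["body"] == b"a=1"

    # ...so the endpoint's read finds an empty queue and waits indefinitely.
    # The timeout is only here so that this demonstration terminates.
    try:
        await asyncio.wait_for(receive(), timeout=0.1)
    except asyncio.TimeoutError:
        return "second receive() blocked"
    return "second receive() returned"


result = asyncio.run(main())
print(result)
```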

@kigawas

kigawas commented Jun 23, 2022

I think we can do it in this way:

  1. when you do Request(scope, receive, send) it will set self instance into scope.
  2. next time, when you instantiate a new request object, it will get the request instance from the scope (achievable by implementing __new__).
    This means we will always have the same Request object.

You mean this? This definitely works and it's pretty smart 😄

diff --git a/starlette/requests.py b/starlette/requests.py
index 66c510c..69cad4c 100644
--- a/starlette/requests.py
+++ b/starlette/requests.py
@@ -188,11 +188,22 @@ async def empty_send(message: Message) -> typing.NoReturn:
 
 
 class Request(HTTPConnection):
+    def __new__(
+        cls, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
+    ):
+        if "request" in scope:
+            return scope["request"]
+
+        obj = object.__new__(cls)
+        scope["request"] = obj
+        return obj
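
To see the idea in isolation, here is a toy sketch that does not depend on Starlette. CachedRequest is a made-up stand-in, not the real Request class. One detail worth noting: Python calls __init__ again even when __new__ returns an already existing instance, so the sketch guards against re-initialisation; otherwise the cached body would be wiped on the second construction.

```python
from typing import Any, Awaitable, Callable, Dict, Optional

Scope = Dict[str, Any]
Receive = Callable[[], Awaitable[Dict[str, Any]]]


class CachedRequest:
    """Toy stand-in for a request class; not Starlette's Request."""

    def __new__(cls, scope: Scope, receive: Optional[Receive] = None) -> "CachedRequest":
        # Reuse the instance stored in the scope, if there is one.
        if "request" in scope:
            return scope["request"]
        obj = super().__new__(cls)
        scope["request"] = obj
        return obj

    def __init__(self, scope: Scope, receive: Optional[Receive] = None) -> None:
        # __init__ runs again when __new__ returned the cached instance,
        # so skip re-initialisation to keep the cached body.
        if getattr(self, "_initialised", False):
            return
        self._initialised = True
        self.scope = scope
        self._receive = receive
        self._body: Optional[bytes] = None


scope: Scope = {"type": "http"}
in_middleware = CachedRequest(scope)
in_middleware._body = b"a=1"  # pretend the middleware read and cached the body
in_endpoint = CachedRequest(scope)
```

Both names end up bound to the same object, so whatever the middleware cached is visible in the endpoint.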

@csrgxtu

csrgxtu commented Sep 15, 2022

@kigawas so this modification didn't get merged into master, right?

@kigawas

kigawas commented Sep 16, 2022

@kigawas so this modification didn't get merged into master, right?

Check this instead: #1702

@vickumar1981

vickumar1981 commented Dec 9, 2022

@tomchristie we're currently using the above workaround using the code below:

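import logging
import time

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import Message


class LoggingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        self.logger = logging.getLogger()

    async def set_body(self, request):
        # Reads a single ASGI message and replays it on every later receive() call,
        # so this assumes the whole body arrives in one message.
        receive_ = await request.receive()

        async def receive() -> Message:
            return receive_

        request._receive = receive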

    async def dispatch(self, request, call_next):
        req_uuid = request.state.correlation_id
        await self.set_body(request)
        body = await request.body()
        self.logger.info(
            "Request",
            extra={
                "uuid": req_uuid,
                "type": "api-request",
                "method": str(request.method).upper(),
                "url": str(request.url),
                "payload": body.decode("utf-8"),
            },
        )
        start_time = time.time()
        response = await call_next(request)
        process_time = (time.time() - start_time) * 1000
        formatted_process_time = "{0:.2f}".format(process_time)
        self.logger.info(
            "Response sent",
            extra={
                "uuid": req_uuid,
                "type": "api-response",
                "url": f"[{str(request.method).upper()}] {str(request.url)}",
                "code": response.status_code,
                "elapsed_time": f"{formatted_process_time}ms",
            },
        )
        return response

This is working great for us. What would you think about adding some kind of mixin class to Starlette, maybe a RequestLoggingMiddleware class, to make it easier to log request payloads from the API? I do agree with you that, in general, this is problematic and should be avoided.

@SamOyeAH


This is very similar to our workaround.

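import json
import logging
from typing import Callable

from starlette.concurrency import iterate_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message

# Assumption: a module-level logger; the original snippet does not show how logger is defined.
logger = logging.getLogger(__name__)


class LogRequestsMiddleware(BaseHTTPMiddleware):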
    """Logs requests and responses.

    Logs contain:
    - Headers
    - Cookies
    - Referer
    - Path
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app, dispatch=self._get_log_requests_middleware())

    @staticmethod
    def _get_log_requests_middleware() -> Callable:
        async def set_body(request: Request, body: bytes) -> None:
            """
            Sets the request body to the given value that can be parsed.
            Args:
                request:
                body:

            Returns:

            """

            async def receive() -> Message:
                return {"type": "http.request", "body": body}

            request._receive = receive

        async def middleware(request: Request, call_next: Callable) -> Response:
            req_body = await request.body()
            await set_body(request, req_body)
            logger.info("REQUEST: req_body=%s", req_body)
            response: Response = await call_next(request)

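            # Drain the streamed response so it can be logged, then restore the iterator.
            response_body = [chunk async for chunk in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            # Join all chunks before decoding; this assumes the response body is JSON.
            res_body_json = json.loads(b"".join(response_body).decode())
            logger.info("RESPONSE: res_body_json=%s", res_body_json)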
            return response

        return middleware

@bricker

bricker commented Apr 25, 2023

I'm posting in this 4-year-old issue in the hope that someone will see it and that it will lead to a better understanding of the problem and ultimately a fix. I have run into this problem many times, and because it doesn't throw an error but just hangs the program, it always takes me too long to figure out what's going on.

I think there is some confusion about what the problem is. I don't believe people are confused about how to wrap receive with a function that consumes the request body, nor do I think people are misguided in needing to consume the full request body in a middleware. Sure, it removes the benefits of a streaming architecture, but that's an acceptable trade-off in many cases. For example, in the case of HMAC signature verification: the first thing I want to do when my application receives a request is verify the signature. If the signature is bad, I want the request to be rejected before it uses any more resources (database connections, logging, etc.).

If I only wrap the receive function in another function that does signature verification, then the signature won't be checked until the request has gotten all the way to the endpoint handler where Starlette begins to consume the request body.

The problem comes when you try to consume the request body (i.e., call receive) directly inside of the middleware, like this:

async def __call__(self, scope, receive, send):
    body = b""

    while True:
        message = await receive()
        chunk: bytes = message.get("body", b"")
        body += chunk
        if message.get("more_body", False) is False:
            break

receive is now exhausted, and this line hangs because it's waiting for something that will never come.

For me, all that's needed is a tiny, undocumented property on state that can signal to Starlette that the body has been consumed and it shouldn't wait for anything else:

async def __call__(self, scope, receive, send):
    body = b""
    # ... consume the body
    scope["state"]["previously_consumed_body"] = body
    await self.app(scope, receive, send)

And in Request.body:

async def body(self) -> bytes:
    if (previously_consumed_body := self.scope["state"].get("previously_consumed_body")):
        return previously_consumed_body
    
    # else, the rest of the function is the same

Sure, it's hacky. But it's just one line of code, doesn't need to be documented (i.e. comes with the understanding that using it is risky), adds virtually no request overhead, but will save me and probably thousands of other developers from having to refactor their middleware stack or routes, fork this repository, or use another library.

If someone needs a fix now, you can do something like this to trick Starlette (this is similar to other workarounds already mentioned above). Note that this code isn't production-ready:

async def __call__(self, scope, receive, send):
    body = b""
    # ... consume the body

    async def dummy_receive():
        return {
            "type": "http.request",
            "body": body,
            "more_body": False,
        }

    await self.app(scope, dummy_receive, send)
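
Putting the pieces together for the HMAC use case mentioned earlier, a fuller sketch might look like the following. Everything specific here is an assumption made for illustration: the class name VerifySignatureMiddleware, the x-signature header, and the hex-encoded SHA-256 scheme. It consumes the body, rejects bad signatures before the application runs, and otherwise replays the body once before falling through to the real receive.

```python
import asyncio
import hashlib
import hmac


class VerifySignatureMiddleware:
    """Hypothetical pure-ASGI middleware: checks an HMAC over the raw body."""

    def __init__(self, app, secret: bytes, header: bytes = b"x-signature"):
        self.app = app
        self.secret = secret
        self.header = header  # assumed header name carrying a hex SHA-256 HMAC

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Consume the whole body up front.
        body = b""
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                return
            body += message.get("body", b"")
            if not message.get("more_body", False):
                break

        supplied = dict(scope["headers"]).get(self.header, b"")
        expected = hmac.new(self.secret, body, hashlib.sha256).hexdigest().encode()
        if not hmac.compare_digest(supplied, expected):
            # Reject before the application does any work.
            await send({"type": "http.response.start", "status": 401,
                        "headers": [(b"content-type", b"text/plain")]})
            await send({"type": "http.response.body", "body": b"bad signature"})
            return

        replayed = False

        async def replay_receive():
            # Replay the buffered body once, then defer to the real channel.
            nonlocal replayed
            if not replayed:
                replayed = True
                return {"type": "http.request", "body": body, "more_body": False}
            return await receive()

        await self.app(scope, replay_receive, send)


async def _demo(signature: bytes) -> list:
    sent = []

    async def app(scope, receive, send):
        message = await receive()
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": message["body"]})

    messages = [{"type": "http.request", "body": b"a=1", "more_body": False}]

    async def receive():
        return messages.pop(0)

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "headers": [(b"x-signature", signature)]}
    await VerifySignatureMiddleware(app, secret=b"s3cret")(scope, receive, send)
    return sent


good = hmac.new(b"s3cret", b"a=1", hashlib.sha256).hexdigest().encode()
accepted = asyncio.run(_demo(good))
rejected = asyncio.run(_demo(b"not-a-valid-signature"))
```

Like the snippet above, this is not production-ready: there is no limit on body size, and the whole request is held in memory.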

@four43

four43 commented Apr 25, 2023

This issue seems to be stuck on a new PR: #1519

@adriangb
Member

receive is now exhausted, and this line hangs because it's waiting for something that will never come.

You can always insert an empty body message:

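async def __call__(self, scope, receive, send):
    body = b""

    while True:
        message = await receive()
        chunk: bytes = message.get("body", b"")
        body += chunk
        if message.get("more_body", False) is False:
            break

    body_sent = False

    async def wrapped_receive():
        # First call: hand downstream an empty, final body message.
        # Later calls: defer to the real receive (e.g. for http.disconnect).
        nonlocal body_sent
        if not body_sent:
            body_sent = True
            return {"type": "http.request", "body": b"", "more_body": False}
        return await receive()

    await self.app(scope, wrapped_receive, send)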

Or something like that
