WebSocket fails via CSRFMiddleware #14

Open
johnpaulett opened this issue Aug 29, 2023 · 2 comments
Labels: bug, polar

johnpaulett commented Aug 29, 2023

Ideally, I'd like to use CSRFMiddleware on a WebSocket route. At present, however, CSRFMiddleware hits an AssertionError whenever the WebSocket route is accessed:

Error

.venv/lib/python3.11/site-packages/starlette/testclient.py:91: in __enter__
    message = self.receive()
.venv/lib/python3.11/site-packages/starlette/testclient.py:160: in receive
    raise message
.venv/lib/python3.11/site-packages/anyio/from_thread.py:219: in _call_func
    retval = await retval
.venv/lib/python3.11/site-packages/starlette/testclient.py:118: in _run
    await self.app(scope, receive, send)
.venv/lib/python3.11/site-packages/fastapi/applications.py:289: in __call__
    await super().__call__(scope, receive, send)
.venv/lib/python3.11/site-packages/starlette/applications.py:122: in __call__
    await self.middleware_stack(scope, receive, send)
.venv/lib/python3.11/site-packages/starlette/middleware/errors.py:149: in __call__
    await self.app(scope, receive, send)
.venv/lib/python3.11/site-packages/starlette/middleware/cors.py:75: in __call__
    await self.app(scope, receive, send)
.venv/lib/python3.11/site-packages/starlette/middleware/sessions.py:86: in __call__
    await self.app(scope, receive, send_wrapper)
.venv/lib/python3.11/site-packages/starlette_csrf/middleware.py:55: in __call__
    request = Request(scope)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <starlette.requests.Request object at 0x137109750>
scope = {'app': <fastapi.applications.FastAPI object at 0x117817e10>, 'client': ['testclient', 50000], 'headers': [(b'host', '...'), (b'connection', b'upgrade'), ...], 'path': '/ws', ...}
receive = <function empty_receive at 0x102887600>, send = <function empty_send at 0x105c49b20>

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        super().__init__(scope)
>       assert scope["type"] == "http"
E       AssertionError

.venv/lib/python3.11/site-packages/starlette/requests.py:197: AssertionError

Attempted solutions

  • Failed: added my WebSocket route to exempt_urls. I think the error happens before this is ever applied.
  • Failed: added my WebSocket route to required_urls. I hoped the initial HTTP connection that gets upgraded might pass through.
  • Tried replacing request = Request(scope) with request = WebSocket(scope, receive=receive, send=send) if scope["type"] == "websocket" else Request(scope), but the error still occurred.

Current workaround

I wrap CSRFMiddleware and only pass HTTP requests into it. This is suboptimal because I would like to enforce CSRF protection for my websocket route.

from starlette.types import Receive, Scope, Send
from starlette_csrf.middleware import CSRFMiddleware as _CSRFMiddleware


class CSRFMiddleware(_CSRFMiddleware):
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only HTTP scopes get CSRF checks; type="websocket" currently raises
        # an AssertionError inside the parent __call__, so bypass it.
        if scope["type"] == "http":
            await super().__call__(scope, receive, send)
        else:
            await self.app(scope, receive, send)
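To keep some protection on the WebSocket route instead of bypassing it entirely, one possible direction is to check the handshake itself. Browser scripts cannot set custom headers on a WebSocket handshake, so this minimal sketch assumes the client reads the CSRF cookie and repeats it as a query parameter. The `WebSocketCSRFGuard` class and the `csrftoken` query parameter are hypothetical, not part of starlette-csrf. It only checks that the two values are equal; it does not reuse whatever token validation starlette-csrf applies to HTTP requests, and tokens in query strings can end up in access logs. Checking the Origin header is another common option.

```python
import hmac
from http.cookies import SimpleCookie
from urllib.parse import parse_qs


class WebSocketCSRFGuard:
    """Hypothetical pure-ASGI wrapper (not part of starlette-csrf).

    - "http" scopes go to `http_app` (for example an existing CSRFMiddleware).
    - "websocket" scopes reach `app` only if the handshake's query string
      repeats the value of the CSRF cookie (double-submit style check).
    - Any other scope type (e.g. "lifespan") goes straight to `app`.
    """

    def __init__(self, app, http_app=None, cookie_name="csrftoken", query_param="csrftoken"):
        self.app = app
        self.http_app = http_app if http_app is not None else app
        self.cookie_name = cookie_name
        self.query_param = query_param

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            await self.http_app(scope, receive, send)
        elif scope["type"] != "websocket" or self._tokens_match(scope):
            await self.app(scope, receive, send)
        else:
            # Per the ASGI spec, sending websocket.close before websocket.accept
            # makes the server reject the handshake (HTTP 403).
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.close", "code": 1008})

    def _tokens_match(self, scope):
        # Join every Cookie header; there may be more than one.
        cookie_header = "; ".join(
            value.decode("latin-1")
            for name, value in scope.get("headers", [])
            if name.lower() == b"cookie"
        )
        cookies = SimpleCookie()
        cookies.load(cookie_header)
        morsel = cookies.get(self.cookie_name)

        query = parse_qs(scope.get("query_string", b"").decode("latin-1"))
        submitted = query.get(self.query_param, [""])[0]

        if morsel is None or not submitted:
            return False
        # Constant-time comparison of the cookie copy and the query copy.
        return hmac.compare_digest(morsel.value.encode(), submitted.encode())
```

Wiring it up might look like `WebSocketCSRFGuard(app, http_app=CSRFMiddleware(app, secret="SECRET"))`, assuming CSRFMiddleware's constructor takes the inner app plus the same keyword arguments as the `Middleware(CSRFMiddleware, secret="SECRET")` call in the partial test case.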

Partial test case

I tried to document the error in a test case, but the httpx client does not expose the .websocket_connect() method that Starlette's TestClient provides, so this code does not fully work yet:

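import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocket
from starlette_csrf import CSRFMiddleware


def get_app(**middleware_kwargs) -> Starlette:
    async def get(request: Request):
        return JSONResponse({"hello": "world"})

    async def post(request: Request):
        json = await request.json()
        return JSONResponse(json)
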
    async def websocket(websocket: WebSocket):
        await websocket.accept()
        data = await websocket.receive_text()
        await websocket.send_text(data)
        await websocket.close()

    app = Starlette(
        debug=True,
        routes=[
            Route("/get", get, methods=["GET"]),
            Route("/post1", post, methods=["POST"]),
            Route("/post2", post, methods=["POST"]),
            WebSocketRoute("/ws", websocket),
        ],
        middleware=[Middleware(CSRFMiddleware, secret="SECRET", **middleware_kwargs)],
    )

    return app

@pytest.mark.asyncio
async def test_valid_websocket():
    async with get_test_client(get_app()) as client:
        response_get = await client.get("/get")
        csrf_cookie = response_get.cookies["csrftoken"]

        async with client.websocket_connect("/ws") as websocket:
            await websocket.send_text("hello world")
            data = await websocket.receive_text()
            assert data == "hello world"
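Since httpx has no WebSocket support, one way to exercise the /ws route without any client is to call the ASGI app directly with a hand-built websocket scope. A minimal stdlib-only sketch; `run_websocket_exchange` and the scope values it fills in are hypothetical test scaffolding, not part of Starlette or starlette-csrf.

```python
import asyncio


async def run_websocket_exchange(app, path, text, headers=None, query_string=b""):
    """Hypothetical test helper: drive an ASGI app through one WebSocket exchange.

    Feeds the app websocket.connect, one text frame, then websocket.disconnect,
    and returns every message the app sent back (accept, send, close, ...).
    """
    scope = {
        "type": "websocket",
        "asgi": {"version": "3.0"},
        "scheme": "ws",
        "path": path,
        "raw_path": path.encode(),
        "root_path": "",
        "query_string": query_string,
        "headers": headers or [],
        "client": ("testclient", 50000),
        "server": ("testserver", 80),
        "subprotocols": [],
    }
    incoming = [
        {"type": "websocket.connect"},
        {"type": "websocket.receive", "text": text},
    ]
    sent = []

    async def receive():
        # Once the scripted messages are used up, report a client disconnect.
        if incoming:
            return incoming.pop(0)
        return {"type": "websocket.disconnect", "code": 1000}

    async def send(message):
        sent.append(message)

    await app(scope, receive, send)
    return sent
```

Calling `asyncio.run(run_websocket_exchange(get_app(), "/ws", "hello world"))` against the unpatched CSRFMiddleware should surface the same AssertionError from `Request(scope)`; with the workaround subclass, the returned messages should include a `websocket.send` carrying "hello world". If a synchronous test is acceptable, Starlette's own `TestClient(app).websocket_connect("/ws")` avoids the hand-built scope entirely.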

Happy to try to help with some pointers.

frankie567 added the bug label Sep 5, 2023
frankie567 (Owner)

Indeed, I never considered the WebSocket case, so the code is probably not adequate for such requests.

I'm not sure right now how this can be solved, but I'll investigate. Thank you for the report 🙏

polar-sh bot added the polar label Sep 5, 2023
gantoine

@johnpaulett Not sure if this'll work for ya, but I "fixed" this by only applying CSRF to HTTP requests:

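from starlette.types import Receive, Scope, Send
from starlette_csrf import CSRFMiddleware


class CustomCSRFMiddleware(CSRFMiddleware):
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Pass non-HTTP scopes (websocket, lifespan) straight to the wrapped app
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        await super().__call__(scope, receive, send)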
