From 2cd235695f9c6a8bb88ae06d734a051337394342 Mon Sep 17 00:00:00 2001 From: Dimitris Kontokostas Date: Fri, 24 Oct 2025 11:34:05 +0300 Subject: [PATCH 1/2] Fix ASGI event handling for long-lived connections After body events are consumed for authentication, the middleware's _fake_receive function now delegates to the original receive callable instead of returning None. This allows downstream applications to properly receive lifecycle events like http.disconnect, enabling proper cleanup for SSE connections, streaming responses, and other long-lived HTTP connections. Adds test to verify that _fake_receive correctly delegates to original receive after body events are exhausted. --- mauth_client/middlewares/asgi.py | 22 ++++++++---- tests/middlewares/asgi_test.py | 62 ++++++++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/mauth_client/middlewares/asgi.py b/mauth_client/middlewares/asgi.py index e0a1045..c571b25 100644 --- a/mauth_client/middlewares/asgi.py +++ b/mauth_client/middlewares/asgi.py @@ -62,7 +62,7 @@ async def __call__( scope_copy[ENV_APP_UUID] = signed.app_uuid scope_copy[ENV_AUTHENTIC] = True scope_copy[ENV_PROTOCOL_VERSION] = signed.protocol_version() - await self.app(scope_copy, self._fake_receive(events), send) + await self.app(scope_copy, self._fake_receive(events, receive), send) else: await self._send_response(send, status, message) @@ -100,12 +100,18 @@ async def _send_response(self, send: ASGISendCallable, status: int, msg: str) -> "body": json.dumps(body).encode("utf-8"), }) - def _fake_receive(self, events: List[ASGIReceiveEvent]) -> ASGIReceiveCallable: + def _fake_receive(self, events: List[ASGIReceiveEvent], + original_receive: ASGIReceiveCallable) -> ASGIReceiveCallable: """ - Create a fake, async receive function using an iterator of the events - we've already read. This will be passed to downstream middlewares/apps - instead of the usual receive fn, so that they can also "receive" the - body events. + Create a fake receive function that replays cached body events. + + After the middleware consumes request body events for authentication, + this allows downstream apps to also "receive" those events. Once all + cached events are exhausted, delegates to the original receive to + properly forward lifecycle events (like http.disconnect). + + This is essential for long-lived connections (SSE, streaming responses) + that need to detect client disconnects. """ events_iter = iter(events) @@ -113,5 +119,7 @@ async def _receive() -> ASGIReceiveEvent: try: return next(events_iter) except StopIteration: - pass + # After body events are consumed, delegate to original receive + # This allows proper handling of disconnects for SSE connections + return await original_receive() return _receive diff --git a/tests/middlewares/asgi_test.py b/tests/middlewares/asgi_test.py index 20f704b..d5624e2 100644 --- a/tests/middlewares/asgi_test.py +++ b/tests/middlewares/asgi_test.py @@ -1,9 +1,9 @@ import unittest -from unittest.mock import patch - from fastapi import FastAPI, Request from fastapi.testclient import TestClient from fastapi.websockets import WebSocket +from unittest.mock import AsyncMock +from unittest.mock import patch from uuid import uuid4 from mauth_client.authenticator import LocalAuthenticator @@ -220,3 +220,61 @@ def is_authentic_effect(self): self.client.get("/sub_app/path") self.assertEqual(request_url, "/sub_app/path") + + +class TestMAuthASGIMiddlewareInLongLivedConnections(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.app = FastAPI() + Config.APP_UUID = str(uuid4()) + Config.MAUTH_URL = "https://mauth.com" + Config.MAUTH_API_VERSION = "v1" + Config.PRIVATE_KEY = "key" + + @patch.object(LocalAuthenticator, "is_authentic") + async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock): + """Test that after body events are consumed, _fake_receive delegates to original receive""" + is_authentic_mock.return_value = (True, 200, "") + + # Track that original receive was called after body events exhausted + call_order = [] + + async def mock_app(scope, receive, send): + # First receive should get body event + event1 = await receive() + call_order.append(("body", event1["type"])) + + # Second receive should delegate to original receive + event2 = await receive() + call_order.append(("disconnect", event2["type"])) + + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + + middleware = MAuthASGIMiddleware(mock_app) + + # Mock receive that returns body then disconnect + receive_calls = 0 + + async def mock_receive(): + nonlocal receive_calls + receive_calls += 1 + if receive_calls == 1: + return {"type": "http.request", "body": b"test", "more_body": False} + return {"type": "http.disconnect"} + + send_mock = AsyncMock() + scope = { + "type": "http", + "method": "POST", + "path": "/test", + "query_string": b"", + "headers": [] + } + + await middleware(scope, mock_receive, send_mock) + + # Verify events were received in correct order + self.assertEqual(len(call_order), 2) + self.assertEqual(call_order[0], ("body", "http.request")) + self.assertEqual(call_order[1], ("disconnect", "http.disconnect")) + self.assertEqual(receive_calls, 2) # Called once for auth, once from app From 09332f3e1dc21d613252461078ae3d9d82ee890c Mon Sep 17 00:00:00 2001 From: Dimitris Kontokostas Date: Fri, 24 Oct 2025 15:40:44 +0300 Subject: [PATCH 2/2] bump up version and add changelog entry --- CHANGELOG.md | 3 +++ pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9b7caa..10a1609 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.6.6 +- Support long-lived connections in ASGI middleware + # 1.6.5 - Resolved dependabot identified security issues - Removed build status icon from travis (not used for CI any longer) diff --git a/pyproject.toml b/pyproject.toml index 40e2c8e..fb879ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mauth-client" -version = "1.6.5" +version = "1.6.6" description = "MAuth Client for Python" repository = "https://github.com/mdsol/mauth-client-python" authors = ["Medidata Solutions "]