diff --git a/CHANGELOG.md b/CHANGELOG.md index e9b7caa..10a1609 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# 1.6.6 +- Support long-lived connections in ASGI middleware + # 1.6.5 - Resolved dependabot identified security issues - Removed build status icon from travis (not used for CI any longer) diff --git a/mauth_client/middlewares/asgi.py b/mauth_client/middlewares/asgi.py index e0a1045..c571b25 100644 --- a/mauth_client/middlewares/asgi.py +++ b/mauth_client/middlewares/asgi.py @@ -62,7 +62,7 @@ async def __call__( scope_copy[ENV_APP_UUID] = signed.app_uuid scope_copy[ENV_AUTHENTIC] = True scope_copy[ENV_PROTOCOL_VERSION] = signed.protocol_version() - await self.app(scope_copy, self._fake_receive(events), send) + await self.app(scope_copy, self._fake_receive(events, receive), send) else: await self._send_response(send, status, message) @@ -100,12 +100,18 @@ async def _send_response(self, send: ASGISendCallable, status: int, msg: str) -> "body": json.dumps(body).encode("utf-8"), }) - def _fake_receive(self, events: List[ASGIReceiveEvent]) -> ASGIReceiveCallable: + def _fake_receive(self, events: List[ASGIReceiveEvent], + original_receive: ASGIReceiveCallable) -> ASGIReceiveCallable: """ - Create a fake, async receive function using an iterator of the events - we've already read. This will be passed to downstream middlewares/apps - instead of the usual receive fn, so that they can also "receive" the - body events. + Create a fake receive function that replays cached body events. + + After the middleware consumes request body events for authentication, + this allows downstream apps to also "receive" those events. Once all + cached events are exhausted, delegates to the original receive to + properly forward lifecycle events (like http.disconnect). + + This is essential for long-lived connections (SSE, streaming responses) + that need to detect client disconnects. """ events_iter = iter(events) @@ -113,5 +119,7 @@ async def _receive() -> ASGIReceiveEvent: try: return next(events_iter) except StopIteration: - pass + # After body events are consumed, delegate to original receive + # This allows proper handling of disconnects for SSE connections + return await original_receive() return _receive diff --git a/pyproject.toml b/pyproject.toml index 40e2c8e..fb879ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mauth-client" -version = "1.6.5" +version = "1.6.6" description = "MAuth Client for Python" repository = "https://github.com/mdsol/mauth-client-python" authors = ["Medidata Solutions "] diff --git a/tests/middlewares/asgi_test.py b/tests/middlewares/asgi_test.py index 20f704b..d5624e2 100644 --- a/tests/middlewares/asgi_test.py +++ b/tests/middlewares/asgi_test.py @@ -1,9 +1,9 @@ import unittest -from unittest.mock import patch - from fastapi import FastAPI, Request from fastapi.testclient import TestClient from fastapi.websockets import WebSocket +from unittest.mock import AsyncMock +from unittest.mock import patch from uuid import uuid4 from mauth_client.authenticator import LocalAuthenticator @@ -220,3 +220,61 @@ def is_authentic_effect(self): self.client.get("/sub_app/path") self.assertEqual(request_url, "/sub_app/path") + + +class TestMAuthASGIMiddlewareInLongLivedConnections(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.app = FastAPI() + Config.APP_UUID = str(uuid4()) + Config.MAUTH_URL = "https://mauth.com" + Config.MAUTH_API_VERSION = "v1" + Config.PRIVATE_KEY = "key" + + @patch.object(LocalAuthenticator, "is_authentic") + async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock): + """Test that after body events are consumed, _fake_receive delegates to original receive""" + is_authentic_mock.return_value = (True, 200, "") + + # Track that original receive was called after body events exhausted + call_order = [] + + async def mock_app(scope, receive, send): + # First receive should get body event + event1 = await receive() + call_order.append(("body", event1["type"])) + + # Second receive should delegate to original receive + event2 = await receive() + call_order.append(("disconnect", event2["type"])) + + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + + middleware = MAuthASGIMiddleware(mock_app) + + # Mock receive that returns body then disconnect + receive_calls = 0 + + async def mock_receive(): + nonlocal receive_calls + receive_calls += 1 + if receive_calls == 1: + return {"type": "http.request", "body": b"test", "more_body": False} + return {"type": "http.disconnect"} + + send_mock = AsyncMock() + scope = { + "type": "http", + "method": "POST", + "path": "/test", + "query_string": b"", + "headers": [] + } + + await middleware(scope, mock_receive, send_mock) + + # Verify events were received in correct order + self.assertEqual(len(call_order), 2) + self.assertEqual(call_order[0], ("body", "http.request")) + self.assertEqual(call_order[1], ("disconnect", "http.disconnect")) + self.assertEqual(receive_calls, 2) # Called once for auth, once from app