Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 1.6.6
- Support long-lived connections in ASGI middleware

# 1.6.5
- Resolved dependabot identified security issues
- Removed build status icon from travis (not used for CI any longer)
Expand Down
22 changes: 15 additions & 7 deletions mauth_client/middlewares/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def __call__(
scope_copy[ENV_APP_UUID] = signed.app_uuid
scope_copy[ENV_AUTHENTIC] = True
scope_copy[ENV_PROTOCOL_VERSION] = signed.protocol_version()
await self.app(scope_copy, self._fake_receive(events), send)
await self.app(scope_copy, self._fake_receive(events, receive), send)
else:
await self._send_response(send, status, message)

Expand Down Expand Up @@ -100,18 +100,26 @@ async def _send_response(self, send: ASGISendCallable, status: int, msg: str) ->
"body": json.dumps(body).encode("utf-8"),
})

def _fake_receive(self, events: List[ASGIReceiveEvent]) -> ASGIReceiveCallable:
def _fake_receive(self, events: List[ASGIReceiveEvent],
original_receive: ASGIReceiveCallable) -> ASGIReceiveCallable:
"""
Create a fake, async receive function using an iterator of the events
we've already read. This will be passed to downstream middlewares/apps
instead of the usual receive fn, so that they can also "receive" the
body events.
Create a fake receive function that replays cached body events.

After the middleware consumes request body events for authentication,
this allows downstream apps to also "receive" those events. Once all
cached events are exhausted, delegates to the original receive to
properly forward lifecycle events (like http.disconnect).

This is essential for long-lived connections (SSE, streaming responses)
that need to detect client disconnects.
"""
events_iter = iter(events)

async def _receive() -> ASGIReceiveEvent:
try:
return next(events_iter)
except StopIteration:
pass
# After body events are consumed, delegate to original receive
# This allows proper handling of disconnects for SSE connections
return await original_receive()
return _receive
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mauth-client"
version = "1.6.5"
version = "1.6.6"
description = "MAuth Client for Python"
repository = "https://github.com/mdsol/mauth-client-python"
authors = ["Medidata Solutions <support@mdsol.com>"]
Expand Down
62 changes: 60 additions & 2 deletions tests/middlewares/asgi_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
from unittest.mock import patch

from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from fastapi.websockets import WebSocket
from unittest.mock import AsyncMock
from unittest.mock import patch
from uuid import uuid4

from mauth_client.authenticator import LocalAuthenticator
Expand Down Expand Up @@ -220,3 +220,61 @@ def is_authentic_effect(self):
self.client.get("/sub_app/path")

self.assertEqual(request_url, "/sub_app/path")


class TestMAuthASGIMiddlewareInLongLivedConnections(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = FastAPI()
Config.APP_UUID = str(uuid4())
Config.MAUTH_URL = "https://mauth.com"
Config.MAUTH_API_VERSION = "v1"
Config.PRIVATE_KEY = "key"

@patch.object(LocalAuthenticator, "is_authentic")
async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock):
"""Test that after body events are consumed, _fake_receive delegates to original receive"""
is_authentic_mock.return_value = (True, 200, "")

# Track that original receive was called after body events exhausted
call_order = []

async def mock_app(scope, receive, send):
# First receive should get body event
event1 = await receive()
call_order.append(("body", event1["type"]))

# Second receive should delegate to original receive
event2 = await receive()
call_order.append(("disconnect", event2["type"]))

await send({"type": "http.response.start", "status": 200, "headers": []})
await send({"type": "http.response.body", "body": b""})

middleware = MAuthASGIMiddleware(mock_app)

# Mock receive that returns body then disconnect
receive_calls = 0

async def mock_receive():
nonlocal receive_calls
receive_calls += 1
if receive_calls == 1:
return {"type": "http.request", "body": b"test", "more_body": False}
return {"type": "http.disconnect"}

send_mock = AsyncMock()
scope = {
"type": "http",
"method": "POST",
"path": "/test",
"query_string": b"",
"headers": []
}

await middleware(scope, mock_receive, send_mock)

# Verify events were received in correct order
self.assertEqual(len(call_order), 2)
self.assertEqual(call_order[0], ("body", "http.request"))
self.assertEqual(call_order[1], ("disconnect", "http.disconnect"))
self.assertEqual(receive_calls, 2) # Called once for auth, once from app
Loading