/
base.py
67 lines (54 loc) · 2.26 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import asyncio
import typing
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
# Shape of ``call_next``: given a Request, produce an awaitable Response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

# Shape of a dispatch callable: takes the incoming Request plus the
# ``call_next`` endpoint above, and produces an awaitable Response.
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint],
    typing.Awaitable[Response],
]
class BaseHTTPMiddleware:
    """ASGI middleware base that lets you work with Request/Response objects.

    Either subclass and override :meth:`dispatch`, or pass a ``dispatch``
    callable to the constructor. Scopes whose ``type`` is not ``"http"`` are
    forwarded to the wrapped app untouched.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Wrap ``app``.

        :param app: The downstream ASGI application.
        :param dispatch: Optional ``async (request, call_next) -> Response``
            callable. When ``None``, this instance's :meth:`dispatch` is used.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        HTTP requests are wrapped in a :class:`Request`, passed to the dispatch
        function, and the Response it returns is sent to the client.
        """
        if scope["type"] != "http":
            # Only plain HTTP goes through dispatch; everything else passes through.
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)
        response = await self.dispatch_func(request, self.call_next)
        await response(scope, receive, send)

    async def call_next(self, request: Request) -> Response:
        """Invoke the wrapped app for ``request`` and return its response.

        The wrapped app runs in its own task and its ``send`` calls are
        collected through a queue. The first message must be
        ``http.response.start``; its status and headers build the returned
        :class:`StreamingResponse`, whose body streams the subsequent
        ``http.response.body`` messages as they arrive.

        :raises RuntimeError: if the app finishes without sending any message.
        :raises Exception: whatever the wrapped app raised — re-raised here if
            it failed before starting the response, or while the returned
            response's body is being iterated if it failed afterwards.
        """
        loop = asyncio.get_event_loop()
        queue: "asyncio.Queue[typing.Optional[Message]]" = asyncio.Queue()

        scope = request.scope
        receive = request.receive
        # The downstream app's ``send`` simply enqueues each ASGI message.
        send = queue.put

        async def coro() -> None:
            try:
                await self.app(scope, receive, send)
            finally:
                # ``None`` is the sentinel meaning "the app has returned or raised".
                await queue.put(None)

        task = loop.create_task(coro())
        message = await queue.get()
        if message is None:
            # The app sent nothing; propagate its exception if it raised one.
            task.result()
            raise RuntimeError("No response returned.")
        assert message["type"] == "http.response.start"

        async def body_stream() -> typing.AsyncGenerator[bytes, None]:
            while True:
                message = await queue.get()
                if message is None:
                    break
                assert message["type"] == "http.response.body"
                yield message.get("body", b"")
            # Propagate any exception raised after the response was started.
            task.result()

        response = StreamingResponse(
            status_code=message["status"], content=body_stream()
        )
        # The ASGI spec makes "headers" optional on http.response.start
        # (default: empty), so do not assume the key is present.
        response.raw_headers = message.get("headers", [])
        return response

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Handle an HTTP ``request``; subclasses must override this.

        Implementations typically ``await call_next(request)`` to run the
        wrapped app, and may inspect or replace the Response it returns.
        Unused when a ``dispatch`` callable was passed to the constructor.
        """
        raise NotImplementedError()  # pragma: no cover