diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index e87b28e..e7400cd 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -1,7 +1,7 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: connectrpc/conformance/v1/service.proto -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -72,16 +72,17 @@ async def idempotent_unary( raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") -class ConformanceServiceASGIApplication(ConnectASGIApplication): +class ConformanceServiceASGIApplication(ConnectASGIApplication[ConformanceService]): def __init__( self, - service: ConformanceService, + service: ConformanceService | AsyncGenerator[ConformanceService], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: super().__init__( - endpoints={ + service=service, + endpoints=lambda svc: { "/connectrpc.conformance.v1.ConformanceService/Unary": Endpoint.unary( method=MethodInfo( name="Unary", @@ -90,7 +91,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.unary, + function=svc.unary, ), "/connectrpc.conformance.v1.ConformanceService/ServerStream": Endpoint.server_stream( method=MethodInfo( @@ -100,7 +101,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.server_stream, + function=svc.server_stream, ), "/connectrpc.conformance.v1.ConformanceService/ClientStream": Endpoint.client_stream( method=MethodInfo( @@ -110,7 +111,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.client_stream, + function=svc.client_stream, ), "/connectrpc.conformance.v1.ConformanceService/BidiStream": Endpoint.bidi_stream( method=MethodInfo( @@ -120,7 +121,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.bidi_stream, + function=svc.bidi_stream, ), "/connectrpc.conformance.v1.ConformanceService/Unimplemented": Endpoint.unary( method=MethodInfo( @@ -130,7 +131,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.unimplemented, + function=svc.unimplemented, ), "/connectrpc.conformance.v1.ConformanceService/IdempotentUnary": Endpoint.unary( method=MethodInfo( @@ -140,7 +141,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.idempotent_unary, + function=svc.idempotent_unary, ), }, interceptors=interceptors, diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 13f1f71..90eb71b 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -1,7 +1,7 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: example/eliza.proto -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -39,16 +39,17 @@ def introduce( raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") -class ElizaServiceASGIApplication(ConnectASGIApplication): +class ElizaServiceASGIApplication(ConnectASGIApplication[ElizaService]): def __init__( self, - service: ElizaService, + service: ElizaService | AsyncGenerator[ElizaService], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: super().__init__( - endpoints={ + service=service, + endpoints=lambda svc: { "/connectrpc.eliza.v1.ElizaService/Say": Endpoint.unary( method=MethodInfo( name="Say", @@ -57,7 +58,7 @@ def __init__( output=example_dot_eliza__pb2.SayResponse, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.say, + function=svc.say, ), "/connectrpc.eliza.v1.ElizaService/Converse": Endpoint.bidi_stream( method=MethodInfo( @@ -67,7 +68,7 @@ def __init__( output=example_dot_eliza__pb2.ConverseResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.converse, + function=svc.converse, ), "/connectrpc.eliza.v1.ElizaService/Introduce": Endpoint.server_stream( method=MethodInfo( @@ -77,7 +78,7 @@ def __init__( output=example_dot_eliza__pb2.IntroduceResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.introduce, + function=svc.introduce, ), }, interceptors=interceptors, diff --git a/protoc-gen-connect-python/generator/generator_test.go b/protoc-gen-connect-python/generator/generator_test.go index a8cedb9..e15ed39 100644 --- a/protoc-gen-connect-python/generator/generator_test.go +++ b/protoc-gen-connect-python/generator/generator_test.go @@ -106,7 +106,7 @@ func TestGenerateConnectFile(t *testing.T) { } content := got.GetContent() - if !strings.Contains(content, "from collections.abc import AsyncIterator, Iterable, Iterator, Mapping") { + if !strings.Contains(content, "from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping") { t.Error("Generated code missing required imports") } if !strings.Contains(content, "class "+strings.Split(tt.input.GetService()[0].GetName(), ".")[0]) { diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 75e54db..12a6b38 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -44,7 +44,7 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*- # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: {{.FileName}} {{if .Services}} -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -67,10 +67,11 @@ class {{.Name}}(Protocol):{{- range .Methods }} raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") {{ end }} -class {{.Name}}ASGIApplication(ConnectASGIApplication): - def __init__(self, service: {{.Name}}, *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: +class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: super().__init__( - endpoints={ {{- range .Methods }} + service=service, + endpoints=lambda svc: { {{- range .Methods }} "/{{.ServiceName}}/{{.Name}}": Endpoint.{{.EndpointType}}( method=MethodInfo( name="{{.Name}}", @@ -79,7 +80,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication): output={{.OutputType}}, idempotency_level=IdempotencyLevel.{{.IdempotencyLevel}}, ), - function=service.{{.PythonName}}, + function=svc.{{.PythonName}}, ),{{- end }} }, interceptors=interceptors, diff --git a/protoc-gen-connect-python/generator/template_test.go b/protoc-gen-connect-python/generator/template_test.go index 9afdffc..ff614c4 100644 --- a/protoc-gen-connect-python/generator/template_test.go +++ b/protoc-gen-connect-python/generator/template_test.go @@ -36,9 +36,9 @@ func TestConnectTemplate(t *testing.T) { }, }, contains: []string{ - "from collections.abc import AsyncIterator, Iterable, Iterator, Mapping", + "from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping", "class TestService(Protocol):", - "class TestServiceASGIApplication(ConnectASGIApplication):", + "class TestServiceASGIApplication(ConnectASGIApplication[TestService]):", "def TestMethod", }, }, diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 9608940..176841d 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -1,11 +1,19 @@ import base64 import functools +import inspect from abc import ABC, abstractmethod from asyncio import CancelledError, sleep -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Callable, + Iterable, + Mapping, + Sequence, +) from dataclasses import replace from http import HTTPStatus -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar, cast from urllib.parse import parse_qs from . import _compression, _server_shared @@ -48,6 +56,7 @@ Scope = "asgiref.typing.Scope" +_SVC = TypeVar("_SVC") _REQ = TypeVar("_REQ") _RES = TypeVar("_RES") @@ -64,9 +73,11 @@ ) -class ConnectASGIApplication(ABC): +class ConnectASGIApplication(ABC, Generic[_SVC]): """An ASGI application for the Connect protocol.""" + _resolved_endpoints: Mapping[str, Endpoint] | None + @property @abstractmethod def path(self) -> str: ... @@ -74,35 +85,84 @@ def path(self) -> str: ... def __init__( self, *, - endpoints: Mapping[str, Endpoint], + service: _SVC | AsyncGenerator[_SVC], + endpoints: Callable[[_SVC], Mapping[str, Endpoint]], interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: """Initialize the ASGI application.""" super().__init__() - if interceptors: - interceptors = resolve_interceptors(interceptors) - endpoints = { - path: _apply_interceptors(endpoint, interceptors) - for path, endpoint in endpoints.items() - } + self._service = service self._endpoints = endpoints + self._interceptors = interceptors + self._resolved_endpoints = None self._read_max_bytes = read_max_bytes async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ) -> None: - assert scope["type"] == "http" # noqa: S101 - only for type narrowing, in practice always true + if scope["type"] == "websocket": + msg = "connect does not support websockets" + raise RuntimeError(msg) + + if scope["type"] == "lifespan": + service_iter = None + while True: + msg = await receive() + match msg["type"]: + case "lifespan.startup": + # Need to cast since type checking doesn't seem to narrow well with isasyncgen + if inspect.isasyncgen(self._service): + service_iter = cast( + "AsyncGenerator[_SVC, None]", self._service + ) + try: + service = await anext(service_iter) + except Exception as e: + await send( + { + "type": "lifespan.startup.failed", + "message": str(e), + } + ) + return None + else: + service = cast("_SVC", self._service) + self._resolved_endpoints = self._resolve_endpoints(service) + await send({"type": "lifespan.startup.complete"}) + case "lifespan.shutdown": + if service_iter is not None: + try: + await service_iter.aclose() + except Exception as e: + await send( + { + "type": "lifespan.shutdown.failed", + "message": str(e), + } + ) + return None + await send({"type": "lifespan.shutdown.complete"}) + return None + + if not self._resolved_endpoints: + if inspect.isasyncgen(self._service): + msg = "ASGI server does not support lifespan but async generator passed for service. Enable lifespan support." + raise RuntimeError(msg) + + self._resolved_endpoints = self._resolve_endpoints( + cast("_SVC", self._service) + ) + endpoints = self._resolved_endpoints ctx: RequestContext | None = None try: path = scope["path"] - endpoint = self._endpoints.get(path) + endpoint = endpoints.get(path) if not endpoint and scope["root_path"]: # The application was mounted at some root so try stripping the prefix. path = path.removeprefix(scope["root_path"]) - endpoint = self._endpoints.get(path) - + endpoint = endpoints.get(path) if not endpoint: raise HTTPException(HTTPStatus.NOT_FOUND, []) @@ -381,6 +441,17 @@ async def _handle_error( ) await send({"type": "http.response.body", "body": body, "more_body": False}) + def _resolve_endpoints(self, service: _SVC) -> Mapping[str, Endpoint]: + resolved_endpoints = self._endpoints(service) + if self._interceptors: + resolved_endpoints = { + path: _apply_interceptors( + endpoint, resolve_interceptors(self._interceptors) + ) + for path, endpoint in resolved_endpoints.items() + } + return resolved_endpoints + async def _send_stream_response_headers( send: ASGISendCallable, codec: Codec, compression_name: str, ctx: RequestContext diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index df3c2a6..34a8537 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -1,7 +1,7 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: haberdasher.proto -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol import google.protobuf.empty_pb2 as google_dot_protobuf_dot_empty__pb2 @@ -54,16 +54,17 @@ async def do_nothing( raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") -class HaberdasherASGIApplication(ConnectASGIApplication): +class HaberdasherASGIApplication(ConnectASGIApplication[Haberdasher]): def __init__( self, - service: Haberdasher, + service: Haberdasher | AsyncGenerator[Haberdasher], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: super().__init__( - endpoints={ + service=service, + endpoints=lambda svc: { "/connectrpc.example.Haberdasher/MakeHat": Endpoint.unary( method=MethodInfo( name="MakeHat", @@ -72,7 +73,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.make_hat, + function=svc.make_hat, ), "/connectrpc.example.Haberdasher/MakeFlexibleHat": Endpoint.client_stream( method=MethodInfo( @@ -82,7 +83,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.make_flexible_hat, + function=svc.make_flexible_hat, ), "/connectrpc.example.Haberdasher/MakeSimilarHats": Endpoint.server_stream( method=MethodInfo( @@ -92,7 +93,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.make_similar_hats, + function=svc.make_similar_hats, ), "/connectrpc.example.Haberdasher/MakeVariousHats": Endpoint.bidi_stream( method=MethodInfo( @@ -102,7 +103,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.make_various_hats, + function=svc.make_various_hats, ), "/connectrpc.example.Haberdasher/ListParts": Endpoint.server_stream( method=MethodInfo( @@ -112,7 +113,7 @@ def __init__( output=haberdasher__pb2.Hat.Part, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.list_parts, + function=svc.list_parts, ), "/connectrpc.example.Haberdasher/DoNothing": Endpoint.unary( method=MethodInfo( @@ -122,7 +123,7 @@ def __init__( output=google_dot_protobuf_dot_empty__pb2.Empty, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.do_nothing, + function=svc.do_nothing, ), }, interceptors=interceptors, diff --git a/test/test_lifespan.py b/test/test_lifespan.py new file mode 100644 index 0000000..b45dd45 --- /dev/null +++ b/test/test_lifespan.py @@ -0,0 +1,176 @@ +import asyncio +from collections import Counter +from io import StringIO + +import pytest +import uvicorn +from httpx import ASGITransport, AsyncClient +from uvicorn.config import LOGGING_CONFIG + +from connectrpc.errors import ConnectError + +from .haberdasher_connect import ( + Haberdasher, + HaberdasherASGIApplication, + HaberdasherClient, +) +from .haberdasher_pb2 import Hat, Size + + +class CountingHaberdasher(Haberdasher): + def __init__(self, counter: Counter) -> None: + self._counter = counter + + async def make_hat(self, request, ctx): + self._counter["requests"] += 1 + return Hat(size=request.inches, color="blue") + + +@pytest.mark.asyncio +async def test_lifespan() -> None: + final_count = None + + async def counting_haberdasher(): + counter = Counter() + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + nonlocal final_count + final_count = counter["requests"] + + app = HaberdasherASGIApplication(counting_haberdasher()) + # Use uvicorn since it supports lifespan + config = uvicorn.Config( + app, port=0, log_level="critical", timeout_graceful_shutdown=0 + ) + server = uvicorn.Server(config) + uvicorn_task = asyncio.create_task(server.serve()) + + for _ in range(50): + if server.started: + break + await asyncio.sleep(0.1) + else: + msg = "Server did not start" + raise RuntimeError(msg) + + port = server.servers[0].sockets[0].getsockname()[1] + + async with HaberdasherClient(f"http://localhost:{port}") as client: + for _ in range(5): + hat = await client.make_hat(Size(inches=10)) + assert hat.size == 10 + assert hat.color == "blue" + + server.should_exit = True + await uvicorn_task + assert final_count == 5 + + +@pytest.mark.asyncio +async def test_lifespan_startup_error() -> None: + final_count = None + + async def counting_haberdasher(): + counter = Counter() + if True: + msg = "Haberdasher failed to start" + raise RuntimeError(msg) + # Unreachable code below but keep it to make this an async generator + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + nonlocal final_count + final_count = counter["requests"] + + app = HaberdasherASGIApplication(counting_haberdasher()) + # Use uvicorn since it supports lifespan + logs = StringIO() + log_config = LOGGING_CONFIG.copy() + log_config["handlers"]["default"]["stream"] = logs + config = uvicorn.Config( + app, + port=0, + log_level="error", + timeout_graceful_shutdown=0, + log_config=log_config, + ) + server = uvicorn.Server(config) + await server.serve() + + assert "Haberdasher failed to start" in logs.getvalue() + + +@pytest.mark.asyncio +async def test_lifespan_shutdown_error() -> None: + async def counting_haberdasher(): + counter = Counter() + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + msg = "Haberdasher failed to shut down" + raise RuntimeError(msg) + + app = HaberdasherASGIApplication(counting_haberdasher()) + # Use uvicorn since it supports lifespan + logs = StringIO() + log_config = LOGGING_CONFIG.copy() + log_config["handlers"]["default"]["stream"] = logs + config = uvicorn.Config( + app, + port=0, + log_level="error", + timeout_graceful_shutdown=0, + log_config=log_config, + ) + server = uvicorn.Server(config) + uvicorn_task = asyncio.create_task(server.serve()) + + for _ in range(50): + if server.started: + break + await asyncio.sleep(0.1) + else: + msg = "Server did not start" + raise RuntimeError(msg) + + port = server.servers[0].sockets[0].getsockname()[1] + + async with HaberdasherClient(f"http://localhost:{port}") as client: + for _ in range(5): + hat = await client.make_hat(Size(inches=10)) + assert hat.size == 10 + assert hat.color == "blue" + + server.should_exit = True + await uvicorn_task + assert "Haberdasher failed to shut down" in logs.getvalue() + + +@pytest.mark.asyncio +async def test_lifespan_not_supported() -> None: + final_count = None + + async def counting_haberdasher(): + counter = Counter() + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + nonlocal final_count + final_count = counter["requests"] + + app = HaberdasherASGIApplication(counting_haberdasher()) + transport = ASGITransport(app) # pyright:ignore[reportArgumentType] - httpx type is not complete + async with HaberdasherClient( + "http://localhost", session=AsyncClient(transport=transport) + ) as client: + with pytest.raises(ConnectError) as e: + await client.make_hat(Size(inches=10)) + assert ( + "ASGI server does not support lifespan but async generator passed for service." + in str(e.value) + )