From 0fb269cafae0a4403e0af0fa45816055a0af3734 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 15:38:21 +0900 Subject: [PATCH 1/8] Add support for ASGI lifespan Signed-off-by: Anuraag Agrawal --- .../conformance/v1/service_connect.py | 21 +-- example/example/eliza_connect.py | 15 +- .../generator/template.go | 11 +- src/connectrpc/_server_async.py | 104 +++++++++-- test/haberdasher_connect.py | 21 +-- test/test_lifespan.py | 162 ++++++++++++++++++ 6 files changed, 287 insertions(+), 47 deletions(-) create mode 100644 test/test_lifespan.py diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index e87b28e..6993ce1 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -1,7 +1,7 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: connectrpc/conformance/v1/service.proto -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -72,16 +72,17 @@ async def idempotent_unary( raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") -class ConformanceServiceASGIApplication(ConnectASGIApplication): +class ConformanceServiceASGIApplication(ConnectASGIApplication[ConformanceService]): def __init__( self, - service: ConformanceService, + service: ConformanceService | AsyncGenerator[ConformanceService, None], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: super().__init__( - endpoints={ + service=service, + endpoints=lambda svc: { "/connectrpc.conformance.v1.ConformanceService/Unary": Endpoint.unary( method=MethodInfo( name="Unary", @@ -90,7 +91,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.unary, + function=svc.unary, ), "/connectrpc.conformance.v1.ConformanceService/ServerStream": Endpoint.server_stream( method=MethodInfo( @@ -100,7 +101,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.server_stream, + function=svc.server_stream, ), "/connectrpc.conformance.v1.ConformanceService/ClientStream": Endpoint.client_stream( method=MethodInfo( @@ -110,7 +111,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.client_stream, + function=svc.client_stream, ), "/connectrpc.conformance.v1.ConformanceService/BidiStream": Endpoint.bidi_stream( method=MethodInfo( @@ -120,7 +121,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.bidi_stream, + function=svc.bidi_stream, ), "/connectrpc.conformance.v1.ConformanceService/Unimplemented": Endpoint.unary( method=MethodInfo( @@ -130,7 +131,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.unimplemented, + function=svc.unimplemented, ), "/connectrpc.conformance.v1.ConformanceService/IdempotentUnary": Endpoint.unary( method=MethodInfo( @@ -140,7 +141,7 @@ def __init__( output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.idempotent_unary, + function=svc.idempotent_unary, ), }, interceptors=interceptors, diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 13f1f71..4600502 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -1,7 +1,7 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: example/eliza.proto -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -39,16 +39,17 @@ def introduce( raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") -class ElizaServiceASGIApplication(ConnectASGIApplication): +class ElizaServiceASGIApplication(ConnectASGIApplication[ElizaService]): def __init__( self, - service: ElizaService, + service: ElizaService | AsyncGenerator[ElizaService, None], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: super().__init__( - endpoints={ + service=service, + endpoints=lambda svc: { "/connectrpc.eliza.v1.ElizaService/Say": Endpoint.unary( method=MethodInfo( name="Say", @@ -57,7 +58,7 @@ def __init__( output=example_dot_eliza__pb2.SayResponse, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.say, + function=svc.say, ), "/connectrpc.eliza.v1.ElizaService/Converse": Endpoint.bidi_stream( method=MethodInfo( @@ -67,7 +68,7 @@ def __init__( output=example_dot_eliza__pb2.ConverseResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.converse, + function=svc.converse, ), "/connectrpc.eliza.v1.ElizaService/Introduce": Endpoint.server_stream( method=MethodInfo( @@ -77,7 +78,7 @@ def __init__( output=example_dot_eliza__pb2.IntroduceResponse, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.introduce, + function=svc.introduce, ), }, interceptors=interceptors, diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 75e54db..4bf7d9d 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -44,7 +44,7 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*- # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: {{.FileName}} {{if .Services}} -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncIterator, AsyncGenerator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -67,10 +67,11 @@ class {{.Name}}(Protocol):{{- range .Methods }} raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") {{ end }} -class {{.Name}}ASGIApplication(ConnectASGIApplication): - def __init__(self, service: {{.Name}}, *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: +class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}, None], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: super().__init__( - endpoints={ {{- range .Methods }} + service=service, + endpoints=lambda svc: { {{- range .Methods }} "/{{.ServiceName}}/{{.Name}}": Endpoint.{{.EndpointType}}( method=MethodInfo( name="{{.Name}}", @@ -79,7 +80,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication): output={{.OutputType}}, idempotency_level=IdempotencyLevel.{{.IdempotencyLevel}}, ), - function=service.{{.PythonName}}, + function=svc.{{.PythonName}}, ),{{- end }} }, interceptors=interceptors, diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 9608940..191270f 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -1,11 +1,19 @@ import base64 import functools -from abc import ABC, abstractmethod +import inspect +from abc import abstractmethod from asyncio import CancelledError, sleep -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Callable, + Iterable, + Mapping, + Sequence, +) from dataclasses import replace from http import HTTPStatus -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar, cast from urllib.parse import parse_qs from . import _compression, _server_shared @@ -48,6 +56,7 @@ Scope = "asgiref.typing.Scope" +_SVC = TypeVar("_SVC") _REQ = TypeVar("_REQ") _RES = TypeVar("_RES") @@ -64,9 +73,11 @@ ) -class ConnectASGIApplication(ABC): +class ConnectASGIApplication(Generic[_SVC]): """An ASGI application for the Connect protocol.""" + _resolved_endpoints: Mapping[str, Endpoint] | None + @property @abstractmethod def path(self) -> str: ... @@ -74,35 +85,87 @@ def path(self) -> str: ... def __init__( self, *, - endpoints: Mapping[str, Endpoint], + service: _SVC | AsyncGenerator[_SVC, None], + endpoints: Callable[[_SVC], Mapping[str, Endpoint]], interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: """Initialize the ASGI application.""" super().__init__() - if interceptors: - interceptors = resolve_interceptors(interceptors) - endpoints = { - path: _apply_interceptors(endpoint, interceptors) - for path, endpoint in endpoints.items() - } + self._service = service self._endpoints = endpoints + self._interceptors = interceptors + self._resolved_endpoints = None self._read_max_bytes = read_max_bytes async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ) -> None: - assert scope["type"] == "http" # noqa: S101 - only for type narrowing, in practice always true + if scope["type"] == "websocket": + msg = "connect does not support websockets" + raise RuntimeError(msg) + + if scope["type"] == "lifespan": + service_iter = None + while True: + msg = await receive() + match msg["type"]: + case "lifespan.startup": + # Need to cast since type checking doesn't seem to narrow well with isasyncgen + if inspect.isasyncgen(self._service): + service_iter = cast( + "AsyncGenerator[_SVC, None]", self._service + ) + try: + service = await anext(service_iter) + except Exception as e: + await send( + { + "type": "lifespan.startup.failed", + "message": str(e), + } + ) + return None + else: + service = cast("_SVC", self._service) + if (state := scope.get("state")) is not None: + state["endpoints"] = self._resolve_endpoints(service) + await send({"type": "lifespan.startup.complete"}) + case "lifespan.shutdown": + if service_iter is not None: + try: + await service_iter.aclose() + except Exception as e: + await send( + { + "type": "lifespan.shutdown.failed", + "message": str(e), + } + ) + await send({"type": "lifespan.shutdown.complete"}) + return None + + if state := scope.get("state"): + endpoints: Mapping[str, Endpoint] = state["endpoints"] + else: + if not self._resolved_endpoints: + if inspect.isasyncgen(self._service): + msg = "ASGI server does not support lifespan but async generator passed for service. Enable lifespan support." + raise RuntimeError(msg) + + self._resolved_endpoints = self._resolve_endpoints( + cast("_SVC", self._service) + ) + endpoints = self._resolved_endpoints ctx: RequestContext | None = None try: path = scope["path"] - endpoint = self._endpoints.get(path) + endpoint = endpoints.get(path) if not endpoint and scope["root_path"]: # The application was mounted at some root so try stripping the prefix. path = path.removeprefix(scope["root_path"]) - endpoint = self._endpoints.get(path) - + endpoint = endpoints.get(path) if not endpoint: raise HTTPException(HTTPStatus.NOT_FOUND, []) @@ -381,6 +444,17 @@ async def _handle_error( ) await send({"type": "http.response.body", "body": body, "more_body": False}) + def _resolve_endpoints(self, service: _SVC) -> Mapping[str, Endpoint]: + resolved_endpoints = self._endpoints(service) + if self._interceptors: + resolved_endpoints = { + path: _apply_interceptors( + endpoint, resolve_interceptors(self._interceptors) + ) + for path, endpoint in resolved_endpoints.items() + } + return resolved_endpoints + async def _send_stream_response_headers( send: ASGISendCallable, codec: Codec, compression_name: str, ctx: RequestContext diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index df3c2a6..79343a3 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -1,7 +1,7 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: haberdasher.proto -from collections.abc import AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol import google.protobuf.empty_pb2 as google_dot_protobuf_dot_empty__pb2 @@ -54,16 +54,17 @@ async def do_nothing( raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") -class HaberdasherASGIApplication(ConnectASGIApplication): +class HaberdasherASGIApplication(ConnectASGIApplication[Haberdasher]): def __init__( self, - service: Haberdasher, + service: Haberdasher | AsyncGenerator[Haberdasher, None], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, ) -> None: super().__init__( - endpoints={ + service=service, + endpoints=lambda svc: { "/connectrpc.example.Haberdasher/MakeHat": Endpoint.unary( method=MethodInfo( name="MakeHat", @@ -72,7 +73,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.make_hat, + function=svc.make_hat, ), "/connectrpc.example.Haberdasher/MakeFlexibleHat": Endpoint.client_stream( method=MethodInfo( @@ -82,7 +83,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.make_flexible_hat, + function=svc.make_flexible_hat, ), "/connectrpc.example.Haberdasher/MakeSimilarHats": Endpoint.server_stream( method=MethodInfo( @@ -92,7 +93,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, ), - function=service.make_similar_hats, + function=svc.make_similar_hats, ), "/connectrpc.example.Haberdasher/MakeVariousHats": Endpoint.bidi_stream( method=MethodInfo( @@ -102,7 +103,7 @@ def __init__( output=haberdasher__pb2.Hat, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.make_various_hats, + function=svc.make_various_hats, ), "/connectrpc.example.Haberdasher/ListParts": Endpoint.server_stream( method=MethodInfo( @@ -112,7 +113,7 @@ def __init__( output=haberdasher__pb2.Hat.Part, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.list_parts, + function=svc.list_parts, ), "/connectrpc.example.Haberdasher/DoNothing": Endpoint.unary( method=MethodInfo( @@ -122,7 +123,7 @@ def __init__( output=google_dot_protobuf_dot_empty__pb2.Empty, idempotency_level=IdempotencyLevel.UNKNOWN, ), - function=service.do_nothing, + function=svc.do_nothing, ), }, interceptors=interceptors, diff --git a/test/test_lifespan.py b/test/test_lifespan.py new file mode 100644 index 0000000..fcd5747 --- /dev/null +++ b/test/test_lifespan.py @@ -0,0 +1,162 @@ +import asyncio +from collections import Counter +from io import StringIO + +import pytest +import uvicorn +from uvicorn.config import LOGGING_CONFIG + +from .haberdasher_connect import ( + Haberdasher, + HaberdasherASGIApplication, + HaberdasherClient, +) +from .haberdasher_pb2 import Hat, Size + + +@pytest.mark.asyncio +async def test_lifespan() -> None: + class CountingHaberdasher(Haberdasher): + def __init__(self, counter: Counter) -> None: + self._counter = counter + + async def make_hat(self, request, ctx): + self._counter["requests"] += 1 + return Hat(size=request.inches, color="blue") + + final_count = None + + async def counting_haberdasher(): + counter = Counter() + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + nonlocal final_count + final_count = counter["requests"] + + app = HaberdasherASGIApplication(counting_haberdasher()) + # Use uvicorn since it supports lifespan + config = uvicorn.Config( + app, port=0, log_level="critical", timeout_graceful_shutdown=0 + ) + server = uvicorn.Server(config) + uvicorn_task = asyncio.create_task(server.serve()) + + for _ in range(50): + if server.started: + break + await asyncio.sleep(0.1) + else: + msg = "Server did not start" + raise RuntimeError(msg) + + port = server.servers[0].sockets[0].getsockname()[1] + + async with HaberdasherClient(f"http://localhost:{port}") as client: + for _ in range(5): + hat = await client.make_hat(Size(inches=10)) + assert hat.size == 10 + assert hat.color == "blue" + + server.should_exit = True + await uvicorn_task + assert final_count == 5 + + +@pytest.mark.asyncio +async def test_lifespan_startup_error() -> None: + class CountingHaberdasher(Haberdasher): + def __init__(self, counter: Counter) -> None: + self._counter = counter + + async def make_hat(self, request, ctx): + self._counter["requests"] += 1 + return Hat(size=request.inches, color="blue") + + final_count = None + + async def counting_haberdasher(): + counter = Counter() + if True: + msg = "Haberdasher failed to start" + raise RuntimeError(msg) + # Unreachable code below but keep it to make this an async generator + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + nonlocal final_count + final_count = counter["requests"] + + app = HaberdasherASGIApplication(counting_haberdasher()) + # Use uvicorn since it supports lifespan + logs = StringIO() + log_config = LOGGING_CONFIG.copy() + log_config["handlers"]["default"]["stream"] = logs + config = uvicorn.Config( + app, + port=0, + log_level="error", + timeout_graceful_shutdown=0, + log_config=log_config, + ) + server = uvicorn.Server(config) + await server.serve() + + assert "Haberdasher failed to start" in logs.getvalue() + + +@pytest.mark.asyncio +async def test_lifespan_shutdown_error() -> None: + class CountingHaberdasher(Haberdasher): + def __init__(self, counter: Counter) -> None: + self._counter = counter + + async def make_hat(self, request, ctx): + self._counter["requests"] += 1 + return Hat(size=request.inches, color="blue") + + async def counting_haberdasher(): + counter = Counter() + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + msg = "Haberdasher failed to shut down" + raise RuntimeError(msg) + + app = HaberdasherASGIApplication(counting_haberdasher()) + # Use uvicorn since it supports lifespan + logs = StringIO() + log_config = LOGGING_CONFIG.copy() + log_config["handlers"]["default"]["stream"] = logs + config = uvicorn.Config( + app, + port=0, + log_level="error", + timeout_graceful_shutdown=0, + log_config=log_config, + ) + server = uvicorn.Server(config) + uvicorn_task = asyncio.create_task(server.serve()) + + for _ in range(50): + if server.started: + break + await asyncio.sleep(0.1) + else: + msg = "Server did not start" + raise RuntimeError(msg) + + port = server.servers[0].sockets[0].getsockname()[1] + + async with HaberdasherClient(f"http://localhost:{port}") as client: + for _ in range(5): + hat = await client.make_hat(Size(inches=10)) + assert hat.size == 10 + assert hat.color == "blue" + + server.should_exit = True + await uvicorn_task + assert "Haberdasher failed to shut down" in logs.getvalue() From 71751d4f204a5e8c44eecaa431f2da125c676124 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 15:42:06 +0900 Subject: [PATCH 2/8] Simpler Signed-off-by: Anuraag Agrawal --- .../test/gen/connectrpc/conformance/v1/service_connect.py | 2 +- example/example/eliza_connect.py | 2 +- protoc-gen-connect-python/generator/template.go | 2 +- src/connectrpc/_server_async.py | 2 +- test/haberdasher_connect.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index 6993ce1..e7400cd 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -75,7 +75,7 @@ async def idempotent_unary( class ConformanceServiceASGIApplication(ConnectASGIApplication[ConformanceService]): def __init__( self, - service: ConformanceService | AsyncGenerator[ConformanceService, None], + service: ConformanceService | AsyncGenerator[ConformanceService], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 4600502..90eb71b 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -42,7 +42,7 @@ def introduce( class ElizaServiceASGIApplication(ConnectASGIApplication[ElizaService]): def __init__( self, - service: ElizaService | AsyncGenerator[ElizaService, None], + service: ElizaService | AsyncGenerator[ElizaService], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 4bf7d9d..4e03176 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -68,7 +68,7 @@ class {{.Name}}(Protocol):{{- range .Methods }} {{ end }} class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): - def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}, None], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: super().__init__( service=service, endpoints=lambda svc: { {{- range .Methods }} diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 191270f..0a319dc 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -85,7 +85,7 @@ def path(self) -> str: ... def __init__( self, *, - service: _SVC | AsyncGenerator[_SVC, None], + service: _SVC | AsyncGenerator[_SVC], endpoints: Callable[[_SVC], Mapping[str, Endpoint]], interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index 79343a3..34a8537 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -57,7 +57,7 @@ async def do_nothing( class HaberdasherASGIApplication(ConnectASGIApplication[Haberdasher]): def __init__( self, - service: Haberdasher | AsyncGenerator[Haberdasher, None], + service: Haberdasher | AsyncGenerator[Haberdasher], *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, From d9fd56e532924436f2e2e42c5dc9ee4272967e8d Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 15:47:43 +0900 Subject: [PATCH 3/8] Update protoc tests Signed-off-by: Anuraag Agrawal --- protoc-gen-connect-python/generator/generator_test.go | 2 +- protoc-gen-connect-python/generator/template.go | 2 +- protoc-gen-connect-python/generator/template_test.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/protoc-gen-connect-python/generator/generator_test.go b/protoc-gen-connect-python/generator/generator_test.go index a8cedb9..e15ed39 100644 --- a/protoc-gen-connect-python/generator/generator_test.go +++ b/protoc-gen-connect-python/generator/generator_test.go @@ -106,7 +106,7 @@ func TestGenerateConnectFile(t *testing.T) { } content := got.GetContent() - if !strings.Contains(content, "from collections.abc import AsyncIterator, Iterable, Iterator, Mapping") { + if !strings.Contains(content, "from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping") { t.Error("Generated code missing required imports") } if !strings.Contains(content, "class "+strings.Split(tt.input.GetService()[0].GetName(), ".")[0]) { diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 4e03176..12a6b38 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -44,7 +44,7 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*- # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: {{.FileName}} {{if .Services}} -from collections.abc import AsyncIterator, AsyncGenerator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync diff --git a/protoc-gen-connect-python/generator/template_test.go b/protoc-gen-connect-python/generator/template_test.go index 9afdffc..ff614c4 100644 --- a/protoc-gen-connect-python/generator/template_test.go +++ b/protoc-gen-connect-python/generator/template_test.go @@ -36,9 +36,9 @@ func TestConnectTemplate(t *testing.T) { }, }, contains: []string{ - "from collections.abc import AsyncIterator, Iterable, Iterator, Mapping", + "from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping", "class TestService(Protocol):", - "class TestServiceASGIApplication(ConnectASGIApplication):", + "class TestServiceASGIApplication(ConnectASGIApplication[TestService]):", "def TestMethod", }, }, From 5377386fd30f4cec2319d48ed7e63822c488c79f Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 16:16:49 +0900 Subject: [PATCH 4/8] No need to use state Signed-off-by: Anuraag Agrawal --- src/connectrpc/_server_async.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 0a319dc..0f00d99 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -128,8 +128,7 @@ async def __call__( return None else: service = cast("_SVC", self._service) - if (state := scope.get("state")) is not None: - state["endpoints"] = self._resolve_endpoints(service) + self._resolved_endpoints = self._resolve_endpoints(service) await send({"type": "lifespan.startup.complete"}) case "lifespan.shutdown": if service_iter is not None: @@ -145,18 +144,15 @@ async def __call__( await send({"type": "lifespan.shutdown.complete"}) return None - if state := scope.get("state"): - endpoints: Mapping[str, Endpoint] = state["endpoints"] - else: - if not self._resolved_endpoints: - if inspect.isasyncgen(self._service): - msg = "ASGI server does not support lifespan but async generator passed for service. Enable lifespan support." - raise RuntimeError(msg) + if not self._resolved_endpoints: + if inspect.isasyncgen(self._service): + msg = "ASGI server does not support lifespan but async generator passed for service. Enable lifespan support." + raise RuntimeError(msg) - self._resolved_endpoints = self._resolve_endpoints( - cast("_SVC", self._service) - ) - endpoints = self._resolved_endpoints + self._resolved_endpoints = self._resolve_endpoints( + cast("_SVC", self._service) + ) + endpoints = self._resolved_endpoints ctx: RequestContext | None = None try: From 634cf820b08d165658573631d47ee97927a4bad2 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 16:21:04 +0900 Subject: [PATCH 5/8] Add test for lifespan required but not available Signed-off-by: Anuraag Agrawal --- test/test_lifespan.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/test_lifespan.py b/test/test_lifespan.py index fcd5747..80f89a1 100644 --- a/test/test_lifespan.py +++ b/test/test_lifespan.py @@ -4,8 +4,11 @@ import pytest import uvicorn +from httpx import ASGITransport, AsyncClient from uvicorn.config import LOGGING_CONFIG +from connectrpc.errors import ConnectError + from .haberdasher_connect import ( Haberdasher, HaberdasherASGIApplication, @@ -160,3 +163,37 @@ async def counting_haberdasher(): server.should_exit = True await uvicorn_task assert "Haberdasher failed to shut down" in logs.getvalue() + + +@pytest.mark.asyncio +async def test_lifespan_not_supported() -> None: + class CountingHaberdasher(Haberdasher): + def __init__(self, counter: Counter) -> None: + self._counter = counter + + async def make_hat(self, request, ctx): + self._counter["requests"] += 1 + return Hat(size=request.inches, color="blue") + + final_count = None + + async def counting_haberdasher(): + counter = Counter() + try: + haberdasher = CountingHaberdasher(counter) + yield haberdasher + finally: + nonlocal final_count + final_count = counter["requests"] + + app = HaberdasherASGIApplication(counting_haberdasher()) + transport = ASGITransport(app) # pyright:ignore[reportArgumentType] - httpx type is not complete + async with HaberdasherClient( + "http://localhost", session=AsyncClient(transport=transport) + ) as client: + with pytest.raises(ConnectError) as e: + await client.make_hat(Size(inches=10)) + assert ( + "ASGI server does not support lifespan but async generator passed for service." + in str(e.value) + ) From 18566c093afb211730276f0fa631bd7dd1e42b1f Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 18:00:02 +0900 Subject: [PATCH 6/8] Dedupe Signed-off-by: Anuraag Agrawal --- test/test_lifespan.py | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/test/test_lifespan.py b/test/test_lifespan.py index 80f89a1..b45dd45 100644 --- a/test/test_lifespan.py +++ b/test/test_lifespan.py @@ -17,16 +17,17 @@ from .haberdasher_pb2 import Hat, Size -@pytest.mark.asyncio -async def test_lifespan() -> None: - class CountingHaberdasher(Haberdasher): - def __init__(self, counter: Counter) -> None: - self._counter = counter +class CountingHaberdasher(Haberdasher): + def __init__(self, counter: Counter) -> None: + self._counter = counter + + async def make_hat(self, request, ctx): + self._counter["requests"] += 1 + return Hat(size=request.inches, color="blue") - async def make_hat(self, request, ctx): - self._counter["requests"] += 1 - return Hat(size=request.inches, color="blue") +@pytest.mark.asyncio +async def test_lifespan() -> None: final_count = None async def counting_haberdasher(): @@ -69,14 +70,6 @@ async def counting_haberdasher(): @pytest.mark.asyncio async def test_lifespan_startup_error() -> None: - class CountingHaberdasher(Haberdasher): - def __init__(self, counter: Counter) -> None: - self._counter = counter - - async def make_hat(self, request, ctx): - self._counter["requests"] += 1 - return Hat(size=request.inches, color="blue") - final_count = None async def counting_haberdasher(): @@ -112,14 +105,6 @@ async def counting_haberdasher(): @pytest.mark.asyncio async def test_lifespan_shutdown_error() -> None: - class CountingHaberdasher(Haberdasher): - def __init__(self, counter: Counter) -> None: - self._counter = counter - - async def make_hat(self, request, ctx): - self._counter["requests"] += 1 - return Hat(size=request.inches, color="blue") - async def counting_haberdasher(): counter = Counter() try: @@ -167,14 +152,6 @@ async def counting_haberdasher(): @pytest.mark.asyncio async def test_lifespan_not_supported() -> None: - class CountingHaberdasher(Haberdasher): - def __init__(self, counter: Counter) -> None: - self._counter = counter - - async def make_hat(self, request, ctx): - self._counter["requests"] += 1 - return Hat(size=request.inches, color="blue") - final_count = None async def counting_haberdasher(): From e6309aa9c9fde2f433efd7e8cb58e61eae46abb0 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 18:01:51 +0900 Subject: [PATCH 7/8] Fix doublesend Signed-off-by: Anuraag Agrawal --- src/connectrpc/_server_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 0f00d99..bf49e56 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -141,6 +141,7 @@ async def __call__( "message": str(e), } ) + return None await send({"type": "lifespan.shutdown.complete"}) return None From 8620f9ddf85cbaa4d0bb8acdc445a104df1d16ad Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 11 Nov 2025 21:02:25 +0900 Subject: [PATCH 8/8] ABC Signed-off-by: Anuraag Agrawal --- src/connectrpc/_server_async.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index bf49e56..176841d 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -1,7 +1,7 @@ import base64 import functools import inspect -from abc import abstractmethod +from abc import ABC, abstractmethod from asyncio import CancelledError, sleep from collections.abc import ( AsyncGenerator, @@ -73,7 +73,7 @@ ) -class ConnectASGIApplication(Generic[_SVC]): +class ConnectASGIApplication(ABC, Generic[_SVC]): """An ASGI application for the Connect protocol.""" _resolved_endpoints: Mapping[str, Endpoint] | None