Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions conformance/test/gen/connectrpc/conformance/v1/service_connect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT!
# source: connectrpc/conformance/v1/service.proto

from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping
from typing import Protocol

from connectrpc.client import ConnectClient, ConnectClientSync
Expand Down Expand Up @@ -72,16 +72,17 @@ async def idempotent_unary(
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")


class ConformanceServiceASGIApplication(ConnectASGIApplication):
class ConformanceServiceASGIApplication(ConnectASGIApplication[ConformanceService]):
def __init__(
self,
service: ConformanceService,
service: ConformanceService | AsyncGenerator[ConformanceService],
*,
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
) -> None:
super().__init__(
endpoints={
service=service,
endpoints=lambda svc: {
"/connectrpc.conformance.v1.ConformanceService/Unary": Endpoint.unary(
method=MethodInfo(
name="Unary",
Expand All @@ -90,7 +91,7 @@ def __init__(
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.unary,
function=svc.unary,
),
"/connectrpc.conformance.v1.ConformanceService/ServerStream": Endpoint.server_stream(
method=MethodInfo(
Expand All @@ -100,7 +101,7 @@ def __init__(
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.server_stream,
function=svc.server_stream,
),
"/connectrpc.conformance.v1.ConformanceService/ClientStream": Endpoint.client_stream(
method=MethodInfo(
Expand All @@ -110,7 +111,7 @@ def __init__(
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.client_stream,
function=svc.client_stream,
),
"/connectrpc.conformance.v1.ConformanceService/BidiStream": Endpoint.bidi_stream(
method=MethodInfo(
Expand All @@ -120,7 +121,7 @@ def __init__(
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.bidi_stream,
function=svc.bidi_stream,
),
"/connectrpc.conformance.v1.ConformanceService/Unimplemented": Endpoint.unary(
method=MethodInfo(
Expand All @@ -130,7 +131,7 @@ def __init__(
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.unimplemented,
function=svc.unimplemented,
),
"/connectrpc.conformance.v1.ConformanceService/IdempotentUnary": Endpoint.unary(
method=MethodInfo(
Expand All @@ -140,7 +141,7 @@ def __init__(
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse,
idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS,
),
function=service.idempotent_unary,
function=svc.idempotent_unary,
),
},
interceptors=interceptors,
Expand Down
15 changes: 8 additions & 7 deletions example/example/eliza_connect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT!
# source: example/eliza.proto

from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping
from typing import Protocol

from connectrpc.client import ConnectClient, ConnectClientSync
Expand Down Expand Up @@ -39,16 +39,17 @@ def introduce(
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")


class ElizaServiceASGIApplication(ConnectASGIApplication):
class ElizaServiceASGIApplication(ConnectASGIApplication[ElizaService]):
def __init__(
self,
service: ElizaService,
service: ElizaService | AsyncGenerator[ElizaService],
*,
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
) -> None:
super().__init__(
endpoints={
service=service,
endpoints=lambda svc: {
"/connectrpc.eliza.v1.ElizaService/Say": Endpoint.unary(
method=MethodInfo(
name="Say",
Expand All @@ -57,7 +58,7 @@ def __init__(
output=example_dot_eliza__pb2.SayResponse,
idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS,
),
function=service.say,
function=svc.say,
),
"/connectrpc.eliza.v1.ElizaService/Converse": Endpoint.bidi_stream(
method=MethodInfo(
Expand All @@ -67,7 +68,7 @@ def __init__(
output=example_dot_eliza__pb2.ConverseResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.converse,
function=svc.converse,
),
"/connectrpc.eliza.v1.ElizaService/Introduce": Endpoint.server_stream(
method=MethodInfo(
Expand All @@ -77,7 +78,7 @@ def __init__(
output=example_dot_eliza__pb2.IntroduceResponse,
idempotency_level=IdempotencyLevel.UNKNOWN,
),
function=service.introduce,
function=svc.introduce,
),
},
interceptors=interceptors,
Expand Down
2 changes: 1 addition & 1 deletion protoc-gen-connect-python/generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestGenerateConnectFile(t *testing.T) {
}

content := got.GetContent()
if !strings.Contains(content, "from collections.abc import AsyncIterator, Iterable, Iterator, Mapping") {
if !strings.Contains(content, "from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping") {
t.Error("Generated code missing required imports")
}
if !strings.Contains(content, "class "+strings.Split(tt.input.GetService()[0].GetName(), ".")[0]) {
Expand Down
11 changes: 6 additions & 5 deletions protoc-gen-connect-python/generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*-
# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT!
# source: {{.FileName}}
{{if .Services}}
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping
from typing import Protocol

from connectrpc.client import ConnectClient, ConnectClientSync
Expand All @@ -67,10 +67,11 @@ class {{.Name}}(Protocol):{{- range .Methods }}
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")
{{ end }}

class {{.Name}}ASGIApplication(ConnectASGIApplication):
def __init__(self, service: {{.Name}}, *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None:
class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None:
super().__init__(
endpoints={ {{- range .Methods }}
service=service,
endpoints=lambda svc: { {{- range .Methods }}
"/{{.ServiceName}}/{{.Name}}": Endpoint.{{.EndpointType}}(
method=MethodInfo(
name="{{.Name}}",
Expand All @@ -79,7 +80,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication):
output={{.OutputType}},
idempotency_level=IdempotencyLevel.{{.IdempotencyLevel}},
),
function=service.{{.PythonName}},
function=svc.{{.PythonName}},
),{{- end }}
},
interceptors=interceptors,
Expand Down
4 changes: 2 additions & 2 deletions protoc-gen-connect-python/generator/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func TestConnectTemplate(t *testing.T) {
},
},
contains: []string{
"from collections.abc import AsyncIterator, Iterable, Iterator, Mapping",
"from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping",
"class TestService(Protocol):",
"class TestServiceASGIApplication(ConnectASGIApplication):",
"class TestServiceASGIApplication(ConnectASGIApplication[TestService]):",
"def TestMethod",
},
},
Expand Down
99 changes: 85 additions & 14 deletions src/connectrpc/_server_async.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import base64
import functools
import inspect
from abc import ABC, abstractmethod
from asyncio import CancelledError, sleep
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Callable,
Iterable,
Mapping,
Sequence,
)
from dataclasses import replace
from http import HTTPStatus
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar, cast
from urllib.parse import parse_qs

from . import _compression, _server_shared
Expand Down Expand Up @@ -48,6 +56,7 @@
Scope = "asgiref.typing.Scope"


_SVC = TypeVar("_SVC")
_REQ = TypeVar("_REQ")
_RES = TypeVar("_RES")

Expand All @@ -64,45 +73,96 @@
)


class ConnectASGIApplication(ABC):
class ConnectASGIApplication(ABC, Generic[_SVC]):
"""An ASGI application for the Connect protocol."""

_resolved_endpoints: Mapping[str, Endpoint] | None

@property
@abstractmethod
def path(self) -> str: ...

def __init__(
self,
*,
endpoints: Mapping[str, Endpoint],
service: _SVC | AsyncGenerator[_SVC],
endpoints: Callable[[_SVC], Mapping[str, Endpoint]],
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
) -> None:
"""Initialize the ASGI application."""
super().__init__()
if interceptors:
interceptors = resolve_interceptors(interceptors)
endpoints = {
path: _apply_interceptors(endpoint, interceptors)
for path, endpoint in endpoints.items()
}
self._service = service
self._endpoints = endpoints
self._interceptors = interceptors
self._resolved_endpoints = None
self._read_max_bytes = read_max_bytes

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
assert scope["type"] == "http" # noqa: S101 - only for type narrowing, in practice always true
if scope["type"] == "websocket":
msg = "connect does not support websockets"
raise RuntimeError(msg)

if scope["type"] == "lifespan":
service_iter = None
while True:
msg = await receive()
match msg["type"]:
case "lifespan.startup":
# Need to cast since type checking doesn't seem to narrow well with isasyncgen
if inspect.isasyncgen(self._service):
service_iter = cast(
"AsyncGenerator[_SVC, None]", self._service
)
try:
service = await anext(service_iter)
except Exception as e:
await send(
{
"type": "lifespan.startup.failed",
"message": str(e),
}
)
return None
else:
service = cast("_SVC", self._service)
self._resolved_endpoints = self._resolve_endpoints(service)
await send({"type": "lifespan.startup.complete"})
case "lifespan.shutdown":
if service_iter is not None:
try:
await service_iter.aclose()
except Exception as e:
await send(
{
"type": "lifespan.shutdown.failed",
"message": str(e),
}
)
return None
await send({"type": "lifespan.shutdown.complete"})
return None

if not self._resolved_endpoints:
if inspect.isasyncgen(self._service):
msg = "ASGI server does not support lifespan but async generator passed for service. Enable lifespan support."
raise RuntimeError(msg)

self._resolved_endpoints = self._resolve_endpoints(
cast("_SVC", self._service)
)
endpoints = self._resolved_endpoints

ctx: RequestContext | None = None
try:
path = scope["path"]
endpoint = self._endpoints.get(path)
endpoint = endpoints.get(path)
if not endpoint and scope["root_path"]:
# The application was mounted at some root so try stripping the prefix.
path = path.removeprefix(scope["root_path"])
endpoint = self._endpoints.get(path)

endpoint = endpoints.get(path)
if not endpoint:
raise HTTPException(HTTPStatus.NOT_FOUND, [])

Expand Down Expand Up @@ -381,6 +441,17 @@ async def _handle_error(
)
await send({"type": "http.response.body", "body": body, "more_body": False})

def _resolve_endpoints(self, service: _SVC) -> Mapping[str, Endpoint]:
resolved_endpoints = self._endpoints(service)
if self._interceptors:
resolved_endpoints = {
path: _apply_interceptors(
endpoint, resolve_interceptors(self._interceptors)
)
for path, endpoint in resolved_endpoints.items()
}
return resolved_endpoints


async def _send_stream_response_headers(
send: ASGISendCallable, codec: Codec, compression_name: str, ctx: RequestContext
Expand Down
Loading