diff --git a/cmd/protoc-gen-connect-python/generator/generator.go b/cmd/protoc-gen-connect-python/generator/generator.go index 6d29b83..74377c0 100644 --- a/cmd/protoc-gen-connect-python/generator/generator.go +++ b/cmd/protoc-gen-connect-python/generator/generator.go @@ -180,7 +180,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { p.P(`from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse`) p.P(`from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler`) p.P(`from connect.options import ClientOptions, ConnectOptions`) - p.P(`from connect.session import AsyncClientSession`) + p.P(`from connect.connection_pool import AsyncConnectionPool`) p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`) p.P() @@ -238,13 +238,13 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { p.P() p.P() p.P(`class `, upperSvcName, `Client:`) - p.P(` def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:`) + p.P(` def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None:`) p.P(` base_url = base_url.removesuffix("/")`) p.P() for _, meth := range sortedMap(p.services) { svc := p.services[meth] p.P(` `, `self.`, meth.Method, ` = `, `Client[`, svc.input.method, `, `, svc.output.method, `](`) - p.P(` `, `session, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`) + p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`) switch meth.RPCType { case Unary: p.P(` `, `).call_unary`) diff --git a/conformance/client_runner.py b/conformance/client_runner.py index dcb3587..0e37df4 100755 --- a/conformance/client_runner.py +++ b/conformance/client_runner.py @@ -11,10 +11,10 @@ from typing import Any from connect.connect import StreamRequest, UnaryRequest +from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError from connect.headers import Headers from connect.options import ClientOptions -from connect.session import AsyncClientSession from google.protobuf import any_pb2 from google.protobuf.internal.containers import RepeatedCompositeFieldContainer @@ -182,7 +182,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c - Captures and returns errors in the response if exceptions occur. Note: - - This function uses an asynchronous HTTP client session (`AsyncClientSession`) + - This function uses an asynchronous HTTP client connection pool (`AsyncConnectionPool`) for making requests. - Compression (e.g., gzip) is applied if specified in the request. - Headers and trailers are converted to protobuf-compatible formats. @@ -215,7 +215,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c url = f"{proto}://{msg.host}:{msg.port}" - async with AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session: + async with AsyncConnectionPool(http1=http1, http2=http2, ssl_context=ssl_context) as pool: payloads = [] try: options = ClientOptions() @@ -231,7 +231,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c if msg.codec == config_pb2.CODEC_JSON: options.use_binary_format = False - client = service_connect.ConformanceServiceClient(base_url=url, session=session, options=options) + client = service_connect.ConformanceServiceClient(base_url=url, pool=pool, options=options) if msg.stream_type == config_pb2.STREAM_TYPE_UNARY: if msg.request_delay_ms > 0: await asyncio.sleep(msg.request_delay_ms / 1000) diff --git a/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py b/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py index 057d7e9..ee2a8ec 100644 --- a/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py +++ b/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py @@ -10,7 +10,7 @@ import connect.connect from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler, BidiStreamHandler from connect.options import ClientOptions, ConnectOptions -from connect.session import AsyncClientSession +from connect.connection_pool import AsyncConnectionPool from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from connect.idempotency_level import IdempotencyLevel @@ -40,26 +40,26 @@ class ConformanceServiceProcedures(Enum): class ConformanceServiceClient: - def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None: + def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None: base_url = base_url.removesuffix("/") self.Unary = Client[UnaryRequest, UnaryResponse]( - session, base_url + ConformanceServiceProcedures.Unary.value, UnaryRequest, UnaryResponse, options + pool, base_url + ConformanceServiceProcedures.Unary.value, UnaryRequest, UnaryResponse, options ).call_unary self.ServerStream = Client[ServerStreamRequest, ServerStreamResponse]( - session, base_url + ConformanceServiceProcedures.ServerStream.value, ServerStreamRequest, ServerStreamResponse, options + pool, base_url + ConformanceServiceProcedures.ServerStream.value, ServerStreamRequest, ServerStreamResponse, options ).call_server_stream self.ClientStream = Client[ClientStreamRequest, ClientStreamResponse]( - session, base_url + ConformanceServiceProcedures.ClientStream.value, ClientStreamRequest, ClientStreamResponse, options + pool, base_url + ConformanceServiceProcedures.ClientStream.value, ClientStreamRequest, ClientStreamResponse, options ).call_client_stream self.BidiStream = Client[BidiStreamRequest, BidiStreamResponse]( - session, base_url + ConformanceServiceProcedures.BidiStream.value, BidiStreamRequest, BidiStreamResponse, options + pool, base_url + ConformanceServiceProcedures.BidiStream.value, BidiStreamRequest, BidiStreamResponse, options ).call_bidi_stream self.Unimplemented = Client[UnimplementedRequest, UnimplementedResponse]( - session, base_url + ConformanceServiceProcedures.Unimplemented.value, UnimplementedRequest, UnimplementedResponse, options + pool, base_url + ConformanceServiceProcedures.Unimplemented.value, UnimplementedRequest, UnimplementedResponse, options ).call_unary self.IdempotentUnary = Client[IdempotentUnaryRequest, IdempotentUnaryResponse]( - session, base_url + ConformanceServiceProcedures.IdempotentUnary.value, IdempotentUnaryRequest, IdempotentUnaryResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options), + pool, base_url + ConformanceServiceProcedures.IdempotentUnary.value, IdempotentUnaryRequest, IdempotentUnaryResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options), ).call_unary diff --git a/examples/client.py b/examples/client.py index 69d4c50..e5f3d35 100644 --- a/examples/client.py +++ b/examples/client.py @@ -4,7 +4,7 @@ import logging from connect.connect import UnaryRequest -from connect.session import AsyncClientSession +from connect.connection_pool import AsyncConnectionPool from proto.connectrpc.eliza.v1.eliza_pb2 import SayRequest from proto.connectrpc.eliza.v1.v1connect.eliza_connect_pb2 import ElizaServiceClient @@ -15,9 +15,9 @@ async def main() -> None: """Interact with the ElizaServiceClient asynchronously.""" - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = ElizaServiceClient( - session=session, + pool=pool, base_url="http://localhost:8080/", ) response = await client.Say(UnaryRequest(SayRequest(sentence="I feel happy."))) diff --git a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py index c52e011..9b545e8 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py @@ -8,13 +8,13 @@ from connect.client import Client from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse +from connect.connection_pool import AsyncConnectionPool from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler from connect.options import ClientOptions, ConnectOptions -from connect.session import AsyncClientSession from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from .. import eliza_pb2 -from ..eliza_pb2 import SayRequest, SayResponse, ConverseRequest, ConverseResponse, IntroduceRequest, IntroduceResponse +from ..eliza_pb2 import ConverseRequest, ConverseResponse, IntroduceRequest, IntroduceResponse, SayRequest, SayResponse class ElizaServiceProcedures(Enum): @@ -35,20 +35,20 @@ class ElizaServiceProcedures(Enum): class ElizaServiceClient: - def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None: + def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None: base_url = base_url.removesuffix("/") self.Say = Client[SayRequest, SayResponse]( - session, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options + pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options ).call_unary self.Converse = Client[ConverseRequest, ConverseResponse]( - session, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options + pool, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options ).call_server_stream self.IntroduceServer = Client[IntroduceRequest, IntroduceResponse]( - session, base_url + ElizaServiceProcedures.IntroduceServer.value, IntroduceRequest, IntroduceResponse, options + pool, base_url + ElizaServiceProcedures.IntroduceServer.value, IntroduceRequest, IntroduceResponse, options ).call_server_stream self.IntroduceClient = Client[IntroduceRequest, IntroduceResponse]( - session, base_url + ElizaServiceProcedures.IntroduceClient.value, IntroduceRequest, IntroduceResponse, options + pool, base_url + ElizaServiceProcedures.IntroduceClient.value, IntroduceRequest, IntroduceResponse, options ).call_client_stream diff --git a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py index 6690b5e..919973a 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py @@ -9,9 +9,9 @@ from connect.client import Client from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse +from connect.connection_pool import AsyncConnectionPool from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler from connect.options import ClientOptions, ConnectOptions -from connect.session import AsyncClientSession from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from .. import eliza_pb2 @@ -38,21 +38,21 @@ class ElizaServiceProcedures(Enum): class ElizaServiceClient: - def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None: + def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOptions | None = None) -> None: base_url = base_url.removesuffix("/") self.Say = Client[SayRequest, SayResponse]( - session, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options + pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options ).call_unary self.IntroduceServer = Client[IntroduceRequest, IntroduceResponse]( - session, + pool, base_url + ElizaServiceProcedures.IntroduceServer.value, IntroduceRequest, IntroduceResponse, options, ).call_server_stream self.IntroduceClient = Client[IntroduceRequest, IntroduceResponse]( - session, + pool, base_url + ElizaServiceProcedures.IntroduceClient.value, IntroduceRequest, IntroduceResponse, diff --git a/src/connect/client.py b/src/connect/client.py index e74fbb2..8c0329b 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -23,6 +23,7 @@ recieve_stream_response, recieve_unary_response, ) +from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError from connect.idempotency_level import IdempotencyLevel from connect.interceptor import apply_interceptors @@ -30,7 +31,6 @@ from connect.protocol import Protocol, ProtocolClient, ProtocolClientParams from connect.protocol_connect.connect_protocol import ProtocolConnect from connect.protocol_grpc.grpc_protocol import ProtocolGRPC -from connect.session import AsyncClientSession from connect.utils import aiterate @@ -186,7 +186,7 @@ class Client[T_Request, T_Response]: def __init__( self, - session: AsyncClientSession, + pool: AsyncConnectionPool, url: str, input: type[T_Request], output: type[T_Response], @@ -195,7 +195,7 @@ def __init__( """Initialize the client with the given URL, request and response types, and optional client options. Args: - session (AsyncClientSession): The client session to use for the connection. + pool (AsyncConnectionPool): The connection pool to use for making requests. url (str): The URL of the server to connect to. input (type[T_Request]): The type of the request object. output (type[T_Response]): The type of the response object. @@ -212,7 +212,7 @@ def __init__( protocol_client = config.protocol.client( ProtocolClientParams( - session=session, + pool=pool, codec=config.codec, url=config.url, compression_name=config.request_compression_name, diff --git a/src/connect/connection_pool.py b/src/connect/connection_pool.py new file mode 100644 index 0000000..030e742 --- /dev/null +++ b/src/connect/connection_pool.py @@ -0,0 +1,3 @@ +"""Provides connection pool functionality using httpcore's AsyncConnectionPool.""" + +from httpcore import AsyncConnectionPool as AsyncConnectionPool diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 348129d..30a4f3e 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -16,12 +16,12 @@ StreamingHandlerConn, StreamType, ) +from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel from connect.request import Request from connect.response_writer import ServerResponseWriter -from connect.session import AsyncClientSession PROTOCOL_CONNECT = "connect" PROTOCOL_GRPC = "grpc" @@ -70,7 +70,7 @@ class ProtocolClientParams(BaseModel): arbitrary_types_allowed=True, ) - session: AsyncClientSession + pool: AsyncConnectionPool codec: Codec url: URL compression_name: str | None = Field(default=None) diff --git a/src/connect/protocol_connect/connect_client.py b/src/connect/protocol_connect/connect_client.py index e8cb072..8fca3dc 100644 --- a/src/connect/protocol_connect/connect_client.py +++ b/src/connect/protocol_connect/connect_client.py @@ -26,6 +26,7 @@ StreamType, ensure_single, ) +from connect.connection_pool import AsyncConnectionPool from connect.content_stream import BoundAsyncStream from connect.error import ConnectError from connect.headers import Headers, include_request_headers @@ -59,7 +60,6 @@ from connect.protocol_connect.error_json import error_from_json from connect.protocol_connect.marshaler import ConnectStreamingMarshaler, ConnectUnaryRequestMarshaler from connect.protocol_connect.unmarshaler import ConnectStreamingUnmarshaler, ConnectUnaryUnmarshaler -from connect.session import AsyncClientSession from connect.utils import ( map_httpcore_exceptions, ) @@ -147,7 +147,7 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: conn: StreamingClientConn if spec.stream_type == StreamType.Unary: conn = ConnectUnaryClientConn( - session=self.params.session, + pool=self.params.pool, spec=spec, peer=self.peer, url=self.params.url, @@ -172,7 +172,7 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: conn.marshaler.stable_codec = self.params.codec else: conn = ConnectStreamingClientConn( - session=self.params.session, + pool=self.params.pool, spec=spec, peer=self.peer, url=self.params.url, @@ -212,7 +212,7 @@ class ConnectUnaryClientConn(StreamingClientConn): """ - session: AsyncClientSession + pool: AsyncConnectionPool _spec: Spec _peer: Peer url: URL @@ -226,7 +226,7 @@ class ConnectUnaryClientConn(StreamingClientConn): def __init__( self, - session: AsyncClientSession, + pool: AsyncConnectionPool, spec: Spec, peer: Peer, url: URL, @@ -239,7 +239,7 @@ def __init__( """Initialize the ConnectProtocol instance. Args: - session (AsyncClientSession): The session for the connection. + pool (AsyncConnectionPool): The connection pool for the client. spec (Spec): The specification for the connection. peer (Peer): The peer information. url (URL): The URL for the connection. @@ -255,7 +255,7 @@ def __init__( """ event_hooks = {} if event_hooks is None else event_hooks - self.session = session + self.pool = pool self._spec = spec self._peer = peer self.url = url @@ -412,9 +412,9 @@ async def send( with map_httpcore_exceptions(): if not abort_event: - response = await self.session.pool.handle_async_request(request=request) + response = await self.pool.handle_async_request(request=request) else: - request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + request_task = asyncio.create_task(self.pool.handle_async_request(request=request)) abort_task = asyncio.create_task(abort_event.wait()) done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) @@ -563,7 +563,7 @@ class ConnectStreamingClientConn(StreamingClientConn): def __init__( self, - session: AsyncClientSession, + pool: AsyncConnectionPool, spec: Spec, peer: Peer, url: URL, @@ -577,7 +577,7 @@ def __init__( """Initialize a new instance of the class. Args: - session (AsyncClientSession): The session object for the connection. + pool (AsyncConnectionPool): The connection pool for the client. spec (Spec): The specification object. peer (Peer): The peer object. url (URL): The URL for the connection. @@ -594,7 +594,7 @@ def __init__( """ event_hooks = {} if event_hooks is None else event_hooks - self.session = session + self.pool = pool self._spec = spec self._peer = peer self.url = url @@ -771,9 +771,9 @@ async def send( with map_httpcore_exceptions(): if not abort_event: - response = await self.session.pool.handle_async_request(request) + response = await self.pool.handle_async_request(request) else: - request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + request_task = asyncio.create_task(self.pool.handle_async_request(request=request)) abort_task = asyncio.create_task(abort_event.wait()) done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) diff --git a/src/connect/protocol_grpc/grpc_client.py b/src/connect/protocol_grpc/grpc_client.py index 51e69b9..87908e8 100644 --- a/src/connect/protocol_grpc/grpc_client.py +++ b/src/connect/protocol_grpc/grpc_client.py @@ -19,6 +19,7 @@ StreamingClientConn, StreamType, ) +from connect.connection_pool import AsyncConnectionPool from connect.content_stream import BoundAsyncStream from connect.error import ConnectError from connect.headers import Headers, include_request_headers @@ -41,7 +42,6 @@ from connect.protocol_grpc.error_trailer import grpc_error_from_trailer from connect.protocol_grpc.marshaler import GRPCMarshaler from connect.protocol_grpc.unmarshaler import GRPCUnmarshaler -from connect.session import AsyncClientSession from connect.utils import map_httpcore_exceptions EventHook = Callable[..., Any] @@ -51,7 +51,7 @@ class GRPCClient(ProtocolClient): """GRPCClient is a protocol client implementation for gRPC communication, supporting both standard and web environments. Attributes: - params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, session, and URL. + params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, pool, and URL. _peer (Peer): The peer instance associated with this client, representing the remote endpoint. web (bool): Indicates whether the client is running in a web environment, affecting header and content-type handling. @@ -133,14 +133,14 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: StreamingClientConn: An initialized gRPC streaming client connection. Details: - - Configures the connection with parameters such as session, peer, URL, codec, and compression settings. + - Configures the connection with parameters such as pool, peer, URL, codec, and compression settings. - Initializes GRPCMarshaler and GRPCUnmarshaler with appropriate codecs and limits. - Compression is determined using the provided compression name and available compressions. """ return GRPCClientConn( web=self.web, - session=self.params.session, + pool=self.params.pool, spec=spec, peer=self.peer, url=self.params.url, @@ -164,10 +164,10 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: class GRPCClientConn(StreamingClientConn): """GRPCClientConn is a gRPC client connection implementation supporting asynchronous streaming requests and responses over HTTP/2. - This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client session and supports cancellation via asyncio events. + This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client connection pool and supports cancellation via asyncio events. Attributes: - session (AsyncClientSession): The asynchronous client session used for HTTP requests. + pool (AsyncConnectionPool): The asynchronous connection pool for managing HTTP/2 connections. _spec (Spec): The protocol or API specification. _peer (Peer): Information about the remote peer. url (URL): The endpoint URL for the connection. @@ -183,7 +183,7 @@ class GRPCClientConn(StreamingClientConn): """ web: bool - session: AsyncClientSession + pool: AsyncConnectionPool _spec: Spec _peer: Peer url: URL @@ -199,7 +199,7 @@ class GRPCClientConn(StreamingClientConn): def __init__( self, web: bool, - session: AsyncClientSession, + pool: AsyncConnectionPool, spec: Spec, peer: Peer, url: URL, @@ -214,7 +214,7 @@ def __init__( Args: web (bool): Indicates if the connection is for a web environment. - session (AsyncClientSession): The asynchronous client session to use for requests. + pool (AsyncConnectionPool): The connection pool for managing HTTP/2 connections. spec (Spec): The specification object describing the protocol or API. peer (Peer): The peer information for the connection. url (URL): The URL endpoint for the connection. @@ -229,7 +229,7 @@ def __init__( event_hooks = {} if event_hooks is None else event_hooks self.web = web - self.session = session + self.pool = pool self._spec = spec self._peer = peer self.url = url @@ -363,9 +363,9 @@ async def send( with map_httpcore_exceptions(): if not abort_event: - response = await self.session.pool.handle_async_request(request) + response = await self.pool.handle_async_request(request) else: - request_task = asyncio.create_task(self.session.pool.handle_async_request(request=request)) + request_task = asyncio.create_task(self.pool.handle_async_request(request=request)) abort_task = asyncio.create_task(abort_event.wait()) done, _ = await asyncio.wait({request_task, abort_task}, return_when=asyncio.FIRST_COMPLETED) diff --git a/src/connect/session.py b/src/connect/session.py deleted file mode 100644 index 3344d25..0000000 --- a/src/connect/session.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Provides an asynchronous client session for managing HTTP connections.""" - -import abc -import ssl -import types -import typing -from types import TracebackType -from typing import Self - -import httpcore - -from connect.utils import map_httpcore_exceptions - - -class AbstractAsyncContextManager(abc.ABC): - """Abstract base class for an asynchronous context manager.""" - - async def __aenter__(self) -> Self: - """Enter the context manager and return the instance.""" - return self - - @abc.abstractmethod - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the context manager and handle any exceptions that occur.""" - return None - - -class AsyncClientSession(AbstractAsyncContextManager): - """An asynchronous client session for managing HTTP connections. - - This class provides an asynchronous context manager for managing a pool of HTTP connections. - It supports both HTTP/1 and HTTP/2 protocols and allows configuration of various connection - parameters such as SSL context, proxy, maximum connections, keep-alive settings, retries, - and network backend. - - Attributes: - pool (httpcore.AsyncConnectionPool): The connection pool used for managing HTTP connections. - - Args: - ssl_context (ssl.SSLContext | None): The SSL context for secure connections. Defaults to None. - proxy (httpcore.Proxy | None): The proxy configuration. Defaults to None. - max_connections (int | None): The maximum number of connections. Defaults to 10. - max_keepalive_connections (int | None): The maximum number of keep-alive connections. Defaults to None. - keepalive_expiry (float | None): The keep-alive expiry time in seconds. Defaults to None. - http1 (bool): Whether to support HTTP/1 protocol. Defaults to True. - http2 (bool): Whether to support HTTP/2 protocol. Defaults to True. - retries (int): The number of retries for failed requests. Defaults to 0. - local_address (str | None): The local address to bind to. Defaults to None. - uds (str | None): The Unix domain socket to bind to. Defaults to None. - network_backend (httpcore.AsyncNetworkBackend | None): The network backend to use. Defaults to None. - socket_options (typing.Iterable[httpcore.SOCKET_OPTION] | None): The socket options to set. Defaults to None. - - """ - - def __init__( - self, - ssl_context: ssl.SSLContext | None = None, - proxy: httpcore.Proxy | None = None, - max_connections: int | None = 10, - max_keepalive_connections: int | None = None, - keepalive_expiry: float | None = None, - http1: bool = False, - http2: bool = True, # because bidi-streams are not supported in HTTP/1 - retries: int = 0, - local_address: str | None = None, - uds: str | None = None, - network_backend: httpcore.AsyncNetworkBackend | None = None, - socket_options: typing.Iterable[httpcore.SOCKET_OPTION] | None = None, - ) -> None: - """Initialize the connection pool with the given parameters.""" - self.pool = httpcore.AsyncConnectionPool( - ssl_context=ssl_context, - proxy=proxy, - max_connections=max_connections, - max_keepalive_connections=max_keepalive_connections, - keepalive_expiry=keepalive_expiry, - http1=http1, - http2=http2, - retries=retries, - local_address=local_address, - uds=uds, - network_backend=network_backend, - socket_options=socket_options, - ) - - async def aclose(self) -> None: - """Close the connection pool.""" - await self.pool.aclose() - - async def __aenter__(self) -> "AsyncClientSession": - """Enter the context manager and return the instance.""" - await self.pool.__aenter__() - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None = None, - exc: BaseException | None = None, - tb: types.TracebackType | None = None, - ) -> None: - """Exit the context manager and handle any exceptions that occur.""" - with map_httpcore_exceptions(): - await self.pool.__aexit__(exc_type, exc, tb) diff --git a/tests/test_streaming_connect_client.py b/tests/test_streaming_connect_client.py index dfd8c7e..999491a 100644 --- a/tests/test_streaming_connect_client.py +++ b/tests/test_streaming_connect_client.py @@ -11,11 +11,11 @@ from connect.client import Client from connect.code import Code from connect.connect import StreamRequest, StreamResponse, StreamType +from connect.connection_pool import AsyncConnectionPool from connect.envelope import Envelope, EnvelopeFlags from connect.error import ConnectError from connect.interceptor import Interceptor, StreamFunc from connect.options import ClientOptions -from connect.session import AsyncClientSession from tests.conftest import ASGIRequest, Receive, Scope, Send, ServerConfig from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse from tests.testdata.ping.v1.v1connect.ping_connect import PingServiceProcedures @@ -65,8 +65,8 @@ async def server_streaming(scope: Scope, receive: Receive, send: Send) -> None: async def test_server_streaming(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) async with client.call_server_stream(ping_request) as response: @@ -116,8 +116,8 @@ async def server_streaming_end_stream_error(scope: Scope, receive: Receive, send async def test_server_streaming_end_stream_error(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) async with client.call_server_stream(ping_request) as response: @@ -179,8 +179,8 @@ async def server_streaming_received_message_after_end_stream(scope: Scope, recei async def test_server_streaming_received_message_after_end_stream(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) async with client.call_server_stream(ping_request) as response: @@ -244,8 +244,8 @@ async def server_streaming_received_extra_end_stream(scope: Scope, receive: Rece async def test_server_streaming_received_extra_end_stream(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) async with client.call_server_stream(ping_request) as response: @@ -295,8 +295,8 @@ async def server_streaming_not_received_end_stream(scope: Scope, receive: Receiv async def test_server_streaming_not_received_end_stream(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) async with client.call_server_stream(ping_request) as response: @@ -352,8 +352,8 @@ async def server_streaming_response_envelope_message_compression(scope: Scope, r async def test_server_streaming_response_envelope_message_compression(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) async with client.call_server_stream(ping_request) as response: @@ -411,9 +411,9 @@ async def server_streaming_request_envelope_message_compression(scope: Scope, re async def test_server_streaming_request_envelope_message_compression(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = Client( - session=session, + pool=pool, url=url, input=PingRequest, output=PingResponse, @@ -470,9 +470,9 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: return _wrapped - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = Client( - session=session, + pool=pool, url=url, input=PingRequest, output=PingResponse, @@ -514,8 +514,8 @@ async def server_streaming_not_httpstatus_200(scope: Scope, receive: Receive, se async def test_server_streaming_not_httpstatus_200(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=PingRequest(name="Bob")) with pytest.raises(ConnectError) as excinfo: @@ -576,8 +576,8 @@ async def iterator() -> AsyncIterator[PingRequest]: for message in messages: yield PingRequest(name=message) - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = StreamRequest(content=iterator()) async with client.call_client_stream(ping_request) as response: @@ -632,9 +632,9 @@ async def _wrapped(request: StreamRequest[Any]) -> StreamResponse[Any]: async def iterator() -> AsyncIterator[PingRequest]: yield PingRequest(name="test") - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = Client( - session=session, + pool=pool, url=url, input=PingRequest, output=PingResponse, diff --git a/tests/test_unary_connect_client.py b/tests/test_unary_connect_client.py index 48831d1..04c4ce8 100644 --- a/tests/test_unary_connect_client.py +++ b/tests/test_unary_connect_client.py @@ -10,11 +10,11 @@ from connect.client import Client from connect.code import Code from connect.connect import StreamType, UnaryRequest, UnaryResponse +from connect.connection_pool import AsyncConnectionPool from connect.error import ConnectError from connect.idempotency_level import IdempotencyLevel from connect.interceptor import Interceptor, UnaryFunc from connect.options import ClientOptions -from connect.session import AsyncClientSession from tests.conftest import ASGIRequest, Receive, Scope, Send, ServerConfig from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse from tests.testdata.ping.v1.v1connect.ping_connect import PingServiceProcedures @@ -25,8 +25,8 @@ async def test_post_application_proto(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = UnaryRequest(content=PingRequest(name="test")) response = await client.call_unary(ping_request) @@ -56,8 +56,8 @@ async def post_response_gzip(scope: Scope, receive: Receive, send: Send) -> None async def test_post_response_gzip(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = UnaryRequest(content=PingRequest(name="test")) await client.call_unary(ping_request) @@ -90,9 +90,9 @@ async def post_request_gzip(scope: Scope, receive: Receive, send: Send) -> None: async def test_post_request_gzip(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = Client( - session=session, + pool=pool, url=url, input=PingRequest, output=PingResponse, @@ -154,9 +154,9 @@ async def get_application_proto(scope: Scope, receive: Receive, send: Send) -> N async def test_get_application_proto(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = Client( - session=session, + pool=pool, url=url, input=PingRequest, output=PingResponse, @@ -186,8 +186,8 @@ async def post_not_found(scope: Scope, receive: Receive, send: Send) -> None: async def test_post_not_found(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = UnaryRequest(content=PingRequest(name="test")) with pytest.raises(ConnectError) as excinfo: @@ -218,8 +218,8 @@ async def post_invalid_content_type_prefix(scope: Scope, receive: Receive, send: async def test_post_invalid_content_type_prefix(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = UnaryRequest(content=PingRequest(name="test")) with pytest.raises(ConnectError) as excinfo: @@ -268,8 +268,8 @@ async def test_post_error_details(hypercorn_server: ServerConfig) -> None: url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = UnaryRequest(content=PingRequest(name="test")) with pytest.raises(ConnectError) as excinfo: @@ -328,8 +328,8 @@ async def test_post_compressed_error_details(hypercorn_server: ServerConfig) -> url = hypercorn_server.base_url + PingServiceProcedures.Ping.value + "/proto" - async with AsyncClientSession() as session: - client = Client(session=session, url=url, input=PingRequest, output=PingResponse) + async with AsyncConnectionPool() as pool: + client = Client(pool=pool, url=url, input=PingRequest, output=PingResponse) ping_request = UnaryRequest(content=PingRequest(name="test")) with pytest.raises(ConnectError) as excinfo: @@ -393,9 +393,9 @@ async def _wrapped(request: UnaryRequest[Any]) -> UnaryResponse[Any]: return _wrapped - async with AsyncClientSession() as session: + async with AsyncConnectionPool() as pool: client = Client( - session=session, + pool=pool, url=url, input=PingRequest, output=PingResponse,