diff --git a/cmd/protoc-gen-connect-python/generator/generator.go b/cmd/protoc-gen-connect-python/generator/generator.go index 74377c0..ed968b7 100644 --- a/cmd/protoc-gen-connect-python/generator/generator.go +++ b/cmd/protoc-gen-connect-python/generator/generator.go @@ -20,8 +20,6 @@ import ( "google.golang.org/protobuf/types/pluginpb" ) -var version = "devel" - type Config struct{} type Generator struct { @@ -103,6 +101,7 @@ type Method struct { Method string FullName string RPCType RPCType + Options protoreflect.ProtoMessage } type message struct { @@ -140,16 +139,17 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { idx: i, Method: meth.GoName, FullName: fullname[:idx], + Options: meth.Desc.Options(), } // parse RPC type switch { + case meth.Desc.IsStreamingServer() && meth.Desc.IsStreamingClient(): + method.RPCType = BidirectionalStreaming case meth.Desc.IsStreamingServer(): method.RPCType = ServerStreaming case meth.Desc.IsStreamingClient(): method.RPCType = ClientStreaming - case meth.Desc.IsStreamingServer() && meth.Desc.IsStreamingClient(): - method.RPCType = BidirectionalStreaming default: method.RPCType = Unary } @@ -174,13 +174,23 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { } p.P(`"""Generated connect code."""`) p.P() + p.P(`import abc`) p.P(`from enum import Enum`) p.P() - p.P(`from connect.client import Client`) - p.P(`from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse`) - p.P(`from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler`) - p.P(`from connect.options import ClientOptions, ConnectOptions`) + p.P(`from connect import (`) + p.P(` Client,`) + p.P(` ClientOptions,`) + p.P(` ConnectOptions,`) + p.P(` Handler,`) + p.P(` HandlerContext,`) + p.P(` IdempotencyLevel,`) + p.P(` StreamRequest,`) + p.P(` StreamResponse,`) + p.P(` UnaryRequest,`) + p.P(` UnaryResponse,`) + p.P(`)`) p.P(`from connect.connection_pool import AsyncConnectionPool`) + p.P(`from connect.handler import BidiStreamHandler, ClientStreamHandler, ServerStreamHandler, UnaryHandler`) p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`) p.P() @@ -191,7 +201,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { var sb strings.Builder numSvc := len(p.services) if numSvc > 0 { - fmt.Fprintf(&sb, "from ..%s import ", svcNamePB) + fmt.Fprintf(&sb, "from ..%s import (\n", svcNamePB) } seem := make(map[string]bool) i := 0 @@ -201,22 +211,25 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { case seem[svc.input.method] && seem[svc.output.method]: continue case !seem[svc.input.method] && seem[svc.output.method]: - fmt.Fprintf(&sb, "%s", svc.input.method) + fmt.Fprintf(&sb, " %s\n", svc.input.method) seem[svc.input.method] = true case seem[svc.input.method] && !seem[svc.output.method]: - fmt.Fprintf(&sb, "%s", svc.output.method) + fmt.Fprintf(&sb, " %s\n", svc.output.method) seem[svc.output.method] = true default: - fmt.Fprintf(&sb, "%s, %s", svc.input.method, svc.output.method) + fmt.Fprintf(&sb, " %s,\n %s", svc.input.method, svc.output.method) seem[svc.input.method] = true seem[svc.output.method] = true } if i <= numSvc-2 { - fmt.Fprint(&sb, ", ") + fmt.Fprint(&sb, ",\n") + } else { + fmt.Fprint(&sb, ",") } i++ } - p.P(strings.TrimSuffix(sb.String(), ", ")) + p.P(sb.String()) + p.P(`)`) p.P() p.P() procedures := upperSvcName + `Procedures` @@ -244,7 +257,13 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { for _, meth := range sortedMap(p.services) { svc := p.services[meth] p.P(` `, `self.`, meth.Method, ` = `, `Client[`, svc.input.method, `, `, svc.output.method, `](`) - p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`) + if options := meth.Options; options != nil { + if desc, ok := options.(*descriptorpb.MethodOptions); ok && desc.GetIdempotencyLevel() != descriptorpb.MethodOptions_IDEMPOTENCY_UNKNOWN { + p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, ClientOptions(idempotency_level=IdempotencyLevel.`, desc.GetIdempotencyLevel().String(), `, enable_get=True).merge(options)`) + } else { + p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`) + } + } switch meth.RPCType { case Unary: p.P(` `, `).call_unary`) @@ -252,12 +271,14 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { p.P(` `, `).call_server_stream`) case ClientStreaming: p.P(` `, `).call_client_stream`) + case BidirectionalStreaming: + p.P(` `, `).call_bidi_stream`) } } p.P() p.P() handler := upperSvcName + `Handler` - p.P(`class `, handler, `:`) + p.P(`class `, handler, `(metaclass=abc.ABCMeta):`) p.P(` `, `"""Handler for the `, lowerCamelCase(upperSvcName), ` service."""`) p.P() j := 0 @@ -269,16 +290,15 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { reqRPCType string respRPCType string ) - // TODO(zchee): BidirectionalStreaming? switch meth.RPCType { case Unary: reqRPCType = `UnaryRequest` respRPCType = `UnaryResponse` - case ServerStreaming, ClientStreaming: + case ServerStreaming, ClientStreaming, BidirectionalStreaming: reqRPCType = `StreamRequest` respRPCType = `StreamResponse` } - fmt.Fprintf(&sb, "%s[%s]) -> %s[%s]: ...", reqRPCType, svc.input.method, respRPCType, svc.output.method) + fmt.Fprintf(&sb, "%s[%s], context: HandlerContext) -> %s[%s]:\n raise NotImplementedError()", reqRPCType, svc.input.method, respRPCType, svc.output.method) p.P(sb.String()) if j <= len(p.services)-2 { p.P() @@ -295,7 +315,6 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { rpcHandler string call string ) - // TODO(zchee): BidirectionalStreaming? switch meth.RPCType { case Unary: rpcHandler = `UnaryHandler` @@ -306,17 +325,25 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { case ClientStreaming: rpcHandler = `ClientStreamHandler` call = fmt.Sprintf(" stream=service.%s,", meth.Method) + case BidirectionalStreaming: + rpcHandler = `BidiStreamHandler` + call = fmt.Sprintf(" stream=service.%s,", meth.Method) } - // TODO(zchee): BidirectionalStreaming? switch meth.RPCType { - case Unary, ServerStreaming, ClientStreaming: + case Unary, ServerStreaming, ClientStreaming, BidirectionalStreaming: p.P(` `, rpcHandler, `(`) p.P(` procedure=`, procedures+`.`+meth.Method+`.value,`) p.P(call) p.P(` input=`, svc.input.method, `,`) p.P(` output=`, svc.output.method, `,`) - p.P(` options=options,`) + if options := meth.Options; options != nil { + if desc, ok := options.(*descriptorpb.MethodOptions); ok && desc.GetIdempotencyLevel() != descriptorpb.MethodOptions_IDEMPOTENCY_UNKNOWN { + p.P(` options=ConnectOptions(idempotency_level=IdempotencyLevel.`, desc.GetIdempotencyLevel().String(), `).merge(options),`) + } else { + p.P(` options=options,`) + } + } p.P(` ),`) } } diff --git a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py index bc6f9b9..cb31310 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py @@ -1,20 +1,39 @@ # Generated by the protoc-gen-connect-python. DO NOT EDIT! # source: examples/proto/connectrpc/eliza/v1/v1connect/eliza.proto # Protobuf Python Version: v5.29.3 -# protoc-gen-connect-python version: v0.0.0-20250517015031-b19a36b52499+dirty +# protoc-gen-connect-python version: v0.0.0-20250708090951-d93686e5039f """Generated connect code.""" +import abc from enum import Enum -from connect.client import Client -from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse -from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler -from connect.options import ClientOptions, ConnectOptions +from connect import ( + Client, + ClientOptions, + ConnectOptions, + Handler, + HandlerContext, + IdempotencyLevel, + StreamRequest, + StreamResponse, + UnaryRequest, + UnaryResponse, +) from connect.connection_pool import AsyncConnectionPool +from connect.handler import BidiStreamHandler, ClientStreamHandler, ServerStreamHandler, UnaryHandler from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from .. import eliza_pb2 -from ..eliza_pb2 import SayRequest, SayResponse, ConverseRequest, ConverseResponse, IntroduceRequest, IntroduceResponse, ReflectRequest, ReflectResponse +from ..eliza_pb2 import ( + SayRequest, + SayResponse, + ConverseRequest, + ConverseResponse, + IntroduceRequest, + IntroduceResponse, + ReflectRequest, + ReflectResponse, +) class ElizaServiceProcedures(Enum): @@ -39,11 +58,11 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti base_url = base_url.removesuffix("/") self.Say = Client[SayRequest, SayResponse]( - pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options + pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options) ).call_unary self.Converse = Client[ConverseRequest, ConverseResponse]( pool, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options - ).call_server_stream + ).call_bidi_stream self.Introduce = Client[IntroduceRequest, IntroduceResponse]( pool, base_url + ElizaServiceProcedures.Introduce.value, IntroduceRequest, IntroduceResponse, options ).call_server_stream @@ -52,16 +71,20 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti ).call_client_stream -class ElizaServiceHandler: +class ElizaServiceHandler(metaclass=abc.ABCMeta): """Handler for the elizaService service.""" - async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayResponse]: ... + async def Say(self, request: UnaryRequest[SayRequest], context: HandlerContext) -> UnaryResponse[SayResponse]: + raise NotImplementedError() - async def Converse(self, request: StreamRequest[ConverseRequest]) -> StreamResponse[ConverseResponse]: ... + async def Converse(self, request: StreamRequest[ConverseRequest], context: HandlerContext) -> StreamResponse[ConverseResponse]: + raise NotImplementedError() - async def Introduce(self, request: StreamRequest[IntroduceRequest]) -> StreamResponse[IntroduceResponse]: ... + async def Introduce(self, request: StreamRequest[IntroduceRequest], context: HandlerContext) -> StreamResponse[IntroduceResponse]: + raise NotImplementedError() - async def Reflect(self, request: StreamRequest[ReflectRequest]) -> StreamResponse[ReflectResponse]: ... + async def Reflect(self, request: StreamRequest[ReflectRequest], context: HandlerContext) -> StreamResponse[ReflectResponse]: + raise NotImplementedError() def create_ElizaService_handlers(service: ElizaServiceHandler, options: ConnectOptions | None = None) -> list[Handler]: @@ -71,9 +94,9 @@ def create_ElizaService_handlers(service: ElizaServiceHandler, options: ConnectO unary=service.Say, input=SayRequest, output=SayResponse, - options=options, + options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), ), - ServerStreamHandler( + BidiStreamHandler( procedure=ElizaServiceProcedures.Converse.value, stream=service.Converse, input=ConverseRequest, diff --git a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py index f567890..0c888b8 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py @@ -58,7 +58,7 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti base_url = base_url.removesuffix("/") self.Say = Client[SayRequest, SayResponse]( - pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options + pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options) ).call_unary self.Converse = Client[ConverseRequest, ConverseResponse]( pool, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options