Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions cmd/protoc-gen-connect-python/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"google.golang.org/protobuf/types/pluginpb"
)

var version = "devel"

type Config struct{}

type Generator struct {
Expand Down Expand Up @@ -103,6 +101,7 @@ type Method struct {
Method string
FullName string
RPCType RPCType
Options protoreflect.ProtoMessage
}

type message struct {
Expand Down Expand Up @@ -140,16 +139,17 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
idx: i,
Method: meth.GoName,
FullName: fullname[:idx],
Options: meth.Desc.Options(),
}

// parse RPC type
switch {
case meth.Desc.IsStreamingServer() && meth.Desc.IsStreamingClient():
method.RPCType = BidirectionalStreaming
case meth.Desc.IsStreamingServer():
method.RPCType = ServerStreaming
case meth.Desc.IsStreamingClient():
method.RPCType = ClientStreaming
case meth.Desc.IsStreamingServer() && meth.Desc.IsStreamingClient():
method.RPCType = BidirectionalStreaming
default:
method.RPCType = Unary
}
Expand All @@ -174,13 +174,23 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
}
p.P(`"""Generated connect code."""`)
p.P()
p.P(`import abc`)
p.P(`from enum import Enum`)
p.P()
p.P(`from connect.client import Client`)
p.P(`from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse`)
p.P(`from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler`)
p.P(`from connect.options import ClientOptions, ConnectOptions`)
p.P(`from connect import (`)
p.P(` Client,`)
p.P(` ClientOptions,`)
p.P(` ConnectOptions,`)
p.P(` Handler,`)
p.P(` HandlerContext,`)
p.P(` IdempotencyLevel,`)
p.P(` StreamRequest,`)
p.P(` StreamResponse,`)
p.P(` UnaryRequest,`)
p.P(` UnaryResponse,`)
p.P(`)`)
p.P(`from connect.connection_pool import AsyncConnectionPool`)
p.P(`from connect.handler import BidiStreamHandler, ClientStreamHandler, ServerStreamHandler, UnaryHandler`)
p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`)
p.P()

Expand All @@ -191,7 +201,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
var sb strings.Builder
numSvc := len(p.services)
if numSvc > 0 {
fmt.Fprintf(&sb, "from ..%s import ", svcNamePB)
fmt.Fprintf(&sb, "from ..%s import (\n", svcNamePB)
}
seem := make(map[string]bool)
i := 0
Expand All @@ -201,22 +211,25 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
case seem[svc.input.method] && seem[svc.output.method]:
continue
case !seem[svc.input.method] && seem[svc.output.method]:
fmt.Fprintf(&sb, "%s", svc.input.method)
fmt.Fprintf(&sb, " %s\n", svc.input.method)
seem[svc.input.method] = true
case seem[svc.input.method] && !seem[svc.output.method]:
fmt.Fprintf(&sb, "%s", svc.output.method)
fmt.Fprintf(&sb, " %s\n", svc.output.method)
seem[svc.output.method] = true
default:
fmt.Fprintf(&sb, "%s, %s", svc.input.method, svc.output.method)
fmt.Fprintf(&sb, " %s,\n %s", svc.input.method, svc.output.method)
seem[svc.input.method] = true
seem[svc.output.method] = true
}
if i <= numSvc-2 {
fmt.Fprint(&sb, ", ")
fmt.Fprint(&sb, ",\n")
} else {
fmt.Fprint(&sb, ",")
}
i++
}
p.P(strings.TrimSuffix(sb.String(), ", "))
p.P(sb.String())
p.P(`)`)
p.P()
p.P()
procedures := upperSvcName + `Procedures`
Expand Down Expand Up @@ -244,20 +257,28 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
p.P(` `, `self.`, meth.Method, ` = `, `Client[`, svc.input.method, `, `, svc.output.method, `](`)
p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`)
if options := meth.Options; options != nil {
if desc, ok := options.(*descriptorpb.MethodOptions); ok && desc.GetIdempotencyLevel() != descriptorpb.MethodOptions_IDEMPOTENCY_UNKNOWN {
p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, ClientOptions(idempotency_level=IdempotencyLevel.`, desc.GetIdempotencyLevel().String(), `, enable_get=True).merge(options)`)
} else {
p.P(` `, `pool, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`)
}
}
switch meth.RPCType {
case Unary:
p.P(` `, `).call_unary`)
case ServerStreaming:
p.P(` `, `).call_server_stream`)
case ClientStreaming:
p.P(` `, `).call_client_stream`)
case BidirectionalStreaming:
p.P(` `, `).call_bidi_stream`)
}
}
p.P()
p.P()
handler := upperSvcName + `Handler`
p.P(`class `, handler, `:`)
p.P(`class `, handler, `(metaclass=abc.ABCMeta):`)
p.P(` `, `"""Handler for the `, lowerCamelCase(upperSvcName), ` service."""`)
p.P()
j := 0
Expand All @@ -269,16 +290,15 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
reqRPCType string
respRPCType string
)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary:
reqRPCType = `UnaryRequest`
respRPCType = `UnaryResponse`
case ServerStreaming, ClientStreaming:
case ServerStreaming, ClientStreaming, BidirectionalStreaming:
reqRPCType = `StreamRequest`
respRPCType = `StreamResponse`
}
fmt.Fprintf(&sb, "%s[%s]) -> %s[%s]: ...", reqRPCType, svc.input.method, respRPCType, svc.output.method)
fmt.Fprintf(&sb, "%s[%s], context: HandlerContext) -> %s[%s]:\n raise NotImplementedError()", reqRPCType, svc.input.method, respRPCType, svc.output.method)
p.P(sb.String())
if j <= len(p.services)-2 {
p.P()
Expand All @@ -295,7 +315,6 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
rpcHandler string
call string
)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary:
rpcHandler = `UnaryHandler`
Expand All @@ -306,17 +325,25 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
case ClientStreaming:
rpcHandler = `ClientStreamHandler`
call = fmt.Sprintf(" stream=service.%s,", meth.Method)
case BidirectionalStreaming:
rpcHandler = `BidiStreamHandler`
call = fmt.Sprintf(" stream=service.%s,", meth.Method)
}

// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary, ServerStreaming, ClientStreaming:
case Unary, ServerStreaming, ClientStreaming, BidirectionalStreaming:
p.P(` `, rpcHandler, `(`)
p.P(` procedure=`, procedures+`.`+meth.Method+`.value,`)
p.P(call)
p.P(` input=`, svc.input.method, `,`)
p.P(` output=`, svc.output.method, `,`)
p.P(` options=options,`)
if options := meth.Options; options != nil {
if desc, ok := options.(*descriptorpb.MethodOptions); ok && desc.GetIdempotencyLevel() != descriptorpb.MethodOptions_IDEMPOTENCY_UNKNOWN {
p.P(` options=ConnectOptions(idempotency_level=IdempotencyLevel.`, desc.GetIdempotencyLevel().String(), `).merge(options),`)
} else {
p.P(` options=options,`)
}
}
p.P(` ),`)
}
}
Expand Down
53 changes: 38 additions & 15 deletions examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
# Generated by the protoc-gen-connect-python. DO NOT EDIT!
# source: examples/proto/connectrpc/eliza/v1/v1connect/eliza.proto
# Protobuf Python Version: v5.29.3
# protoc-gen-connect-python version: v0.0.0-20250517015031-b19a36b52499+dirty
# protoc-gen-connect-python version: v0.0.0-20250708090951-d93686e5039f
"""Generated connect code."""

import abc
from enum import Enum

from connect.client import Client
from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse
from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler
from connect.options import ClientOptions, ConnectOptions
from connect import (
Client,
ClientOptions,
ConnectOptions,
Handler,
HandlerContext,
IdempotencyLevel,
StreamRequest,
StreamResponse,
UnaryRequest,
UnaryResponse,
)
from connect.connection_pool import AsyncConnectionPool
from connect.handler import BidiStreamHandler, ClientStreamHandler, ServerStreamHandler, UnaryHandler
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor

from .. import eliza_pb2
from ..eliza_pb2 import SayRequest, SayResponse, ConverseRequest, ConverseResponse, IntroduceRequest, IntroduceResponse, ReflectRequest, ReflectResponse
from ..eliza_pb2 import (
SayRequest,
SayResponse,
ConverseRequest,
ConverseResponse,
IntroduceRequest,
IntroduceResponse,
ReflectRequest,
ReflectResponse,
)


class ElizaServiceProcedures(Enum):
Expand All @@ -39,11 +58,11 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti
base_url = base_url.removesuffix("/")

self.Say = Client[SayRequest, SayResponse](
pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options
pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options)
).call_unary
self.Converse = Client[ConverseRequest, ConverseResponse](
pool, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options
).call_server_stream
).call_bidi_stream
self.Introduce = Client[IntroduceRequest, IntroduceResponse](
pool, base_url + ElizaServiceProcedures.Introduce.value, IntroduceRequest, IntroduceResponse, options
).call_server_stream
Expand All @@ -52,16 +71,20 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti
).call_client_stream


class ElizaServiceHandler:
class ElizaServiceHandler(metaclass=abc.ABCMeta):
"""Handler for the elizaService service."""

async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayResponse]: ...
async def Say(self, request: UnaryRequest[SayRequest], context: HandlerContext) -> UnaryResponse[SayResponse]:
raise NotImplementedError()

async def Converse(self, request: StreamRequest[ConverseRequest]) -> StreamResponse[ConverseResponse]: ...
async def Converse(self, request: StreamRequest[ConverseRequest], context: HandlerContext) -> StreamResponse[ConverseResponse]:
raise NotImplementedError()

async def Introduce(self, request: StreamRequest[IntroduceRequest]) -> StreamResponse[IntroduceResponse]: ...
async def Introduce(self, request: StreamRequest[IntroduceRequest], context: HandlerContext) -> StreamResponse[IntroduceResponse]:
raise NotImplementedError()

async def Reflect(self, request: StreamRequest[ReflectRequest]) -> StreamResponse[ReflectResponse]: ...
async def Reflect(self, request: StreamRequest[ReflectRequest], context: HandlerContext) -> StreamResponse[ReflectResponse]:
raise NotImplementedError()


def create_ElizaService_handlers(service: ElizaServiceHandler, options: ConnectOptions | None = None) -> list[Handler]:
Expand All @@ -71,9 +94,9 @@ def create_ElizaService_handlers(service: ElizaServiceHandler, options: ConnectO
unary=service.Say,
input=SayRequest,
output=SayResponse,
options=options,
options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options),
),
ServerStreamHandler(
BidiStreamHandler(
procedure=ElizaServiceProcedures.Converse.value,
stream=service.Converse,
input=ConverseRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti
base_url = base_url.removesuffix("/")

self.Say = Client[SayRequest, SayResponse](
pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, options
pool, base_url + ElizaServiceProcedures.Say.value, SayRequest, SayResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options)
).call_unary
self.Converse = Client[ConverseRequest, ConverseResponse](
pool, base_url + ElizaServiceProcedures.Converse.value, ConverseRequest, ConverseResponse, options
Expand Down
Loading