Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 148 additions & 149 deletions cmd/protoc-gen-connect-python/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,165 +165,164 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) {
},
}
}
}

p.P(`# Generated by the protoc-gen-connect-python. DO NOT EDIT!`)
p.P(`# source: `, f.GeneratedFilenamePrefix, `.proto`)
p.P(`# Protobuf Python Version: `, protocVersion(g.plugin))
if bi, ok := debug.ReadBuildInfo(); ok {
p.P(`# protoc-gen-connect-python version: `, bi.Main.Version)
}
p.P(`"""Generated connect code."""`)
p.P()
p.P(`from enum import Enum`)
p.P()
p.P(`from connect.client import Client`)
p.P(`from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse`)
p.P(`from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler`)
p.P(`from connect.options import ClientOptions, ConnectOptions`)
p.P(`from connect.session import AsyncClientSession`)
p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`)
p.P()

svcName := filepath.Base(f.GeneratedFilenamePrefix)
upperSvcName := camelCase(svcName)
svcNameService := upperSvcName + `Service`
svcNamePB := svcName + "_pb2"
p.P(`from `, `..`, ` import `, svcNamePB)
var sb strings.Builder
numSvc := len(p.services)
if numSvc > 0 {
fmt.Fprintf(&sb, "from ..%s import ", svcNamePB)
}
seem := make(map[string]bool)
i := 0
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
switch {
case seem[svc.input.method] && seem[svc.output.method]:
continue
case !seem[svc.input.method] && seem[svc.output.method]:
fmt.Fprintf(&sb, "%s", svc.input.method)
seem[svc.input.method] = true
case seem[svc.input.method] && !seem[svc.output.method]:
fmt.Fprintf(&sb, "%s", svc.output.method)
seem[svc.output.method] = true
default:
fmt.Fprintf(&sb, "%s, %s", svc.input.method, svc.output.method)
seem[svc.input.method] = true
seem[svc.output.method] = true
p.P(`# Generated by the protoc-gen-connect-python. DO NOT EDIT!`)
p.P(`# source: `, f.GeneratedFilenamePrefix, `.proto`)
p.P(`# Protobuf Python Version: `, protocVersion(g.plugin))
if bi, ok := debug.ReadBuildInfo(); ok {
p.P(`# protoc-gen-connect-python version: `, bi.Main.Version)
}
if i <= numSvc-2 {
fmt.Fprint(&sb, ", ")
p.P(`"""Generated connect code."""`)
p.P()
p.P(`from enum import Enum`)
p.P()
p.P(`from connect.client import Client`)
p.P(`from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse`)
p.P(`from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler`)
p.P(`from connect.options import ClientOptions, ConnectOptions`)
p.P(`from connect.session import AsyncClientSession`)
p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`)
p.P()

svcName := svc.GoName
upperSvcName := camelCase(svcName)
svcNamePB := filepath.Base(f.GeneratedFilenamePrefix) + "_pb2"
p.P(`from `, `..`, ` import `, svcNamePB)
var sb strings.Builder
numSvc := len(p.services)
if numSvc > 0 {
fmt.Fprintf(&sb, "from ..%s import ", svcNamePB)
}
i++
}
p.P(strings.TrimSuffix(sb.String(), ", "))
p.P()
p.P()
procedures := svcNameService + `Procedures`
p.P(`class `, procedures, `(Enum):`)
p.P(` """Procedures for the `, svcName, ` service."""`)
p.P()
for _, meth := range sortedMap(p.services) {
p.P(` `, meth.Method, ` = `, strconv.Quote(`/`+filepath.Join(meth.FullName, meth.Method)))
}
p.P()
p.P()
serviceDescriptor := svcNameService + `_service_descriptor`
p.P(serviceDescriptor, `: `, `ServiceDescriptor`, ` = `, svcNamePB+`.DESCRIPTOR.services_by_name[`, strconv.Quote(svcNameService), `]`)
p.P()
for _, meth := range sortedMap(p.services) {
methodDescriptor := svcNameService + meth.Method + `_method_descriptor`
p.P(methodDescriptor+`: `, `MethodDescriptor = `, serviceDescriptor+`.methods_by_name[`, strconv.Quote(meth.Method), `]`)
}
p.P()
p.P()
p.P(`class `, upperSvcName, `Client:`)
p.P(` def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:`)
p.P(` base_url = base_url.removesuffix("/")`)
p.P()
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
p.P(` `, `self.`, meth.Method, ` = `, `Client[`, svc.input.method, `, `, svc.output.method, `](`)
p.P(` `, `session, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`)
switch meth.RPCType {
case Unary:
p.P(` `, `).call_unary`)
case ServerStreaming:
p.P(` `, `).call_server_stream`)
case ClientStreaming:
p.P(` `, `).call_client_stream`)
seem := make(map[string]bool)
i := 0
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
switch {
case seem[svc.input.method] && seem[svc.output.method]:
continue
case !seem[svc.input.method] && seem[svc.output.method]:
fmt.Fprintf(&sb, "%s", svc.input.method)
seem[svc.input.method] = true
case seem[svc.input.method] && !seem[svc.output.method]:
fmt.Fprintf(&sb, "%s", svc.output.method)
seem[svc.output.method] = true
default:
fmt.Fprintf(&sb, "%s, %s", svc.input.method, svc.output.method)
seem[svc.input.method] = true
seem[svc.output.method] = true
}
if i <= numSvc-2 {
fmt.Fprint(&sb, ", ")
}
i++
}
}
p.P()
p.P()
handler := svcNameService + `Handler`
p.P(`class `, handler, `:`)
p.P(` `, `"""Handler for the `, lowerCamelCase(upperSvcName), ` service."""`)
p.P()
j := 0
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
sb.Reset()
fmt.Fprintf(&sb, " async def %s(self, request: ", meth.Method)
var (
reqRPCType string
respRPCType string
)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary:
reqRPCType = `UnaryRequest`
respRPCType = `UnaryResponse`
case ServerStreaming, ClientStreaming:
reqRPCType = `StreamRequest`
respRPCType = `StreamResponse`
p.P(strings.TrimSuffix(sb.String(), ", "))
p.P()
p.P()
procedures := upperSvcName + `Procedures`
p.P(`class `, procedures, `(Enum):`)
p.P(` """Procedures for the `, svcName, ` service."""`)
p.P()
for _, meth := range sortedMap(p.services) {
p.P(` `, meth.Method, ` = `, strconv.Quote(`/`+filepath.Join(meth.FullName, meth.Method)))
}
fmt.Fprintf(&sb, "%s[%s]) -> %s[%s]: ...", reqRPCType, svc.input.method, respRPCType, svc.output.method)
p.P(sb.String())
if j <= len(p.services)-2 {
p.P()
p.P()
p.P()
serviceDescriptor := upperSvcName + `_service_descriptor`
p.P(serviceDescriptor, `: `, `ServiceDescriptor`, ` = `, svcNamePB+`.DESCRIPTOR.services_by_name[`, strconv.Quote(upperSvcName), `]`)
p.P()
for _, meth := range sortedMap(p.services) {
methodDescriptor := upperSvcName + meth.Method + `_method_descriptor`
p.P(methodDescriptor+`: `, `MethodDescriptor = `, serviceDescriptor+`.methods_by_name[`, strconv.Quote(meth.Method), `]`)
}
j++
}
p.P()
p.P()
p.P(`def create_`, svcNameService, `_handlers`, `(`, `service: `, handler, `, options: ConnectOptions | None = None`, `) -> list[Handler]:`)
p.P(` handlers = [`)
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
var (
rpcHandler string
call string
)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary:
rpcHandler = `UnaryHandler`
call = fmt.Sprintf(" unary=service.%s,", meth.Method)
case ServerStreaming:
rpcHandler = `ServerStreamHandler`
call = fmt.Sprintf(" stream=service.%s,", meth.Method)
case ClientStreaming:
rpcHandler = `ClientStreamHandler`
call = fmt.Sprintf(" stream=service.%s,", meth.Method)
p.P()
p.P()
p.P(`class `, upperSvcName, `Client:`)
p.P(` def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:`)
p.P(` base_url = base_url.removesuffix("/")`)
p.P()
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
p.P(` `, `self.`, meth.Method, ` = `, `Client[`, svc.input.method, `, `, svc.output.method, `](`)
p.P(` `, `session, `, `base_url + `, procedures+`.`+meth.Method+`.value, `, svc.input.method+`, `, svc.output.method, `, options`)
switch meth.RPCType {
case Unary:
p.P(` `, `).call_unary`)
case ServerStreaming:
p.P(` `, `).call_server_stream`)
case ClientStreaming:
p.P(` `, `).call_client_stream`)
}
}
p.P()
p.P()
handler := upperSvcName + `Handler`
p.P(`class `, handler, `:`)
p.P(` `, `"""Handler for the `, lowerCamelCase(upperSvcName), ` service."""`)
p.P()
j := 0
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
sb.Reset()
fmt.Fprintf(&sb, " async def %s(self, request: ", meth.Method)
var (
reqRPCType string
respRPCType string
)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary:
reqRPCType = `UnaryRequest`
respRPCType = `UnaryResponse`
case ServerStreaming, ClientStreaming:
reqRPCType = `StreamRequest`
respRPCType = `StreamResponse`
}
fmt.Fprintf(&sb, "%s[%s]) -> %s[%s]: ...", reqRPCType, svc.input.method, respRPCType, svc.output.method)
p.P(sb.String())
if j <= len(p.services)-2 {
p.P()
}
j++
}
p.P()
p.P()
p.P(`def create_`, upperSvcName, `_handlers`, `(`, `service: `, handler, `, options: ConnectOptions | None = None`, `) -> list[Handler]:`)
p.P(` handlers = [`)
for _, meth := range sortedMap(p.services) {
svc := p.services[meth]
var (
rpcHandler string
call string
)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary:
rpcHandler = `UnaryHandler`
call = fmt.Sprintf(" unary=service.%s,", meth.Method)
case ServerStreaming:
rpcHandler = `ServerStreamHandler`
call = fmt.Sprintf(" stream=service.%s,", meth.Method)
case ClientStreaming:
rpcHandler = `ClientStreamHandler`
call = fmt.Sprintf(" stream=service.%s,", meth.Method)
}

// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary, ServerStreaming, ClientStreaming:
p.P(` `, rpcHandler, `(`)
p.P(` procedure=`, procedures+`.`+meth.Method+`.value,`)
p.P(call)
p.P(` input=`, svc.input.method, `,`)
p.P(` output=`, svc.output.method, `,`)
p.P(` options=options,`)
p.P(` ),`)
// TODO(zchee): BidirectionalStreaming?
switch meth.RPCType {
case Unary, ServerStreaming, ClientStreaming:
p.P(` `, rpcHandler, `(`)
p.P(` procedure=`, procedures+`.`+meth.Method+`.value,`)
p.P(call)
p.P(` input=`, svc.input.method, `,`)
p.P(` output=`, svc.output.method, `,`)
p.P(` options=options,`)
p.P(` ),`)
}
}
p.P(` ]`)
p.P(` return handlers`)
}
p.P(` ]`)
p.P(` return handlers`)
}

func lowerCamelCase(s string) string {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Generated by the protoc-gen-connect-python. DO NOT EDIT!
# source: examples/proto/connectrpc/eliza/v1/v1connect/eliza.proto
# Protobuf Python Version: v5.29.3
# protoc-gen-connect-python version: v0.0.0-20250225082729-4cfb74729a2c
# protoc-gen-connect-python version: v0.0.0-20250225130907-52aad1ea4ad5
"""Generated connect code."""

from enum import Enum
Expand All @@ -18,7 +18,7 @@


class ElizaServiceProcedures(Enum):
"""Procedures for the eliza service."""
"""Procedures for the ElizaService service."""

Say = "/connectrpc.eliza.v1.ElizaService/Say"
Converse = "/connectrpc.eliza.v1.ElizaService/Converse"
Expand All @@ -34,7 +34,7 @@ class ElizaServiceProcedures(Enum):
ElizaServiceIntroduceClient_method_descriptor: MethodDescriptor = ElizaService_service_descriptor.methods_by_name["IntroduceClient"]


class ElizaClient:
class ElizaServiceClient:
def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOptions | None = None) -> None:
base_url = base_url.removesuffix("/")

Expand All @@ -53,7 +53,7 @@ def __init__(self, base_url: str, session: AsyncClientSession, options: ClientOp


class ElizaServiceHandler:
"""Handler for the eliza service."""
"""Handler for the elizaService service."""

async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayResponse]: ...

Expand Down