Skip to content

Commit

Permalink
cmd/protoc-gen-go-grpc: allow hooks to modify client structs and serv…
Browse files Browse the repository at this point in the history
…ice handlers (#5240)
  • Loading branch information
ZhouyihaiDing committed Apr 6, 2022
1 parent 337b815 commit 18fdf54
Showing 1 changed file with 68 additions and 54 deletions.
122 changes: 68 additions & 54 deletions cmd/protoc-gen-go-grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ const (

type serviceGenerateHelperInterface interface {
formatFullMethodName(service *protogen.Service, method *protogen.Method) string
generateClientStruct(g *protogen.GeneratedFile, clientName string)
generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string)
generateUnimplementedServerType(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service)
generateServerFunctions(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, serverType string, serviceDescVar string)
formatHandlerFuncName(service *protogen.Service, hname string) string
}

type serviceGenerateHelper struct{}
Expand All @@ -47,7 +49,15 @@ func (serviceGenerateHelper) formatFullMethodName(service *protogen.Service, met
return fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
}

func (serviceGenerateHelper) generateClientStruct(g *protogen.GeneratedFile, clientName string) {
g.P("type ", unexport(clientName), " struct {")
g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
g.P("}")
g.P()
}

func (serviceGenerateHelper) generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string) {
g.P("return &", unexport(clientName), "{cc}")
}

func (serviceGenerateHelper) generateUnimplementedServerType(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
Expand Down Expand Up @@ -77,6 +87,19 @@ func (serviceGenerateHelper) generateUnimplementedServerType(gen *protogen.Plugi
}

func (serviceGenerateHelper) generateServerFunctions(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, serverType string, serviceDescVar string) {
// Server handler implementations.
handlerNames := make([]string, 0, len(service.Methods))
for _, method := range service.Methods {
hname := genServerMethod(gen, file, g, method, func(hname string) string {
return hname
})
handlerNames = append(handlerNames, hname)
}
genServiceDesc(file, g, serviceDescVar, serverType, service, handlerNames)
}

func (serviceGenerateHelper) formatHandlerFuncName(service *protogen.Service, hname string) string {
return hname
}

var helper serviceGenerateHelperInterface = serviceGenerateHelper{}
Expand Down Expand Up @@ -158,18 +181,14 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
g.P()

// Client structure.
g.P("type ", unexport(clientName), " struct {")
g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
g.P("}")
g.P()
helper.generateClientStruct(g, clientName)

// NewClient factory.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
helper.generateNewClientDefinitions(g, service, clientName)
g.P("return &", unexport(clientName), "{cc}")
g.P("}")
g.P()

Expand Down Expand Up @@ -239,52 +258,6 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
g.P()

helper.generateServerFunctions(gen, file, g, service, serverType, serviceDescVar)

// Server handler implementations.
handlerNames := make([]string, 0, len(service.Methods))
for _, method := range service.Methods {
hname := genServerMethod(gen, file, g, method)
handlerNames = append(handlerNames, hname)
}

// Service descriptor.
g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
g.P("// and not to be introspected or modified (even as a copy)")
g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
g.P("HandlerType: (*", serverType, ")(nil),")
g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
for i, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
g.P("},")
}
g.P("},")
g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
for i, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
if method.Desc.IsStreamingServer() {
g.P("ServerStreams: true,")
}
if method.Desc.IsStreamingClient() {
g.P("ClientStreams: true,")
}
g.P("},")
}
g.P("},")
g.P("Metadata: \"", file.Desc.Path(), "\",")
g.P("}")
g.P()
}

func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
Expand Down Expand Up @@ -397,12 +370,53 @@ func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string
return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
}

func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
func genServiceDesc(file *protogen.File, g *protogen.GeneratedFile, serviceDescVar string, serverType string, service *protogen.Service, handlerNames []string) {
// Service descriptor.
g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
g.P("// and not to be introspected or modified (even as a copy)")
g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
g.P("HandlerType: (*", serverType, ")(nil),")
g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
for i, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
g.P("},")
}
g.P("},")
g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
for i, method := range service.Methods {
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
continue
}
g.P("{")
g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
g.P("Handler: ", handlerNames[i], ",")
if method.Desc.IsStreamingServer() {
g.P("ServerStreams: true,")
}
if method.Desc.IsStreamingClient() {
g.P("ClientStreams: true,")
}
g.P("},")
}
g.P("},")
g.P("Metadata: \"", file.Desc.Path(), "\",")
g.P("}")
g.P()
}

func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, hnameFuncNameFormatter func(string) string) string {
service := method.Parent
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)

if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
g.P("func ", hnameFuncNameFormatter(hname), "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
g.P("in := new(", method.Input.GoIdent, ")")
g.P("if err := dec(in); err != nil { return nil, err }")
g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
Expand All @@ -420,7 +434,7 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
return hname
}
streamType := unexport(service.GoName) + method.GoName + "Server"
g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
g.P("func ", hnameFuncNameFormatter(hname), "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
if !method.Desc.IsStreamingClient() {
g.P("m := new(", method.Input.GoIdent, ")")
g.P("if err := stream.RecvMsg(m); err != nil { return err }")
Expand Down

0 comments on commit 18fdf54

Please sign in to comment.