Skip to content

Commit

Permalink
gapic: avoid generating duplicate iterators (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
noahdietz committed Jul 23, 2019
1 parent 8a10cc2 commit b489dd0
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 24 deletions.
41 changes: 28 additions & 13 deletions internal/gengapic/gengapic.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,19 @@ type generator struct {
serviceConfig *serviceConfig

grpcConf *conf.ServiceConfig

// Auxiliary types to be generated in the package
aux *auxTypes
}

func (g *generator) init(files []*descriptor.FileDescriptorProto) {
g.descInfo = pbinfo.Of(files)

g.comments = map[proto.Message]string{}
g.imports = map[pbinfo.ImportSpec]bool{}
g.aux = &auxTypes{
iters: map[string]*iterType{},
}

for _, f := range files {
for _, loc := range f.GetSourceCodeInfo().GetLocation() {
Expand Down Expand Up @@ -299,27 +305,36 @@ func (g *generator) gen(serv *descriptor.ServiceDescriptorProto, pkgName string)
return err
}

aux := auxTypes{
iters: map[string]iterType{},
}
// clear LRO types between services
g.aux.lros = []*descriptor.MethodDescriptorProto{}

for _, m := range serv.Method {
g.methodDoc(m)
if err := g.genMethod(servName, serv, m, &aux); err != nil {
if err := g.genMethod(servName, serv, m); err != nil {
return errors.E(err, "method: %s", m.GetName())
}
}

sort.Slice(aux.lros, func(i, j int) bool {
return aux.lros[i].GetName() < aux.lros[j].GetName()
sort.Slice(g.aux.lros, func(i, j int) bool {
return g.aux.lros[i].GetName() < g.aux.lros[j].GetName()
})
for _, m := range aux.lros {
for _, m := range g.aux.lros {
if err := g.lroType(servName, serv, m); err != nil {
return err
}
}

var iters []iterType
for _, iter := range aux.iters {
var iters []*iterType
for _, iter := range g.aux.iters {
// skip iterators that have already been generated in this package
//
// TODO(ndietz): investigate generating auxiliary types in a
// separate file in the same package to avoid keeping this state
if iter.generated {
continue
}

iter.generated = true
iters = append(iters, iter)
}
sort.Slice(iters, func(i, j int) bool {
Expand All @@ -340,14 +355,14 @@ type auxTypes struct {
// "List" of iterator types. We use these to generate FooIterator returned by paging methods.
// Since multiple methods can page over the same type, we dedupe by the name of the iterator,
// which is in turn determined by the element type name.
iters map[string]iterType
iters map[string]*iterType
}

// genMethod generates a single method from a client. m must be a method declared in serv.
// If the generated method requires an auxillary type, it is added to aux.
func (g *generator) genMethod(servName string, serv *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto, aux *auxTypes) error {
func (g *generator) genMethod(servName string, serv *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) error {
if m.GetOutputType() == lroType {
aux.lros = append(aux.lros, m)
g.aux.lros = append(g.aux.lros, m)
return g.lroCall(servName, m)
}

Expand All @@ -362,7 +377,7 @@ func (g *generator) genMethod(servName string, serv *descriptor.ServiceDescripto
if err != nil {
return err
}
aux.iters[iter.iterTypeName] = iter

return g.pagingCall(servName, m, pf, iter)
}

Expand Down
10 changes: 5 additions & 5 deletions internal/gengapic/gengapic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,22 +253,22 @@ methods:
proto.SetExtension(m.Options, longrunning.E_OperationInfo, lroType)
}

aux := auxTypes{
iters: map[string]iterType{},
g.aux = &auxTypes{
iters: map[string]*iterType{},
}
if err := g.genMethod("Foo", serv, m, &aux); err != nil {
if err := g.genMethod("Foo", serv, m); err != nil {
t.Error(err)
continue
}

for _, m := range aux.lros {
for _, m := range g.aux.lros {
if err := g.lroType("MyService", serv, m); err != nil {
t.Error(err)
continue methods
}
}

for _, iter := range aux.iters {
for _, iter := range g.aux.iters {
g.pagingIter(iter)
}

Expand Down
18 changes: 13 additions & 5 deletions internal/gengapic/paging.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ type iterType struct {
// If the elem type is a message, elemImports contains pbinfo.ImportSpec for the type.
// Otherwise, len(elemImports)==0.
elemImports []pbinfo.ImportSpec

generated bool
}

// iterTypeOf deduces iterType from a field to be iterated over.
// elemField should be the "resource" of a paginating RPC.
func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (iterType, error) {
func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (*iterType, error) {
var pt iterType

switch t := *elemField.Type; {
Expand All @@ -43,7 +45,7 @@ func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (iter

imp, err := g.descInfo.ImportSpec(eType)
if err != nil {
return iterType{}, err
return &iterType{}, err
}

pt.elemTypeName = fmt.Sprintf("*%s.%s", imp.Name, eType.GetName())
Expand All @@ -66,7 +68,13 @@ func (g *generator) iterTypeOf(elemField *descriptor.FieldDescriptorProto) (iter
pt.elemTypeName = pType
pt.iterTypeName = upperFirst(pt.elemTypeName) + "Iterator"
}
return pt, nil

if iter, ok := g.aux.iters[pt.iterTypeName]; ok {
return iter, nil
}
g.aux.iters[pt.iterTypeName] = &pt

return &pt, nil
}

// TODO(pongad): this will probably need to read from annotations later.
Expand Down Expand Up @@ -126,7 +134,7 @@ func (g *generator) pagingField(m *descriptor.MethodDescriptorProto) (*descripto
return elemFields[0], nil
}

func (g *generator) pagingCall(servName string, m *descriptor.MethodDescriptorProto, elemField *descriptor.FieldDescriptorProto, pt iterType) error {
func (g *generator) pagingCall(servName string, m *descriptor.MethodDescriptorProto, elemField *descriptor.FieldDescriptorProto, pt *iterType) error {
inType := g.descInfo.Type[*m.InputType]
outType := g.descInfo.Type[*m.OutputType]

Expand Down Expand Up @@ -202,7 +210,7 @@ func (g *generator) pagingCall(servName string, m *descriptor.MethodDescriptorPr
return nil
}

func (g *generator) pagingIter(pt iterType) {
func (g *generator) pagingIter(pt *iterType) {
p := g.printf

p("// %s manages a stream of %s.", pt.iterTypeName, pt.elemTypeName)
Expand Down
5 changes: 4 additions & 1 deletion internal/gengapic/paging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ func TestIterTypeOf(t *testing.T) {
Name: proto.String("Foo"),
}
g := &generator{
aux: &auxTypes{
iters: map[string]*iterType{},
},
descInfo: pbinfo.Info{
Type: map[string]pbinfo.ProtoType{
msgType.GetName(): msgType,
Expand Down Expand Up @@ -204,7 +207,7 @@ func TestIterTypeOf(t *testing.T) {
got, err := g.iterTypeOf(tst.field)
if err != nil {
t.Error(err)
} else if diff := cmp.Diff(tst.want, got, cmp.AllowUnexported(got)); diff != "" {
} else if diff := cmp.Diff(tst.want, *got, cmp.AllowUnexported(*got)); diff != "" {
t.Errorf("%d: (got=-, want=+):\n%s", i, diff)
}
}
Expand Down

0 comments on commit b489dd0

Please sign in to comment.