Skip to content

Commit

Permalink
feat!: add HybridResolver (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronc committed May 3, 2023
1 parent a0d0171 commit 8051872
Show file tree
Hide file tree
Showing 18 changed files with 488 additions and 2,004 deletions.
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ require (
golang.org/x/text v0.9.0 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
)

retract (
// API changed in an incompatible way
v1.4.8
)
91 changes: 0 additions & 91 deletions proto/all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1646,97 +1646,6 @@ func TestExtensionMapFieldMarshalDeterministic(t *testing.T) {
}
}

// Many extensions, because small maps might not iterate differently on each iteration.
var exts = []*ExtensionDesc{
E_X201,
E_X202,
E_X203,
E_X204,
E_X205,
E_X206,
E_X207,
E_X208,
E_X209,
E_X210,
E_X211,
E_X212,
E_X213,
E_X214,
E_X215,
E_X216,
E_X217,
E_X218,
E_X219,
E_X220,
E_X221,
E_X222,
E_X223,
E_X224,
E_X225,
E_X226,
E_X227,
E_X228,
E_X229,
E_X230,
E_X231,
E_X232,
E_X233,
E_X234,
E_X235,
E_X236,
E_X237,
E_X238,
E_X239,
E_X240,
E_X241,
E_X242,
E_X243,
E_X244,
E_X245,
E_X246,
E_X247,
E_X248,
E_X249,
E_X250,
}

func TestMessageSetMarshalOrder(t *testing.T) {
m := &MyMessageSet{}
for _, x := range exts {
if err := SetExtension(m, x, &Empty{}); err != nil {
t.Fatalf("SetExtension: %v", err)
}
}

buf, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}

// Serialize m several times, and check we get the same bytes each time.
for i := 0; i < 10; i++ {
b1, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if !bytes.Equal(b1, buf) {
t.Errorf("Bytes differ on re-Marshal #%d", i)
}

m2 := &MyMessageSet{}
if err = Unmarshal(buf, m2); err != nil {
t.Errorf("Unmarshal: %v", err)
}
b2, err := Marshal(m2)
if err != nil {
t.Errorf("re-Marshal: %v", err)
}
if !bytes.Equal(b2, buf) {
t.Errorf("Bytes differ on round-trip #%d", i)
}
}
}

func TestUnmarshalMergesMessages(t *testing.T) {
// If a nested message occurs twice in the input,
// the fields should be merged when decoding.
Expand Down
6 changes: 0 additions & 6 deletions proto/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,6 @@ func TestExtensionsRoundTrip(t *testing.T) {
if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
t.Errorf("got %v, expected ErrMissingExtension", e)
}
if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
t.Error("expected bad extension error, got nil")
}
if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
t.Error("expected extension err")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
t.Error("expected some sort of type mismatch error, got nil")
}
Expand Down
130 changes: 77 additions & 53 deletions proto/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package proto

import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"runtime"
Expand All @@ -24,8 +23,8 @@ import (
//
// In contrast to MergedFileDescriptorsWithValidation,
// MergedFileDescriptors does not validate import paths
func MergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string][]byte) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, appFiles, false)
func MergedFileDescriptors(globalFiles *protoregistry.Files, gogoFiles *protoregistry.Files) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, gogoFiles, false)
}

// MergedFileDescriptorsWithValidation returns a single FileDescriptorSet containing all the
Expand All @@ -34,22 +33,22 @@ func MergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string
// If there are any incorrect import paths that do not match
// the fully qualified package name, or if there is a common file descriptor
// that differs accross globalFiles and appFiles, an error is returned.
func MergedFileDescriptorsWithValidation(globalFiles *protoregistry.Files, appFiles map[string][]byte) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, appFiles, true)
func MergedFileDescriptorsWithValidation(globalFiles *protoregistry.Files, gogoFiles *protoregistry.Files) (*descriptorpb.FileDescriptorSet, error) {
return mergedFileDescriptors(globalFiles, gogoFiles, true)
}

// MergedGlobalFileDescriptors calls MergedFileDescriptors
// with [protoregistry.GlobalFiles] and all files
// registered through [RegisterFile].
func MergedGlobalFileDescriptors() (*descriptorpb.FileDescriptorSet, error) {
return MergedFileDescriptors(protoregistry.GlobalFiles, protoFiles)
return MergedFileDescriptors(protoregistry.GlobalFiles, gogoProtoRegistry)
}

// MergedGlobalFileDescriptorsWithValidation calls MergedFileDescriptorsWithValidation
// with [protoregistry.GlobalFiles] and all files
// registered through [RegisterFile].
func MergedGlobalFileDescriptorsWithValidation() (*descriptorpb.FileDescriptorSet, error) {
return MergedFileDescriptorsWithValidation(protoregistry.GlobalFiles, protoFiles)
return MergedFileDescriptorsWithValidation(protoregistry.GlobalFiles, gogoProtoRegistry)
}

// MergedRegistry returns a *protoregistry.Files that acts as a single registry
Expand Down Expand Up @@ -177,7 +176,7 @@ LOOP:
type descriptorProcessor struct {
processWG sync.WaitGroup
globalFileCh chan protoreflect.FileDescriptor
appFileCh chan []byte
appFileCh chan protoreflect.FileDescriptor

fdWG sync.WaitGroup
fdCh chan *descriptorpb.FileDescriptorProto
Expand All @@ -186,7 +185,7 @@ type descriptorProcessor struct {

// process reads from p.globalFileCh and p.appFileCh, processing each file descriptor as appropriate,
// and sends the processed file descriptors through p.fdCh for eventual return from mergedFileDescriptors.
// Any errors during processing are sent to ec.ProcessErrCh,
// Any errors during processing are sent to ec.ProcessErrCh,
// which collects the errors also for possible return from mergedFileDescriptors.
//
// If validate is true, extra work is performed to validate import paths
Expand All @@ -213,45 +212,19 @@ func (p *descriptorProcessor) process(globalFiles *protoregistry.Files, ec *desc
}

// Now handle all the app files.

// Reuse a single gzip reader throughout the loop,
// so we don't have to repeatedly allocate new readers.
gzr := new(gzip.Reader)

// Also reuse a single byte buffer for each gzip read.
buf := new(bytes.Buffer)

for compressedBz := range p.appFileCh {
if err := gzr.Reset(bytes.NewReader(compressedBz)); err != nil {
// This should only fail if there is an invalid gzip header in compressedBz.
ec.ProcessErrCh <- fmt.Errorf("failed to reset gzip reader: %w", err)
continue
}

buf.Reset()
if _, err := buf.ReadFrom(gzr); err != nil {
// This should only fail if there was invalidly gzipped content in compressedBz.
ec.ProcessErrCh <- fmt.Errorf("failed to read from gzip reader: %w", err)
continue
}

fd := &descriptorpb.FileDescriptorProto{}
if err := protov2.Unmarshal(buf.Bytes(), fd); err != nil {
// This should only fail if the gzipped data contained invalid bytes for a FileDescriptorProto.
ec.ProcessErrCh <- err
continue
}

for gogoFd := range p.appFileCh {
// If the app FD is not in protoregistry, we need to track it.
gogoFdp := protodesc.ToFileDescriptorProto(gogoFd)
if validate {
// Ensure the import path on the app file is good.
if err := CheckImportPath(fd.GetName(), fd.GetPackage()); err != nil {
if err := CheckImportPath(gogoFdp.GetName(), gogoFdp.GetPackage()); err != nil {
// Track the import error but don't stop processing.
// It is more helpful to present all the import errors,
// rather than just stopping on the first one.
ec.ImportErrCh <- err
// Don't break the loop here, continue to check for a file descriptor diff.
}
}

// If the app FD is not in protoregistry, we need to track it.
protoregFd, err := globalFiles.FindFileByPath(*fd.Name)
protoregFd, err := globalFiles.FindFileByPath(*gogoFdp.Name)
if err != nil {
if !errors.Is(err, protoregistry.NotFound) {
// Non-nil error, and it wasn't a not found error.
Expand All @@ -260,15 +233,16 @@ func (p *descriptorProcessor) process(globalFiles *protoregistry.Files, ec *desc
}
// Otherwise it was a not found error, so add it.
// At this point we can't validate.
p.fdCh <- fd
p.fdCh <- gogoFdp
continue
}

if validate {
fdp := protodesc.ToFileDescriptorProto(protoregFd)
if !protov2.Equal(fdp, fd) {
diff := cmp.Diff(fdp, fd, protocmp.Transform())
ec.DiffCh <- fmt.Sprintf("Mismatch in %s:\n%s", *fd.Name, diff)

if !protov2.Equal(fdp, gogoFdp) {
diff := cmp.Diff(fdp, gogoFdp, protocmp.Transform())
ec.DiffCh <- fmt.Sprintf("Mismatch in %s:\n%s", *gogoFdp.Name, diff)
}
}
}
Expand All @@ -295,7 +269,7 @@ func (p *descriptorProcessor) collectFDs() {
// If validate is true, do extra work to validate that import paths are properly formed
// and that "duplicated" file descriptors across globalFiles and appFiles
// are indeed identical, returning an error if either of those conditions are invalidated.
func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string][]byte, validate bool) (*descriptorpb.FileDescriptorSet, error) {
func mergedFileDescriptors(globalFiles *protoregistry.Files, gogoFiles *protoregistry.Files, validate bool) (*descriptorpb.FileDescriptorSet, error) {
// GOMAXPROCS is the number of CPU cores available, by default.
// Respect that setting as the number of CPU-bound goroutines,
// and for channel sizes.
Expand All @@ -305,7 +279,7 @@ func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string

p := &descriptorProcessor{
globalFileCh: make(chan protoreflect.FileDescriptor, nProcs),
appFileCh: make(chan []byte, nProcs),
appFileCh: make(chan protoreflect.FileDescriptor, nProcs),

fdCh: make(chan *descriptorpb.FileDescriptorProto, nProcs),
fds: make([]*descriptorpb.FileDescriptorProto, 0, globalFiles.NumFiles()),
Expand All @@ -330,10 +304,11 @@ func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string
// Signal that no more global files will be sent.
close(p.globalFileCh)

// Same for appFiles: send everything then signal app files are finished.
for _, bz := range appFiles {
p.appFileCh <- bz
}
// Same for gogoFiles: send everything then signal app files are finished.
gogoFiles.RangeFiles(func(fileDescriptor protoreflect.FileDescriptor) bool {
p.appFileCh <- fileDescriptor
return true
})
close(p.appFileCh)

// Since we are done sending file descriptors and we have closed those channels,
Expand All @@ -360,3 +335,52 @@ func mergedFileDescriptors(globalFiles *protoregistry.Files, appFiles map[string
File: p.fds,
}, nil
}

// HybridResolver is a protodesc.Resolver that uses both protoregistry.GlobalFiles
// and the gogo proto global registry, checking protoregistry.GlobalFiles first and
// then gogo proto global registry.
var HybridResolver Resolver = &hybridResolver{}

// Resolver is a protodesc.Resolver that can range over all the files in the resolver.
type Resolver interface {
protodesc.Resolver

// RangeFiles calls f for each file descriptor in the resolver while f returns true.
RangeFiles(f func(fileDescriptor protoreflect.FileDescriptor) bool)
}

type hybridResolver struct{}

var _ protodesc.Resolver = &hybridResolver{}

func (r *hybridResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
if fd, err := protoregistry.GlobalFiles.FindFileByPath(path); err == nil {
return fd, nil
}

return gogoProtoRegistry.FindFileByPath(path)
}

func (r *hybridResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
if desc, err := protoregistry.GlobalFiles.FindDescriptorByName(name); err == nil {
return desc, nil
}

return gogoProtoRegistry.FindDescriptorByName(name)
}

func (r *hybridResolver) RangeFiles(f func(fileDescriptor protoreflect.FileDescriptor) bool) {
seen := make(map[protoreflect.FullName]bool, protoregistry.GlobalFiles.NumFiles())

protoregistry.GlobalFiles.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
seen[fd.FullName()] = true
return f(fd)
})

gogoProtoRegistry.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
if seen[fd.FullName()] {
return true
}
return f(fd)
})
}
Loading

0 comments on commit 8051872

Please sign in to comment.