Skip to content

Commit

Permalink
Fix panic when inferring imports (#575)
Browse files Browse the repository at this point in the history
Also adds new tests and fixes other (less severe) issues when using the InferImportPaths.
  • Loading branch information
jhump committed Oct 2, 2023
1 parent a276f9d commit 35c5957
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 34 deletions.
161 changes: 128 additions & 33 deletions desc/protoparse/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/internal"
"github.com/jhump/protoreflect/desc/protoparse/ast"
)

Expand Down Expand Up @@ -155,8 +156,15 @@ func (p Parser) ParseFiles(filenames ...string) ([]*desc.FileDescriptor, error)
return nil, err
}
// then we can infer import paths
// TODO: if this re-writes one of the names in filenames, lookups below will break
results = fixupFilenames(results)
var rewritten map[string]string
results, rewritten = fixupFilenames(results)
if len(rewritten) > 0 {
for i := range filenames {
if replace, ok := rewritten[filenames[i]]; ok {
filenames[i] = replace
}
}
}
resolverFromResults := protocompile.ResolverFunc(func(path string) (protocompile.SearchResult, error) {
res, ok := results[path]
if !ok {
Expand Down Expand Up @@ -244,7 +252,15 @@ func (p Parser) ParseFilesButDoNotLink(filenames ...string) ([]*descriptorpb.Fil
for _, res := range results {
resultsMap[res.FileDescriptorProto().GetName()] = res
}
resultsMap = fixupFilenames(resultsMap)
var rewritten map[string]string
resultsMap, rewritten = fixupFilenames(resultsMap)
if len(rewritten) > 0 {
for i := range filenames {
if replace, ok := rewritten[filenames[i]]; ok {
filenames[i] = replace
}
}
}
for i := range filenames {
results[i] = resultsMap[filenames[i]]
}
Expand Down Expand Up @@ -362,48 +378,100 @@ func parseToProtos(res protocompile.Resolver, filenames []string, rep *reporter.
func parseToProtosRecursive(res protocompile.Resolver, filenames []string, rep *reporter.Handler, srcPosAddr *SourcePos) (map[string]parser.Result, error) {
results := make(map[string]parser.Result, len(filenames))
for _, filename := range filenames {
parseToProtoRecursive(res, filename, rep, srcPosAddr, results)
if err := parseToProtoRecursive(res, filename, rep, srcPosAddr, results); err != nil {
return results, err
}
}
return results, rep.Error()
}

func parseToProtoRecursive(res protocompile.Resolver, filename string, rep *reporter.Handler, srcPosAddr *SourcePos, results map[string]parser.Result) {
func parseToProtoRecursive(res protocompile.Resolver, filename string, rep *reporter.Handler, srcPosAddr *SourcePos, results map[string]parser.Result) error {
if _, ok := results[filename]; ok {
// already processed this one
return
return nil
}
results[filename] = nil // placeholder entry

astRoot, parseResult, _ := parseToAST(res, filename, rep)
if rep.ReporterError() != nil {
return
astRoot, parseResult, err := parseToAST(res, filename, rep)
if err != nil {
return err
}
if parseResult == nil {
parseResult, _ = parser.ResultFromAST(astRoot, true, rep)
if rep.ReporterError() != nil {
return
parseResult, err = parser.ResultFromAST(astRoot, true, rep)
if err != nil {
return err
}
}
results[filename] = parseResult

for _, decl := range astRoot.Decls {
imp, ok := decl.(*ast2.ImportNode)
if !ok {
continue
if astRoot != nil {
// We have an AST, so we use it to recursively examine imports.
for _, decl := range astRoot.Decls {
imp, ok := decl.(*ast2.ImportNode)
if !ok {
continue
}
err := func() error {
orig := *srcPosAddr
*srcPosAddr = astRoot.NodeInfo(imp.Name).Start()
defer func() {
*srcPosAddr = orig
}()

return parseToProtoRecursive(res, imp.Name.AsString(), rep, srcPosAddr, results)
}()
if err != nil {
return err
}
}
func() {
return nil
}

// Without an AST, we must recursively examine the proto. This makes it harder
// (but not necessarily impossible) to get the source location of the import.
fd := parseResult.FileDescriptorProto()
for i, dep := range fd.Dependency {
path := []int32{internal.File_dependencyTag, int32(i)}
err := func() error {
orig := *srcPosAddr
*srcPosAddr = astRoot.NodeInfo(imp.Name).Start()
found := false
for _, loc := range fd.GetSourceCodeInfo().GetLocation() {
if pathsEqual(loc.Path, path) {
*srcPosAddr = SourcePos{
Filename: dep,
Line: int(loc.Span[0]),
Col: int(loc.Span[1]),
}
found = true
break
}
}
if !found {
*srcPosAddr = *ast.UnknownPos(dep)
}
defer func() {
*srcPosAddr = orig
}()

parseToProtoRecursive(res, imp.Name.AsString(), rep, srcPosAddr, results)
return parseToProtoRecursive(res, dep, rep, srcPosAddr, results)
}()
if rep.ReporterError() != nil {
return
if err != nil {
return err
}
}
return nil
}

func pathsEqual(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

func newReporter(errRep ErrorReporter, warnRep WarningReporter) reporter.Reporter {
Expand Down Expand Up @@ -487,11 +555,12 @@ func (p Parser) getResolver(filenames []string) (protocompile.Resolver, *SourceP
}, &srcPos
}

func fixupFilenames(protos map[string]parser.Result) map[string]parser.Result {
func fixupFilenames(protos map[string]parser.Result) (revisedProtos map[string]parser.Result, rewrittenPaths map[string]string) {
// In the event that the given filenames (keys in the supplied map) do not
// match the actual paths used in 'import' statements in the files, we try
// to revise names in the protos so that they will match and be linkable.
revisedProtos := map[string]parser.Result{}
revisedProtos = make(map[string]parser.Result, len(protos))
rewrittenPaths = make(map[string]string, len(protos))

protoPaths := map[string]struct{}{}
// TODO: this is O(n^2) but could likely be O(n) with a clever data structure (prefix tree that is indexed backwards?)
Expand All @@ -501,7 +570,7 @@ func fixupFilenames(protos map[string]parser.Result) map[string]parser.Result {
candidatesAvailable[name] = struct{}{}
for _, f := range protos {
for _, imp := range f.FileDescriptorProto().Dependency {
if strings.HasSuffix(name, imp) {
if strings.HasSuffix(name, imp) || strings.HasSuffix(imp, name) {
candidates := importCandidates[imp]
if candidates == nil {
candidates = map[string]struct{}{}
Expand Down Expand Up @@ -529,37 +598,62 @@ func fixupFilenames(protos map[string]parser.Result) map[string]parser.Result {
if best == "" {
best = c
} else {
// HACK: we can't actually tell which files is supposed to match
// this import, so arbitrarily pick the "shorter" one (fewest
// path elements) or, on a tie, the lexically earlier one
// NB: We can't actually tell which file is supposed to match
// this import. So we prefer the longest name. On a tie, we
// choose the lexically earliest match.
minLen := strings.Count(best, string(filepath.Separator))
cLen := strings.Count(c, string(filepath.Separator))
if cLen < minLen || (cLen == minLen && c < best) {
if cLen > minLen || (cLen == minLen && c < best) {
best = c
}
}
}
if best != "" {
prefix := best[:len(best)-len(imp)]
if len(prefix) > 0 {
if len(best) > len(imp) {
prefix := best[:len(best)-len(imp)]
protoPaths[prefix] = struct{}{}
}
f := protos[best]
f.FileDescriptorProto().Name = proto.String(imp)
revisedProtos[imp] = f
rewrittenPaths[best] = imp
delete(candidatesAvailable, best)

// If other candidates are actually references to the same file, remove them.
for c := range candidates {
if _, ok := candidatesAvailable[c]; !ok {
// already used this candidate and re-written its filename accordingly
continue
}
possibleDup := protos[c]
prevName := possibleDup.FileDescriptorProto().Name
possibleDup.FileDescriptorProto().Name = proto.String(imp)
if !proto.Equal(f.FileDescriptorProto(), protos[c].FileDescriptorProto()) {
// not equal: restore name and look at next one
possibleDup.FileDescriptorProto().Name = prevName
continue
}
// This file used a different name but was actually the same file. So
// we prune it from the set.
rewrittenPaths[c] = imp
delete(candidatesAvailable, c)
if len(c) > len(imp) {
prefix := c[:len(c)-len(imp)]
protoPaths[prefix] = struct{}{}
}
}
}
}

if len(candidatesAvailable) == 0 {
return revisedProtos
return revisedProtos, rewrittenPaths
}

if len(protoPaths) == 0 {
for c := range candidatesAvailable {
revisedProtos[c] = protos[c]
}
return revisedProtos
return revisedProtos, rewrittenPaths
}

// Any remaining candidates are entry-points (not imported by others), so
Expand Down Expand Up @@ -588,12 +682,13 @@ func fixupFilenames(protos map[string]parser.Result) map[string]parser.Result {
f.FileDescriptorProto().Name = proto.String(imp)
f.FileNode()
revisedProtos[imp] = f
rewrittenPaths[c] = imp
} else {
revisedProtos[c] = protos[c]
}
}

return revisedProtos
return revisedProtos, rewrittenPaths
}

func removeDynamicExtensions(fd protoreflect.FileDescriptor, alreadySeen map[string]struct{}) {
Expand Down
105 changes: 104 additions & 1 deletion desc/protoparse/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestSimpleParse(t *testing.T) {
// We'll also check our fixup logic to make sure it correctly rewrites the
// names of the files to match corresponding import statementes. This should
// strip the "../../internal/testprotos/" prefix from each file.
protos = fixupFilenames(protos)
protos, _ = fixupFilenames(protos)
var actual []string
for n := range protos {
actual = append(actual, n)
Expand Down Expand Up @@ -422,3 +422,106 @@ message Foo {
comment := fds[0].GetMessageTypes()[0].GetFields()[0].GetSourceInfo().GetLeadingComments()
testutil.Eq(t, " leading comments\n", comment)
}

func TestParseInferImportPaths_SimpleNoOp(t *testing.T) {
sources := map[string]string{
"test.proto": `
syntax = "proto3";
import "google/protobuf/struct.proto";
message Foo {
string name = 1;
repeated uint64 refs = 2;
google.protobuf.Struct meta = 3;
}`,
}
p := Parser{
Accessor: FileContentsFromMap(sources),
InferImportPaths: true,
}
fds, err := p.ParseFiles("test.proto")
testutil.Ok(t, err)
testutil.Eq(t, 1, len(fds))
}

func TestParseInferImportPaths_FixesNestedPaths(t *testing.T) {
sources := FileContentsFromMap(map[string]string{
"/foo/bar/a.proto": `
syntax = "proto3";
import "baz/b.proto";
message A {
B b = 1;
}`,
"/foo/bar/baz/b.proto": `
syntax = "proto3";
import "baz/c.proto";
message B {
C c = 1;
}`,
"/foo/bar/baz/c.proto": `
syntax = "proto3";
message C {}`,
"/foo/bar/baz/d.proto": `
syntax = "proto3";
import "a.proto";
message D {
A a = 1;
}`,
})

testCases := []struct {
name string
cwd string
filenames []string
expect []string
}{
{
name: "outside hierarchy",
cwd: "/buzz",
filenames: []string{"../foo/bar/a.proto", "../foo/bar/baz/b.proto", "../foo/bar/baz/c.proto", "../foo/bar/baz/d.proto"},
},
{
name: "inside hierarchy",
cwd: "/foo",
filenames: []string{"bar/a.proto", "bar/baz/b.proto", "bar/baz/c.proto", "bar/baz/d.proto"},
},
{
name: "import path root (no-op)",
cwd: "/foo/bar",
filenames: []string{"a.proto", "baz/b.proto", "baz/c.proto", "baz/d.proto"},
},
{
name: "inside leaf directory",
cwd: "/foo/bar/baz",
filenames: []string{"../a.proto", "b.proto", "c.proto", "d.proto"},
// NB: Expected names differ from above cases because nothing imports d.proto.
// So when inferring the root paths, the fact that d.proto is defined in
// the baz sub-directory will not be discovered. That's okay since the parse
// operation still succeeds.
expect: []string{"a.proto", "baz/b.proto", "baz/c.proto", "d.proto"},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
p := Parser{
Accessor: sources,
ImportPaths: []string{testCase.cwd, "/foo/bar"},
InferImportPaths: true,
}
fds, err := p.ParseFiles(testCase.filenames...)
testutil.Ok(t, err)
testutil.Eq(t, 4, len(fds))
var expectedNames []string
if len(testCase.expect) == 0 {
expectedNames = []string{"a.proto", "baz/b.proto", "baz/c.proto", "baz/d.proto"}
} else {
testutil.Eq(t, 4, len(testCase.expect))
expectedNames = testCase.expect
}
// check that they have the expected name
testutil.Eq(t, expectedNames[0], fds[0].GetName())
testutil.Eq(t, expectedNames[1], fds[1].GetName())
testutil.Eq(t, expectedNames[2], fds[2].GetName())
testutil.Eq(t, expectedNames[3], fds[3].GetName())
})
}
}

0 comments on commit 35c5957

Please sign in to comment.