diff --git a/desc/protoparse/linker.go b/desc/protoparse/linker.go index 34350263..7bc27b08 100644 --- a/desc/protoparse/linker.go +++ b/desc/protoparse/linker.go @@ -71,7 +71,7 @@ func (l *linker) linkFiles() (map[string]*desc.FileDescriptor, error) { } // we should now have any message_set_wire_format options parsed // and can do further validation on tag ranges - if err := checkExtensionsInFile(fd, r); err != nil { + if err := l.checkExtensionsInFile(fd, r); err != nil { return nil, err } } @@ -999,3 +999,59 @@ func (l *linker) checkForUnusedImports(filename string) { } } } + +func (l *linker) checkExtensionsInFile(fd *desc.FileDescriptor, res *parseResult) error { + for _, fld := range fd.GetExtensions() { + if err := l.checkExtension(fld, res); err != nil { + return err + } + } + for _, md := range fd.GetMessageTypes() { + if err := l.checkExtensionsInMessage(md, res); err != nil { + return err + } + } + return nil +} + +func (l *linker) checkExtensionsInMessage(md *desc.MessageDescriptor, res *parseResult) error { + for _, fld := range md.GetNestedExtensions() { + if err := l.checkExtension(fld, res); err != nil { + return err + } + } + for _, nmd := range md.GetNestedMessageTypes() { + if err := l.checkExtensionsInMessage(nmd, res); err != nil { + return err + } + } + return nil +} + +func (l *linker) checkExtension(fld *desc.FieldDescriptor, res *parseResult) error { + // NB: It's a little gross that we don't enforce these in validateBasic(). + // But requires some minimal linking to resolve the extendee, so we can + // interrogate its descriptor. + if fld.GetOwner().GetMessageOptions().GetMessageSetWireFormat() { + // Message set wire format requires that all extensions be messages + // themselves (no scalar extensions) + if fld.GetType() != dpb.FieldDescriptorProto_TYPE_MESSAGE { + pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldType().Start() + return l.errs.handleErrorWithPos(pos, "messages with message-set wire format cannot contain scalar extensions, only messages") + } + if fld.IsRepeated() { + pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldLabel().Start() + return l.errs.handleErrorWithPos(pos, "messages with message-set wire format cannot contain repeated extensions, only optional") + } + } else { + // In validateBasic() we just made sure these were within bounds for any message. But + // now that things are linked, we can check if the extendee is messageset wire format + // and, if not, enforce tighter limit. + if fld.GetNumber() > internal.MaxNormalTag { + pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldTag().Start() + return l.errs.handleErrorWithPos(pos, "tag number %d is higher than max allowed tag number (%d)", fld.GetNumber(), internal.MaxNormalTag) + } + } + + return nil +} diff --git a/desc/protoparse/linker_test.go b/desc/protoparse/linker_test.go index 3365e543..8cbcee23 100644 --- a/desc/protoparse/linker_test.go +++ b/desc/protoparse/linker_test.go @@ -413,6 +413,12 @@ func TestLinkerValidation(t *testing.T) { }, "", // should succeed }, + { + map[string]string{ + "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to 100; } extend Foo { repeated Foo bar = 1; }", + }, + "foo.proto:1:90: messages with message-set wire format cannot contain repeated extensions, only optional", + }, { map[string]string{ "foo.proto": "message Foo { extensions 1 to max; } extend Foo { optional int32 bar = 536870912; }", diff --git a/desc/protoparse/parser.go b/desc/protoparse/parser.go index d9801cd5..5300aab6 100644 --- a/desc/protoparse/parser.go +++ b/desc/protoparse/parser.go @@ -817,58 +817,6 @@ func checkTag(pos *SourcePos, v uint64, maxTag int32) error { return nil } -func checkExtensionsInFile(fd *desc.FileDescriptor, res *parseResult) error { - for _, fld := range fd.GetExtensions() { - if err := checkExtension(fld, res); err != nil { - return err - } - } - for _, md := range fd.GetMessageTypes() { - if err := checkExtensionsInMessage(md, res); err != nil { - return err - } - } - return nil -} - -func checkExtensionsInMessage(md *desc.MessageDescriptor, res *parseResult) error { - for _, fld := range md.GetNestedExtensions() { - if err := checkExtension(fld, res); err != nil { - return err - } - } - for _, nmd := range md.GetNestedMessageTypes() { - if err := checkExtensionsInMessage(nmd, res); err != nil { - return err - } - } - return nil -} - -func checkExtension(fld *desc.FieldDescriptor, res *parseResult) error { - // NB: It's a little gross that we don't enforce these in validateBasic(). - // But requires some minimal linking to resolve the extendee, so we can - // interrogate its descriptor. - if fld.GetOwner().GetMessageOptions().GetMessageSetWireFormat() { - // Message set wire format requires that all extensions be messages - // themselves (no scalar extensions) - if fld.GetType() != dpb.FieldDescriptorProto_TYPE_MESSAGE { - pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldType().Start() - return errorWithPos(pos, "messages with message-set wire format cannot contain scalar extensions, only messages") - } - } else { - // In validateBasic() we just made sure these were within bounds for any message. But - // now that things are linked, we can check if the extendee is messageset wire format - // and, if not, enforce tighter limit. - if fld.GetNumber() > internal.MaxNormalTag { - pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldTag().Start() - return errorWithPos(pos, "tag number %d is higher than max allowed tag number (%d)", fld.GetNumber(), internal.MaxNormalTag) - } - } - - return nil -} - func aggToString(agg []*ast.MessageFieldNode, buf *bytes.Buffer) { buf.WriteString("{") for _, a := range agg {