Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide access to reflection resolver from bufimage.Image and bufprotosource #2887

Merged
merged 6 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,7 @@ func checkEnumZeroValueSuffix(add addFunc, enumValue bufprotosource.EnumValue, s
var CheckFieldLowerSnakeCase = newFieldCheckFunc(checkFieldLowerSnakeCase)

func checkFieldLowerSnakeCase(add addFunc, field bufprotosource.Field) error {
message := field.ParentMessage()
if message == nil {
// just a sanity check
return errors.New("field.Message() was nil")
}
Comment on lines -272 to -275
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

message will be nil for top-level extensions. (See other comment about bug-fix below.)

if message.IsMapEntry() {
if message := field.ParentMessage(); message != nil && message.IsMapEntry() {
// this check should always pass anyways but just in case
return nil
}
Expand Down Expand Up @@ -630,10 +625,19 @@ func checkPackageVersionSuffix(add addFunc, file bufprotosource.File) error {
}

// CheckProtovalidate is a check function.
var CheckProtovalidate = newFilesWithImportsCheckFunc(checkProtovalidate)
var CheckProtovalidate = combine(
newMessageCheckFunc(checkProtovalidateMessage),
newFieldCheckFunc(checkProtovalidateField),
// NOTE: Oneofs also have protovalidate support, but they
// only have a "required" field, so nothing to lint.
)

func checkProtovalidateMessage(add addFunc, message bufprotosource.Message) error {
return buflintvalidate.CheckMessage(add, message)
}

func checkProtovalidate(add addFunc, files []bufprotosource.File) error {
return buflintvalidate.Check(add, files)
func checkProtovalidateField(add addFunc, field bufprotosource.Field) error {
return buflintvalidate.CheckField(add, field)
}

// CheckRPCNoClientStreaming is a check function.
Expand Down
53 changes: 40 additions & 13 deletions private/bufpkg/bufcheck/buflint/internal/buflintcheck/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,21 +246,32 @@ func newMessageCheckFunc(
func newFieldCheckFunc(
f func(addFunc, bufprotosource.Field) error,
) func(string, internal.IgnoreFunc, []bufprotosource.File) ([]bufanalysis.FileAnnotation, error) {
return newMessageCheckFunc(
func(add addFunc, message bufprotosource.Message) error {
for _, field := range message.Fields() {
if err := f(add, field); err != nil {
return err
return combine(
newMessageCheckFunc(
func(add addFunc, message bufprotosource.Message) error {
for _, field := range message.Fields() {
if err := f(add, field); err != nil {
return err
}
}
}
// TODO: is this right?
for _, field := range message.Extensions() {
if err := f(add, field); err != nil {
return err
for _, field := range message.Extensions() {
if err := f(add, field); err != nil {
return err
}
}
}
return nil
},
return nil
},
),
newFileCheckFunc(
func(add addFunc, file bufprotosource.File) error {
for _, field := range file.Extensions() {
if err := f(add, field); err != nil {
return err
}
}
return nil
},
),
Comment on lines +265 to +274
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doh! We previously weren't applying lint rules to top-level extensions. Fixed in this branch.

)
}

Expand Down Expand Up @@ -308,3 +319,19 @@ func newMethodCheckFunc(
},
)
}

func combine(
bufdev marked this conversation as resolved.
Show resolved Hide resolved
checks ...func(string, internal.IgnoreFunc, []bufprotosource.File) ([]bufanalysis.FileAnnotation, error),
) func(string, internal.IgnoreFunc, []bufprotosource.File) ([]bufanalysis.FileAnnotation, error) {
return func(id string, ignoreFunc internal.IgnoreFunc, files []bufprotosource.File) ([]bufanalysis.FileAnnotation, error) {
var annotations []bufanalysis.FileAnnotation
for _, check := range checks {
checkAnnotations, err := check(id, ignoreFunc, files)
if err != nil {
return nil, err
}
annotations = append(annotations, checkAnnotations...)
}
return annotations, nil
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,63 +17,18 @@ package buflintvalidate
import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/bufbuild/buf/private/bufpkg/bufprotosource"
"github.com/bufbuild/buf/private/pkg/protodescriptor"
"github.com/bufbuild/protovalidate-go/resolver"
"google.golang.org/protobuf/reflect/protodesc"
)

// https://buf.build/bufbuild/protovalidate/docs/v0.5.1:buf.validate#buf.validate.MessageConstraints
const disabledFieldNumberInMesageConstraints = 1

// Check validates that all rules on fields are valid, and all CEL expressions compile.
//
// For a set of rules to be valid, it must
// 1. permit _some_ value
// 2. have a type compatible with the field it validates.
func Check(
add func(bufprotosource.Descriptor, bufprotosource.Location, []bufprotosource.Location, string, ...interface{}),
files []bufprotosource.File,
) error {
fileDescriptors := make([]protodescriptor.FileDescriptor, 0, len(files))
for _, file := range files {
fileDescriptors = append(fileDescriptors, file.FileDescriptor())
}
descriptorResolver, err := protodesc.NewFiles(protodescriptor.FileDescriptorSetForFileDescriptors(fileDescriptors...))
if err != nil {
return err
}
for _, file := range files {
if file.IsImport() {
continue
}
for _, message := range file.Messages() {
if err := checkForMessage(
add,
descriptorResolver,
message,
); err != nil {
return err
}
}
for _, extension := range file.Extensions() {
if err := checkForField(
add,
descriptorResolver,
extension,
); err != nil {
return err
}
}
}
return nil
}

func checkForMessage(
// CheckMessage validates that all rules on the message are valid, and any CEL expressions compile.
func CheckMessage(
add func(bufprotosource.Descriptor, bufprotosource.Location, []bufprotosource.Location, string, ...interface{}),
descriptorResolver protodesc.Resolver,
message bufprotosource.Message,
) error {
messageDescriptor, err := getReflectMessageDescriptor(descriptorResolver, message)
messageDescriptor, err := message.AsDescriptor()
if err != nil {
return err
}
Expand All @@ -87,36 +42,22 @@ func checkForMessage(
message.Name(),
)
}
if err := checkCELForMessage(
return checkCELForMessage(
add,
messageConstraints,
messageDescriptor,
message,
); err != nil {
return err
}
for _, nestedMessage := range message.Messages() {
if err := checkForMessage(add, descriptorResolver, nestedMessage); err != nil {
return err
}
}
for _, field := range message.Fields() {
if err := checkForField(
add,
descriptorResolver,
field,
); err != nil {
return err
}
}
for _, extension := range message.Extensions() {
if err := checkForField(
add,
descriptorResolver,
extension,
); err != nil {
return err
}
}
return nil
)
}

// CheckField validates that all rules on the field are valid, and any CEL expressions compile.
//
// For a set of rules to be valid, it must
// 1. permit _some_ value
// 2. have a type compatible with the field it validates.
func CheckField(
add func(bufprotosource.Descriptor, bufprotosource.Location, []bufprotosource.Location, string, ...interface{}),
field bufprotosource.Field,
) error {
return checkField(add, field)
}
34 changes: 0 additions & 34 deletions private/bufpkg/bufcheck/buflint/internal/buflintvalidate/cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ import (
"github.com/bufbuild/protovalidate-go/celext"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
)

Expand Down Expand Up @@ -206,38 +204,6 @@ func checkCEL(
}
}

func getReflectMessageDescriptor(resolver protodesc.Resolver, message bufprotosource.Message) (protoreflect.MessageDescriptor, error) {
descriptor, err := resolver.FindDescriptorByName(protoreflect.FullName(message.FullName()))
if err == protoregistry.NotFound {
return nil, fmt.Errorf("unable to resolve MessageDescriptor: %s", message.FullName())
}
if err != nil {
return nil, err
}
messageDescriptor, ok := descriptor.(protoreflect.MessageDescriptor)
if !ok {
// this should not happen
return nil, fmt.Errorf("%s is not a message", descriptor.FullName())
}
return messageDescriptor, nil
}

func getReflectFieldDescriptor(resolver protodesc.Resolver, field bufprotosource.Field) (protoreflect.FieldDescriptor, error) {
descriptor, err := resolver.FindDescriptorByName(protoreflect.FullName(field.FullName()))
if err == protoregistry.NotFound {
return nil, fmt.Errorf("unable to resolve FieldDescriptor: %s", field.FullName())
}
if err != nil {
return nil, err
}
fieldDescriptor, ok := descriptor.(protoreflect.FieldDescriptor)
if !ok {
// this should never happen
return nil, fmt.Errorf("%s is not a field", descriptor.FullName())
}
return fieldDescriptor, nil
}

// this depends on the undocumented behavior of cel-go's error message
//
// maps a string in this form:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,13 @@ var (
typeOneofDescriptor = fieldConstraintsDescriptor.Oneofs().ByName("type")
)

// checkForField validates that protovalidate rules defined for this field are
// checkField validates that protovalidate rules defined for this field are
// valid, not including CEL expressions.
func checkForField(
func checkField(
add func(bufprotosource.Descriptor, bufprotosource.Location, []bufprotosource.Location, string, ...interface{}),
descriptorResolver protodesc.Resolver,
field bufprotosource.Field,
) error {
fieldDescriptor, err := getReflectFieldDescriptor(descriptorResolver, field)
fieldDescriptor, err := field.AsDescriptor()
if err != nil {
return err
}
Expand Down