Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[interceptors/validator] feat: add error logging in validator #544

Merged
merged 8 commits into from
Mar 26, 2023
Merged
69 changes: 69 additions & 0 deletions interceptors/validator/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package validator

import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"

var (
defaultOptions = &options{
logger: DefaultLoggerMethod,
shouldFailFast: DefaultDeciderMethod,
}
)

type options struct {
logger Logger
shouldFailFast Decider
}

// Option
type Option func(*options)

func evaluateServerOpt(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
return optCopy
}

func evaluateClientOpt(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
return optCopy
}

// Logger
type Logger func() (logging.Level, logging.Logger)

// DefaultLoggerMethod
func DefaultLoggerMethod() (logging.Level, logging.Logger) {
return "", nil
}

// WithLogger
func WithLogger(logger Logger) Option {
rohanraj7316 marked this conversation as resolved.
Show resolved Hide resolved
return func(o *options) {
o.logger = logger
rohanraj7316 marked this conversation as resolved.
Show resolved Hide resolved
}
}

// Decision
type Decision bool

// Decider function defines rules for suppressing any interceptor logs.
type Decider func() Decision

// DefaultDeciderMethod
func DefaultDeciderMethod() Decision {
rohanraj7316 marked this conversation as resolved.
Show resolved Hide resolved
return false
}

// WithFailFast
func WithFailFast(d Decider) Option {
rohanraj7316 marked this conversation as resolved.
Show resolved Hide resolved
return func(o *options) {
o.shouldFailFast = d
}
}
69 changes: 32 additions & 37 deletions interceptors/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
)

// The validateAller interface at protoc-gen-validate main branch.
Expand All @@ -28,20 +30,31 @@ type validatorLegacy interface {
Validate() error
}

func validate(req any, all bool) error {
if all {
func log(level logging.Level, logger logging.Logger, msg string) {
if logger != nil {
logger.Log(level, msg)
}
}

func validate(req interface{}, d Decider, l Logger) error {
isFailFast := bool(d())
level, logger := l()
if isFailFast {
switch v := req.(type) {
case validateAller:
if err := v.ValidateAll(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(true); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validatorLegacy:
// Fallback to legacy validator
if err := v.Validate(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
}
Expand All @@ -50,78 +63,60 @@ func validate(req any, all bool) error {
switch v := req.(type) {
case validatorLegacy:
if err := v.Validate(); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(false); err != nil {
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
}
return nil
}

// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
//
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryServerInterceptor(all bool) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if err := validate(req, all); err != nil {
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateServerOpt(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validate(req, o.shouldFailFast, o.logger); err != nil {
return nil, err
}
return handler(ctx, req)
}
}

// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages.
//
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryClientInterceptor(all bool) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := validate(req, all); err != nil {
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
o := evaluateClientOpt(opts)
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := validate(req, o.shouldFailFast, o.logger); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
rohanraj7316 marked this conversation as resolved.
Show resolved Hide resolved
//
// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the
// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace
// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on
// calls to `stream.Recv()`.
func StreamServerInterceptor(all bool) grpc.StreamServerInterceptor {
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateServerOpt(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{
all: all,
options: o,
ServerStream: stream,
}

return handler(srv, wrapper)
}
}

type recvWrapper struct {
all bool
*options
grpc.ServerStream
}

func (s *recvWrapper) RecvMsg(m any) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if err := validate(m, s.all); err != nil {
if err := validate(m, s.shouldFailFast, s.logger); err != nil {
return err
}
return nil
Expand Down
56 changes: 34 additions & 22 deletions interceptors/validator/validator_test.go
Original file line number Diff line number Diff line change
@@ -1,53 +1,65 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

package validator
package validator_test

import (
"io"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
)

type Logger struct {
}

func (l *Logger) Log(lvl logging.Level, msg string) {}

func (l *Logger) With(fields ...string) logging.Logger {
return &Logger{}
}

func TestValidateWrapper(t *testing.T) {
assert.NoError(t, validate(testpb.GoodPing, false))
assert.Error(t, validate(testpb.BadPing, false))
assert.NoError(t, validate(testpb.GoodPing, true))
assert.Error(t, validate(testpb.BadPing, true))

assert.NoError(t, validate(testpb.GoodPingError, false))
assert.Error(t, validate(testpb.BadPingError, false))
assert.NoError(t, validate(testpb.GoodPingError, true))
assert.Error(t, validate(testpb.BadPingError, true))

assert.NoError(t, validate(testpb.GoodPingResponse, false))
assert.NoError(t, validate(testpb.GoodPingResponse, true))
assert.Error(t, validate(testpb.BadPingResponse, false))
assert.Error(t, validate(testpb.BadPingResponse, true))
assert.NoError(t, validate(testpb.GoodPing, false, &Logger{}))
assert.Error(t, validate(testpb.BadPing, false, &Logger{}))
assert.NoError(t, validate(testpb.GoodPing, true, &Logger{}))
assert.Error(t, validate(testpb.BadPing, true, &Logger{}))

assert.NoError(t, validate(testpb.GoodPingError, false, &Logger{}))
assert.Error(t, validate(testpb.BadPingError, false, &Logger{}))
assert.NoError(t, validate(testpb.GoodPingError, true, &Logger{}))
assert.Error(t, validate(testpb.BadPingError, true, &Logger{}))

assert.NoError(t, validate(testpb.GoodPingResponse, false, &Logger{}))
assert.NoError(t, validate(testpb.GoodPingResponse, true, &Logger{}))
assert.Error(t, validate(testpb.BadPingResponse, false, &Logger{}))
assert.Error(t, validate(testpb.BadPingResponse, true, &Logger{}))
}

func TestValidatorTestSuite(t *testing.T) {
s := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(StreamServerInterceptor(false)),
grpc.UnaryInterceptor(UnaryServerInterceptor(false)),
grpc.StreamInterceptor(StreamServerInterceptor(false, &Logger{})),
grpc.UnaryInterceptor(UnaryServerInterceptor(false, &Logger{})),
},
},
}
suite.Run(t, s)
sAll := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(StreamServerInterceptor(true)),
grpc.UnaryInterceptor(UnaryServerInterceptor(true)),
grpc.StreamInterceptor(StreamServerInterceptor(true, &Logger{})),
grpc.UnaryInterceptor(UnaryServerInterceptor(true, &Logger{})),
},
},
}
Expand All @@ -56,15 +68,15 @@ func TestValidatorTestSuite(t *testing.T) {
cs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(UnaryClientInterceptor(false)),
grpc.WithUnaryInterceptor(UnaryClientInterceptor(false, &Logger{})),
},
},
}
suite.Run(t, cs)
csAll := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(UnaryClientInterceptor(true)),
grpc.WithUnaryInterceptor(UnaryClientInterceptor(true, &Logger{})),
},
},
}
Expand Down