Skip to content

Commit

Permalink
feat(client): Add options for intercepting gRPC operations (#1724)
Browse files Browse the repository at this point in the history
* Add client options for intercepting gRPC operations
* Add OpenTelemetry interceptors to the gRPC server

Signed-off-by: Andrew Haines <haines@cerbos.dev>
  • Loading branch information
haines committed Aug 4, 2023
1 parent 690d90a commit b9228f6
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 5 deletions.
42 changes: 37 additions & 5 deletions client/client.go
Expand Up @@ -69,6 +69,8 @@ type config struct {
tlsClientKey string
userAgent string
playgroundInstance string
streamInterceptors []grpc.StreamClientInterceptor
unaryInterceptors []grpc.UnaryClientInterceptor
connectTimeout time.Duration
retryTimeout time.Duration
maxRetries uint
Expand Down Expand Up @@ -151,6 +153,20 @@ func WithPlaygroundInstance(instance string) Opt {
}
}

// WithStreamInterceptors sets the interceptors to be used for streaming gRPC operations.
func WithStreamInterceptors(interceptors ...grpc.StreamClientInterceptor) Opt {
return func(c *config) {
c.streamInterceptors = interceptors
}
}

// WithUnaryInterceptors sets the interceptors to be used for unary gRPC operations.
func WithUnaryInterceptors(interceptors ...grpc.UnaryClientInterceptor) Opt {
return func(c *config) {
c.unaryInterceptors = interceptors
}
}

// New creates a new Cerbos client.
func New(address string, opts ...Opt) (Client, error) {
grpcConn, _, err := mkConn(address, opts...)
Expand Down Expand Up @@ -194,23 +210,39 @@ func mkDialOpts(conf *config) ([]grpc.DialOption, error) {
dialOpts = append(dialOpts, grpc.WithConnectParams(grpc.ConnectParams{MinConnectTimeout: conf.connectTimeout}))
}

streamInterceptors := conf.streamInterceptors
unaryInterceptors := conf.unaryInterceptors

if conf.maxRetries > 0 && conf.retryTimeout > 0 {
dialOpts = append(dialOpts,
grpc.WithChainStreamInterceptor(
streamInterceptors = append(
[]grpc.StreamClientInterceptor{
grpc_retry.StreamClientInterceptor(
grpc_retry.WithMax(conf.maxRetries),
grpc_retry.WithPerRetryTimeout(conf.retryTimeout),
),
),
grpc.WithChainUnaryInterceptor(
},
streamInterceptors...,
)

unaryInterceptors = append(
[]grpc.UnaryClientInterceptor{
grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(conf.maxRetries),
grpc_retry.WithPerRetryTimeout(conf.retryTimeout),
),
),
},
unaryInterceptors...,
)
}

if len(streamInterceptors) > 0 {
dialOpts = append(dialOpts, grpc.WithChainStreamInterceptor(streamInterceptors...))
}

if len(unaryInterceptors) > 0 {
dialOpts = append(dialOpts, grpc.WithChainUnaryInterceptor(unaryInterceptors...))
}

if conf.plaintext {
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
Expand Down
39 changes: 39 additions & 0 deletions client/client_test.go
Expand Up @@ -13,7 +13,11 @@ import (
"time"

"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

svcv1 "github.com/cerbos/cerbos/api/genpb/cerbos/svc/v1"
"github.com/cerbos/cerbos/client"
"github.com/cerbos/cerbos/client/testutil"
"github.com/cerbos/cerbos/internal/test"
Expand Down Expand Up @@ -108,6 +112,41 @@ func TestClient(t *testing.T) {
})
})
}

t.Run("interceptors", func(t *testing.T) {
errCanceled := status.Error(codes.Canceled, "canceled")

t.Run("stream", func(t *testing.T) {
var called string

c, err := client.NewAdminClientWithCredentials("unix:/dev/null", "username", "password", client.WithStreamInterceptors(func(_ context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) {
called = method
return nil, errCanceled
}))
require.NoError(t, err, "Failed to create client")

_, err = c.AuditLogs(context.Background(), client.AuditLogOptions{
Type: client.DecisionLogs,
Tail: 1,
})
require.ErrorIs(t, err, errCanceled)
require.Equal(t, svcv1.CerbosAdminService_ListAuditLogEntries_FullMethodName, called)
})

t.Run("unary", func(t *testing.T) {
var called string

c, err := client.New("unix:/dev/null", client.WithUnaryInterceptors(func(_ context.Context, method string, _, _ any, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error {
called = method
return errCanceled
}))
require.NoError(t, err, "Failed to create client")

_, err = c.IsAllowed(context.Background(), client.NewPrincipal("id", "role"), client.NewResource("kind", "id"), "action")
require.ErrorIs(t, err, errCanceled)
require.Equal(t, svcv1.CerbosService_CheckResources_FullMethodName, called)
})
})
}

func mkServerOpts(t *testing.T, withTLS bool) []testutil.ServerOpt {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -76,6 +76,7 @@ require (
github.com/twmb/franz-go/plugin/kzap v1.1.2
go.elastic.co/ecszap v1.0.1
go.opencensus.io v0.24.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0
go.opentelemetry.io/contrib/propagators/autoprop v0.42.0
go.opentelemetry.io/contrib/propagators/b3 v1.17.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -864,6 +864,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0 h1:ZOLJc06r4CB42laIXg/7udr0pbZyuAihN10A/XuiQRY=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0/go.mod h1:5z+/ZWJQKXa9YT34fQNx5K8Hd1EoIhvtUygUQPqEOgQ=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8=
go.opentelemetry.io/contrib/propagators/autoprop v0.42.0 h1:s2RzYOAqHVgG23q8fPWYChobUoZM6rJZ98EnylJr66w=
Expand Down
3 changes: 3 additions & 0 deletions internal/server/server.go
Expand Up @@ -30,6 +30,7 @@ import (
"go.opencensus.io/plugin/ochttp"
"go.opencensus.io/stats/view"
"go.opencensus.io/zpages"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/zap"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
Expand Down Expand Up @@ -443,6 +444,7 @@ func (s *Server) mkGRPCServer(log *zap.Logger, auditLog audit.Log) (*grpc.Server
grpc.ChainStreamInterceptor(
grpc_recovery.StreamServerInterceptor(),
telemetryInt.StreamServerInterceptor(),
otelgrpc.StreamServerInterceptor(),
grpc_validator.StreamServerInterceptor(),
grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractorForInitialReq(svc.ExtractRequestFields)),
grpc_zap.StreamServerInterceptor(log,
Expand All @@ -454,6 +456,7 @@ func (s *Server) mkGRPCServer(log *zap.Logger, auditLog audit.Log) (*grpc.Server
grpc.ChainUnaryInterceptor(
grpc_recovery.UnaryServerInterceptor(),
telemetryInt.UnaryServerInterceptor(),
otelgrpc.UnaryServerInterceptor(),
grpc_validator.UnaryServerInterceptor(),
grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractorForInitialReq(svc.ExtractRequestFields)),
XForwardedHostUnaryServerInterceptor,
Expand Down

0 comments on commit b9228f6

Please sign in to comment.