From 2674e5ce465dd06b1da9b0544eea0ee03b9db964 Mon Sep 17 00:00:00 2001 From: Michal Witkowski Date: Sat, 14 May 2016 18:33:54 +0100 Subject: [PATCH] implement grpc interceptor chaining --- chain.go | 61 ++++++++++++++++++++++++++++++ chain_test.go | 98 ++++++++++++++++++++++++++++++++++++++++++++++++ wrappers.go | 29 ++++++++++++++ wrappers_test.go | 47 +++++++++++++++++++++++ 4 files changed, 235 insertions(+) create mode 100644 chain.go create mode 100644 chain_test.go create mode 100644 wrappers.go create mode 100644 wrappers_test.go diff --git a/chain.go b/chain.go new file mode 100644 index 000000000..8445008eb --- /dev/null +++ b/chain.go @@ -0,0 +1,61 @@ +// Copyright 2016 Michal Witkowski. All Rights Reserved. +// See LICENSE for licensing terms. + +// gRPC Server Interceptor chaining middleware. + +package grpc_middleware + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +// ChainUnaryServer creates a single interceptor out of a chain of many interceptors. +// Execution is done in left-to-right order, including passing of context. +// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three +// will see context changes of one and two. +func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + buildChain := func(current grpc.UnaryServerInterceptor, next grpc.UnaryHandler) grpc.UnaryHandler { + return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) { + return current(currentCtx, currentReq, info, next) + } + } + chain := handler + for i := len(interceptors) - 1; i >= 0; i-- { + chain = buildChain(interceptors[i], chain) + } + return chain(ctx, req) + } +} + +// ChainStreamServer creates a single interceptor out of a chain of many interceptors. +// Execution is done in left-to-right order, including passing of context. +// For example ChainUnaryServer(one, two, three) will execute one before two before three. +// If you want to pass context between interceptors, use WrapServerStream. +func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + buildChain := func(current grpc.StreamServerInterceptor, next grpc.StreamHandler) grpc.StreamHandler { + return func(currentSrv interface{}, currentStream grpc.ServerStream) error { + return current(currentSrv, currentStream, info, next) + } + } + chain := handler + for i := len(interceptors) - 1; i >= 0; i-- { + chain = buildChain(interceptors[i], chain) + } + return chain(srv, stream) + } +} + +// WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors. +// Basically syntactic sugar. +func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { + return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...)) +} + +// WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors. +// Basically syntactic sugar. +func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { + return grpc.StreamInterceptor(ChainStreamServer(interceptors...)) +} diff --git a/chain_test.go b/chain_test.go new file mode 100644 index 000000000..2b26092b5 --- /dev/null +++ b/chain_test.go @@ -0,0 +1,98 @@ +// Copyright 2016 Michal Witkowski. All Rights Reserved. +// See LICENSE for licensing terms. + +package grpc_middleware + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +var ( + parentUnaryInfo = &grpc.UnaryServerInfo{FullMethod: "SomeService.UnaryMethod"} + parentStreamInfo = &grpc.StreamServerInfo{ + FullMethod: "SomeService.StreamMethod", + IsServerStream: true, + } + someValue = 1 + parentContext = context.WithValue(context.TODO(), "parent", someValue) +) + +func TestChainUnaryServer(t *testing.T) { + input := "input" + output := "output" + + first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value") + require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo") + ctx = context.WithValue(ctx, "first", 1) + return handler(ctx, req) + } + second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value") + requireContextValue(t, ctx, "first", "second interceptor must know the first context value") + require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo") + ctx = context.WithValue(ctx, "second", 1) + return handler(ctx, req) + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + require.EqualValues(t, input, req, "handler must get the input") + requireContextValue(t, ctx, "parent", "handler must know the parent context value") + requireContextValue(t, ctx, "first", "handler must know the first context value") + requireContextValue(t, ctx, "second", "handler must know the second context value") + return output, nil + } + + chain := ChainUnaryServer(first, second) + out, _ := chain(parentContext, input, parentUnaryInfo, handler) + require.EqualValues(t, output, out, "chain must return handler's output") +} + +func TestChainStreamServer(t *testing.T) { + someService := &struct{}{} + recvMessage := "received" + sentMessage := "sent" + outputError := fmt.Errorf("some error") + + first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value") + require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo") + require.Equal(t, someService, srv, "first interceptor must know someService") + wrapped := WrapServerStream(stream) + wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1) + return handler(srv, wrapped) + } + second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value") + requireContextValue(t, stream.Context(), "parent", "second interceptor must know the first context value") + require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo") + require.Equal(t, someService, srv, "second interceptor must know someService") + wrapped := WrapServerStream(stream) + wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1) + return handler(srv, wrapped) + } + handler := func(srv interface{}, stream grpc.ServerStream) error { + require.Equal(t, someService, srv, "handler must know someService") + requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value") + requireContextValue(t, stream.Context(), "first", "handler must know the first context value") + requireContextValue(t, stream.Context(), "second", "handler must know the second context value") + require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages") + require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages") + return outputError + } + fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage} + chain := ChainStreamServer(first, second) + err := chain(someService, fakeStream, parentStreamInfo, handler) + require.Equal(t, outputError, err, "chain must return handler's error") + require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream") +} + +func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) { + val := ctx.Value(key) + require.NotNil(t, val, msg...) + require.Equal(t, someValue, val, msg...) +} diff --git a/wrappers.go b/wrappers.go new file mode 100644 index 000000000..597b86244 --- /dev/null +++ b/wrappers.go @@ -0,0 +1,29 @@ +// Copyright 2016 Michal Witkowski. All Rights Reserved. +// See LICENSE for licensing terms. + +package grpc_middleware + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +// WrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context. +type WrappedServerStream struct { + grpc.ServerStream + // WrappedContext is the wrapper's own Context. You can assign it. + WrappedContext context.Context +} + +// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context() +func (w *WrappedServerStream) Context() context.Context { + return w.WrappedContext +} + +// WrapServerStream returns a ServerStream that has the ability to overwrite context. +func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream { + if existing, ok := stream.(*WrappedServerStream); ok { + return existing + } + return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()} +} diff --git a/wrappers_test.go b/wrappers_test.go new file mode 100644 index 000000000..2796700d1 --- /dev/null +++ b/wrappers_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 Michal Witkowski. All Rights Reserved. +// See LICENSE for licensing terms. + +package grpc_middleware + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "testing" + "github.com/stretchr/testify/assert" +) + +func TestWrapServerStream(t *testing.T) { + ctx := context.WithValue(context.TODO(), "something", 1) + fake := &fakeServerStream{ctx: ctx} + wrapped := WrapServerStream(fake) + assert.NotNil(t, wrapped.Context().Value("something"), "values from fake must propagate to wrapper") + wrapped.WrappedContext = context.WithValue(wrapped.Context(), "other", 2) + assert.NotNil(t, wrapped.Context().Value("other"), "values from wrapper must be set") +} + +type fakeServerStream struct { + grpc.ServerStream + ctx context.Context + recvMessage interface{} + sentMessage interface{} +} + +func (f * fakeServerStream) Context() context.Context { + return f.ctx +} + +func (f *fakeServerStream) SendMsg(m interface{}) error { + if f.sentMessage != nil { + return grpc.Errorf(codes.AlreadyExists, "fakeServerStream only takes one message, sorry") + } + f.sentMessage = m + return nil +} + +func (f *fakeServerStream) RecvMsg(m interface{}) error { + if f.recvMessage == nil { + return grpc.Errorf(codes.NotFound, "fakeServerStream has no message, sorry") + } + return nil +} \ No newline at end of file