Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Michal Witkowski
committed
May 14, 2016
1 parent
202298a
commit 2674e5c
Showing
4 changed files
with
235 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// Copyright 2016 Michal Witkowski. All Rights Reserved. | ||
// See LICENSE for licensing terms. | ||
|
||
// gRPC Server Interceptor chaining middleware. | ||
|
||
package grpc_middleware | ||
|
||
import ( | ||
"golang.org/x/net/context" | ||
"google.golang.org/grpc" | ||
) | ||
|
||
// ChainUnaryServer creates a single interceptor out of a chain of many interceptors. | ||
// Execution is done in left-to-right order, including passing of context. | ||
// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three | ||
// will see context changes of one and two. | ||
func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { | ||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { | ||
buildChain := func(current grpc.UnaryServerInterceptor, next grpc.UnaryHandler) grpc.UnaryHandler { | ||
return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) { | ||
return current(currentCtx, currentReq, info, next) | ||
} | ||
} | ||
chain := handler | ||
for i := len(interceptors) - 1; i >= 0; i-- { | ||
chain = buildChain(interceptors[i], chain) | ||
} | ||
return chain(ctx, req) | ||
} | ||
} | ||
|
||
// ChainStreamServer creates a single interceptor out of a chain of many interceptors. | ||
// Execution is done in left-to-right order, including passing of context. | ||
// For example ChainUnaryServer(one, two, three) will execute one before two before three. | ||
// If you want to pass context between interceptors, use WrapServerStream. | ||
func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { | ||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { | ||
buildChain := func(current grpc.StreamServerInterceptor, next grpc.StreamHandler) grpc.StreamHandler { | ||
return func(currentSrv interface{}, currentStream grpc.ServerStream) error { | ||
return current(currentSrv, currentStream, info, next) | ||
} | ||
} | ||
chain := handler | ||
for i := len(interceptors) - 1; i >= 0; i-- { | ||
chain = buildChain(interceptors[i], chain) | ||
} | ||
return chain(srv, stream) | ||
} | ||
} | ||
|
||
// WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors. | ||
// Basically syntactic sugar. | ||
func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { | ||
return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...)) | ||
} | ||
|
||
// WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors. | ||
// Basically syntactic sugar. | ||
func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { | ||
return grpc.StreamInterceptor(ChainStreamServer(interceptors...)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright 2016 Michal Witkowski. All Rights Reserved. | ||
// See LICENSE for licensing terms. | ||
|
||
package grpc_middleware | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
"golang.org/x/net/context" | ||
"google.golang.org/grpc" | ||
) | ||
|
||
var ( | ||
parentUnaryInfo = &grpc.UnaryServerInfo{FullMethod: "SomeService.UnaryMethod"} | ||
parentStreamInfo = &grpc.StreamServerInfo{ | ||
FullMethod: "SomeService.StreamMethod", | ||
IsServerStream: true, | ||
} | ||
someValue = 1 | ||
parentContext = context.WithValue(context.TODO(), "parent", someValue) | ||
) | ||
|
||
func TestChainUnaryServer(t *testing.T) { | ||
input := "input" | ||
output := "output" | ||
|
||
first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { | ||
requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value") | ||
require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo") | ||
ctx = context.WithValue(ctx, "first", 1) | ||
return handler(ctx, req) | ||
} | ||
second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { | ||
requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value") | ||
requireContextValue(t, ctx, "first", "second interceptor must know the first context value") | ||
require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo") | ||
ctx = context.WithValue(ctx, "second", 1) | ||
return handler(ctx, req) | ||
} | ||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||
require.EqualValues(t, input, req, "handler must get the input") | ||
requireContextValue(t, ctx, "parent", "handler must know the parent context value") | ||
requireContextValue(t, ctx, "first", "handler must know the first context value") | ||
requireContextValue(t, ctx, "second", "handler must know the second context value") | ||
return output, nil | ||
} | ||
|
||
chain := ChainUnaryServer(first, second) | ||
out, _ := chain(parentContext, input, parentUnaryInfo, handler) | ||
require.EqualValues(t, output, out, "chain must return handler's output") | ||
} | ||
|
||
func TestChainStreamServer(t *testing.T) { | ||
someService := &struct{}{} | ||
recvMessage := "received" | ||
sentMessage := "sent" | ||
outputError := fmt.Errorf("some error") | ||
|
||
first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { | ||
requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value") | ||
require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo") | ||
require.Equal(t, someService, srv, "first interceptor must know someService") | ||
wrapped := WrapServerStream(stream) | ||
wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1) | ||
return handler(srv, wrapped) | ||
} | ||
second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { | ||
requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value") | ||
requireContextValue(t, stream.Context(), "parent", "second interceptor must know the first context value") | ||
require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo") | ||
require.Equal(t, someService, srv, "second interceptor must know someService") | ||
wrapped := WrapServerStream(stream) | ||
wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1) | ||
return handler(srv, wrapped) | ||
} | ||
handler := func(srv interface{}, stream grpc.ServerStream) error { | ||
require.Equal(t, someService, srv, "handler must know someService") | ||
requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value") | ||
requireContextValue(t, stream.Context(), "first", "handler must know the first context value") | ||
requireContextValue(t, stream.Context(), "second", "handler must know the second context value") | ||
require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages") | ||
require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages") | ||
return outputError | ||
} | ||
fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage} | ||
chain := ChainStreamServer(first, second) | ||
err := chain(someService, fakeStream, parentStreamInfo, handler) | ||
require.Equal(t, outputError, err, "chain must return handler's error") | ||
require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream") | ||
} | ||
|
||
func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) { | ||
val := ctx.Value(key) | ||
require.NotNil(t, val, msg...) | ||
require.Equal(t, someValue, val, msg...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright 2016 Michal Witkowski. All Rights Reserved. | ||
// See LICENSE for licensing terms. | ||
|
||
package grpc_middleware | ||
|
||
import ( | ||
"golang.org/x/net/context" | ||
"google.golang.org/grpc" | ||
) | ||
|
||
// WrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context. | ||
type WrappedServerStream struct { | ||
grpc.ServerStream | ||
// WrappedContext is the wrapper's own Context. You can assign it. | ||
WrappedContext context.Context | ||
} | ||
|
||
// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context() | ||
func (w *WrappedServerStream) Context() context.Context { | ||
return w.WrappedContext | ||
} | ||
|
||
// WrapServerStream returns a ServerStream that has the ability to overwrite context. | ||
func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream { | ||
if existing, ok := stream.(*WrappedServerStream); ok { | ||
return existing | ||
} | ||
return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright 2016 Michal Witkowski. All Rights Reserved. | ||
// See LICENSE for licensing terms. | ||
|
||
package grpc_middleware | ||
|
||
import ( | ||
"golang.org/x/net/context" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"testing" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestWrapServerStream(t *testing.T) { | ||
ctx := context.WithValue(context.TODO(), "something", 1) | ||
fake := &fakeServerStream{ctx: ctx} | ||
wrapped := WrapServerStream(fake) | ||
assert.NotNil(t, wrapped.Context().Value("something"), "values from fake must propagate to wrapper") | ||
wrapped.WrappedContext = context.WithValue(wrapped.Context(), "other", 2) | ||
assert.NotNil(t, wrapped.Context().Value("other"), "values from wrapper must be set") | ||
} | ||
|
||
type fakeServerStream struct { | ||
grpc.ServerStream | ||
ctx context.Context | ||
recvMessage interface{} | ||
sentMessage interface{} | ||
} | ||
|
||
func (f * fakeServerStream) Context() context.Context { | ||
return f.ctx | ||
} | ||
|
||
func (f *fakeServerStream) SendMsg(m interface{}) error { | ||
if f.sentMessage != nil { | ||
return grpc.Errorf(codes.AlreadyExists, "fakeServerStream only takes one message, sorry") | ||
} | ||
f.sentMessage = m | ||
return nil | ||
} | ||
|
||
func (f *fakeServerStream) RecvMsg(m interface{}) error { | ||
if f.recvMessage == nil { | ||
return grpc.Errorf(codes.NotFound, "fakeServerStream has no message, sorry") | ||
} | ||
return nil | ||
} |