Skip to content

Commit

Permalink
implement grpc interceptor chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Witkowski committed May 14, 2016
1 parent 202298a commit 2674e5c
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 0 deletions.
61 changes: 61 additions & 0 deletions chain.go
@@ -0,0 +1,61 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

// gRPC Server Interceptor chaining middleware.

package grpc_middleware

import (
"golang.org/x/net/context"
"google.golang.org/grpc"
)

// ChainUnaryServer creates a single interceptor out of a chain of many interceptors.
// Execution is done in left-to-right order, including passing of context.
// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
// will see context changes of one and two.
func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
buildChain := func(current grpc.UnaryServerInterceptor, next grpc.UnaryHandler) grpc.UnaryHandler {
return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
return current(currentCtx, currentReq, info, next)
}
}
chain := handler
for i := len(interceptors) - 1; i >= 0; i-- {
chain = buildChain(interceptors[i], chain)
}
return chain(ctx, req)
}
}

// ChainStreamServer creates a single interceptor out of a chain of many interceptors.
// Execution is done in left-to-right order, including passing of context.
// For example ChainUnaryServer(one, two, three) will execute one before two before three.
// If you want to pass context between interceptors, use WrapServerStream.
func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
buildChain := func(current grpc.StreamServerInterceptor, next grpc.StreamHandler) grpc.StreamHandler {
return func(currentSrv interface{}, currentStream grpc.ServerStream) error {
return current(currentSrv, currentStream, info, next)
}
}
chain := handler
for i := len(interceptors) - 1; i >= 0; i-- {
chain = buildChain(interceptors[i], chain)
}
return chain(srv, stream)
}
}

// WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors.
// Basically syntactic sugar.
func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...))
}

// WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors.
// Basically syntactic sugar.
func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
return grpc.StreamInterceptor(ChainStreamServer(interceptors...))
}
98 changes: 98 additions & 0 deletions chain_test.go
@@ -0,0 +1,98 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

package grpc_middleware

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"google.golang.org/grpc"
)

var (
parentUnaryInfo = &grpc.UnaryServerInfo{FullMethod: "SomeService.UnaryMethod"}
parentStreamInfo = &grpc.StreamServerInfo{
FullMethod: "SomeService.StreamMethod",
IsServerStream: true,
}
someValue = 1
parentContext = context.WithValue(context.TODO(), "parent", someValue)
)

func TestChainUnaryServer(t *testing.T) {
input := "input"
output := "output"

first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value")
require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo")
ctx = context.WithValue(ctx, "first", 1)
return handler(ctx, req)
}
second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value")
requireContextValue(t, ctx, "first", "second interceptor must know the first context value")
require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo")
ctx = context.WithValue(ctx, "second", 1)
return handler(ctx, req)
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
require.EqualValues(t, input, req, "handler must get the input")
requireContextValue(t, ctx, "parent", "handler must know the parent context value")
requireContextValue(t, ctx, "first", "handler must know the first context value")
requireContextValue(t, ctx, "second", "handler must know the second context value")
return output, nil
}

chain := ChainUnaryServer(first, second)
out, _ := chain(parentContext, input, parentUnaryInfo, handler)
require.EqualValues(t, output, out, "chain must return handler's output")
}

func TestChainStreamServer(t *testing.T) {
someService := &struct{}{}
recvMessage := "received"
sentMessage := "sent"
outputError := fmt.Errorf("some error")

first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value")
require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo")
require.Equal(t, someService, srv, "first interceptor must know someService")
wrapped := WrapServerStream(stream)
wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1)
return handler(srv, wrapped)
}
second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value")
requireContextValue(t, stream.Context(), "parent", "second interceptor must know the first context value")
require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo")
require.Equal(t, someService, srv, "second interceptor must know someService")
wrapped := WrapServerStream(stream)
wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1)
return handler(srv, wrapped)
}
handler := func(srv interface{}, stream grpc.ServerStream) error {
require.Equal(t, someService, srv, "handler must know someService")
requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value")
requireContextValue(t, stream.Context(), "first", "handler must know the first context value")
requireContextValue(t, stream.Context(), "second", "handler must know the second context value")
require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages")
require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages")
return outputError
}
fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage}
chain := ChainStreamServer(first, second)
err := chain(someService, fakeStream, parentStreamInfo, handler)
require.Equal(t, outputError, err, "chain must return handler's error")
require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream")
}

func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) {
val := ctx.Value(key)
require.NotNil(t, val, msg...)
require.Equal(t, someValue, val, msg...)
}
29 changes: 29 additions & 0 deletions wrappers.go
@@ -0,0 +1,29 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

package grpc_middleware

import (
"golang.org/x/net/context"
"google.golang.org/grpc"
)

// WrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
type WrappedServerStream struct {
grpc.ServerStream
// WrappedContext is the wrapper's own Context. You can assign it.
WrappedContext context.Context
}

// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
func (w *WrappedServerStream) Context() context.Context {
return w.WrappedContext
}

// WrapServerStream returns a ServerStream that has the ability to overwrite context.
func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream {
if existing, ok := stream.(*WrappedServerStream); ok {
return existing
}
return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()}
}
47 changes: 47 additions & 0 deletions wrappers_test.go
@@ -0,0 +1,47 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

package grpc_middleware

import (
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"testing"
"github.com/stretchr/testify/assert"
)

func TestWrapServerStream(t *testing.T) {
ctx := context.WithValue(context.TODO(), "something", 1)
fake := &fakeServerStream{ctx: ctx}
wrapped := WrapServerStream(fake)
assert.NotNil(t, wrapped.Context().Value("something"), "values from fake must propagate to wrapper")
wrapped.WrappedContext = context.WithValue(wrapped.Context(), "other", 2)
assert.NotNil(t, wrapped.Context().Value("other"), "values from wrapper must be set")
}

type fakeServerStream struct {
grpc.ServerStream
ctx context.Context
recvMessage interface{}
sentMessage interface{}
}

func (f * fakeServerStream) Context() context.Context {
return f.ctx
}

func (f *fakeServerStream) SendMsg(m interface{}) error {
if f.sentMessage != nil {
return grpc.Errorf(codes.AlreadyExists, "fakeServerStream only takes one message, sorry")
}
f.sentMessage = m
return nil
}

func (f *fakeServerStream) RecvMsg(m interface{}) error {
if f.recvMessage == nil {
return grpc.Errorf(codes.NotFound, "fakeServerStream has no message, sorry")
}
return nil
}

0 comments on commit 2674e5c

Please sign in to comment.