Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make ratelimit interface context aware #367

Merged
merged 2 commits into from Nov 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion interceptors/ratelimit/examples_test.go
@@ -1,6 +1,8 @@
package ratelimit_test

import (
"context"

"google.golang.org/grpc"

middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
Expand All @@ -11,7 +13,7 @@ import (
// It does not limit any request because Limit function always returns false.
type alwaysPassLimiter struct{}

func (*alwaysPassLimiter) Limit() bool {
func (*alwaysPassLimiter) Limit(_ context.Context) bool {
return false
}

Expand Down
6 changes: 3 additions & 3 deletions interceptors/ratelimit/ratelimit.go
Expand Up @@ -12,13 +12,13 @@ import (
// If Limit function return true, the request will be rejected.
// Otherwise, the request will pass.
type Limiter interface {
Limit() bool
Limit(ctx context.Context) bool
}

// UnaryServerInterceptor returns a new unary server interceptors that performs request rate limiting.
func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if limiter.Limit() {
if limiter.Limit(ctx) {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod)
}
return handler(ctx, req)
Expand All @@ -28,7 +28,7 @@ func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor {
// StreamServerInterceptor returns a new stream server interceptor that performs rate limiting on the request.
func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if limiter.Limit() {
if limiter.Limit(stream.Context()) {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod)
}
return handler(srv, stream)
Expand Down
71 changes: 45 additions & 26 deletions interceptors/ratelimit/ratelimit_test.go
Expand Up @@ -11,64 +11,83 @@ import (

const errMsgFake = "fake error"

type mockPassLimiter struct{}
var ctxLimitKey = struct{}{}

func (*mockPassLimiter) Limit() bool {
return false
type mockGRPCServerStream struct {
grpc.ServerStream

ctx context.Context
}

func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
interceptor := UnaryServerInterceptor(&mockPassLimiter{})
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, errors.New(errMsgFake)
}
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
req, err := interceptor(nil, nil, info, handler)
assert.Nil(t, req)
assert.EqualError(t, err, errMsgFake)
func (m *mockGRPCServerStream) Context() context.Context {
return m.ctx
}

type mockFailLimiter struct{}
type mockContextBasedLimiter struct{}

func (*mockFailLimiter) Limit() bool {
return true
func (*mockContextBasedLimiter) Limit(ctx context.Context) bool {
l, ok := ctx.Value(ctxLimitKey).(bool)
return ok && l
}

func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
interceptor := UnaryServerInterceptor(&mockFailLimiter{})
func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, false)

interceptor := UnaryServerInterceptor(limiter)
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, errors.New(errMsgFake)
}
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
req, err := interceptor(nil, nil, info, handler)
assert.Nil(t, req)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
resp, err := interceptor(ctx, nil, info, handler)
assert.Nil(t, resp)
assert.EqualError(t, err, errMsgFake)
}

func TestStreamServerInterceptor_RateLimitPass(t *testing.T) {
interceptor := StreamServerInterceptor(&mockPassLimiter{})
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, false)

interceptor := StreamServerInterceptor(limiter)
handler := func(srv interface{}, stream grpc.ServerStream) error {
return errors.New(errMsgFake)
}
info := &grpc.StreamServerInfo{
FullMethod: "FakeMethod",
}
err := interceptor(nil, nil, info, handler)
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
assert.EqualError(t, err, errMsgFake)
}

func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, true)

interceptor := UnaryServerInterceptor(limiter)
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, errors.New(errMsgFake)
}
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
resp, err := interceptor(ctx, nil, info, handler)
assert.Nil(t, resp)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
}

func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
interceptor := StreamServerInterceptor(&mockFailLimiter{})
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, true)

interceptor := StreamServerInterceptor(limiter)
handler := func(srv interface{}, stream grpc.ServerStream) error {
return errors.New(errMsgFake)
}
info := &grpc.StreamServerInfo{
FullMethod: "FakeMethod",
}
err := interceptor(nil, nil, info, handler)
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
}