-
Notifications
You must be signed in to change notification settings - Fork 178
/
rate_limit_interceptor.go
95 lines (73 loc) · 2.88 KB
/
rate_limit_interceptor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package rpc
import (
"context"
"path/filepath"
"github.com/rs/zerolog"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Default rate limiting parameters.
const (
	// defaultRateLimit is the aggregate rate limit, in calls per second, shared by all API calls
	// whose method has no explicitly configured rate limit.
	defaultRateLimit = 1000

	// defaultBurst is the default burst limit (calls that may be admitted at the same time) for an API.
	// It is used when no burst limit is configured for that API.
	defaultBurst = 100
)
// rateLimiterInterceptor rate limits incoming unary gRPC calls. Each API method that has an explicitly
// configured limit gets its own limiter; all remaining methods share a single default limiter.
type rateLimiterInterceptor struct {
	log zerolog.Logger

	// a shared default rate limiter for APIs whose rate limit is not explicitly defined
	defaultLimiter *rate.Limiter

	// a map of api and its limiter, keyed by the bare method name (the last path element of the
	// gRPC full method, e.g. "Ping" for "/flow.access.AccessAPI/Ping")
	methodLimiterMap map[string]*rate.Limiter
}
// NewRateLimiterInterceptor creates a new rate limiter interceptor with the defined per second rate limits and the
// optional burst limit for each API.
//
// apiRateLimits maps a bare method name (e.g. "Ping") to the number of calls per second allowed for it.
// apiBurstLimits optionally overrides the burst limit for an API listed in apiRateLimits. An API without an
// override uses defaultBurst. Methods not present in apiRateLimits share one limiter of defaultRateLimit calls
// per second with a burst of defaultBurst.
//
// A burst limit given for an API that has no rate limit cannot take effect. Such entries are logged at warn
// level rather than silently dropped.
func NewRateLimiterInterceptor(log zerolog.Logger, apiRateLimits map[string]int, apiBurstLimits map[string]int) *rateLimiterInterceptor {
	defaultLimiter := rate.NewLimiter(rate.Limit(defaultRateLimit), defaultBurst)

	// read rate limit values for each API and create a limiter for each
	methodLimiterMap := make(map[string]*rate.Limiter, len(apiRateLimits))
	for api, limit := range apiRateLimits {
		// if a burst limit is defined for this api, use that else use the default
		burst := defaultBurst
		if b, ok := apiBurstLimits[api]; ok {
			burst = b
		}
		methodLimiterMap[api] = rate.NewLimiter(rate.Limit(limit), burst)
	}

	// A burst limit is only applied together with a rate limit for the same API, so surface any
	// orphaned burst entries; otherwise a typo or missing rate entry would go unnoticed.
	for api, burst := range apiBurstLimits {
		if _, ok := apiRateLimits[api]; !ok {
			log.Warn().
				Str("method", api).
				Int("burst_limit", burst).
				Msg("burst limit specified without a rate limit, ignoring")
		}
	}

	if len(methodLimiterMap) == 0 {
		log.Info().Int("default_rate_limit", defaultRateLimit).Msg("no rate limits specified, using the default limit")
	}

	return &rateLimiterInterceptor{
		defaultLimiter:   defaultLimiter,
		methodLimiterMap: methodLimiterMap,
		log:              log,
	}
}
// UnaryServerInterceptor admits the call only if the limiter for its method has capacity. The limiter is the one
// configured for the method when the interceptor was created, or the shared default limiter otherwise. Admitted
// calls are passed to handler and its result is returned unchanged. Rejected calls are not forwarded and fail
// with codes.ResourceExhausted.
func (rl *rateLimiterInterceptor) UnaryServerInterceptor(ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler) (resp interface{}, err error) {
	// Limiters are keyed by the bare method name: "/flow.access.AccessAPI/Ping" -> "Ping".
	method := filepath.Base(info.FullMethod)

	lim := rl.methodLimiterMap[method]
	if lim == nil {
		rl.log.Trace().Str("method", method).Msg("rate limit not defined, using default limit")
		lim = rl.defaultLimiter
	}

	// Happy path: capacity is available, so forward the call as-is.
	if lim.Allow() {
		return handler(ctx, req)
	}

	rl.log.Trace().
		Str("method", method).
		Interface("request", req).
		Float64("limit", float64(lim.Limit())).
		Msg("rate limit exceeded")

	return nil, status.Errorf(codes.ResourceExhausted, "%s rate limit reached, please retry later.",
		info.FullMethod)
}