-
Notifications
You must be signed in to change notification settings - Fork 25
/
global-token-bucket.go
367 lines (317 loc) · 10.8 KB
/
global-token-bucket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
package globaltokenbucket
import (
"context"
"math"
"sync"
"time"
"github.com/buraksezer/olric"
"github.com/buraksezer/olric/config"
"google.golang.org/protobuf/types/known/timestamppb"
tokenbucketv1 "github.com/fluxninja/aperture/v2/api/gen/proto/go/aperture/tokenbucket/v1"
distcache "github.com/fluxninja/aperture/v2/pkg/dist-cache"
"github.com/fluxninja/aperture/v2/pkg/log"
ratelimiter "github.com/fluxninja/aperture/v2/pkg/rate-limiter"
)
// TakeNFunction is the name under which the take-N-tokens function is
// registered on the bucket's DMap and invoked for each request.
const TakeNFunction = "TakeN"

// lookupMargin is the minimum time that must remain before a context's
// deadline for isMarginExceeded to consider the context still usable.
const lookupMargin = 10 * time.Millisecond
// GlobalTokenBucket implements Limiter.
type GlobalTokenBucket struct {
dMap olric.DMap
dc *distcache.DistCache
name string
bucketCapacity float64
fillAmount float64
interval time.Duration
mu sync.RWMutex
continuousFill bool
delayInitialFill bool
passThrough bool
}
// NewGlobalTokenBucket creates a GlobalTokenBucket whose state lives in a DMap
// named name inside dc. The DMap is configured with maxIdleDuration as its
// MaxIdleDuration and registers takeN under TakeNFunction.
//
// The returned bucket starts in pass-through mode with zero capacity and zero
// fill amount. Configure it with SetBucketCapacity, SetFillAmount and
// SetPassThrough. Any error from creating the DMap is returned unchanged.
func NewGlobalTokenBucket(dc *distcache.DistCache,
	name string,
	interval time.Duration,
	maxIdleDuration time.Duration,
	continuousFill bool,
	delayInitialFill bool,
) (*GlobalTokenBucket, error) {
	gtb := &GlobalTokenBucket{
		dc:               dc,
		name:             name,
		interval:         interval,
		continuousFill:   continuousFill,
		delayInitialFill: delayInitialFill,
		passThrough:      true,
	}

	dMap, err := dc.NewDMap(name, config.DMap{
		MaxIdleDuration: maxIdleDuration,
		Functions: map[string]config.Function{
			TakeNFunction: gtb.takeN,
		},
	})
	if err != nil {
		return nil, err
	}
	gtb.dMap = dMap

	return gtb, nil
}
// SetBucketCapacity updates the maximum number of tokens a label's bucket may
// hold. The value is written under the write lock.
func (gtb *GlobalTokenBucket) SetBucketCapacity(bucketCapacity float64) {
	gtb.mu.Lock()
	gtb.bucketCapacity = bucketCapacity
	gtb.mu.Unlock()
}
// GetBucketCapacity reports the currently configured bucket capacity, read
// under the read lock.
func (gtb *GlobalTokenBucket) GetBucketCapacity() float64 {
	gtb.mu.RLock()
	capacity := gtb.bucketCapacity
	gtb.mu.RUnlock()
	return capacity
}
// isMarginExceeded reports whether ctx has a deadline that is less than
// lookupMargin away, or already past. A context without a deadline never
// exceeds the margin.
func isMarginExceeded(ctx context.Context) bool {
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return false
	}
	return time.Until(deadline) < lookupMargin
}
// SetFillAmount updates how many tokens are added to a bucket per fill
// interval. The value is written under the write lock.
func (gtb *GlobalTokenBucket) SetFillAmount(fillAmount float64) {
	gtb.mu.Lock()
	gtb.fillAmount = fillAmount
	gtb.mu.Unlock()
}
// Name returns the name passed to NewGlobalTokenBucket, which is also the
// name of the bucket's DMap in the distributed cache.
func (gtb *GlobalTokenBucket) Name() string {
	name := gtb.name
	return name
}
// Close deletes the bucket's DMap (looked up by gtb.name) from the distributed
// cache while holding the write lock. It returns any error the deletion reports.
func (gtb *GlobalTokenBucket) Close() error {
	gtb.mu.Lock()
	defer gtb.mu.Unlock()

	if err := gtb.dc.DeleteDMap(gtb.name); err != nil {
		return err
	}
	return nil
}
// executeTakeRequest asks the DMap to take n tokens from label's bucket and
// converts the reply into (ok, waitTime, remaining, current).
//
// It short-circuits in these cases:
//   - pass-through mode returns (true, 0, 0, 0) without contacting the DMap;
//   - a ctx deadline within lookupMargin returns (false, 0, 0, 0);
//   - any encode, call or decode failure is logged and fails open with
//     (true, 0, 0, 0).
//
// waitTime is the time remaining until the reply's AvailableAt, clamped at zero.
func (gtb *GlobalTokenBucket) executeTakeRequest(ctx context.Context, label string, n float64, canWait bool, deadline time.Time) (bool, time.Duration, float64, float64) {
	switch {
	case gtb.GetPassThrough():
		return true, 0, 0, 0
	case isMarginExceeded(ctx):
		return false, 0, 0, 0
	}

	request := tokenbucketv1.TakeNRequest{
		Want:     n,
		CanWait:  canWait,
		Deadline: timestamppb.New(deadline),
	}
	payload, err := request.MarshalVT()
	if err != nil {
		log.Autosample().Errorf("error encoding request: %v", err)
		return true, 0, 0, 0
	}

	reply, err := gtb.dMap.Function(ctx, label, TakeNFunction, payload)
	if err != nil {
		log.Autosample().Error().Err(err).Str("dmapName", gtb.dMap.Name()).Float64("tokens", n).Msg("error taking from token bucket")
		return true, 0, 0, 0
	}

	var response tokenbucketv1.TakeNResponse
	if err = response.UnmarshalVT(reply); err != nil {
		log.Autosample().Errorf("error decoding response: %v", err)
		return true, 0, 0, 0
	}

	wait := time.Duration(0)
	if at := response.GetAvailableAt().AsTime(); !at.IsZero() {
		if until := time.Until(at); until > 0 {
			wait = until
		}
	}
	return response.Ok, wait, response.Remaining, response.Current
}
// TakeIfAvailable tries to take n tokens from label's bucket without allowing
// the request to wait. It returns whether the n events may proceed, the wait
// until the tokens would become available, the tokens remaining in the bucket,
// and the current count (capacity minus remaining).
// On internal errors it fails open and returns true, 0, 0 and 0.
func (gtb *GlobalTokenBucket) TakeIfAvailable(ctx context.Context, label string, n float64) (bool, time.Duration, float64, float64) {
	const canWait = false
	var noDeadline time.Time
	return gtb.executeTakeRequest(ctx, label, n, canWait, noDeadline)
}
// Take is like TakeIfAvailable, but the request may be admitted with a
// non-zero wait time. If ctx has a deadline, that deadline is forwarded, and
// requests whose tokens would only become available after it are rejected.
// It returns (ok, waitTime, remaining, current).
func (gtb *GlobalTokenBucket) Take(ctx context.Context, label string, n float64) (bool, time.Duration, float64, float64) {
	var deadline time.Time
	if d, ok := ctx.Deadline(); ok {
		deadline = d
	}
	return gtb.executeTakeRequest(ctx, label, n, true, deadline)
}
// Return gives n tokens back to label's bucket by taking -n tokens. It reports
// the remaining tokens and current count after the return.
func (gtb *GlobalTokenBucket) Return(ctx context.Context, label string, n float64) (remaining, current float64) {
	_, _, remaining, current = gtb.TakeIfAvailable(ctx, label, -n)
	return remaining, current
}
// takeN is the DMap function registered under TakeNFunction. It receives the
// stored state for key (nil if none) and an encoded TakeNRequest. It returns
// the new encoded State and an encoded TakeNResponse.
//
// Steps:
//  1. Fast-forward the bucket to now.
//  2. With delayInitialFill, a draw against a full bucket pushes StartFillAt
//     to now + timeToFill(capacity).
//  3. Positive requests that overdraw the bucket are admitted only if the
//     caller can wait, the fill amount is non-zero and, when a deadline is
//     set, the tokens become available by then. Rejected requests have their
//     tokens put back.
//
// Available is capped at the bucket capacity before encoding.
func (gtb *GlobalTokenBucket) takeN(key string, stateBytes, argBytes []byte) ([]byte, []byte, error) {
	gtb.mu.RLock()
	defer gtb.mu.RUnlock()

	now := time.Now()
	state, err := gtb.fastForwardState(now, stateBytes, key)
	if err != nil {
		return nil, nil, err
	}

	var req tokenbucketv1.TakeNRequest
	if argBytes != nil {
		if err = req.UnmarshalVT(argBytes); err != nil {
			log.Autosample().Errorf("error decoding arg: %v", err)
			return nil, nil, err
		}
	}

	resp := tokenbucketv1.TakeNResponse{
		Ok:          true,
		AvailableAt: timestamppb.New(now),
	}

	// A draw against a full bucket (e.g. the very first one, since new state
	// starts full) postpones filling by a full bucket's worth of fill time.
	if gtb.delayInitialFill && state.Available == gtb.bucketCapacity {
		state.StartFillAt = timestamppb.New(now.Add(gtb.timeToFill(gtb.bucketCapacity)))
		if gtb.continuousFill {
			state.LastFillAt = state.StartFillAt
		} else {
			// Anchor the previous fill one interval earlier so that the first
			// discrete fill lands exactly at StartFillAt.
			state.LastFillAt = timestamppb.New(state.StartFillAt.AsTime().Add(-gtb.interval))
		}
	}

	state.Available -= req.Want

	// Only draws (Want > 0) that overdraw the bucket need admission control.
	if req.Want > 0 && state.Available < 0 {
		canRefill := gtb.fillAmount != 0
		resp.Ok = req.CanWait && canRefill
		if canRefill {
			resp.AvailableAt = timestamppb.New(gtb.getAvailableAt(now, state))
			deadline := req.Deadline.AsTime()
			if req.CanWait && !deadline.IsZero() && resp.AvailableAt.AsTime().After(deadline) {
				resp.Ok = false
			}
		}
		if !resp.Ok {
			// Undo the deduction for rejected requests.
			state.Available += req.Want
		}
	}

	if state.Available > gtb.bucketCapacity {
		state.Available = gtb.bucketCapacity
	}
	resp.Remaining = state.Available
	resp.Current = gtb.bucketCapacity - state.Available

	respBytes, err := resp.MarshalVT()
	if err != nil {
		log.Autosample().Errorf("error encoding result: %v", err)
		return nil, nil, err
	}

	newStateBytes, err := state.MarshalVT()
	if err != nil {
		log.Autosample().Errorf("error encoding new state: %v", err)
		return nil, nil, err
	}

	return newStateBytes, respBytes, nil
}
// fastForwardState decodes the stored bucket state and credits any fill
// accrued up to now. When stateBytes is nil it creates a full bucket with
// LastFillAt set to now.
//
// If StartFillAt is set and now is not strictly after it, no fill is credited.
// Otherwise:
//   - continuous fill credits fillAmount per interval, pro-rated by elapsed
//     time, and moves LastFillAt to now;
//   - discrete fill credits fillAmount per whole elapsed interval and advances
//     LastFillAt by exactly those intervals, so the fill schedule stays aligned.
//
// After filling, Available never exceeds bucketCapacity.
func (gtb *GlobalTokenBucket) fastForwardState(now time.Time, stateBytes []byte, key string) (*tokenbucketv1.State, error) {
	state := &tokenbucketv1.State{}
	if stateBytes == nil {
		log.Info().Msgf("Creating new token bucket state for key %s in dmap %s", key, gtb.dMap.Name())
		state.LastFillAt = timestamppb.New(now)
		state.Available = gtb.bucketCapacity
	} else if err := state.UnmarshalVT(stateBytes); err != nil {
		log.Autosample().Errorf("error decoding current state: %v", err)
		return nil, err
	}

	startFillAt := state.StartFillAt.AsTime()
	if !startFillAt.IsZero() && !now.After(startFillAt) {
		// Filling has not started yet.
		return state, nil
	}

	lastFillAt := state.LastFillAt.AsTime()
	elapsed := now.Sub(lastFillAt)
	credit := 0.0
	switch {
	case gtb.continuousFill:
		credit = gtb.fillAmount * float64(elapsed) / float64(gtb.interval)
		state.LastFillAt = timestamppb.New(now)
	case elapsed >= gtb.interval:
		if fills := int(elapsed / gtb.interval); fills > 0 {
			credit = gtb.fillAmount * float64(fills)
			state.LastFillAt = timestamppb.New(lastFillAt.Add(time.Duration(fills) * gtb.interval))
		}
	}

	state.Available += credit
	if state.Available > gtb.bucketCapacity {
		state.Available = gtb.bucketCapacity
	}
	return state, nil
}
// timeToFill returns how long the configured fill rate needs to accumulate the
// given number of tokens. Continuous fill gives a proportional duration.
// Discrete fill gives a whole number of intervals, rounded up. It returns 0
// when fillAmount is 0.
func (gtb *GlobalTokenBucket) timeToFill(tokens float64) time.Duration {
	if gtb.fillAmount == 0 {
		return 0
	}
	if gtb.continuousFill {
		return time.Duration(tokens / gtb.fillAmount * float64(gtb.interval))
	}
	intervals := math.Ceil(tokens / gtb.fillAmount)
	return time.Duration(intervals) * gtb.interval
}
// getAvailableAt returns the time at which the deficit recorded in state (a
// negative Available) will have been refilled. A non-negative Available means
// the tokens are available now.
//
// It relies on the fill schedule maintained by takeN and fastForwardState:
//   - Before StartFillAt nothing is filled. With continuous fill, filling
//     starts at StartFillAt. With discrete fill, takeN anchors LastFillAt one
//     interval before StartFillAt, so the first fill lands at StartFillAt.
//   - Once StartFillAt has passed, LastFillAt is the most recent fill and
//     should be no more than one interval in the past.
func (gtb *GlobalTokenBucket) getAvailableAt(now time.Time, state *tokenbucketv1.State) time.Time {
	if state.Available >= 0 {
		return now
	}
	timeToFill := gtb.timeToFill(-state.Available)

	startFillAtTime := state.StartFillAt.AsTime()
	if now.Before(startFillAtTime) {
		availableAt := startFillAtTime.Add(timeToFill)
		if !gtb.continuousFill && timeToFill >= gtb.interval {
			// Discrete fill: the first of the required fills happens at
			// StartFillAt itself, not one interval after it, so k fills
			// complete at StartFillAt + (k-1)*interval.
			availableAt = availableAt.Add(-gtb.interval)
		}
		return availableAt
	}

	// Subtract the part of the current interval that has already elapsed
	// since the last fill (always ~0 for continuous fill).
	lastFillAtTime := state.LastFillAt.AsTime()
	timeSinceLastFill := now.Sub(lastFillAtTime)
	if timeSinceLastFill > gtb.interval {
		log.Autosample().Errorf("time since last fill is greater than interval: %v", timeSinceLastFill)
		timeSinceLastFill = time.Duration(0)
	}
	return now.Add(timeToFill - timeSinceLastFill)
}
// SetPassThrough enables or disables pass-through mode, under the write lock.
// While enabled, executeTakeRequest admits every request without contacting
// the DMap.
func (gtb *GlobalTokenBucket) SetPassThrough(passThrough bool) {
	gtb.mu.Lock()
	gtb.passThrough = passThrough
	gtb.mu.Unlock()
}
// GetPassThrough reports whether pass-through mode is enabled, read under the
// read lock.
func (gtb *GlobalTokenBucket) GetPassThrough() bool {
	gtb.mu.RLock()
	enabled := gtb.passThrough
	gtb.mu.RUnlock()
	return enabled
}
// Compile-time assertion that *GlobalTokenBucket satisfies ratelimiter.RateLimiter.
var _ ratelimiter.RateLimiter = (*GlobalTokenBucket)(nil)