-
Notifications
You must be signed in to change notification settings - Fork 0
/
limiter.go
152 lines (137 loc) · 4.56 KB
/
limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package ratelimiter
import (
"encoding/json"
"strings"
"sync"
"time"
"github.com/go-redis/redis"
)
// bucket is the JSON document stored in Redis for one rate-limit key.
type bucket struct {
	Attempts  int       `json:"attempts"`    // requests counted in the current window
	BlockTill time.Time `json:"block_until"` // end of the current window; once Attempts reaches the limit, requests are refused until this instant
}

// marshalBinary encodes the bucket as JSON for storage in Redis.
func (t *bucket) marshalBinary() ([]byte, error) {
	return json.Marshal(t)
}

// unmarshalBinary decodes JSON-encoded data into t. It returns an error for
// empty or malformed data; the limiter methods treat such an error as
// "no bucket stored for this key".
func (t *bucket) unmarshalBinary(data []byte) error {
	// Decode straight into the *bucket. Passing &t (a **bucket) only worked
	// because encoding/json dereferences through the extra pointer.
	return json.Unmarshal(data, t)
}
// RateLimiter tracks two request limits per caller-supplied key: a total
// limit over a MaxTime window and a burst limit over a BurstPeriod window.
// The per-key counters are stored as JSON buckets in Redis through Client.
type RateLimiter struct {
	TotalLimit  int           // attempts permitted within one MaxTime window
	BurstLimit  int           // attempts permitted within one BurstPeriod window
	MaxTime     time.Duration // length of the total-limit window
	BurstPeriod time.Duration // length of the burst-limit window

	Client *redis.Client // Redis connection that holds the buckets

	TotalLimitPrefix string // last "_"-separated segment of every total-limit key
	BurstLimitPrefix string // last "_"-separated segment of every burst-limit key
}
// UpdateTotalRequests records one more request against the total-limit
// bucket for user_params.
//
// The key is user_params followed by TotalLimitPrefix, joined with "_". If no
// readable bucket is stored, or the stored bucket's window (BlockTill) has
// already ended, a new bucket is started with Attempts = 1 and
// BlockTill = now+MaxTime. The new key gets a 24h TTL. Otherwise Attempts is
// incremented and the key's TTL is set to the time remaining in the window.
//
// NOTE(review): the Get/Set pair is not atomic, so concurrent requests for
// the same key can lose increments. Making it exact would need a Redis
// transaction or Lua script. Redis errors are not surfaced to the caller.
func (limiter *RateLimiter) UpdateTotalRequests(user_params ...string) {
	// The full slice expression caps capacity at len, so append always
	// allocates and never writes into the caller's backing array. That array
	// may be shared with a concurrent goroutine (see UpdateRequest).
	keyParts := append(user_params[:len(user_params):len(user_params)], limiter.TotalLimitPrefix)
	key := strings.Join(keyParts, "_")

	now := time.Now()
	ttl := time.Hour * 24 // lifetime given to a freshly created bucket key
	var bck bucket
	err := bck.unmarshalBinary([]byte(limiter.Client.Get(key).Val()))
	if err != nil || !now.Before(bck.BlockTill) {
		// Missing/unreadable bucket, or its window is over: open a new window.
		// Extending an expired bucket would pass a negative TTL to Set, which
		// makes go-redis store the key with no expiry at all.
		bck = bucket{BlockTill: now.Add(limiter.MaxTime)}
	} else {
		ttl = bck.BlockTill.Sub(now) // strictly positive: now is before BlockTill
	}
	bck.Attempts++

	jsonified, err := bck.marshalBinary()
	if err != nil {
		// Only possible for an out-of-range time; never store a nil payload.
		return
	}
	limiter.Client.Set(key, jsonified, ttl)
}
// AllowWithinTotalRequests reports whether a new request for user_params may
// proceed under the total limit.
//
// It returns true when no readable bucket is stored or the bucket holds fewer
// than TotalLimit attempts. Once the limit is reached it returns false until
// the bucket's BlockTill has passed. At that point the bucket is deleted and
// true is returned.
func (limiter *RateLimiter) AllowWithinTotalRequests(user_params ...string) bool {
	// Cap capacity so append copies instead of writing into the caller's
	// backing array, which AllowRequest shares with a concurrent goroutine.
	keyParts := append(user_params[:len(user_params):len(user_params)], limiter.TotalLimitPrefix)
	key := strings.Join(keyParts, "_")

	var bck bucket
	if err := bck.unmarshalBinary([]byte(limiter.Client.Get(key).Val())); err != nil {
		// No bucket yet. NOTE(review): go-redis also yields an empty Val() on
		// Redis errors, so a failed lookup allows the request too.
		return true
	}
	if bck.Attempts < limiter.TotalLimit {
		return true
	}
	if time.Now().After(bck.BlockTill) {
		// Limit was reached but the blocking window is over: reset the key.
		limiter.Client.Del(key)
		return true
	}
	return false
}
// UpdateConsecutiveRequests records one more request against the burst-limit
// bucket for user_params.
//
// The key is user_params followed by BurstLimitPrefix, joined with "_". If no
// readable bucket is stored, or the stored bucket's window (BlockTill) has
// already ended, a new bucket is started with Attempts = 1 and
// BlockTill = now+BurstPeriod. The new key gets a 1h TTL. Otherwise Attempts
// is incremented and the key's TTL is set to the time remaining in the window.
//
// NOTE(review): the Get/Set pair is not atomic, so concurrent requests for
// the same key can lose increments. Redis errors are not surfaced to the
// caller.
func (limiter *RateLimiter) UpdateConsecutiveRequests(user_params ...string) {
	// The full slice expression caps capacity at len, so append always
	// allocates and never writes into the caller's backing array. That array
	// may be shared with a concurrent goroutine (see UpdateRequest).
	keyParts := append(user_params[:len(user_params):len(user_params)], limiter.BurstLimitPrefix)
	key := strings.Join(keyParts, "_")

	now := time.Now()
	ttl := time.Hour * 1 // lifetime given to a freshly created bucket key
	var bck bucket
	err := bck.unmarshalBinary([]byte(limiter.Client.Get(key).Val()))
	if err != nil || !now.Before(bck.BlockTill) {
		// Missing/unreadable bucket, or its window is over: open a new window.
		// Extending an expired bucket would pass a negative TTL to Set, which
		// makes go-redis store the key with no expiry at all.
		bck = bucket{BlockTill: now.Add(limiter.BurstPeriod)}
	} else {
		ttl = bck.BlockTill.Sub(now) // strictly positive: now is before BlockTill
	}
	bck.Attempts++

	jsonified, err := bck.marshalBinary()
	if err != nil {
		// Only possible for an out-of-range time; never store a nil payload.
		return
	}
	limiter.Client.Set(key, jsonified, ttl)
}
// AllowConsecutiveRequest reports whether a new request for user_params may
// proceed under the burst limit.
//
// It returns true when no readable bucket is stored or the bucket holds fewer
// than BurstLimit attempts. Once the limit is reached it returns false until
// the bucket's BlockTill has passed. At that point the bucket is deleted and
// true is returned.
func (limiter *RateLimiter) AllowConsecutiveRequest(user_params ...string) bool {
	// Cap capacity so append copies instead of writing into the caller's
	// backing array, which AllowRequest shares with a concurrent goroutine.
	keyParts := append(user_params[:len(user_params):len(user_params)], limiter.BurstLimitPrefix)
	key := strings.Join(keyParts, "_")

	var bck bucket
	if err := bck.unmarshalBinary([]byte(limiter.Client.Get(key).Val())); err != nil {
		// No bucket yet. NOTE(review): go-redis also yields an empty Val() on
		// Redis errors, so a failed lookup allows the request too.
		return true
	}
	if bck.Attempts < limiter.BurstLimit {
		return true
	}
	if time.Now().After(bck.BlockTill) {
		// Limit was reached but the blocking window is over: reset the key.
		limiter.Client.Del(key)
		return true
	}
	return false
}
// UpdateRequest records the request for user_params in both the burst bucket
// and the total bucket. The two updates run concurrently, and the method
// returns only after both have finished.
func (limiter *RateLimiter) UpdateRequest(user_params ...string) {
	updaters := []func(...string){
		limiter.UpdateConsecutiveRequests,
		limiter.UpdateTotalRequests,
	}

	var pending sync.WaitGroup
	pending.Add(len(updaters))
	for _, update := range updaters {
		go func(fn func(...string)) {
			defer pending.Done()
			fn(user_params...)
		}(update)
	}
	pending.Wait()
}
// AllowRequest reports whether a new request for user_params passes both the
// burst limit and the total limit. Both checks always run, concurrently, so
// any key reset either check performs happens even when the other one fails.
func (limiter *RateLimiter) AllowRequest(user_params ...string) bool {
	var (
		pending sync.WaitGroup
		burstOK bool
		totalOK bool
	)
	pending.Add(2)
	go func() {
		defer pending.Done()
		burstOK = limiter.AllowConsecutiveRequest(user_params...)
	}()
	go func() {
		defer pending.Done()
		totalOK = limiter.AllowWithinTotalRequests(user_params...)
	}()
	// Wait establishes happens-before, so both results are visible here.
	pending.Wait()
	return burstOK && totalOK
}