/
limiter.go
106 lines (91 loc) · 2.29 KB
/
limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package limiter
import (
"fmt"
"sync"
"time"
"github.com/go-redis/redis"
uuid "github.com/gofrs/uuid"
)
type (
	// opt is the limit registered for one group: at most max units of weight
	// may be recorded for a key within any sliding window of length window.
	opt struct {
		max    int
		window time.Duration
	}

	// Limiter is a sliding-window rate limiter whose hit logs live in Redis.
	// Limits are registered per group with AddGroup and consumed per key with
	// Available.
	Limiter struct {
		pool *redis.Client   // client used for every Redis command
		opts map[string]*opt // group name -> limit; read and written under mux
		mux  sync.Mutex      // guards opts
	}
)
// NewLimiter creates a Redis client for addr (using password and database db)
// and returns a Limiter backed by it.
//
// The client is configured with 3s read/write timeouts, a 4s pool wait
// timeout, a 60s idle timeout and a pool of up to 512 connections. The server
// is pinged before the Limiter is returned. If that fails, the client created
// here is closed (so its connection pool is not leaked) and the ping error is
// returned.
func NewLimiter(addr string, password string, db int) (*Limiter, error) {
	client := redis.NewClient(&redis.Options{
		Addr:         addr,
		Password:     password,
		DB:           db,
		ReadTimeout:  3 * time.Second,
		WriteTimeout: 3 * time.Second,
		PoolTimeout:  4 * time.Second,
		IdleTimeout:  60 * time.Second,
		PoolSize:     512,
	})
	limiter, err := NewLimiterWithClient(client)
	if err != nil {
		// The caller never receives this client, so release it here. The
		// ping error is more useful to report than any Close error.
		_ = client.Close()
		return nil, err
	}
	return limiter, nil
}
// NewLimiterWithClient wraps an existing Redis client in a Limiter with no
// groups registered. It pings the server first and, if the ping fails, returns
// that error without building a Limiter. The client is not closed on failure.
func NewLimiterWithClient(c *redis.Client) (*Limiter, error) {
	pingErr := c.Ping().Err()
	if pingErr != nil {
		return nil, pingErr
	}
	l := new(Limiter)
	l.pool = c
	l.opts = make(map[string]*opt)
	return l, nil
}
// AddGroup registers the limit for group, replacing any earlier one: at most
// max units of weight per key within a sliding window of length window. The
// groups table is updated while holding the limiter's mutex.
func (limiter *Limiter) AddGroup(group string, max int, window time.Duration) {
	limit := &opt{max: max, window: window}

	limiter.mux.Lock()
	defer limiter.mux.Unlock()
	limiter.opts[group] = limit
}
// limiterKey returns the Redis key holding the hit log for key within group,
// in the form "limiter:<group>:<key>".
func limiterKey(group, key string) string {
	// All operands are strings, so fmt.Sprint inserts no separators.
	return fmt.Sprint("limiter:", group, ":", key)
}
// Available records weight hits for key in group and reports how much of the
// group's budget remains in the current sliding window.
//
// The group's max and window come from AddGroup. An unregistered group behaves
// as max 0 with a one-second window. If weight exceeds max, the request can
// never fit, so max-weight is returned without contacting Redis.
//
// Otherwise the method does the following in a single pipelined round trip:
//   - drop log entries older than the window
//   - add weight new entries stamped with the current time in milliseconds
//   - reset the key's TTL to the window (whole seconds) plus 60s
//   - count the entries left
//
// The result is max minus that count; a negative value means the window is over
// budget. The hits are recorded even when the result is negative.
//
// On any error (member-id generation or Redis) the returned int is 0 and
// must be ignored.
func (limiter *Limiter) Available(key, group string, weight int) (int, error) {
	limit, window := 0, time.Second
	limiter.mux.Lock()
	if o := limiter.opts[group]; o != nil {
		limit, window = o.max, o.window
	}
	limiter.mux.Unlock()

	if limit < weight {
		return limit - weight, nil
	}

	now := time.Now()

	// Build the new members before touching Redis so that an id-generation
	// failure is reported instead of silently producing duplicate (zero) ids,
	// which ZADD would collapse into one entry and thus under-count.
	var members []*redis.Z
	if weight > 0 {
		score := float64(now.UnixNano() / 1000000)
		members = make([]*redis.Z, 0, weight)
		for i := 0; i < weight; i++ {
			id, err := uuid.NewV4()
			if err != nil {
				return 0, err
			}
			members = append(members, &redis.Z{Score: score, Member: id.String()})
		}
	}

	redisKey := limiterKey(group, key)
	var zcount *redis.IntCmd
	_, err := limiter.pool.Pipelined(func(pipe redis.Pipeliner) error {
		// Scores are millisecond timestamps; remove everything at or before
		// the start of the window.
		pipe.ZRemRangeByScore(redisKey, "-inf", fmt.Sprint(now.Add(-window).UnixNano()/1000000))
		if len(members) > 0 {
			pipe.ZAdd(redisKey, members...)
		}
		// Keep the key a little longer than the window so idle keys expire.
		pipe.Expire(redisKey, time.Second*time.Duration(int64(window.Seconds())+60))
		zcount = pipe.ZCount(redisKey, "-inf", "+inf")
		return nil
	})
	if err != nil {
		return 0, err
	}
	count, err := zcount.Result()
	if err != nil {
		return 0, err
	}
	return limit - int(count), nil
}
// Clear deletes the hit log stored for key in group, so its next Available
// call starts from an empty window. It returns the error reported for the
// Redis DEL command, if any.
func (limiter *Limiter) Clear(key, group string) error {
	return limiter.pool.Del(limiterKey(group, key)).Err()
}