/
shard_rate_limiter_impl.go
118 lines (99 loc) · 2.32 KB
/
shard_rate_limiter_impl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package sharding
import (
"context"
"log/slog"
"sync"
"time"
"github.com/sasha-s/go-csync"
)
// Compile-time check that *rateLimiterImpl satisfies RateLimiter.
var _ RateLimiter = (*rateLimiterImpl)(nil)

// NewRateLimiter builds the default RateLimiter. It starts from
// DefaultRateLimiterConfig, applies every RateLimiterConfigOpt in order, and
// tags the configured logger with name=sharding_rate_limiter.
func NewRateLimiter(opts ...RateLimiterConfigOpt) RateLimiter {
	cfg := DefaultRateLimiterConfig()
	cfg.Apply(opts)

	tagged := slog.String("name", "sharding_rate_limiter")
	cfg.Logger = cfg.Logger.With(tagged)

	limiter := &rateLimiterImpl{
		config:  *cfg,
		buckets: make(map[int]*bucket),
	}
	return limiter
}
type rateLimiterImpl struct {
mu sync.Mutex
buckets map[int]*bucket
config RateLimiterConfig
}
// Close waits for every bucket currently known to the limiter to become free,
// so that shards holding a bucket get the chance to release it via
// UnlockBucket. Each bucket is acquired and immediately released again. Close
// returns once all buckets have been drained or their wait was abandoned
// because ctx is done. Buckets that could not be acquired are logged.
func (r *rateLimiterImpl) Close(ctx context.Context) {
	// Snapshot the buckets and release r.mu straight away: UnlockBucket goes
	// through getBucket, which takes r.mu, so holding r.mu while waiting below
	// would stop current bucket holders from ever releasing their buckets.
	r.mu.Lock()
	buckets := make([]*bucket, 0, len(r.buckets))
	for _, b := range r.buckets {
		buckets = append(buckets, b)
	}
	r.mu.Unlock()

	var wg sync.WaitGroup
	for _, b := range buckets {
		wg.Add(1)
		go func(b *bucket) {
			defer wg.Done()
			if err := b.mu.CLock(ctx); err != nil {
				// The lock was never acquired, so it must not be released here.
				r.config.Logger.Error("failed to close bucket", slog.Int("key", b.Key), slog.Any("err", err))
				return
			}
			b.mu.Unlock()
		}(b)
	}
	// Block until every bucket has been drained (or given up on).
	wg.Wait()
}
// getBucket returns the bucket that shardID maps to under the configured
// MaxConcurrency. If no bucket exists yet, it creates and stores one when
// create is true, and returns nil otherwise. The lookup runs under r.mu.
func (r *rateLimiterImpl) getBucket(shardID int, create bool) *bucket {
	r.config.Logger.Debug("locking shard rate limiter")
	r.mu.Lock()
	defer func() {
		r.config.Logger.Debug("unlocking shard rate limiter")
		r.mu.Unlock()
	}()

	key := ShardMaxConcurrencyKey(shardID, r.config.MaxConcurrency)
	if existing, found := r.buckets[key]; found {
		return existing
	}
	if !create {
		return nil
	}

	created := &bucket{Key: key}
	r.buckets[key] = created
	return created
}
// WaitBucket acquires the bucket that shardID maps to, then waits until the
// bucket's Reset time has passed.
//
// On a nil return the caller holds the bucket and must release it with
// UnlockBucket. On a non-nil return the bucket is NOT held. The errors are:
//   - the error from acquiring the lock,
//   - context.DeadlineExceeded when ctx's deadline falls before Reset,
//   - ctx.Err() when ctx is done during the wait.
func (r *rateLimiterImpl) WaitBucket(ctx context.Context, shardID int) error {
	b := r.getBucket(shardID, true)
	// b.Reset is written while b.mu is held, so it is not read before locking.
	r.config.Logger.Debug("locking shard bucket", slog.Int("key", b.Key))
	if err := b.mu.CLock(ctx); err != nil {
		return err
	}

	now := time.Now()
	if !b.Reset.After(now) {
		return nil
	}

	// Fail fast when the wait cannot finish before ctx's deadline, releasing
	// the lock so the bucket is not held forever.
	if deadline, ok := ctx.Deadline(); ok && b.Reset.After(deadline) {
		b.mu.Unlock()
		return context.DeadlineExceeded
	}

	timer := time.NewTimer(b.Reset.Sub(now))
	defer timer.Stop()

	select {
	case <-ctx.Done():
		b.mu.Unlock()
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
// UnlockBucket releases the bucket that shardID maps to and sets its Reset time
// to five seconds from now. The next WaitBucket on the same bucket therefore
// waits at least that long. It does nothing when the bucket was never created.
// The caller is expected to hold the bucket from a successful WaitBucket.
func (r *rateLimiterImpl) UnlockBucket(shardID int) {
	b := r.getBucket(shardID, false)
	if b == nil {
		return
	}

	b.Reset = time.Now().Add(5 * time.Second)
	r.config.Logger.Debug("unlocking shard bucket", slog.Int("key", b.Key), slog.Time("reset", b.Reset))
	b.mu.Unlock()
}
// bucket is the per-key state shared by every shard that maps to the same
// ShardMaxConcurrencyKey value.
type bucket struct {
	// Key is the map key this bucket is stored under in rateLimiterImpl.buckets.
	Key int
	// Reset is the earliest time a WaitBucket call may return after acquiring
	// the bucket; UnlockBucket sets it when releasing.
	Reset time.Time

	// mu is acquired (context-aware) by WaitBucket and released by
	// UnlockBucket, or by WaitBucket itself on failure.
	mu csync.Mutex
}