From 97ca004d8ebfda255a5b378ca91eddfe020bb19f Mon Sep 17 00:00:00 2001 From: da440dil Date: Sat, 18 Dec 2021 16:23:44 +0300 Subject: [PATCH] use limit script at batch limiter --- benchmark_test.go | 46 +++++++++++ counter.go | 127 ++--------------------------- counter_test.go | 154 ++++------------------------------- examples/limiter/main.go | 9 ++- fixedwindow.lua | 6 +- fixedwindow_test.go | 20 ++--- limit.lua | 70 ++++++++++++++++ limiter.go | 171 +++++++++++++++++++++++++++++++++++++++ limiter_test.go | 128 +++++++++++++++++++++++++++++ slidingwindow_test.go | 16 ++-- 10 files changed, 468 insertions(+), 279 deletions(-) create mode 100644 limit.lua create mode 100644 limiter.go create mode 100644 limiter_test.go diff --git a/benchmark_test.go b/benchmark_test.go index ba93ddd..c799326 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -34,3 +34,49 @@ func BenchmarkCounter(b *testing.B) { }) } } + +func BenchmarkLimiter(b *testing.B) { + client := redis.NewClient(&redis.Options{}) + defer client.Close() + + size := 10 * time.Second + limit := uint(10000) + tests := map[string]Limiter{ + "One": NewLimiter( + client, + WithLimiter(size, limit), + ), + "Two": NewLimiter( + client, + WithLimiter(size, limit), + WithLimiter(size*2, limit*2), + ), + "Three": NewLimiter( + client, + WithLimiter(size, limit), + WithLimiter(size*2, limit*2), + WithLimiter(size*3, limit*3), + ), + "Four": NewLimiter( + client, + WithLimiter(size, limit), + WithLimiter(size*2, limit*2), + WithLimiter(size*3, limit*3), + WithLimiter(size*4, limit*4), + ), + } + + ctx := context.Background() + key := "key" + for name, tc := range tests { + b.Run(name, func(b *testing.B) { + err := client.Del(ctx, key).Err() + if err != nil { + b.Fatal(err) + } + for i := 0; i < b.N; i++ { + tc.Limit(ctx, key) + } + }) + } +} diff --git a/counter.go b/counter.go index cf4a720..0ae366e 100644 --- a/counter.go +++ b/counter.go @@ -5,9 +5,6 @@ import ( "context" _ "embed" "errors" - "math/rand" - "strconv" - "sync" "time" "github.com/go-redis/redis/v8" @@ -25,7 +22,7 @@ type RedisClient interface { type Result struct { counter int64 ttl int64 - limit int + limit int64 } // OK is operation success flag. @@ -34,13 +31,13 @@ func (r Result) OK() bool { } // Counter is current counter value. -func (r Result) Counter() int { - return int(r.counter) +func (r Result) Counter() int64 { + return r.counter } // Remainder is diff between limit and current counter value. -func (r Result) Remainder() int { - return r.limit - int(r.counter) +func (r Result) Remainder() int64 { + return r.limit - r.counter } // TTL of the current window. @@ -56,8 +53,8 @@ var ErrUnexpectedRedisResponse = errors.New("counter: unexpected redis response" type Counter struct { client RedisClient script *redis.Script + limit int64 size int - limit int } // Count increments key by value. @@ -83,9 +80,6 @@ func (c *Counter) Count(ctx context.Context, key string, value int) (Result, err return r, ErrUnexpectedRedisResponse } r.limit = c.limit - if r.ttl == -2 { - r.ttl = 0 - } return r, nil } @@ -95,7 +89,7 @@ var fwscr = redis.NewScript(fwsrc) // FixedWindow creates new counter which implements distributed counter using fixed window algorithm. func FixedWindow(client RedisClient, size time.Duration, limit uint) *Counter { - return &Counter{client, fwscr, int(size / time.Millisecond), int(limit)} + return &Counter{client: client, script: fwscr, size: int(size / time.Millisecond), limit: int64(limit)} } //go:embed slidingwindow.lua @@ -104,110 +98,5 @@ var swscr = redis.NewScript(swsrc) // SlidingWindow creates new counter which implements distributed counter using sliding window algorithm. func SlidingWindow(client RedisClient, size time.Duration, limit uint) *Counter { - return &Counter{client, swscr, int(size / time.Millisecond), int(limit)} -} - -// Limiter implements distributed rate limiting. -type Limiter interface { - // Limit applies the limit: increments key value of each distributed counter. - Limit(ctx context.Context, key string) (Result, error) -} - -var random *rand.Rand - -func init() { - random = rand.New(rand.NewSource(time.Now().UnixNano())) -} - -// NewLimiter creates new limiter which implements distributed rate limiting. -// Each limiter is created with pseudo-random name which may be set with options, every Redis key will be prefixed with this name. -// The rate of decreasing the window size on each next limiter call by default equal 1, may be set with options. -func NewLimiter(c *Counter, options ...func(*limiter)) Limiter { - lt := &limiter{c, strconv.Itoa(random.Int()) + ":", 1} - for _, option := range options { - option(lt) - } - return lt -} - -// WithLimiterName sets unique limiter name. -func WithLimiterName(name string) func(*limiter) { - return func(lt *limiter) { - lt.prefix = name + ":" - } -} - -// WithLimiterRate sets limiter rate of decreasing the window size on each next limiter call. -func WithLimiterRate(rate uint) func(*limiter) { - return func(lt *limiter) { - lt.rate = int(rate) - } -} - -// NewLimiterSuite creates new limiter suite which contains two or more limiters which run concurently on every limiter suite call. -func NewLimiterSuite(v1 Limiter, v2 Limiter, vs ...Limiter) Limiter { - lts := append([]Limiter{v1, v2}, vs...) - return &limiters{lts: lts, size: len(lts)} -} - -type limiter struct { - counter *Counter - prefix string - rate int -} - -func (lt *limiter) Limit(ctx context.Context, key string) (Result, error) { - return lt.counter.Count(ctx, lt.prefix+key, lt.rate) -} - -type limiters struct { - lts []Limiter - wg sync.WaitGroup - mu sync.Mutex - size int -} - -const maxInt = int(^uint(0) >> 1) - -func (ls *limiters) Limit(ctx context.Context, key string) (Result, error) { - results := make([]result, ls.size) - - ls.mu.Lock() - ls.wg.Add(ls.size) - for i := 0; i < ls.size; i++ { - go func(i int) { - defer ls.wg.Done() - r, err := ls.lts[i].Limit(ctx, key) - results[i] = result{r, err} - }(i) - } - ls.wg.Wait() - ls.mu.Unlock() - - r := Result{0, int64(-1), maxInt} - for i := 0; i < ls.size; i++ { - v := results[i] - if v.err != nil { - return r, v.err - } - if v.result.OK() { - if r.OK() && r.Remainder() > v.result.Remainder() { // minimal remainder - r = v.result - } - continue - } - if r.OK() { // not ok first time - r = v.result - continue - } - if r.TTL() < v.result.TTL() { // maximum TTL - r = v.result - } - } - return r, nil -} - -type result struct { - result Result - err error + return &Counter{client: client, script: swscr, size: int(size / time.Millisecond), limit: int64(limit)} } diff --git a/counter_test.go b/counter_test.go index 98f0ec8..d22b620 100644 --- a/counter_test.go +++ b/counter_test.go @@ -35,170 +35,50 @@ func (m *ClientMock) ScriptLoad(ctx context.Context, script string) *redis.Strin func TestCounter(t *testing.T) { clientMock := &ClientMock{} size := 1000 - limit := 100 - scr := redis.NewScript("") - c := &Counter{clientMock, scr, size, limit} + limit := int64(100) + c := &Counter{client: clientMock, script: fwscr, size: size, limit: limit} ctx := context.Background() - key := "key" - keys := []string{key} + hash := fwscr.Hash() + value := 1 var i interface{} - v := 1 e := errors.New("redis error") - clientMock.On("EvalSha", ctx, scr.Hash(), keys, v, size, limit).Return(redis.NewCmdResult(i, e)) - _, err := c.Count(ctx, key, v) + clientMock.On("EvalSha", ctx, hash, []string{"1"}, value, size, limit).Return(redis.NewCmdResult(i, e)) + _, err := c.Count(ctx, "1", value) require.Equal(t, e, err) - v = 2 - clientMock.On("EvalSha", ctx, scr.Hash(), keys, v, size, limit).Return(redis.NewCmdResult(i, nil)) - _, err = c.Count(ctx, key, v) + clientMock.On("EvalSha", ctx, hash, []string{"2"}, value, size, limit).Return(redis.NewCmdResult(i, nil)) + _, err = c.Count(ctx, "2", value) require.Equal(t, ErrUnexpectedRedisResponse, err) - v = 3 i = []interface{}{1} - clientMock.On("EvalSha", ctx, scr.Hash(), keys, v, size, limit).Return(redis.NewCmdResult(i, nil)) - _, err = c.Count(ctx, key, v) + clientMock.On("EvalSha", ctx, hash, []string{"3"}, value, size, limit).Return(redis.NewCmdResult(i, nil)) + _, err = c.Count(ctx, "3", value) require.Equal(t, ErrUnexpectedRedisResponse, err) - v = 4 i = []interface{}{1, -1} - clientMock.On("EvalSha", ctx, scr.Hash(), keys, v, size, limit).Return(redis.NewCmdResult(i, nil)) - _, err = c.Count(ctx, key, v) + clientMock.On("EvalSha", ctx, hash, []string{"4"}, value, size, limit).Return(redis.NewCmdResult(i, nil)) + _, err = c.Count(ctx, "4", value) require.Equal(t, ErrUnexpectedRedisResponse, err) - v = 5 i = []interface{}{int64(1), -1} - clientMock.On("EvalSha", ctx, scr.Hash(), keys, v, size, limit).Return(redis.NewCmdResult(i, nil)) - _, err = c.Count(ctx, key, v) + clientMock.On("EvalSha", ctx, hash, []string{"5"}, value, size, limit).Return(redis.NewCmdResult(i, nil)) + _, err = c.Count(ctx, "5", value) require.Equal(t, ErrUnexpectedRedisResponse, err) - v = 6 i = []interface{}{int64(1), int64(-1)} - clientMock.On("EvalSha", ctx, scr.Hash(), keys, v, size, limit).Return(redis.NewCmdResult(i, nil)) - result, err := c.Count(ctx, key, v) + clientMock.On("EvalSha", ctx, hash, []string{"6"}, value, size, limit).Return(redis.NewCmdResult(i, nil)) + result, err := c.Count(ctx, "6", value) require.NoError(t, err) require.True(t, result.OK()) - require.Equal(t, 1, result.Counter()) + require.Equal(t, int64(1), result.Counter()) require.Equal(t, limit-1, result.Remainder()) require.Equal(t, msToDuration(-1), result.TTL()) clientMock.AssertExpectations(t) } -func TestLimiter(t *testing.T) { - clientMock := &ClientMock{} - size := 1000 - limit := 100 - scr := redis.NewScript("") - c := &Counter{clientMock, scr, size, limit} - ctx := context.Background() - key := "key" - - rate := uint(1) - v := int(rate) - - name := "name" - lt := NewLimiter(c, WithLimiterName(name), WithLimiterRate(rate)) - - clientMock.On("EvalSha", ctx, scr.Hash(), []string{name + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(rate), int64(-1)}, nil)) - result, err := lt.Limit(ctx, key) - require.NoError(t, err) - require.True(t, result.OK()) - require.Equal(t, v, result.Counter()) - require.Equal(t, limit-v, result.Remainder()) - require.Equal(t, msToDuration(-1), result.TTL()) - - clientMock.AssertExpectations(t) -} - -func TestLimiterSuite(t *testing.T) { - clientMock := &ClientMock{} - size := 1000 - limit := 100 - scr := redis.NewScript("") - c := &Counter{clientMock, scr, size, limit} - ctx := context.Background() - key := "key" - - rate := uint(1) - v := int(rate) - - n1 := "name1" - lt1 := NewLimiter(c, WithLimiterName(n1), WithLimiterRate(rate)) - r1 := 25 - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n1 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r1), int64(-1)}, nil)) - - n2 := "name2" - lt2 := NewLimiter(c, WithLimiterName(n2), WithLimiterRate(rate)) - r2 := 58 - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n2 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r2), int64(-1)}, nil)) - - n3 := "name3" - lt3 := NewLimiter(c, WithLimiterName(n3), WithLimiterRate(rate)) - r3 := 26 - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n3 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r3), int64(-1)}, nil)) - - ls1 := NewLimiterSuite(lt1, lt2, lt3) - result, err := ls1.Limit(ctx, key) - require.NoError(t, err) - require.True(t, result.OK()) - require.Equal(t, r2, result.Counter()) - require.Equal(t, limit-r2, result.Remainder()) - require.Equal(t, msToDuration(-1), result.TTL()) - - n4 := "name4" - lt4 := NewLimiter(c, WithLimiterName(n4), WithLimiterRate(rate)) - r4 := 58 - t4 := int64(42) - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n4 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r4), t4}, nil)) - - ls2 := NewLimiterSuite(lt1, lt4, lt3) - result, err = ls2.Limit(ctx, key) - require.NoError(t, err) - require.False(t, result.OK()) - require.Equal(t, r4, result.Counter()) - require.Equal(t, limit-r4, result.Remainder()) - require.Equal(t, msToDuration(t4), result.TTL()) - - n5 := "name5" - lt5 := NewLimiter(c, WithLimiterName(n5), WithLimiterRate(rate)) - r5 := 58 - t5 := int64(42) - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n5 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r5), t5}, nil)) - - n6 := "name6" - lt6 := NewLimiter(c, WithLimiterName(n6), WithLimiterRate(rate)) - r6 := 25 - t6 := int64(75) - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n6 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r6), t6}, nil)) - - n7 := "name7" - lt7 := NewLimiter(c, WithLimiterName(n7), WithLimiterRate(rate)) - r7 := 26 - t7 := int64(74) - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n7 + ":" + key}, v, size, limit).Return(redis.NewCmdResult([]interface{}{int64(r7), t7}, nil)) - - ls3 := NewLimiterSuite(lt5, lt6, lt7) - result, err = ls3.Limit(ctx, key) - require.NoError(t, err) - require.False(t, result.OK()) - require.Equal(t, r6, result.Counter()) - require.Equal(t, limit-r6, result.Remainder()) - require.Equal(t, msToDuration(t6), result.TTL()) - - n8 := "name8" - lt8 := NewLimiter(c, WithLimiterName(n8), WithLimiterRate(rate)) - e := errors.New("redis error") - clientMock.On("EvalSha", ctx, scr.Hash(), []string{n8 + ":" + key}, v, size, limit).Return(redis.NewCmdResult(0, e)) - - ls4 := NewLimiterSuite(lt1, lt8, lt2) - _, err = ls4.Limit(ctx, key) - require.Equal(t, e, err) - - clientMock.AssertExpectations(t) -} - func msToDuration(ms int64) time.Duration { return time.Duration(ms) * time.Millisecond } diff --git a/examples/limiter/main.go b/examples/limiter/main.go index dcee835..566b934 100644 --- a/examples/limiter/main.go +++ b/examples/limiter/main.go @@ -19,12 +19,13 @@ func main() { err := client.Del(ctx, key).Err() requireNoError(err) - // Create limiter suite with 2 limiters. - ls := counter.NewLimiterSuite( + // Create limiter with 2 limiters. + ls := counter.NewLimiter( + client, // First limiter is limited to 3 calls per second. - counter.NewLimiter(counter.FixedWindow(client, time.Second, 3)), + counter.WithLimiter(time.Second, 3), // Second limiter is limited to 5 calls per 2 seconds. - counter.NewLimiter(counter.FixedWindow(client, time.Second*2, 5)), + counter.WithLimiter(time.Second*2, 5), ) limit := func() { diff --git a/fixedwindow.lua b/fixedwindow.lua index 8fcae6d..44b29e8 100644 --- a/fixedwindow.lua +++ b/fixedwindow.lua @@ -3,7 +3,11 @@ if counter == false then counter = 0 end if counter + ARGV[1] > tonumber(ARGV[3]) then - return { tonumber(counter), redis.call("pttl", KEYS[1]) } + local v = redis.call("pttl", KEYS[1]) + if v == -2 then + v = 0 + end + return { tonumber(counter), v } end if counter == 0 then redis.call("set", KEYS[1], ARGV[1], "px", ARGV[2]) diff --git a/fixedwindow_test.go b/fixedwindow_test.go index 550b993..3fdebec 100644 --- a/fixedwindow_test.go +++ b/fixedwindow_test.go @@ -24,29 +24,29 @@ func TestFixedWindow(t *testing.T) { result, err := counter.Count(ctx, key, 101) require.NoError(t, err) require.False(t, result.OK()) - require.Equal(t, 0, result.Counter()) - require.Equal(t, 100, result.Remainder()) + require.Equal(t, int64(0), result.Counter()) + require.Equal(t, int64(100), result.Remainder()) require.Equal(t, msToDuration(0), result.TTL()) result, err = counter.Count(ctx, key, 20) require.NoError(t, err) require.True(t, result.OK()) - require.Equal(t, 20, result.Counter()) - require.Equal(t, 80, result.Remainder()) + require.Equal(t, int64(20), result.Counter()) + require.Equal(t, int64(80), result.Remainder()) require.Equal(t, msToDuration(-1), result.TTL()) result, err = counter.Count(ctx, key, 30) require.NoError(t, err) require.True(t, result.OK()) - require.Equal(t, 50, result.Counter()) - require.Equal(t, 50, result.Remainder()) + require.Equal(t, int64(50), result.Counter()) + require.Equal(t, int64(50), result.Remainder()) require.Equal(t, msToDuration(-1), result.TTL()) result, err = counter.Count(ctx, key, 51) require.NoError(t, err) require.False(t, result.OK()) - require.Equal(t, 50, result.Counter()) - require.Equal(t, 50, result.Remainder()) + require.Equal(t, int64(50), result.Counter()) + require.Equal(t, int64(50), result.Remainder()) require.True(t, result.TTL() >= msToDuration(0) && result.TTL() <= size) time.Sleep(result.TTL() + 100*time.Millisecond) // wait for the next window to start @@ -54,7 +54,7 @@ func TestFixedWindow(t *testing.T) { result, err = counter.Count(ctx, key, 70) require.NoError(t, err) require.True(t, result.OK()) - require.Equal(t, 70, result.Counter()) - require.Equal(t, 30, result.Remainder()) + require.Equal(t, int64(70), result.Counter()) + require.Equal(t, int64(30), result.Remainder()) require.Equal(t, msToDuration(-1), result.TTL()) } diff --git a/limit.lua b/limit.lua new file mode 100644 index 0000000..64d5df0 --- /dev/null +++ b/limit.lua @@ -0,0 +1,70 @@ +local function fixedWindow(key, value, size, limit) + local counter = redis.call("get", key) + if counter == false then + counter = 0 + end + if counter + value > limit then + local v = redis.call("pttl", key) + if v == -2 then + v = 0 + end + return { tonumber(counter), v } + end + if counter == 0 then + redis.call("set", key, value, "px", size) + return { tonumber(value), -1 } + end + return { redis.call("incrby", key, value), -1 } +end + +local function slidingWindow(key, value, size, limit) + local t = redis.call("time") + local now = t[1] * 1000 + math.floor(t[2]/1000) + local currWindowTime = now - now % size + local currWindowKey = key .. ":" .. currWindowTime + local prevWindowKey = key .. ":" .. currWindowTime - size + local currWindowCounter = redis.call("get", currWindowKey) + if currWindowCounter == false then + currWindowCounter = 0 + end + local prevWindowCounter = redis.call("get", prevWindowKey) + if prevWindowCounter == false then + prevWindowCounter = 0 + end + local currWindowRemainingDuration = size - (now - currWindowTime) + local slidingWindowCounter = math.floor(prevWindowCounter * (currWindowRemainingDuration / size) + currWindowCounter) + local counter = slidingWindowCounter + value + if counter > limit then + return { slidingWindowCounter, currWindowRemainingDuration } + end + if currWindowCounter == 0 then + redis.call("set", currWindowKey, value, "px", size * 2) + else + redis.call("incrby", currWindowKey, value) + end + return { counter, -1 } +end + +local z = 0 +local limit, v, result +for i, key in ipairs(KEYS) do + z = z + 4 + limit = tonumber(ARGV[z - 1]) + if ARGV[z] == "1" then + v = fixedWindow(key, ARGV[z - 3], ARGV[z - 2], limit) + else + v = slidingWindow(key, ARGV[z - 3], ARGV[z - 2], limit) + end + if i == 1 then -- first result + result = { v[1], v[2], limit }; + elseif v[2] == -1 then -- ok + if result[2] == -1 and result[3] - result[1] > limit - v[1] then -- minimal remainder + result = { v[1], v[2], limit }; + end + elseif result[2] == -1 then -- not ok first time + result = { v[1], v[2], limit }; + elseif result[2] < v[2] then -- maximum TTL + result = { v[1], v[2], limit }; + end +end +return result \ No newline at end of file diff --git a/limiter.go b/limiter.go new file mode 100644 index 0000000..735673d --- /dev/null +++ b/limiter.go @@ -0,0 +1,171 @@ +package counter + +import ( + "context" + _ "embed" + "math/rand" + "strconv" + "time" + + "github.com/go-redis/redis/v8" +) + +var random *rand.Rand + +func init() { + random = rand.New(rand.NewSource(time.Now().UnixNano())) +} + +// Limiter implements distributed rate limiting. Contains one or more distributed counters. +type Limiter interface { + // Limit applies the limit: increments key value of each distributed counter. + Limit(ctx context.Context, key string) (Result, error) +} + +type params struct { + prefix string + alg int + rate int + size int + limit int64 +} + +const ( + algFixed = 1 + algSliding = 2 +) + +// WithLimiter creates params to build limiter. +// +// Each limiter uses fixed window algorithm by default, may be set with options. +// Each limiter is created with pseudo-random name which may be set with options, every Redis key will be prefixed with this name. +// The rate of decreasing the window size on each next limiter call by default equal 1, may be set with options. +func WithLimiter(size time.Duration, limit uint, options ...func(*params)) *params { + p := ¶ms{alg: algFixed, size: int(size / time.Millisecond), limit: int64(limit)} + for _, opt := range options { + opt(p) + } + if p.prefix == "" { + p.prefix = strconv.Itoa(random.Int()) + ":" + } + if p.rate == 0 { + p.rate = 1 + } + return p +} + +// WithFixedWindow sets limiter algorithm to fixed window. +func WithFixedWindow() func(*params) { + return func(p *params) { + p.alg = algFixed + } +} + +// WithSlidingWindow sets limiter algorithm to sliding window. +func WithSlidingWindow() func(*params) { + return func(p *params) { + p.alg = algSliding + } +} + +// WithName sets unique limiter name. +func WithName(name string) func(*params) { + return func(p *params) { + p.prefix = name + ":" + } +} + +// WithRate sets limiter rate of decreasing the window size on each next limiter call. +func WithRate(rate uint) func(*params) { + return func(p *params) { + p.rate = int(rate) + } +} + +// NewLimiter creates new limiter which implements distributed rate limiting. +func NewLimiter(client RedisClient, first *params, rest ...*params) Limiter { + n := len(rest) + if n == 0 { + var scr *redis.Script + if first.alg == algFixed { + scr = fwscr + } else { + scr = swscr + } + c := &Counter{client: client, script: scr, size: first.size, limit: first.limit} + return &limiter{counter: c, prefix: first.prefix, rate: first.rate} + } + + size := n + 1 + prefixes := make([]string, size) + prefixes[0] = first.prefix + args := make([]interface{}, size*4) + args[0] = first.rate + args[1] = first.size + args[2] = first.limit + args[3] = first.alg + + z := 0 + for i := 0; i < n; i++ { + z += 4 + prefixes[i+1] = rest[i].prefix + args[z] = rest[i].rate + args[z+1] = rest[i].size + args[z+2] = rest[i].limit + args[z+3] = rest[i].alg + } + + return &batchlimiter{client: client, prefixes: prefixes, args: args} +} + +type limiter struct { + counter *Counter + prefix string + rate int +} + +func (lt *limiter) Limit(ctx context.Context, key string) (Result, error) { + return lt.counter.Count(ctx, lt.prefix+key, lt.rate) +} + +type batchlimiter struct { + client RedisClient + prefixes []string + args []interface{} +} + +//go:embed limit.lua +var ltsrc string +var ltscr = redis.NewScript(ltsrc) + +func (blt *batchlimiter) Limit(ctx context.Context, key string) (Result, error) { + keys := make([]string, len(blt.prefixes)) + for i := 0; i < len(blt.prefixes); i++ { + keys[i] = blt.prefixes[i] + key + } + r := Result{} + res, err := ltscr.Run(ctx, blt.client, keys, blt.args...).Result() + if err != nil { + return r, err + } + arr, ok := res.([]interface{}) + if !ok { + return r, ErrUnexpectedRedisResponse + } + if len(arr) != 3 { + return r, ErrUnexpectedRedisResponse + } + r.counter, ok = arr[0].(int64) + if !ok { + return r, ErrUnexpectedRedisResponse + } + r.ttl, ok = arr[1].(int64) + if !ok { + return r, ErrUnexpectedRedisResponse + } + r.limit, ok = arr[2].(int64) + if !ok { + return r, ErrUnexpectedRedisResponse + } + return r, nil +} diff --git a/limiter_test.go b/limiter_test.go new file mode 100644 index 0000000..e0da145 --- /dev/null +++ b/limiter_test.go @@ -0,0 +1,128 @@ +package counter + +import ( + "context" + "errors" + "math/rand" + "testing" + "time" + + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +func TestNewLimiter(t *testing.T) { + clientMock := &ClientMock{} + size := time.Second + limit := uint(100) + sizev := int(size / time.Millisecond) + limitv := int64(limit) + + v1 := NewLimiter(clientMock, WithLimiter(size, limit, WithName("x"))) + require.Equal(t, &limiter{counter: &Counter{client: clientMock, script: fwscr, size: sizev, limit: limitv}, prefix: "x:", rate: 1}, v1) + + v2 := NewLimiter(clientMock, WithLimiter(size, limit, WithName("x"), WithFixedWindow())) + require.Equal(t, &limiter{counter: &Counter{client: clientMock, script: fwscr, size: sizev, limit: limitv}, prefix: "x:", rate: 1}, v2) + + v3 := NewLimiter(clientMock, WithLimiter(size, limit, WithName("x"), WithSlidingWindow())) + require.Equal(t, &limiter{counter: &Counter{client: clientMock, script: swscr, size: sizev, limit: limitv}, prefix: "x:", rate: 1}, v3) + + v4 := NewLimiter(clientMock, WithLimiter(size, limit, WithName("x"), WithRate(2))) + require.Equal(t, &limiter{counter: &Counter{client: clientMock, script: fwscr, size: sizev, limit: limitv}, prefix: "x:", rate: 2}, v4) + + v5 := NewLimiter(clientMock, WithLimiter(size, limit, WithName("x")), WithLimiter(size, limit, WithName("y"))) + require.Equal(t, &batchlimiter{client: clientMock, prefixes: []string{"x:", "y:"}, args: []interface{}{1, sizev, limitv, algFixed, 1, sizev, limitv, algFixed}}, v5) + + rnd := random + random = rand.New(rand.NewSource(42)) + defer func() { + random = rnd + }() + + v6 := NewLimiter(clientMock, WithLimiter(size, limit)) + require.Equal(t, &limiter{counter: &Counter{client: clientMock, script: fwscr, size: sizev, limit: limitv}, prefix: "3440579354231278675:", rate: 1}, v6) +} + +func TestLimiter(t *testing.T) { + clientMock := &ClientMock{} + size := 1000 + limit := int64(100) + c := &Counter{client: clientMock, script: fwscr, size: size, limit: limit} + prefix := "x:" + rate := 1 + lt := &limiter{counter: c, prefix: prefix, rate: rate} + ctx := context.Background() + hash := fwscr.Hash() + + var i interface{} + + e := errors.New("redis error") + clientMock.On("EvalSha", ctx, hash, []string{"x:1"}, rate, size, limit).Return(redis.NewCmdResult(i, e)) + _, err := lt.Limit(ctx, "1") + require.Equal(t, e, err) + + i = []interface{}{int64(1), int64(-1)} + clientMock.On("EvalSha", ctx, hash, []string{"x:2"}, rate, size, limit).Return(redis.NewCmdResult(i, nil)) + result, err := lt.Limit(ctx, "2") + require.NoError(t, err) + require.True(t, result.OK()) + require.Equal(t, int64(1), result.Counter()) + require.Equal(t, limit-1, result.Remainder()) + require.Equal(t, msToDuration(-1), result.TTL()) + + clientMock.AssertExpectations(t) +} + +func TestBatchLimiter(t *testing.T) { + clientMock := &ClientMock{} + rate := 1 + size := 1000 + limit := int64(100) + prefixes := []string{"x:", "y:"} + args := []interface{}{rate, size, limit, algFixed, rate, size, limit, algFixed} + blt := &batchlimiter{client: clientMock, prefixes: prefixes, args: args} + ctx := context.Background() + hash := ltscr.Hash() + + var i interface{} + + e := errors.New("redis error") + clientMock.On("EvalSha", ctx, hash, []string{"x:1", "y:1"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, e)) + _, err := blt.Limit(ctx, "1") + require.Equal(t, e, err) + + clientMock.On("EvalSha", ctx, hash, []string{"x:2", "y:2"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, nil)) + _, err = blt.Limit(ctx, "2") + require.Equal(t, ErrUnexpectedRedisResponse, err) + + i = []interface{}{1, -1} + clientMock.On("EvalSha", ctx, hash, []string{"x:3", "y:3"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, nil)) + _, err = blt.Limit(ctx, "3") + require.Equal(t, ErrUnexpectedRedisResponse, err) + + i = []interface{}{1, -1, 100} + clientMock.On("EvalSha", ctx, hash, []string{"x:4", "y:4"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, nil)) + _, err = blt.Limit(ctx, "4") + require.Equal(t, ErrUnexpectedRedisResponse, err) + + i = []interface{}{int64(1), -1, 100} + clientMock.On("EvalSha", ctx, hash, []string{"x:5", "y:5"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, nil)) + _, err = blt.Limit(ctx, "5") + require.Equal(t, ErrUnexpectedRedisResponse, err) + + i = []interface{}{int64(1), int64(-1), 100} + clientMock.On("EvalSha", ctx, hash, []string{"x:6", "y:6"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, nil)) + _, err = blt.Limit(ctx, "6") + require.Equal(t, ErrUnexpectedRedisResponse, err) + + i = []interface{}{int64(1), int64(-1), limit} + clientMock.On("EvalSha", ctx, hash, []string{"x:7", "y:7"}, rate, size, limit, algFixed, rate, size, limit, algFixed).Return(redis.NewCmdResult(i, nil)) + result, err := blt.Limit(ctx, "7") + require.NoError(t, err) + require.True(t, result.OK()) + require.Equal(t, int64(1), result.Counter()) + require.Equal(t, limit-1, result.Remainder()) + require.Equal(t, msToDuration(-1), result.TTL()) + + clientMock.AssertExpectations(t) +} diff --git a/slidingwindow_test.go b/slidingwindow_test.go index 96fcbec..6d232ad 100644 --- a/slidingwindow_test.go +++ b/slidingwindow_test.go @@ -24,8 +24,8 @@ func TestSlidingWindow(t *testing.T) { result, err := counter.Count(ctx, key, 101) require.NoError(t, err) require.False(t, result.OK()) - require.Equal(t, 0, result.Counter()) - require.Equal(t, 100, result.Remainder()) + require.Equal(t, int64(0), result.Counter()) + require.Equal(t, int64(100), result.Remainder()) require.True(t, result.TTL() >= msToDuration(0) && result.TTL() <= size) time.Sleep(result.TTL()) // wait for the next window to start @@ -33,22 +33,22 @@ func TestSlidingWindow(t *testing.T) { result, err = counter.Count(ctx, key, 20) require.NoError(t, err) require.True(t, result.OK()) - require.Equal(t, 20, result.Counter()) - require.Equal(t, 80, result.Remainder()) + require.Equal(t, int64(20), result.Counter()) + require.Equal(t, int64(80), result.Remainder()) require.Equal(t, msToDuration(-1), result.TTL()) result, err = counter.Count(ctx, key, 30) require.NoError(t, err) require.True(t, result.OK()) - require.Equal(t, 50, result.Counter()) - require.Equal(t, 50, result.Remainder()) + require.Equal(t, int64(50), result.Counter()) + require.Equal(t, int64(50), result.Remainder()) require.Equal(t, msToDuration(-1), result.TTL()) result, err = counter.Count(ctx, key, 51) require.NoError(t, err) require.False(t, result.OK()) - require.Equal(t, 50, result.Counter()) - require.Equal(t, 50, result.Remainder()) + require.Equal(t, int64(50), result.Counter()) + require.Equal(t, int64(50), result.Remainder()) require.True(t, result.TTL() >= msToDuration(0) && result.TTL() <= size) time.Sleep(result.TTL()) // wait for the next window to start