diff --git a/processor/ratelimitprocessor/gubernator.go b/processor/ratelimitprocessor/gubernator.go
index d1ffd203a..39d04cb45 100644
--- a/processor/ratelimitprocessor/gubernator.go
+++ b/processor/ratelimitprocessor/gubernator.go
@@ -125,49 +125,70 @@ func (r *gubernatorRateLimiter) RateLimit(ctx context.Context, hits int) error {
     uniqueKey := getUniqueKey(ctx, r.cfg.MetadataKeys)
     cfg := resolveRateLimitSettings(r.cfg, uniqueKey)

-    createdAt := time.Now().UnixMilli()
-    getRateLimitsResp, err := r.client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
-        Requests: []*gubernator.RateLimitReq{
-            {
-                Name:      r.set.ID.String(),
-                UniqueKey: uniqueKey,
-                Hits:      int64(hits),
-                Behavior:  r.behavior,
-                Algorithm: gubernator.Algorithm_LEAKY_BUCKET,
-                Limit:     int64(cfg.Rate), // rate is per second
-                Burst:     int64(cfg.Burst),
-                Duration:  cfg.ThrottleInterval.Milliseconds(), // duration is in milliseconds, i.e. 1s
-                CreatedAt: &createdAt,
+    makeRateLimitRequest := func() (*gubernator.RateLimitResp, error) {
+        createdAt := time.Now().UnixMilli()
+        getRateLimitsResp, err := r.client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
+            Requests: []*gubernator.RateLimitReq{
+                {
+                    Name:      r.set.ID.String(),
+                    UniqueKey: uniqueKey,
+                    Hits:      int64(hits),
+                    Behavior:  r.behavior,
+                    Algorithm: gubernator.Algorithm_LEAKY_BUCKET,
+                    Limit:     int64(cfg.Rate), // rate is per second
+                    Burst:     int64(cfg.Burst),
+                    Duration:  cfg.ThrottleInterval.Milliseconds(), // duration is in milliseconds, i.e. 1s
+                    CreatedAt: &createdAt,
+                },
             },
-        },
-    })
+        })
+        if err != nil {
+            return nil, err
+        }
+        // Inside the gRPC response, we should have a single-item list of responses.
+        responses := getRateLimitsResp.GetResponses()
+        if n := len(responses); n != 1 {
+            return nil, fmt.Errorf("expected 1 response from gubernator, got %d", n)
+        }
+        resp := responses[0]
+        if resp.GetError() != "" {
+            return nil, errors.New(resp.GetError())
+        }
+        return resp, nil
+    }
+
+    resp, err := makeRateLimitRequest()
     if err != nil {
         return err
     }
-    // Inside the gRPC response, we should have a single-item list of responses.
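+            // Each time the timer fires, re-check the limit with gubernator:
+            // the request is admitted once the response reports UNDER_LIMIT;
+            // otherwise the timer is reset to the next reported reset time.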
-    responses := getRateLimitsResp.GetResponses()
-    if n := len(responses); n != 1 {
-        return fmt.Errorf("expected 1 response from gubernator, got %d", n)
-    }
-    resp := responses[0]
-    if resp.GetError() != "" {
-        return errors.New(resp.GetError())
-    }
-
     if resp.GetStatus() == gubernator.Status_OVER_LIMIT {
         // Same logic as local
         switch r.cfg.ThrottleBehavior {
         case ThrottleBehaviorError:
             return status.Error(codes.ResourceExhausted, errTooManyRequests.Error())
         case ThrottleBehaviorDelay:
-            delay := time.Duration(resp.GetResetTime()-createdAt) * time.Millisecond
+            delay := time.Duration(resp.GetResetTime()-time.Now().UnixMilli()) * time.Millisecond
             timer := time.NewTimer(delay)
             defer timer.Stop()
-            select {
-            case <-ctx.Done():
-                return ctx.Err()
-            case <-timer.C:
+        retry:
+            for {
+                select {
+                case <-ctx.Done():
+                    return ctx.Err()
+                case <-timer.C:
+                    resp, err = makeRateLimitRequest()
+                    if err != nil {
+                        return err
+                    }
+                    if resp.GetStatus() == gubernator.Status_UNDER_LIMIT {
+                        break retry
+                    }
+                    delay = time.Duration(resp.GetResetTime()-time.Now().UnixMilli()) * time.Millisecond
+                    timer.Reset(delay)
+                }
             }
         }
     }
diff --git a/processor/ratelimitprocessor/gubernator_test.go b/processor/ratelimitprocessor/gubernator_test.go
index 2150a9d5f..0e035bdaf 100644
--- a/processor/ratelimitprocessor/gubernator_test.go
+++ b/processor/ratelimitprocessor/gubernator_test.go
@@ -19,6 +19,8 @@ package ratelimitprocessor

 import (
     "context"
+    "slices"
+    "sync"
     "testing"
     "time"

@@ -142,3 +144,49 @@ func TestGubernatorRateLimiter_RateLimit_MetadataKeys(t *testing.T) {
     err = rateLimiter.RateLimit(clientContext2, 1)
     assert.NoError(t, err)
 }
+
+func TestGubernatorRateLimiter_MultipleRequests_Delay(t *testing.T) {
+    throttleInterval := 100 * time.Millisecond
+    rl := newTestGubernatorRateLimiter(t, &Config{
+        RateLimitSettings: RateLimitSettings{
+            Rate:             1, // 1 request per second
+            Burst:            1, // capacity for only one request
+            ThrottleBehavior: ThrottleBehaviorDelay,
+            ThrottleInterval: throttleInterval, // add 1 token every 100ms
+        },
+        MetadataKeys: []string{"metadata_key"},
+    })
+
+    // Simulate 5 concurrent requests: the first one passes immediately,
+    // and the remaining four hit the rate limit simultaneously.
+    requests := 5
+    endingTimes := make([]time.Time, requests)
+    var wg sync.WaitGroup
+    wg.Add(requests)
+
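+    // Fire all requests concurrently; each goroutine records when its
+    // RateLimit call returned.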
+    for i := 0; i < requests; i++ {
+        go func(i int) {
+            defer wg.Done()
+            err := rl.RateLimit(context.Background(), 1)
+            assert.NoError(t, err) // assert, not require: require must not be called from a non-test goroutine
+            endingTimes[i] = time.Now()
+        }(i)
+    }
+    wg.Wait()
+
+    // Make sure consecutive ending times are at least ~100ms apart, as tokens
+    // are added at that rate. We need to sort them first.
+    slices.SortFunc(endingTimes, func(a, b time.Time) int {
+        return a.Compare(b)
+    })
+
+    for i := 1; i < requests; i++ {
+        diff := endingTimes[i].Sub(endingTimes[i-1]).Milliseconds()
+        minExpected := throttleInterval - 5*time.Millisecond // allow a small tolerance
+        if diff < minExpected.Milliseconds() {
+            t.Fatalf("difference is %dms, requests were sent before tokens were added", diff)
+        }
+    }
+}
diff --git a/processor/ratelimitprocessor/local_test.go b/processor/ratelimitprocessor/local_test.go
index 052d18ada..8db2377e2 100644
--- a/processor/ratelimitprocessor/local_test.go
+++ b/processor/ratelimitprocessor/local_test.go
@@ -19,6 +19,8 @@ package ratelimitprocessor

 import (
     "context"
+    "slices"
+    "sync"
     "testing"
     "time"

@@ -126,3 +128,50 @@ func TestLocalRateLimiter_RateLimit_MetadataKeys(t *testing.T) {
         assert.NoError(t, err)
     }
 }
+
+func TestLocalRateLimiter_MultipleRequests_Delay(t *testing.T) {
+    throttleInterval := 100 * time.Millisecond
+    rl := newTestLocalRateLimiter(t, &Config{
+        Type: LocalRateLimiter,
+        RateLimitSettings: RateLimitSettings{
+            Rate:             1, // 1 request per second
+            Burst:            1, // capacity for only one request
+            ThrottleBehavior: ThrottleBehaviorDelay,
+            ThrottleInterval: throttleInterval, // add 1 token every 100ms
+        },
+        MetadataKeys: []string{"metadata_key"},
+    })
+
+    // Simulate 5 concurrent requests: the first one passes immediately,
+    // and the remaining four hit the rate limit simultaneously.
+    requests := 5
+    endingTimes := make([]time.Time, requests)
+    var wg sync.WaitGroup
+    wg.Add(requests)
+
+    // Fire all requests concurrently; each goroutine records when its
+    // RateLimit call returned.
+    for i := 0; i < requests; i++ {
+        go func(i int) {
+            defer wg.Done()
+            err := rl.RateLimit(context.Background(), 1)
+            assert.NoError(t, err) // assert, not require: require must not be called from a non-test goroutine
+            endingTimes[i] = time.Now()
+        }(i)
+    }
+    wg.Wait()
+
+    // Make sure consecutive ending times are at least ~100ms apart, as tokens
+    // are added at that rate. We need to sort them first.
+    slices.SortFunc(endingTimes, func(a, b time.Time) int {
+        return a.Compare(b)
+    })
+
+    for i := 1; i < requests; i++ {
+        diff := endingTimes[i].Sub(endingTimes[i-1]).Milliseconds()
+        minExpected := throttleInterval - 5*time.Millisecond // allow a small tolerance
+        if diff < minExpected.Milliseconds() {
+            t.Fatalf("difference is %dms, requests were sent before tokens were added", diff)
+        }
+    }
+}