Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions ratex/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package ratex

import (
"context"
"crypto/rand"
"fmt"
"math/big"
"time"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/time/rate"

"github.com/moov-io/base/telemetry"
)

// RateLimitParams is the input to RateLimit.
type RateLimitParams struct {
	// RateLimiter is the limiter to reuse across calls. When it is nil a new
	// limiter is created; otherwise its limit is updated in place.
	RateLimiter *rate.Limiter

	// RetryAttempt is the 1-based number of the current retry attempt. It is
	// used as a multiplier on the randomized delay, so later attempts wait
	// longer.
	RetryAttempt int

	// MinDuration and MaxDuration bound the randomized base delay (at
	// millisecond resolution) before it is scaled by RetryAttempt.
	MinDuration, MaxDuration time.Duration
}

// RateLimit blocks the caller for a jittered delay derived from params and
// returns the limiter that was used, so it can be handed back in
// params.RateLimiter on the next call.
//
// The wait runs inside a "rate-limiter-wait" telemetry span annotated with the
// retry attempt and the min/max durations in milliseconds. An error is
// returned when the limiter cannot be prepared or when waiting on it fails
// (for example because ctx ends first); in both cases the returned limiter is
// nil.
func RateLimit(ctx context.Context, params RateLimitParams) (*rate.Limiter, error) {
	spanAttrs := trace.WithAttributes(
		attribute.Int("retry_attempt", params.RetryAttempt),
		attribute.Int64("min_duration_ms", params.MinDuration.Milliseconds()),
		attribute.Int64("max_duration_ms", params.MaxDuration.Milliseconds()),
	)
	ctx, span := telemetry.StartSpan(ctx, "rate-limiter-wait", spanAttrs)
	defer span.End()

	limiter, err := generateRateLimiter(ctx, params)
	if err != nil {
		return nil, fmt.Errorf("generating rate limiter: %w", err)
	}

	if err := limiter.Wait(ctx); err != nil {
		return nil, fmt.Errorf("rate limiter wait: %w", err)
	}
	return limiter, nil
}

// generateRateLimiter returns a limiter whose limit is one event per freshly
// jittered interval. An existing params.RateLimiter is reused with its limit
// updated; otherwise a new single-token limiter is created and its initial
// token is consumed.
func generateRateLimiter(ctx context.Context, params RateLimitParams) (*rate.Limiter, error) {
	interval, err := generateRateLimitDuration(params.RetryAttempt, params.MinDuration, params.MaxDuration)
	if err != nil {
		return nil, fmt.Errorf("generating rate limit duration: %w", err)
	}
	limit := rate.Every(interval)

	// Reuse the caller's limiter when one was supplied.
	if existing := params.RateLimiter; existing != nil {
		existing.SetLimit(limit)
		return existing, nil
	}

	// A new limiter starts with one token, so its first Wait returns without
	// blocking. Spend that token here so that the caller's next Wait is the
	// one that actually delays.
	fresh := rate.NewLimiter(limit, 1)
	if err := fresh.Wait(ctx); err != nil {
		return nil, fmt.Errorf("rate limiter wait: %w", err)
	}
	return fresh, nil
}

// generateRateLimitDuration returns a jittered wait interval: a random whole
// number of milliseconds in [minDuration, maxDuration) multiplied by
// multiplier.
//
// Both bounds are truncated to whole milliseconds before the random draw. When
// the truncated range is empty (maxDuration <= minDuration, including the case
// where both are zero or differ by less than a millisecond) no jitter is
// applied and minDuration*multiplier is returned. crypto/rand.Int panics on a
// non-positive upper bound, so that case must not reach it.
//
// An error is returned only when reading from the system random source fails.
func generateRateLimitDuration(multiplier int, minDuration, maxDuration time.Duration) (time.Duration, error) {
	minVal := minDuration.Milliseconds()
	maxVal := maxDuration.Milliseconds()

	var jitter int64
	if spread := maxVal - minVal; spread > 0 {
		n, err := rand.Int(rand.Reader, big.NewInt(spread))
		if err != nil {
			return 0, fmt.Errorf("rand int: %w", err)
		}
		jitter = n.Int64()
	}

	waitInterval := (minVal + jitter) * int64(multiplier)
	return time.Millisecond * time.Duration(waitInterval), nil
}

// RetryParams configures ExecRetryable.
type RetryParams struct {
	// ShouldRetry reports whether another attempt should be made after the
	// closure failed with err.
	ShouldRetry func(err error) bool

	// MaxRetries is the upper bound on how many times the closure is invoked;
	// the first call counts towards it.
	MaxRetries int

	// MinDuration and MaxDuration bound the randomized base delay between
	// attempts. The base delay is multiplied by the attempt number.
	MinDuration, MaxDuration time.Duration
}

// ExecRetryable invokes closure until it succeeds, fails with an error that
// params.ShouldRetry rejects, or the attempt budget is used up.
//
// params.MaxRetries is the total number of attempts, the first call included.
// Values below 1 are treated as 1 so that closure always runs at least once;
// otherwise the call would report success without having done any work. A nil
// params.ShouldRetry is treated as "never retry" instead of panicking.
//
// Each attempt runs inside a "try" telemetry span. Between failed attempts the
// call blocks in RateLimit for a jittered delay bounded by params.MinDuration
// and params.MaxDuration and scaled by the attempt number. If that wait itself
// fails, the failure is recorded as a telemetry event and the next attempt
// starts immediately.
//
// Returns:
//   - closure's value and a nil error on the first successful attempt;
//   - closure's value and its unwrapped error when the error is not retryable;
//   - closure's last value and an error of the form
//     "hit max tries N: try N of N: <cause>" once all attempts have failed.
func ExecRetryable[R any](ctx context.Context, closure func(ctx context.Context) (R, error), params RetryParams) (R, error) {
	// Guarantee at least one execution of closure.
	maxTries := max(params.MaxRetries, 1)

	attempt := func(ctx context.Context, retryAttempt int) (R, error) {
		tryCtx, span := telemetry.StartSpan(ctx, "try",
			trace.WithAttributes(
				attribute.Int("retry_attempt", retryAttempt),
				attribute.Int("max_tries", maxTries),
			),
		)
		defer span.End()
		return closure(tryCtx)
	}

	var (
		rateLimiter *rate.Limiter
		retVal      R
		err         error
	)

	for i := range maxTries {
		retryAttempt := i + 1
		retVal, err = attempt(ctx, retryAttempt)

		// No error means success.
		if err == nil {
			return retVal, nil
		}

		// Errors the caller does not consider retryable are returned as-is.
		if params.ShouldRetry == nil || !params.ShouldRetry(err) {
			return retVal, err
		}

		// Record the failed attempt before deciding whether to wait.
		err = fmt.Errorf("try %d of %d: %w", retryAttempt, maxTries, err)
		telemetry.AddEvent(ctx, err.Error())

		// No point in waiting after the final attempt.
		if retryAttempt == maxTries {
			break
		}

		// Jitter a delay before the next attempt. A failed wait is only
		// recorded; the retry still happens, just without the delay. The wait
		// error is kept separate so it never replaces the closure's error.
		var waitErr error
		rateLimiter, waitErr = RateLimit(ctx, RateLimitParams{
			RateLimiter:  rateLimiter,
			RetryAttempt: retryAttempt,
			MinDuration:  params.MinDuration,
			MaxDuration:  params.MaxDuration,
		})
		if waitErr != nil {
			telemetry.AddEvent(ctx, fmt.Sprintf("rate limit: %s", waitErr.Error()))
		}
	}

	return retVal, fmt.Errorf("hit max tries %d: %w", maxTries, err)
}
81 changes: 81 additions & 0 deletions ratex/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package ratex

import (
"context"
"errors"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestExecRetryable covers the four outcomes of ExecRetryable: immediate
// success, success after retryable failures, a non-retryable failure, and
// running out of attempts.
func TestExecRetryable(t *testing.T) {
	ctx := context.Background()

	// newParams builds the configuration shared by every subtest; only the
	// retry decision differs between them.
	newParams := func(retry bool) RetryParams {
		return RetryParams{
			ShouldRetry: func(error) bool { return retry },
			MaxRetries:  3,
			MinDuration: 10 * time.Millisecond,
			MaxDuration: 50 * time.Millisecond,
		}
	}

	t.Run("Success on first try", func(t *testing.T) {
		op := func(context.Context) (string, error) {
			return "success", nil
		}

		got, err := ExecRetryable(ctx, op, newParams(true))
		require.NoErrorf(t, err, "Expected success, got error: %v", err)
		require.Equalf(t, "success", got, "Expected result 'success', got: %v", got)
	})

	t.Run("Retryable failure with success before last retry", func(t *testing.T) {
		// Fail the first two calls, then succeed on the third.
		failuresLeft := 2
		op := func(context.Context) (string, error) {
			if failuresLeft == 0 {
				return "success", nil
			}
			failuresLeft--
			return "", errors.New("retryable error")
		}

		got, err := ExecRetryable(ctx, op, newParams(true))
		require.NoErrorf(t, err, "Expected success, got error: %v", err)
		require.Equalf(t, "success", got, "Expected result 'success', got: %v", got)
	})

	t.Run("Non-retryable failure", func(t *testing.T) {
		op := func(context.Context) (string, error) {
			return "", errors.New("non-retryable error")
		}

		got, err := ExecRetryable(ctx, op, newParams(false))
		require.Errorf(t, err, "Expected non-retryable error, got: %v", err)
		require.Empty(t, got)
		// The error must come back unwrapped.
		require.Equal(t, "non-retryable error", err.Error())
	})

	t.Run("Retryable failures exceeding MaxRetries", func(t *testing.T) {
		op := func(context.Context) (string, error) {
			return "", errors.New("retryable error")
		}

		got, err := ExecRetryable(ctx, op, newParams(true))
		require.Errorf(t, err, "Expected error after exceeding max retries, got: %v", err)
		require.Empty(t, got)
		require.Equal(t, "hit max tries 3: try 3 of 3: retryable error", err.Error())
	})
}
62 changes: 56 additions & 6 deletions sql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/moov-io/base/log"
"github.com/moov-io/base/ratex"
)

type DB struct {
Expand Down Expand Up @@ -65,15 +66,15 @@ func (w *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
return newStmt(ctx, w.logger, w.DB, query, w.id, w.slowQueryThresholdMs)
}

func (w *DB) Exec(query string, args ...interface{}) (gosql.Result, error) {
// Exec runs query with args on the wrapped database handle and returns the
// result together with the error after it has been passed through w.error.
func (w *DB) Exec(query string, args ...any) (gosql.Result, error) {
	// w.start is evaluated now; the function it returns runs on exit.
	defer w.start("exec", query, len(args))()

	res, execErr := w.DB.Exec(query, args...)
	return res, w.error(execErr)
}

func (w *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (gosql.Result, error) {
func (w *DB) ExecContext(ctx context.Context, query string, args ...any) (gosql.Result, error) {
done := w.start("exec", query, len(args))
ctx, end := span(ctx, w.id, "exec", query, len(args))
defer func() {
Expand All @@ -85,7 +86,22 @@ func (w *DB) ExecContext(ctx context.Context, query string, args ...interface{})
return r, w.error(err)
}

func (w *DB) Query(query string, args ...interface{}) (*gosql.Rows, error) {
// ExecContextRetryable runs query with args through ratex.ExecRetryable, so a
// failing statement is re-executed according to retryParams. A single "exec"
// span and start/done pair wrap all attempts; each attempt's error is passed
// through w.error before the retry decision is made.
func (w *DB) ExecContextRetryable(ctx context.Context, query string, retryParams ratex.RetryParams, args ...any) (gosql.Result, error) {
	argCount := len(args)
	finish := w.start("exec", query, argCount)
	ctx, closeSpan := span(ctx, w.id, "exec", query, argCount)
	defer func() {
		closeSpan()
		finish()
	}()

	return ratex.ExecRetryable(ctx, func(ctx context.Context) (gosql.Result, error) {
		res, execErr := w.DB.ExecContext(ctx, query, args...)
		return res, w.error(execErr)
	}, retryParams)
}

func (w *DB) Query(query string, args ...any) (*gosql.Rows, error) {
done := w.start("query", query, len(args))
defer done()

Expand All @@ -94,7 +110,7 @@ func (w *DB) Query(query string, args ...interface{}) (*gosql.Rows, error) {
return r, w.error(err)
}

func (w *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*gosql.Rows, error) {
func (w *DB) QueryContext(ctx context.Context, query string, args ...any) (*gosql.Rows, error) {
done := w.start("query", query, len(args))
ctx, end := span(ctx, w.id, "query", query, len(args))
defer func() {
Expand All @@ -106,7 +122,22 @@ func (w *DB) QueryContext(ctx context.Context, query string, args ...interface{}
return r, w.error(err)
}

func (w *DB) QueryRow(query string, args ...interface{}) *gosql.Row {
// QueryContextRetryable runs query with args through ratex.ExecRetryable, so a
// failing query is re-issued according to retryParams. A single "query" span
// and start/done pair wrap all attempts; each attempt's error is passed
// through w.error before the retry decision is made. The caller owns the
// returned rows and is responsible for closing them.
func (w *DB) QueryContextRetryable(ctx context.Context, query string, retryParams ratex.RetryParams, args ...any) (*gosql.Rows, error) {
	argCount := len(args)
	finish := w.start("query", query, argCount)
	ctx, closeSpan := span(ctx, w.id, "query", query, argCount)
	defer func() {
		closeSpan()
		finish()
	}()

	runQuery := func(ctx context.Context) (*gosql.Rows, error) {
		rows, queryErr := w.DB.QueryContext(ctx, query, args...)
		return rows, w.error(queryErr)
	}
	return ratex.ExecRetryable(ctx, runQuery, retryParams) //nolint:sqlclosecheck
}

func (w *DB) QueryRow(query string, args ...any) *gosql.Row {
done := w.start("query-row", query, len(args))
defer done()

Expand All @@ -116,7 +147,7 @@ func (w *DB) QueryRow(query string, args ...interface{}) *gosql.Row {
return r
}

func (w *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *gosql.Row {
func (w *DB) QueryRowContext(ctx context.Context, query string, args ...any) *gosql.Row {
done := w.start("query-row", query, len(args))
ctx, end := span(ctx, w.id, "query-row", query, len(args))
defer func() {
Expand All @@ -130,6 +161,25 @@ func (w *DB) QueryRowContext(ctx context.Context, query string, args ...interfac
return r
}

// QueryRowContextRetryable runs a single-row query through
// ratex.ExecRetryable, so a failing query is re-issued according to
// retryParams. A single "query-row" span and start/done pair wrap all
// attempts.
//
// Like (*sql.DB).QueryRowContext, the returned row is never nil: any error is
// carried by the row itself and surfaces from Row.Err or Row.Scan. That error
// is the one from the last attempt and does not include the retry metadata
// that the other *Retryable methods add to their returned error.
func (w *DB) QueryRowContextRetryable(ctx context.Context, query string, retryParams ratex.RetryParams, args ...any) *gosql.Row {
	done := w.start("query-row", query, len(args))
	ctx, end := span(ctx, w.id, "query-row", query, len(args))
	defer func() {
		end()
		done()
	}()

	queryRow := func(ctx context.Context) (*gosql.Row, error) {
		r := w.DB.QueryRowContext(ctx, query, args...)
		w.error(r.Err())
		return r, r.Err()
	}

	// The returned error is intentionally dropped: the row already carries the
	// last attempt's error as r.Err().
	r, _ := ratex.ExecRetryable(ctx, queryRow, retryParams)
	if r == nil {
		// Guard against a nil row, which ExecRetryable can yield if it returns
		// without ever invoking the closure. Callers chain .Scan on the result,
		// so fall back to a single execution without retries.
		r, _ = queryRow(ctx)
	}
	return r
}

func (w *DB) Begin() (*Tx, error) {
t, err := w.DB.Begin()
if err != nil {
Expand Down
Loading
Loading