From bae8b922c29bb0fb6d34e279d4df3f663c187957 Mon Sep 17 00:00:00 2001 From: Raymond Tukpe Date: Wed, 22 May 2024 17:01:02 +0200 Subject: [PATCH] Update rate limit function (#2002) * chore: update rate limit function * fix: return when locked row is null * fix: select only specific rows * chore: create record if not exists --- internal/pkg/limiter/pg/client.go | 28 ++++++-- sql/1715118159.sql | 107 ++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 sql/1715118159.sql diff --git a/internal/pkg/limiter/pg/client.go b/internal/pkg/limiter/pg/client.go index 4f64e15455..68c44b936f 100644 --- a/internal/pkg/limiter/pg/client.go +++ b/internal/pkg/limiter/pg/client.go @@ -5,8 +5,12 @@ import ( "errors" "github.com/frain-dev/convoy/database" "github.com/frain-dev/convoy/pkg/log" + "github.com/jmoiron/sqlx" + "github.com/lib/pq" ) +var ErrRateLimitExceeded = errors.New("rate limit exceeded") + type SlidingWindowRateLimiter struct { db database.Database } @@ -31,28 +35,42 @@ func (p *SlidingWindowRateLimiter) takeToken(ctx context.Context, key string, ra tx, err := p.db.GetDB().BeginTxx(ctx, nil) if err != nil { - log.Infof("ratelimit failed: %v", err) return nil } var allowed bool err = tx.QueryRowContext(ctx, `select convoy.take_token($1, $2, $3)::bool;`, key, rate, windowSize).Scan(&allowed) if err != nil { - log.Infof("ratelimit failed: %v", err) - return nil + return postgresErrorTransform(tx, err) } err = tx.Commit() if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { - log.Infof("update failed: %v, unable to rollback: %v", err, rollbackErr) + log.Infof("failed: %v, unable to rollback: %v", err, rollbackErr) } return nil } if !allowed { - return errors.New("rate limit error") + return ErrRateLimitExceeded } return nil } + +func postgresErrorTransform(tx *sqlx.Tx, err error) error { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Infof("failed: %v, unable to rollback: %v", err, rollbackErr) + } + + var pgErr *pq.Error + ok := errors.As(err, &pgErr) + if ok { + if pgErr.Code == "23505" { + return ErrRateLimitExceeded + } + } + + return err +} diff --git a/sql/1715118159.sql b/sql/1715118159.sql new file mode 100644 index 0000000000..667288dd92 --- /dev/null +++ b/sql/1715118159.sql @@ -0,0 +1,107 @@ +-- +migrate Up +-- +migrate StatementBegin +create or replace function convoy.take_token(_key text, _rate integer, _bucket_size integer) returns boolean + language plpgsql +as +$$ +DECLARE + next_min timestamptz; + can_take BOOLEAN; + row RECORD; +BEGIN + next_min := current_timestamp + make_interval(secs := _bucket_size); + + SELECT expires_at, tokens FROM convoy.token_bucket WHERE key = _key FOR UPDATE SKIP LOCKED LIMIT 1 INTO row; + if row is null then + INSERT INTO convoy.token_bucket (key, rate, expires_at) + VALUES (_key, _rate, next_min); + return true; + end if; + + IF current_timestamp < row.expires_at AND row.tokens = _rate THEN + RETURN FALSE; + END IF; + + -- Update existing record + UPDATE convoy.token_bucket + SET tokens = + CASE WHEN current_timestamp > expires_at + THEN 1 + ELSE CASE WHEN tokens < _rate + THEN tokens + 1 + ELSE tokens END + END, + expires_at = + CASE WHEN current_timestamp > expires_at + THEN next_min + ELSE CASE WHEN tokens < _rate + THEN next_min + ELSE expires_at + END + END, + rate = COALESCE(_rate, rate), + updated_at = DEFAULT + WHERE key = _key + RETURNING TRUE INTO can_take; + + RETURN can_take; +END; +$$; +-- +migrate StatementEnd + +-- +migrate Down +-- +migrate StatementBegin +create or replace function convoy.take_token(_key text, _rate integer, _bucket_size integer) returns boolean + language plpgsql +as +$$ +declare + row record; + next_min timestamptz; + new_rate int; +begin + select * from convoy.token_bucket where key = _key for update into row; + next_min := now() + make_interval(secs := _bucket_size); + + -- the bucket doesn't exist yet + if row is null then + insert into convoy.token_bucket (key, rate, expires_at) + SELECT _key, _rate, next_min + WHERE NOT EXISTS ( + SELECT 1 FROM convoy.token_bucket WHERE key = _key + ); + + return true; + end if; + + -- update the rate if it's different from what's in the db + new_rate = case when row.rate != _rate then _rate else row.rate end; + + -- this bucket has expired, reset it + if now() > row.expires_at then + UPDATE convoy.token_bucket + SET tokens = 1, + expires_at = next_min, + updated_at = default, + rate = new_rate + WHERE key = _key; + return true; + end if; + + -- take a token + if row.tokens < new_rate then + update convoy.token_bucket + set tokens = row.tokens + 1, + expires_at = next_min, + updated_at = default, + rate = new_rate + where key = _key; + return true; + end if; + + -- no tokens for you sorry + return false; +end; +$$; +-- +migrate StatementEnd +