Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rate limit function #2002

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions internal/pkg/limiter/pg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import (
"errors"
"github.com/frain-dev/convoy/database"
"github.com/frain-dev/convoy/pkg/log"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)

var ErrRateLimitExceeded = errors.New("rate limit exceeded")

type SlidingWindowRateLimiter struct {
db database.Database
}
Expand All @@ -31,28 +35,42 @@ func (p *SlidingWindowRateLimiter) takeToken(ctx context.Context, key string, ra

tx, err := p.db.GetDB().BeginTxx(ctx, nil)
if err != nil {
log.Infof("ratelimit failed: %v", err)
return nil
}

var allowed bool
err = tx.QueryRowContext(ctx, `select convoy.take_token($1, $2, $3)::bool;`, key, rate, windowSize).Scan(&allowed)
if err != nil {
log.Infof("ratelimit failed: %v", err)
return nil
return postgresErrorTransform(tx, err)
}

err = tx.Commit()
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
log.Infof("update failed: %v, unable to rollback: %v", err, rollbackErr)
log.Infof("failed: %v, unable to rollback: %v", err, rollbackErr)
}
return nil
}

if !allowed {
return errors.New("rate limit error")
return ErrRateLimitExceeded
}

return nil
}

func postgresErrorTransform(tx *sqlx.Tx, err error) error {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
log.Infof("failed: %v, unable to rollback: %v", err, rollbackErr)
}

var pgErr *pq.Error
ok := errors.As(err, &pgErr)
if ok {
if pgErr.Code == "23505" {
return ErrRateLimitExceeded
}
}

return err
}
107 changes: 107 additions & 0 deletions sql/1715118159.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
-- +migrate Up
-- +migrate StatementBegin
create or replace function convoy.take_token(_key text, _rate integer, _bucket_size integer) returns boolean
language plpgsql
as
$$
DECLARE
next_min timestamptz;
can_take BOOLEAN;
row RECORD;
BEGIN
next_min := current_timestamp + make_interval(secs := _bucket_size);

SELECT expires_at, tokens FROM convoy.token_bucket WHERE key = _key FOR UPDATE SKIP LOCKED LIMIT 1 INTO row;
if row is null then
jirevwe marked this conversation as resolved.
Show resolved Hide resolved
INSERT INTO convoy.token_bucket (key, rate, expires_at)
VALUES (_key, _rate, next_min);
return true;
end if;

IF current_timestamp < row.expires_at AND row.tokens = _rate THEN
RETURN FALSE;
END IF;

-- Update existing record
UPDATE convoy.token_bucket
SET tokens =
CASE WHEN current_timestamp > expires_at
THEN 1
ELSE CASE WHEN tokens < _rate
THEN tokens + 1
ELSE tokens END
END,
expires_at =
CASE WHEN current_timestamp > expires_at
THEN next_min
ELSE CASE WHEN tokens < _rate
THEN next_min
ELSE expires_at
END
END,
rate = COALESCE(_rate, rate),
updated_at = DEFAULT
WHERE key = _key
RETURNING TRUE INTO can_take;

RETURN can_take;
END;
$$;
-- +migrate StatementEnd

-- +migrate Down
-- +migrate StatementBegin
create or replace function convoy.take_token(_key text, _rate integer, _bucket_size integer) returns boolean
language plpgsql
as
$$
declare
row record;
next_min timestamptz;
new_rate int;
begin
select * from convoy.token_bucket where key = _key for update into row;
next_min := now() + make_interval(secs := _bucket_size);

-- the bucket doesn't exist yet
if row is null then
insert into convoy.token_bucket (key, rate, expires_at)
SELECT _key, _rate, next_min
WHERE NOT EXISTS (
SELECT 1 FROM convoy.token_bucket WHERE key = _key
);

return true;
end if;

-- update the rate if it's different from what's in the db
new_rate = case when row.rate != _rate then _rate else row.rate end;

-- this bucket has expired, reset it
if now() > row.expires_at then
UPDATE convoy.token_bucket
SET tokens = 1,
expires_at = next_min,
updated_at = default,
rate = new_rate
WHERE key = _key;
return true;
end if;

-- take a token
if row.tokens < new_rate then
update convoy.token_bucket
set tokens = row.tokens + 1,
expires_at = next_min,
updated_at = default,
rate = new_rate
where key = _key;
return true;
end if;

-- no tokens for you sorry
return false;
end;
$$;
-- +migrate StatementEnd

Loading