
Correctly handle some extreme edge cases in the ratelimiter implementation #3706

Merged (5 commits) on May 25, 2023
70 changes: 53 additions & 17 deletions src/rate_limiter/src/lib.rs
@@ -132,7 +132,8 @@ impl TokenBucket {
         // refill_token_count = (delta_time * size) / (complete_refill_time_ms * 1_000_000)
         // In order to avoid overflows, simplify the fractions by computing greatest common divisor.
 
-        let complete_refill_time_ns = complete_refill_time_ms * NANOSEC_IN_ONE_MILLISEC;
+        let complete_refill_time_ns =
+            complete_refill_time_ms.checked_mul(NANOSEC_IN_ONE_MILLISEC)?;
         // Get the greatest common factor between `size` and `complete_refill_time_ns`.
         let common_factor = gcd(size, complete_refill_time_ns);
         // The division will be exact since `common_factor` is a factor of `size`.
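To make the constructor change concrete, here is a minimal standalone sketch (the `simplify` and `gcd` helpers are invented for illustration, not Firecracker's actual API) showing both what `checked_mul` guards against and how the gcd simplification described in the comments plays out on real numbers:

```rust
// Sketch of the constructor math above: `checked_mul` turns a silent
// wrap-around into a `None`, and the gcd keeps the later
// `(time_delta * size) / refill_ns` fraction small enough to avoid overflow.
const NANOSEC_IN_ONE_MILLISEC: u64 = 1_000_000;

fn gcd(a: u64, b: u64) -> u64 {
    if b == 0 { a } else { gcd(b, a % b) }
}

// Returns `(processed_capacity, processed_refill_time)`, or `None` if
// `complete_refill_time_ms` is so large that converting it to nanoseconds
// would overflow a u64.
fn simplify(size: u64, complete_refill_time_ms: u64) -> Option<(u64, u64)> {
    let complete_refill_time_ns =
        complete_refill_time_ms.checked_mul(NANOSEC_IN_ONE_MILLISEC)?;
    let common_factor = gcd(size, complete_refill_time_ns);
    Some((size / common_factor, complete_refill_time_ns / common_factor))
}

fn main() {
    // 1000 tokens per 100 ms simplifies to 1 token per 100_000 ns.
    assert_eq!(simplify(1000, 100), Some((1, 100_000)));
    // u64::MAX milliseconds cannot be expressed in nanoseconds as a u64, so
    // the old `*` would have wrapped in release builds (or panicked in debug
    // builds); `checked_mul` reports the overflow instead.
    assert_eq!(simplify(1000, u64::MAX), None);
}
```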
@@ -158,22 +159,54 @@ impl TokenBucket {
     // Replenishes token bucket based on elapsed time. Should only be called internally by `Self`.
     fn auto_replenish(&mut self) {
         // Compute time passed since last refill/update.
-        let time_delta = self.last_update.elapsed().as_nanos() as u64;
+        let now = Instant::now();
+        let time_delta = (now - self.last_update).as_nanos();
 
-        // At each 'time_delta' nanoseconds the bucket should refill with:
-        // refill_amount = (time_delta * size) / (complete_refill_time_ms * 1_000_000)
-        // `processed_capacity` and `processed_refill_time` are the result of simplifying above
-        // fraction formula with their greatest-common-factor.
-        let tokens = (time_delta * self.processed_capacity) / self.processed_refill_time;
-
-        // We increment `self.last_update` by the minimum time required to generate `tokens`, in the
-        // case where we have the time to generate `1.8` tokens but only generate `x` tokens due to
-        // integer arithmetic this will carry the time required to generate 0.8th of a token over to
-        // the next call, such that if the next call where to generate `2.3` tokens it would instead
-        // generate `3.1` tokens. This minimizes dropping tokens at high frequencies.
-        self.last_update +=
-            Duration::from_nanos((tokens * self.processed_refill_time) / self.processed_capacity);
-        self.budget = std::cmp::min(self.budget + tokens, self.size);
+        if time_delta >= u128::from(self.refill_time * NANOSEC_IN_ONE_MILLISEC) {
+            self.budget = self.size;
+            self.last_update = now;
+        } else {
+            // At each 'time_delta' nanoseconds the bucket should refill with:
+            // refill_amount = (time_delta * size) / (complete_refill_time_ms * 1_000_000)
+            // `processed_capacity` and `processed_refill_time` are the result of simplifying the
+            // above fraction formula with their greatest common factor.
+
+            // In the constructor, we ensured that (self.refill_time * NANOSEC_IN_ONE_MILLISEC)
+            // fits into a u64. That means, at this point, we know that time_delta <
+            // u64::MAX. Since all other values here are u64, this ensures that the u128
+            // multiplication cannot overflow.
+            let processed_capacity = u128::from(self.processed_capacity);
+            let processed_refill_time = u128::from(self.processed_refill_time);
+
+            let tokens = (time_delta * processed_capacity) / processed_refill_time;
+
+            // We increment `self.last_update` by the minimum time required to generate `tokens`.
+            // In the case where we have the time to generate `1.8` tokens but only generate `1`
+            // token due to integer arithmetic, this carries the time required to generate the
+            // remaining 0.8 of a token over to the next call, so that if the next call were to
+            // generate `2.3` tokens it instead generates `3.1` tokens. This minimizes dropping
+            // tokens at high frequencies.
+            // We want the integer division here to round up instead of down (if we rounded down,
+            // we would allow some fraction of a nanosecond to be used twice, allowing the
+            // generation of one extra token in extreme circumstances).
+            let mut time_adjustment = tokens * processed_refill_time / processed_capacity;
+            if tokens * processed_refill_time % processed_capacity != 0 {
+                time_adjustment += 1;
+            }
+
+            // Ensure that we always generate as many tokens as we can: assert that the "unused"
+            // part of time_delta is less than the time it would take to generate a single token
+            // (= processed_refill_time / processed_capacity).
+            debug_assert!(time_adjustment <= time_delta);
+            debug_assert!(
+                (time_delta - time_adjustment) * processed_capacity <= processed_refill_time
+            );
+
+            // time_adjustment is at most time_delta, and since time_delta <= u64::MAX, this cast
+            // is fine.
+            self.last_update += Duration::from_nanos(time_adjustment as u64);
+            self.budget = std::cmp::min(self.budget.saturating_add(tokens as u64), self.size);
+        }
     }
 
     /// Attempts to consume `tokens` from the bucket and returns whether the action succeeded.
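To see the new carry-over arithmetic in isolation, here is a self-contained sketch of just the token/time-adjustment step (the `replenish_step` helper is invented for illustration; the real code operates on `TokenBucket` fields), showing both the fractional carry between calls and why the division rounds up:

```rust
// Rounding `time_adjustment` up means the unused fraction of a nanosecond is
// never counted twice, while advancing `last_update` by only the time that
// was actually "spent" carries fractional tokens into the next call.
fn replenish_step(
    time_delta: u128,            // nanoseconds since last_update
    processed_capacity: u128,    // simplified bucket size
    processed_refill_time: u128, // simplified refill time in ns
) -> (u128 /* tokens */, u128 /* time_adjustment in ns */) {
    let tokens = (time_delta * processed_capacity) / processed_refill_time;
    // Round the integer division up, as in the diff above.
    let mut time_adjustment = tokens * processed_refill_time / processed_capacity;
    if tokens * processed_refill_time % processed_capacity != 0 {
        time_adjustment += 1;
    }
    (tokens, time_adjustment)
}

fn main() {
    // 1 token per 100_000 ns. 180_000 ns is "1.8 tokens": we mint 1 token
    // but consume only 100_000 ns, leaving 80_000 ns for the next call.
    assert_eq!(replenish_step(180_000, 1, 100_000), (1, 100_000));

    // The next call sees its own 230_000 ns plus the 80_000 ns carry-over,
    // i.e. "3.1 tokens" worth of time, so it mints 3 tokens, not 2.
    assert_eq!(replenish_step(80_000 + 230_000, 1, 100_000), (3, 300_000));

    // With 3 tokens per 10 ns, 7 ns is "2.1 tokens": we mint 2. Generating
    // them takes 20/3 ≈ 6.67 ns, rounded *up* to 7 ns, so the leftover third
    // of a nanosecond cannot be spent again on the next call.
    assert_eq!(replenish_step(7, 3, 10), (2, 7));
}
```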
@@ -229,10 +262,13 @@ impl TokenBucket {
         // budget which should now be replenished, but for performance and code-complexity
         // reasons we're just gonna let that slide since it's practically inconsequential.
         if self.one_time_burst > 0 {
-            self.one_time_burst += tokens;
+            self.one_time_burst = std::cmp::min(
+                self.one_time_burst.saturating_add(tokens),
+                self.initial_one_time_burst,
+            );
             return;
         }
-        self.budget = std::cmp::min(self.budget + tokens, self.size);
+        self.budget = std::cmp::min(self.budget.saturating_add(tokens), self.size);
     }
 
     /// Returns the capacity of the token bucket.
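Finally, a hedged sketch of the clamping in this last hunk (the `Bucket` struct below is a stand-in, though the field names follow the diff): `saturating_add` removes the overflow in `one_time_burst += tokens`, and the `min` against `initial_one_time_burst` keeps a replenish from growing the burst allowance past its configured starting value.

```rust
struct Bucket {
    budget: u64,
    size: u64,
    one_time_burst: u64,
    initial_one_time_burst: u64,
}

impl Bucket {
    // Mirrors the replenish logic in the hunk above: while any one-time
    // burst remains, returned tokens refill the burst (clamped to its
    // initial value); otherwise they refill the normal budget (clamped to
    // the bucket size). Both paths saturate instead of overflowing.
    fn replenish(&mut self, tokens: u64) {
        if self.one_time_burst > 0 {
            self.one_time_burst = std::cmp::min(
                self.one_time_burst.saturating_add(tokens),
                self.initial_one_time_burst,
            );
            return;
        }
        self.budget = std::cmp::min(self.budget.saturating_add(tokens), self.size);
    }
}

fn main() {
    let mut b = Bucket { budget: 0, size: 100, one_time_burst: 50, initial_one_time_burst: 60 };
    // Replenishing 20 tokens would lift the burst to 70, past its initial
    // 60, so it is clamped; the normal budget stays untouched while any
    // burst remains.
    b.replenish(20);
    assert_eq!((b.one_time_burst, b.budget), (60, 0));

    // Once the burst is exhausted, even u64::MAX tokens cannot overflow the
    // budget: it saturates, then is clamped to `size`.
    b.one_time_burst = 0;
    b.replenish(u64::MAX);
    assert_eq!(b.budget, b.size);
}
```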