Skip to content

Commit

Permalink
#51 protect from undefined behavior caused by arithmetic overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-bukhtoyarov committed Sep 22, 2017
1 parent 895d436 commit c939130
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ public static IllegalArgumentException restrictionsNotSpecified() {
return new IllegalArgumentException(msg);
}

public static IllegalArgumentException tooHighRefillRate(long periodNanos, long tokens) {
double actualRate = (double) tokens / (double) periodNanos;
String pattern = "{0} token/nanosecond is not permitted refill rate" +
", because highest supported rate is 1 token/nanosecond";
String msg = MessageFormat.format(pattern, actualRate);
return new IllegalArgumentException(msg);
}

// ------------------- end of construction time exceptions --------------------------------

// ------------------- usage time exceptions ---------------------------------------------
Expand Down
89 changes: 65 additions & 24 deletions bucket4j-core/src/main/java/io/github/bucket4j/BucketState.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package io.github.bucket4j;

import java.io.Serializable;
import java.math.BigInteger;
import java.util.Arrays;

public class BucketState implements Serializable {
Expand Down Expand Up @@ -108,8 +109,11 @@ private void addTokens(int bandwidthIndex, Bandwidth bandwidth, long tokensToAdd
long currentSize = getCurrentSize(bandwidthIndex);
long newSize = currentSize + tokensToAdd;
if (newSize >= bandwidth.capacity) {
setCurrentSize(bandwidthIndex, bandwidth.capacity);
setRoundingError(bandwidthIndex, 0L);
resetBandwidth(bandwidthIndex, bandwidth.capacity);
} else if (newSize < currentSize) {
// arithmetic overflow happens. This mean that bucket reached Long.MAX_VALUE tokens.
// just reset bandwidth state
resetBandwidth(bandwidthIndex, bandwidth.capacity);
} else {
setCurrentSize(bandwidthIndex, newSize);
}
Expand All @@ -121,36 +125,65 @@ private void consume(int bandwidth, long tokens) {

private void refill(int bandwidthIndex, Bandwidth bandwidth, long previousRefillNanos, long currentTimeNanos) {
final long capacity = bandwidth.capacity;
long currentSize = getCurrentSize(bandwidthIndex);

if (currentSize >= capacity) {
setCurrentSize(bandwidthIndex, capacity);
setRoundingError(bandwidthIndex, 0L);
return;
}
final long refillPeriodNanos = bandwidth.refill.getPeriodNanos();
final long refillTokens = bandwidth.refill.getTokens();
final long currentSize = getCurrentSize(bandwidthIndex);

long durationSinceLastRefillNanos = currentTimeNanos - previousRefillNanos;
long newSize = currentSize;

if (durationSinceLastRefillNanos > refillPeriodNanos) {
long elapsedPeriods = durationSinceLastRefillNanos / refillPeriodNanos;
newSize += elapsedPeriods * refillTokens;
if (newSize > capacity) {
resetBandwidth(bandwidthIndex, capacity);
return;
}
if (newSize < currentSize) {
// arithmetic overflow happens. This mean that bucket reached Long.MAX_VALUE tokens.
// just reset bandwidth state
resetBandwidth(bandwidthIndex, capacity);
return;
}
durationSinceLastRefillNanos %= refillPeriodNanos;
}

long refillPeriod = bandwidth.refill.getPeriodNanos();
long refillTokens = bandwidth.refill.getTokens();
long roundingError = getRoundingError(bandwidthIndex);
long divided = refillTokens * durationSinceLastRefillNanos + roundingError;
long calculatedRefill = divided / refillPeriod;
if (calculatedRefill == 0) {
setRoundingError(bandwidthIndex, divided);
return;
long dividedWithoutError = refillTokens * durationSinceLastRefillNanos;
long divided = dividedWithoutError + roundingError;
if (divided < 0 || dividedWithoutError < 0) {
// arithmetic overflow happens.
// there is no sense to stay in integer arithmetic when having deal with so big numbers
long calculatedRefill = (long) ((double) durationSinceLastRefillNanos / (double) refillPeriodNanos * (double) refillTokens);
newSize += calculatedRefill;
setRoundingError(bandwidthIndex, 0);
} else {
long calculatedRefill = divided / refillPeriodNanos;
if (calculatedRefill == 0) {
setRoundingError(bandwidthIndex, divided);
} else {
newSize += calculatedRefill;
roundingError = divided % refillPeriodNanos;
setRoundingError(bandwidthIndex, roundingError);
}
}

long newSize = currentSize + calculatedRefill;
if (newSize >= capacity) {
setCurrentSize(bandwidthIndex, capacity);
setRoundingError(bandwidthIndex, 0);
resetBandwidth(bandwidthIndex, capacity);
return;
}
if (newSize < currentSize) {
// arithmetic overflow happens. This mean that bucket reached Long.MAX_VALUE tokens.
// just reset bandwidth state
resetBandwidth(bandwidthIndex, capacity);
return;
}

roundingError = divided % refillPeriod;
setCurrentSize(bandwidthIndex, newSize);
setRoundingError(bandwidthIndex, roundingError);
}

private void resetBandwidth(int bandwidthIndex, long capacity) {
setCurrentSize(bandwidthIndex, capacity);
setRoundingError(bandwidthIndex, 0);
}

private long delayNanosAfterWillBePossibleToConsume(int bandwidthIndex, Bandwidth bandwidth, long tokens) {
Expand All @@ -159,8 +192,16 @@ private long delayNanosAfterWillBePossibleToConsume(int bandwidthIndex, Bandwidt
return 0;
}
long deficit = tokens - currentSize;
long periodNanos = bandwidth.refill.getPeriodNanos();
return periodNanos * deficit / bandwidth.refill.getTokens();
long refillPeriodNanos = bandwidth.refill.getPeriodNanos();
long refillPeriodTokens = bandwidth.refill.getTokens();

if (refillPeriodNanos * deficit < 0) {
// arithmetic overflow happens.
// there is no sense to stay in integer arithmetic when having deal with so big numbers
return (long)((double)deficit / (double)refillPeriodTokens * (double)refillPeriodNanos);
} else {
return refillPeriodNanos * deficit / refillPeriodTokens;
}
}

long getCurrentSize(int bandwidth) {
Expand Down
3 changes: 3 additions & 0 deletions bucket4j-core/src/main/java/io/github/bucket4j/Refill.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ private Refill(long tokens, Duration period) {
if (periodNanos <= 0) {
throw BucketExceptions.nonPositivePeriod(periodNanos);
}
if (tokens > periodNanos) {
throw BucketExceptions.tooHighRefillRate(periodNanos, tokens);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ class DetectionOfIllegalApiUsageSpecification extends Specification {
ex.message == restrictionsNotSpecified().message
}

def "Should detect the high rate of refill"() {
when:
Bucket4j.builder().addLimit(Bandwidth.simple(2, Duration.ofNanos(1)))
then:
IllegalArgumentException ex = thrown()
ex.message == tooHighRefillRate(1, 2).message
}

def "Should check that tokens to consume should be positive"() {
setup:
def bucket = Bucket4j.builder().addLimit(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
*
* Copyright 2015-2017 Vladimir Bukhtoyarov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.bucket4j

import io.github.bucket4j.mock.TimeMeterMock
import spock.lang.Specification

import java.time.Duration

class HandlingArithmeticOverflowSpecification extends Specification {

def "regression test for https://github.com/vladimir-bukhtoyarov/bucket4j/issues/51"() {
setup:
final Bandwidth limit1 = Bandwidth.simple(700000, Duration.ofHours(1))
final Bandwidth limit2 = Bandwidth.simple(14500, Duration.ofMinutes(1))
final Bandwidth limit3 = Bandwidth.simple(300, Duration.ofSeconds(1))
TimeMeterMock customTimeMeter = new TimeMeterMock(0)
long twelveHourNanos = 12 * 60 * 60 * 1_000_000_000L;
Bucket bucket = Bucket4j.builder()
.addLimit(limit1)
.addLimit(limit2)
.addLimit(limit3)
.withCustomTimePrecision(customTimeMeter)
.build()
when:
// shift time to 12 hours forward
customTimeMeter.addTime(twelveHourNanos)
then:
bucket.tryConsume(1)
bucket.tryConsume(300 - 1)
!bucket.tryConsume(1)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
import java.util.function.Function;
import java.util.function.Supplier;

import static org.junit.Assert.assertTrue;

public class LocalTest {

private LocalBucketBuilder builder = Bucket4j.builder()
Expand Down

0 comments on commit c939130

Please sign in to comment.