Skip to content

Commit

Permalink
#51 protect from undefined behavior caused by arithmetic overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-bukhtoyarov committed Sep 23, 2017
1 parent 895d436 commit 2b2ebe7
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ public static IllegalArgumentException restrictionsNotSpecified() {
return new IllegalArgumentException(msg);
}

public static IllegalArgumentException tooHighRefillRate(long periodNanos, long tokens) {
double actualRate = (double) tokens / (double) periodNanos;
String pattern = "{0} token/nanosecond is not permitted refill rate" +
", because highest supported rate is 1 token/nanosecond";
String msg = MessageFormat.format(pattern, actualRate);
return new IllegalArgumentException(msg);
}

// ------------------- end of construction time exceptions --------------------------------

// ------------------- usage time exceptions ---------------------------------------------
Expand Down
105 changes: 82 additions & 23 deletions bucket4j-core/src/main/java/io/github/bucket4j/BucketState.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@ private void addTokens(int bandwidthIndex, Bandwidth bandwidth, long tokensToAdd
long currentSize = getCurrentSize(bandwidthIndex);
long newSize = currentSize + tokensToAdd;
if (newSize >= bandwidth.capacity) {
setCurrentSize(bandwidthIndex, bandwidth.capacity);
setRoundingError(bandwidthIndex, 0L);
resetBandwidth(bandwidthIndex, bandwidth.capacity);
} else if (newSize < currentSize) {
// arithmetic overflow happens. This mean that bucket reached Long.MAX_VALUE tokens.
// just reset bandwidth state
resetBandwidth(bandwidthIndex, bandwidth.capacity);
} else {
setCurrentSize(bandwidthIndex, newSize);
}
Expand All @@ -121,46 +124,85 @@ private void consume(int bandwidth, long tokens) {

private void refill(int bandwidthIndex, Bandwidth bandwidth, long previousRefillNanos, long currentTimeNanos) {
final long capacity = bandwidth.capacity;
long currentSize = getCurrentSize(bandwidthIndex);

if (currentSize >= capacity) {
setCurrentSize(bandwidthIndex, capacity);
setRoundingError(bandwidthIndex, 0L);
return;
}
final long refillPeriodNanos = bandwidth.refill.getPeriodNanos();
final long refillTokens = bandwidth.refill.getTokens();
final long currentSize = getCurrentSize(bandwidthIndex);

long durationSinceLastRefillNanos = currentTimeNanos - previousRefillNanos;
long newSize = currentSize;

if (durationSinceLastRefillNanos > refillPeriodNanos) {
long elapsedPeriods = durationSinceLastRefillNanos / refillPeriodNanos;
long calculatedRefill = elapsedPeriods * refillTokens;
newSize += calculatedRefill;
if (newSize > capacity) {
resetBandwidth(bandwidthIndex, capacity);
return;
}
if (newSize < currentSize) {
// arithmetic overflow happens. This mean that tokens reached Long.MAX_VALUE tokens.
// just reset bandwidth state
resetBandwidth(bandwidthIndex, capacity);
return;
}
durationSinceLastRefillNanos %= refillPeriodNanos;
}

long refillPeriod = bandwidth.refill.getPeriodNanos();
long refillTokens = bandwidth.refill.getTokens();
long roundingError = getRoundingError(bandwidthIndex);
long divided = refillTokens * durationSinceLastRefillNanos + roundingError;
long calculatedRefill = divided / refillPeriod;
if (calculatedRefill == 0) {
setRoundingError(bandwidthIndex, divided);
return;
long dividedWithoutError = multiplyExactOrReturnMaxValue(refillTokens, durationSinceLastRefillNanos);
long divided = dividedWithoutError + roundingError;
if (divided < 0 || dividedWithoutError == Long.MAX_VALUE) {
// arithmetic overflow happens.
// there is no sense to stay in integer arithmetic when having deal with so big numbers
long calculatedRefill = (long) ((double) durationSinceLastRefillNanos / (double) refillPeriodNanos * (double) refillTokens);
newSize += calculatedRefill;
roundingError = 0;
} else {
long calculatedRefill = divided / refillPeriodNanos;
if (calculatedRefill == 0) {
roundingError = divided;
} else {
newSize += calculatedRefill;
roundingError = divided % refillPeriodNanos;
}
}

long newSize = currentSize + calculatedRefill;
if (newSize >= capacity) {
setCurrentSize(bandwidthIndex, capacity);
setRoundingError(bandwidthIndex, 0);
resetBandwidth(bandwidthIndex, capacity);
return;
}
if (newSize < currentSize) {
// arithmetic overflow happens. This mean that bucket reached Long.MAX_VALUE tokens.
// just reset bandwidth state
resetBandwidth(bandwidthIndex, capacity);
return;
}

roundingError = divided % refillPeriod;
setCurrentSize(bandwidthIndex, newSize);
setRoundingError(bandwidthIndex, roundingError);
}

private void resetBandwidth(int bandwidthIndex, long capacity) {
setCurrentSize(bandwidthIndex, capacity);
setRoundingError(bandwidthIndex, 0);
}

private long delayNanosAfterWillBePossibleToConsume(int bandwidthIndex, Bandwidth bandwidth, long tokens) {
long currentSize = getCurrentSize(bandwidthIndex);
if (tokens <= currentSize) {
return 0;
}
long deficit = tokens - currentSize;
long periodNanos = bandwidth.refill.getPeriodNanos();
return periodNanos * deficit / bandwidth.refill.getTokens();
long refillPeriodNanos = bandwidth.refill.getPeriodNanos();
long refillPeriodTokens = bandwidth.refill.getTokens();

long divided = multiplyExactOrReturnMaxValue(refillPeriodNanos, deficit);
if (divided == Long.MAX_VALUE) {
// arithmetic overflow happens.
// there is no sense to stay in integer arithmetic when having deal with so big numbers
return (long)((double) deficit / (double)refillPeriodTokens * (double)refillPeriodNanos);
} else {
return divided / refillPeriodTokens;
}
}

long getCurrentSize(int bandwidth) {
Expand Down Expand Up @@ -194,4 +236,21 @@ public String toString() {
'}';
}

// just a copy of JDK method Math#multiplyExact,
// but instead of throwing exception it returns Long.MAX_VALUE in case of overflow
private static long multiplyExactOrReturnMaxValue(long x, long y) {
long r = x * y;
long ax = Math.abs(x);
long ay = Math.abs(y);
if (((ax | ay) >>> 31 != 0)) {
// Some bits greater than 2^31 that might cause overflow
// Check the result using the divide operator
// and check for the special case of Long.MIN_VALUE * -1
if (((y != 0) && (r / y != x)) || (x == Long.MIN_VALUE && y == -1)) {
return Long.MAX_VALUE;
}
}
return r;
}

}
3 changes: 3 additions & 0 deletions bucket4j-core/src/main/java/io/github/bucket4j/Refill.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ private Refill(long tokens, Duration period) {
if (periodNanos <= 0) {
throw BucketExceptions.nonPositivePeriod(periodNanos);
}
if (tokens > periodNanos) {
throw BucketExceptions.tooHighRefillRate(periodNanos, tokens);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ class DetectionOfIllegalApiUsageSpecification extends Specification {
ex.message == restrictionsNotSpecified().message
}

def "Should detect the high rate of refill"() {
when:
Bucket4j.builder().addLimit(Bandwidth.simple(2, Duration.ofNanos(1)))
then:
IllegalArgumentException ex = thrown()
ex.message == tooHighRefillRate(1, 2).message
}

def "Should check that tokens to consume should be positive"() {
setup:
def bucket = Bucket4j.builder().addLimit(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
*
* Copyright 2015-2017 Vladimir Bukhtoyarov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.bucket4j

import io.github.bucket4j.mock.TimeMeterMock
import spock.lang.Specification

import java.time.Duration

class HandlingArithmeticOverflowSpecification extends Specification {

def "regression test for https://github.com/vladimir-bukhtoyarov/bucket4j/issues/51"() {
setup:
Bandwidth limit1 = Bandwidth.simple(700000, Duration.ofHours(1))
Bandwidth limit2 = Bandwidth.simple(14500, Duration.ofMinutes(1))
Bandwidth limit3 = Bandwidth.simple(300, Duration.ofSeconds(1))
TimeMeterMock customTimeMeter = new TimeMeterMock(0)
long twelveHourNanos = 12 * 60 * 60 * 1_000_000_000L;
Bucket bucket = Bucket4j.builder()
.addLimit(limit1)
.addLimit(limit2)
.addLimit(limit3)
.withCustomTimePrecision(customTimeMeter)
.build()
when:
// shift time to 12 hours forward
customTimeMeter.addTime(twelveHourNanos)
then:
bucket.tryConsume(1)
bucket.tryConsume(300 - 1)
!bucket.tryConsume(1)
}

def "Should check ArithmeticOverflow when add tokens to bucket"() {
setup:
Bandwidth limit = Bandwidth.simple(10, Duration.ofSeconds(1))
TimeMeterMock customTimeMeter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit(9, limit)
.withCustomTimePrecision(customTimeMeter)
.build()
when:
bucket.addTokens(Long.MAX_VALUE - 1)
then:
bucket.tryConsume(10)
!bucket.tryConsume(1)
}

def "Should firstly do refill by completed periods"() {
setup:
Bandwidth limit = Bandwidth.simple((long) Long.MAX_VALUE / 16, Duration.ofNanos((long) Long.MAX_VALUE / 8))
TimeMeterMock meter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit(7, limit)
.withCustomTimePrecision(meter)
.build()
when:
// emulate time shift which equal of 3 refill periods
meter.addTime((long) Long.MAX_VALUE / 8 * 3)
then:
bucket.tryConsume((long) Long.MAX_VALUE / 16)
!bucket.tryConsume(1)
}

def "Should check ArithmeticOverflow when refilling by completed periods"() {
setup:
Bandwidth limit = Bandwidth.classic((long) Long.MAX_VALUE - 10, Refill.smooth(1, Duration.ofNanos(1)))
TimeMeterMock meter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit((long) Long.MAX_VALUE - 13, limit)
.withCustomTimePrecision(meter)
.build()
when:
// add time shift enough to overflow
meter.addTime(20)
then:
bucket.tryConsume(Long.MAX_VALUE - 10)
!bucket.tryConsume(1)
}

def "Should down to floating point arithmetic if necessary during refill"() {
setup:
Bandwidth limit = Bandwidth.simple((long) Long.MAX_VALUE / 16, Duration.ofNanos((long) Long.MAX_VALUE / 8))
TimeMeterMock meter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit(0, limit)
.withCustomTimePrecision(meter)
.build()
when:
// emulate time shift which little bit less then one refill period
meter.addTime((long) Long.MAX_VALUE / 16 - 1)
then:
// should down into floating point arithmetic and successfully refill
bucket.tryConsume((long) Long.MAX_VALUE / 32)
bucket.tryConsumeAsMuchAsPossible() == 1
}

def "Should check ArithmeticOverflow when refilling by uncompleted periods"() {
setup:
Bandwidth limit = Bandwidth.classic((long) Long.MAX_VALUE - 10, Refill.smooth(100, Duration.ofNanos(100)))
TimeMeterMock meter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit((long) Long.MAX_VALUE - 13, limit)
.withCustomTimePrecision(meter)
.build()
when:
// add time shift enough to overflow
meter.addTime(50)
then:
bucket.tryConsume(Long.MAX_VALUE - 10)
!bucket.tryConsume(1)
}

def "Should down to floating point arithmetic when having deal with big number during deficit calculation"() {
setup:
Bandwidth limit = Bandwidth.simple ((long) Long.MAX_VALUE / 2, Duration.ofNanos((long) Long.MAX_VALUE / 2))
TimeMeterMock meter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit(0, limit)
.withCustomTimePrecision(meter)
.build()
BucketState state = bucket.createSnapshot()
Bandwidth[] limits = bucket.configuration.bandwidths

expect:
state.delayNanosAfterWillBePossibleToConsume(limits, 10) == 10
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
import java.util.function.Function;
import java.util.function.Supplier;

import static org.junit.Assert.assertTrue;

public class LocalTest {

private LocalBucketBuilder builder = Bucket4j.builder()
Expand Down

0 comments on commit 2b2ebe7

Please sign in to comment.