Skip to content

Commit

Permalink
#151 "forceAddTokens" implementations + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-bukhtoyarov committed Mar 19, 2021
1 parent a823a24 commit f94453f
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 11 deletions.
39 changes: 30 additions & 9 deletions bucket4j-core/src/main/java/io/github/bucket4j/BucketState.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ private void replaceBandwidthAsIs(BucketState newState, int newBandwidthIndex, B
long lastRefillTimeNanos = getLastRefillTimeNanos(previousBandwidthIndex);
newState.setLastRefillTimeNanos(newBandwidthIndex, lastRefillTimeNanos);

long currentSize = getCurrentSize(previousBandwidthIndex);
if (currentSize >= newBandwidth.capacity) {
newState.setCurrentSize(newBandwidthIndex, newBandwidth.capacity);
return;
}
if (newBandwidth.isGready() && previousBandwidth.isGready()) {
long currentSize = getCurrentSize(previousBandwidthIndex);
long newSize = Math.min(newBandwidth.capacity, currentSize);
newState.setCurrentSize(newBandwidthIndex, newSize);

Expand All @@ -122,17 +126,21 @@ private void replaceBandwidthAsIs(BucketState newState, int newBandwidthIndex, B
newRoundingError = newBandwidth.refillPeriodNanos - 1;
}
newState.setRoundingError(newBandwidthIndex, newRoundingError);
return;
} else {
long newSize = Math.min(newBandwidth.capacity, currentSize);
newState.setCurrentSize(newBandwidthIndex, newSize);
}

long currentSize = getCurrentSize(previousBandwidthIndex);
long newSize = Math.min(newBandwidth.capacity, currentSize);
newState.setCurrentSize(newBandwidthIndex, newSize);
}

private void replaceBandwidthProportional(BucketState newState, int newBandwidthIndex, Bandwidth newBandwidth, int previousBandwidthIndex, Bandwidth previousBandwidth, long currentTimeNanos) {
newState.setLastRefillTimeNanos(newBandwidthIndex, getLastRefillTimeNanos(previousBandwidthIndex));
long currentSize = getCurrentSize(previousBandwidthIndex);
if (currentSize >= previousBandwidth.capacity) {
// can come here if forceAddTokens has been used
newState.setCurrentSize(newBandwidthIndex, newBandwidth.capacity);
return;
}

long roundingError = getRoundingError(previousBandwidthIndex);
double realRoundedError = (double) roundingError / (double) previousBandwidth.refillPeriodNanos;
double scale = (double) newBandwidth.capacity / (double) previousBandwidth.capacity;
Expand Down Expand Up @@ -267,9 +275,17 @@ private void addTokens(int bandwidthIndex, Bandwidth bandwidth, long tokensToAdd
}
}

private void forceAddTokens(int i, Bandwidth limit, long tokensToAdd) {
// TODO
throw new UnsupportedOperationException();
private void forceAddTokens(int bandwidthIndex, Bandwidth bandwidth, long tokensToAdd) {
long currentSize = getCurrentSize(bandwidthIndex);
long newSize = currentSize + tokensToAdd;
if (newSize < currentSize) {
// arithmetic overflow happens. This mean that bucket reached Long.MAX_VALUE tokens.
// just set MAX_VALUE tokens
setCurrentSize(bandwidthIndex, Long.MAX_VALUE);
setRoundingError(bandwidthIndex, 0);
} else {
setCurrentSize(bandwidthIndex, newSize);
}
}

private void refill(int bandwidthIndex, Bandwidth bandwidth, long currentTimeNanos) {
Expand All @@ -293,6 +309,11 @@ private void refill(int bandwidthIndex, Bandwidth bandwidth, long currentTimeNan
final long refillTokens = bandwidth.refillTokens;
final long currentSize = getCurrentSize(bandwidthIndex);

if (currentSize >= capacity) {
// can come here if forceAddTokens has been used
return;
}

long durationSinceLastRefillNanos = currentTimeNanos - previousRefillNanos;
long newSize = currentSize;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public <O> void serialize(SerializationAdapter<O> adapter, O output, ForceAddTok

@Override
public int getTypeId() {
return 25;
return 26;
}

@Override
Expand All @@ -65,7 +65,7 @@ public ForceAddTokensCommand(long tokensToAdd) {
@Override
public Nothing execute(GridBucketState state, long currentTimeNanos) {
state.refillAllBandwidth(currentTimeNanos);
state.addTokens(tokensToAdd);
state.forceAddTokens(tokensToAdd);
return Nothing.INSTANCE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,43 @@ class ConfigurationReplacementSpecification extends Specification {
bucketType << BucketType.values()
}

@Unroll
def "#bucketType test replace configuration proportionally when capacity overflown"(BucketType bucketType) {
expect:
for (boolean sync : [true, false]) {
for (boolean verbose: [true, false]) {
// System.err.println("sync: $sync verbose: $verbose")
TimeMeterMock clock = new TimeMeterMock(0)
Bucket bucket = bucketType.createBucket(Bucket4j.builder()
.addLimit(Bandwidth.simple(3, Duration.ofNanos(5)).withInitialTokens(0)),
clock
)
bucket.forceAddTokens(10000000)
assert bucket.getAvailableTokens() == 10000000

BucketConfiguration newConfiguration = Bucket4j.configurationBuilder()
.addLimit(Bandwidth.simple(60, Duration.ofNanos(1000)))
.build()
if (sync) {
if (!verbose) {
bucket.replaceConfiguration(newConfiguration, TokensInheritanceStrategy.PROPORTIONALLY)
} else {
bucket.asVerbose().replaceConfiguration(newConfiguration, TokensInheritanceStrategy.PROPORTIONALLY)
}
} else {
if (!verbose) {
bucket.asAsync().replaceConfiguration(newConfiguration, TokensInheritanceStrategy.PROPORTIONALLY).get()
} else {
bucket.asAsync().asVerbose().replaceConfiguration(newConfiguration, TokensInheritanceStrategy.PROPORTIONALLY).get()
}
}
assert bucket.getAvailableTokens() == 60 // because should be just reduced to maximum
}
}
where:
bucketType << BucketType.values()
}

@Unroll
def "#bucketType test replace configuration proportionally from gready refill to gready refill. Case for roundingError propogation"(BucketType bucketType) {
expect:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,20 @@ class DetectionOfIllegalApiUsageSpecification extends Specification {
tokens << [0, -1, -10]
}

@Unroll
def "Should check that #tokens tokens is not positive to force add"(long tokens) {
setup:
def bucket = Bucket4j.builder().addLimit(
Bandwidth.simple(VALID_CAPACITY, VALID_PERIOD)
).build()
when:
bucket.forceAddTokens(tokens)
then:
thrown(IllegalArgumentException)
where:
tokens << [0, -1, -10]
}

def "Should that scheduler passed to tryConsume is not null"() {
setup:
def bucket = Bucket4j.builder().addLimit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ public class EqualityUtils {
registerComparator(VerboseCommand.class, (cmd1, cmd2) -> {
return equals(cmd1.getTargetCommand(), cmd2.getTargetCommand());
});

registerComparator(ForceAddTokensCommand.class, (cmd1, cmd2) -> {
return equals(cmd1.getTokensToAdd(), cmd2.getTokensToAdd());
});
}

public static <T> void registerComparator(Class<T> clazz, BiFunction<T, T, Boolean> comparator) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package io.github.bucket4j

import io.github.bucket4j.mock.BucketType
import io.github.bucket4j.mock.TimeMeterMock
import spock.lang.Specification
import spock.lang.Unroll

import java.time.Duration

import static io.github.bucket4j.PackageAcessor.getState
import static io.github.bucket4j.PackageAcessor.getState
import static org.junit.Assert.assertNotSame
import static org.junit.Assert.assertNotSame

class ForceAddTokensSpecification extends Specification {

@Unroll
def "#n Force Add tokens spec"(
int n, long tokensToAdd, long nanosIncrement, long requiredResult, AbstractBucketBuilder builder) {
expect:
for (BucketType type : BucketType.values()) {
for (boolean sync : [true, false]) {
for (boolean verbose : [true, false]) {
// println("type=$type sync=$sync verbose=$verbose")
TimeMeterMock timeMeter = new TimeMeterMock(0)
Bucket bucket = type.createBucket(builder, timeMeter)
timeMeter.addTime(nanosIncrement)
if (sync) {
if (!verbose) {
bucket.forceAddTokens(tokensToAdd)
} else {
def verboseResult = bucket.asVerbose().forceAddTokens(tokensToAdd)
assertNotSame(verboseResult.state, getState(bucket))
}
} else {
if (!verbose) {
bucket.asAsync().forceAddTokens(tokensToAdd).get()
} else {
def verboseResult = bucket.asAsync().asVerbose().forceAddTokens(tokensToAdd).get()
assertNotSame(verboseResult.state, getState(bucket))
}
}
assert bucket.createSnapshot().getAvailableTokens(bucket.configuration.bandwidths) == requiredResult
}
}
}
where:
n | tokensToAdd | nanosIncrement | requiredResult | builder
1 | 49 | 50 | 99 | Bucket4j.builder().addLimit(Bandwidth.simple(100, Duration.ofNanos(100)).withInitialTokens(0))
2 | 50 | 50 | 100 | Bucket4j.builder().addLimit(Bandwidth.simple(100, Duration.ofNanos(100)).withInitialTokens(0))
3 | 50 | 0 | 50 | Bucket4j.builder().addLimit(Bandwidth.simple(100, Duration.ofNanos(100)).withInitialTokens(0))
4 | 120 | 0 | 120 | Bucket4j.builder().addLimit(Bandwidth.simple(100, Duration.ofNanos(100)).withInitialTokens(0))
5 | 120 | 110 | 220 | Bucket4j.builder().addLimit(Bandwidth.simple(100, Duration.ofNanos(100)).withInitialTokens(0))
}

@Unroll
def "#n Tokens that was added over capacity should not be lost"() {
setup:
TimeMeterMock timeMeter = new TimeMeterMock(0)
Bucket bucket = Bucket4j.builder()
.addLimit(Bandwidth.simple(100, Duration.ofNanos(100)))
.withCustomTimePrecision(timeMeter)
.build()
when:
bucket.forceAddTokens(10)
then:
bucket.getAvailableTokens() == 110

when:
timeMeter.addTime(10)
bucket.consumeIgnoringRateLimits(2)
then:
bucket.getAvailableTokens() == 108

when:
timeMeter.addTime(10)
bucket.tryConsume(3)
then:
bucket.getAvailableTokens() == 105
}

}

0 comments on commit f94453f

Please sign in to comment.