Skip to content

Commit

Permalink
[ML] Inference API time to reserve tokens for rate limiter (#107571)
Browse files Browse the repository at this point in the history
* Refactoring tests

* Adding time to reserve tests
  • Loading branch information
jonathan-buttner committed Apr 17, 2024
1 parent cc75338 commit 1aa182a
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;

import java.time.Clock;
import java.time.Instant;
Expand Down Expand Up @@ -92,24 +93,59 @@ public final synchronized void setRate(double newAccumulatedTokensLimit, double
* @throws InterruptedException _
*/
public void acquire(int tokens) throws InterruptedException {
sleeper.sleep(reserveInternal(tokens));
}

/**
* Returns the amount of time to wait for the tokens to become available but does not reserve them in advance.
* A caller will need to call {@link #reserve(int)} or {@link #acquire(int)} after this call.
* @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here. Must be greater than 0.
* @return the amount of time to wait
*/
public TimeValue timeToReserve(int tokens) {
var timeToReserveRes = timeToReserveInternal(tokens);

return new TimeValue((long) timeToReserveRes.microsToWait, TimeUnit.MICROSECONDS);
}

private TimeToReserve timeToReserveInternal(int tokens) {
validateTokenRequest(tokens);

double microsToWait;
accumulateTokens();
var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens);
var additionalTokensRequired = tokens - accumulatedTokensToUse;
microsToWait = additionalTokensRequired / tokensPerMicros;

return new TimeToReserve(microsToWait, accumulatedTokensToUse);
}

private record TimeToReserve(double microsToWait, double accumulatedTokensToUse) {}

private static void validateTokenRequest(int tokens) {
if (tokens <= 0) {
throw new IllegalArgumentException("Requested tokens must be positive");
}
}

double microsToWait;
synchronized (this) {
accumulateTokens();
var accumulatedTokensToUse = Math.min(tokens, accumulatedTokens);
var additionalTokensRequired = tokens - accumulatedTokensToUse;
microsToWait = additionalTokensRequired / tokensPerMicros;
accumulatedTokens -= accumulatedTokensToUse;
nextTokenAvailability = nextTokenAvailability.plus((long) microsToWait, ChronoUnit.MICROS);
}
/**
* Returns the amount of time to wait for the tokens to become available.
* @param tokens the number of items of work that should be throttled, typically you'd pass a value of 1 here. Must be greater than 0.
* @return the amount of time to wait
*/
public TimeValue reserve(int tokens) {
return new TimeValue(reserveInternal(tokens), TimeUnit.MICROSECONDS);
}

private synchronized long reserveInternal(int tokens) {
var timeToReserveRes = timeToReserveInternal(tokens);
accumulatedTokens -= timeToReserveRes.accumulatedTokensToUse;
nextTokenAvailability = nextTokenAvailability.plus((long) timeToReserveRes.microsToWait, ChronoUnit.MICROS);

sleeper.sleep((long) microsToWait);
return (long) timeToReserveRes.microsToWait;
}

private void accumulateTokens() {
private synchronized void accumulateTokens() {
var now = Instant.now(clock);
if (now.isAfter(nextTokenAvailability)) {
var elapsedTimeMicros = microsBetweenExact(nextTokenAvailability, now);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;

import java.time.Clock;
Expand All @@ -17,11 +18,19 @@

import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RateLimiterTests extends ESTestCase {
public abstract class BaseRateLimiterTests extends ESTestCase {

protected abstract TimeValue tokenMethod(RateLimiter limiter, int tokens) throws InterruptedException;

protected abstract void sleepValidationMethod(
TimeValue result,
RateLimiter.Sleeper mockSleeper,
int numberOfClassToExpect,
long expectedMicrosecondsToSleep
) throws InterruptedException;

public void testThrows_WhenAccumulatedTokensLimit_IsNegative() {
var exception = expectThrows(
IllegalArgumentException.class,
Expand Down Expand Up @@ -65,55 +74,55 @@ public void testThrows_WhenTokensPerTimeUnit_IsNegative() {
assertThat(exception.getMessage(), is("Tokens per time unit must be greater than 0"));
}

public void testAcquire_Throws_WhenTokens_IsZero() {
public void testMethod_Throws_WhenTokens_IsZero() {
var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC());
var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(0));
assertThat(exception.getMessage(), is("Requested tokens must be positive"));
}

public void testAcquire_Throws_WhenTokens_IsNegative() {
public void testMethod_Throws_WhenTokens_IsNegative() {
var limiter = new RateLimiter(0, 1, TimeUnit.SECONDS, new RateLimiter.TimeUnitSleeper(), Clock.systemUTC());
var exception = expectThrows(IllegalArgumentException.class, () -> limiter.acquire(-1));
assertThat(exception.getMessage(), is("Requested tokens must be positive"));
}

public void testAcquire_First_CallDoesNotSleep() throws InterruptedException {
public void testMethod_First_CallDoesNotSleep() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
var res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, 0);
}

public void testAcquire_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException {
public void testMethod_DoesNotSleep_WhenTokenRateIsHigh() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
var res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, 0);
}

public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException {
public void testMethod_AcceptsMaxIntValue_WhenTokenRateIsHigh() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, Double.MAX_VALUE, TimeUnit.MICROSECONDS, sleeper, clock);
limiter.acquire(Integer.MAX_VALUE);
verify(sleeper, times(1)).sleep(0);
var res = tokenMethod(limiter, Integer.MAX_VALUE);
sleepValidationMethod(res, sleeper, 1, 0);
}

public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException {
public void testMethod_AcceptsMaxIntValue_WhenTokenRateIsLow() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);
Expand All @@ -122,76 +131,77 @@ public void testAcquire_AcceptsMaxIntValue_WhenTokenRateIsLow() throws Interrupt

double tokensPerDay = 1;
var limiter = new RateLimiter(0, tokensPerDay, TimeUnit.DAYS, sleeper, clock);
limiter.acquire(Integer.MAX_VALUE);

var res = tokenMethod(limiter, Integer.MAX_VALUE);
double tokensPerMicro = tokensPerDay / TimeUnit.DAYS.toMicros(1);
verify(sleeper, times(1)).sleep((long) ((double) Integer.MAX_VALUE / tokensPerMicro));
sleepValidationMethod(res, sleeper, 1, (long) ((double) Integer.MAX_VALUE / tokensPerMicro));
}

public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException {
public void testMethod_SleepsForOneMinute_WhenRequestingOneUnavailableToken() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(2);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
var res = tokenMethod(limiter, 2);
sleepValidationMethod(res, sleeper, 1, TimeUnit.MINUTES.toMicros(1));
}

public void testAcquire_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException {
public void testMethod_SleepsForOneMinute_WhenRequestingOneUnavailableToken_NoAccumulated() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
var res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, TimeUnit.MINUTES.toMicros(1));
}

public void testAcquire_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException {
public void testMethod_SleepsFor10Minute_WhenRequesting10UnavailableToken_NoAccumulated() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(10);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(10));
var res = tokenMethod(limiter, 10);
sleepValidationMethod(res, sleeper, 1, TimeUnit.MINUTES.toMicros(10));
}

public void testAcquire_IncrementsNextTokenAvailabilityInstant_ByOneMinute() throws InterruptedException {
public void testMethod_IncrementsNextTokenAvailabilityInstant_ByOneMinute() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(0, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));
var res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, TimeUnit.MINUTES.toMicros(1));
assertThat(limiter.getNextTokenAvailability(), is(now.plus(1, ChronoUnit.MINUTES)));
}

public void testAcquire_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException {
public void testMethod_SecondCallToAcquire_ShouldWait_WhenAccumulatedTokensAreDepleted() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.MINUTES.toMicros(1));

var res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, 0);
res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, TimeUnit.MINUTES.toMicros(1));
}

public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration()
public void testMethod_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapsedTimeIsHalfRequiredDuration()
throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
Expand All @@ -200,26 +210,28 @@ public void testAcquire_SecondCallToAcquire_ShouldWaitForHalfDuration_WhenElapse
var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(1, 1, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(1);
verify(sleeper, times(1)).sleep(0);

var res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, 0);
when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30)));
limiter.acquire(1);
verify(sleeper, times(1)).sleep(TimeUnit.SECONDS.toMicros(30));
res = tokenMethod(limiter, 1);
sleepValidationMethod(res, sleeper, 1, TimeUnit.SECONDS.toMicros(30));
}

public void testAcquire_ShouldAccumulateTokens() throws InterruptedException {
public void testMethod_ShouldAccumulateTokens() throws InterruptedException {
var now = Clock.systemUTC().instant();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);

var sleeper = mock(RateLimiter.Sleeper.class);

var limiter = new RateLimiter(10, 10, TimeUnit.MINUTES, sleeper, clock);
limiter.acquire(5);
verify(sleeper, times(1)).sleep(0);

var res = tokenMethod(limiter, 5);
sleepValidationMethod(res, sleeper, 1, 0);
// it should accumulate 5 tokens
when(clock.instant()).thenReturn(now.plus(Duration.ofSeconds(30)));
limiter.acquire(10);
verify(sleeper, times(2)).sleep(0);
res = tokenMethod(limiter, 10);
sleepValidationMethod(res, sleeper, 2, 0);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.core.TimeValue;

import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class RateLimiterAcquireTests extends BaseRateLimiterTests {

@Override
protected TimeValue tokenMethod(RateLimiter limiter, int tokens) throws InterruptedException {
limiter.acquire(tokens);
return null;
}

@Override
protected void sleepValidationMethod(
TimeValue result,
RateLimiter.Sleeper mockSleeper,
int numberOfClassToExpect,
long expectedMicrosecondsToSleep
) throws InterruptedException {
verify(mockSleeper, times(numberOfClassToExpect)).sleep(expectedMicrosecondsToSleep);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.core.TimeValue;

import static org.hamcrest.Matchers.is;

public class RateLimiterReserveTests extends BaseRateLimiterTests {

@Override
protected TimeValue tokenMethod(RateLimiter limiter, int tokens) {
return limiter.reserve(tokens);
}

@Override
protected void sleepValidationMethod(
TimeValue result,
RateLimiter.Sleeper mockSleeper,
int numberOfClassToExpect,
long expectedMicrosecondsToSleep
) {
assertThat(result.getMicros(), is(expectedMicrosecondsToSleep));
}
}

0 comments on commit 1aa182a

Please sign in to comment.