Skip to content

Commit

Permalink
[ML] Inference API rate limit queuing logic refactor (#107706)
Browse files Browse the repository at this point in the history
* Adding new executor

* Adding in queuing logic

* Working tests

* Added cleanup task

* Update docs/changelog/107706.yaml

* Updating yml

* deregistering callbacks for settings changes

* Cleaning up code

* Update docs/changelog/107706.yaml

* Fixing rate limit settings bug and only sleeping for the shortest required amount of time

* Removing debug logging

* Removing commented code

* Renaming feedback

* fixing tests

* Updating docs and validation

* Fixing source blocks

* Adjusting cancel logic

* Reformatting ascii

* Addressing feedback

* Adding rate limiting for Google embeddings and Mistral

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
jonathan-buttner and elasticmachine committed Jun 5, 2024
1 parent cd84749 commit fdb5058
Show file tree
Hide file tree
Showing 102 changed files with 1,487 additions and 925 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/107706.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 107706
summary: Add rate limiting support for the Inference API
area: Machine Learning
type: enhancement
issues: []
289 changes: 172 additions & 117 deletions docs/reference/inference/put-inference.asciidoc

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions libs/core/src/main/java/org/elasticsearch/core/TimeValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ public static TimeValue timeValueDays(long days) {
return new TimeValue(days, TimeUnit.DAYS);
}

/**
 * Compares the two values and returns the one representing the shorter duration.
 *
 * @param time1 the first value to compare
 * @param time2 the second value to compare
 * @return {@code time1} if it is strictly shorter than {@code time2}, otherwise {@code time2}
 */
public static TimeValue min(TimeValue time1, TimeValue time2) {
    // Ties deliberately resolve to time2, matching the original ternary's behavior.
    if (time1.compareTo(time2) < 0) {
        return time1;
    }
    return time2;
}

/**
* @return the unit used for the this time value, see {@link #duration()}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.object.HasToString.hasToString;

Expand Down Expand Up @@ -231,6 +232,12 @@ public void testRejectsNegativeValuesAtCreation() {
assertThat(ex.getMessage(), containsString("duration cannot be negative"));
}

public void testMin() {
    // A one-nanosecond value used on both sides of the comparison.
    TimeValue oneNano = TimeValue.timeValueNanos(1);

    // Zero is shorter than one nanosecond.
    assertThat(TimeValue.min(TimeValue.ZERO, oneNano), is(TimeValue.timeValueNanos(0)));
    // Any finite value is shorter than MAX_VALUE.
    assertThat(TimeValue.min(TimeValue.MAX_VALUE, oneNano), is(oneNano));
    // A negative sentinel compares as less than a positive duration.
    assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE));
}

private TimeUnit randomTimeUnitObject() {
return randomFrom(
TimeUnit.NANOSECONDS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor {
private final ServiceComponents serviceComponents;

public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
// TODO Batching - accept a class that can handle batching
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Thread
model.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings"
);
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -37,17 +36,16 @@ public AzureAiStudioChatCompletionRequestManager(AzureAiStudioChatCompletionMode
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input);

return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createCompletionHandler() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -41,17 +40,16 @@ public AzureAiStudioEmbeddingsRequestManager(AzureAiStudioEmbeddingsModel model,
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createEmbeddingsHandler() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -43,16 +42,15 @@ public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, Thr
}

@Override
public Runnable create(
public void execute(
@Nullable String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -55,16 +54,15 @@ public AzureOpenAiEmbeddingsRequestManager(AzureOpenAiEmbeddingsModel model, Tru
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@ public String inferenceEntityId() {

@Override
public Object rateLimitGrouping() {
return rateLimitGroup;
// It's possible that two inference endpoints have the same information defining the group but have different
// rate limits then they should be in different groups otherwise whoever initially created the group will set
// the rate and the other inference endpoint's rate will be ignored
return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}

private record EndpointGrouping(Object group, RateLimitSettings settings) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -46,16 +45,15 @@ private CohereCompletionRequestManager(CohereCompletionModel model, ThreadPool t
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
CohereCompletionRequest request = new CohereCompletionRequest(input, model);

return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -44,16 +43,15 @@ private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool t
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model);

return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -44,16 +43,15 @@ private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPoo
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
CohereRerankRequest request = new CohereRerankRequest(query, input, model);

return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ record ExecutableInferenceRequest(
RequestSender requestSender,
Logger logger,
Request request,
HttpClientContext context,
ResponseHandler responseHandler,
Supplier<Boolean> hasFinished,
ActionListener<InferenceServiceResults> listener
Expand All @@ -34,7 +33,7 @@ public void run() {
var inferenceEntityId = request.createHttpRequest().inferenceEntityId();

try {
requestSender.send(logger, request, context, hasFinished, responseHandler, listener);
requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
} catch (Exception e) {
var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
logger.warn(errorMessage, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -42,15 +41,14 @@ public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel mode
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
Expand Down Expand Up @@ -48,17 +47,16 @@ public GoogleAiStudioEmbeddingsRequestManager(GoogleAiStudioEmbeddingsModel mode
}

@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model);

return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Loading

0 comments on commit fdb5058

Please sign in to comment.