From e9bf001d8c429222b94ed1d9bcaa1a07daeddf71 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 28 Jan 2025 21:18:07 +0000 Subject: [PATCH] [ML] Add _stream path for chat_completions task (#121006) As a replacement for the _unified route --- .../inference/InferenceBaseRestTest.java | 3 +- .../xpack/inference/rest/Paths.java | 1 + .../rest/RestStreamInferenceAction.java | 33 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 546eab471a077..bb3f3e9b46c4d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -355,7 +355,8 @@ protected Deque unifiedCompletionInferOnMockService( List input, @Nullable Consumer responseConsumerCallback ) throws Exception { - var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); + var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route + var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route); return callAsyncUnified(endpoint, input, "user", responseConsumerCallback); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 57c06df8d8dfe..7f43676dfb5f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -31,6 +31,7 @@ public final class Paths { + INFERENCE_ID + "}/_stream"; + // TODO remove the _unified path public static final String UNIFIED_SUFFIX = "_unified"; static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX; static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 881af435b29b6..518056365d88b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -9,12 +9,17 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import java.io.IOException; import java.util.List; import java.util.Objects; @@ -50,4 +55,32 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques protected ActionListener listener(RestChannel channel) { return new ServerSentEventsRestActionListener(channel, threadPool); } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = parseParams(restRequest); + var inferTimeout = parseTimeout(restRequest); + + if (params.taskType() == TaskType.CHAT_COMPLETION) { + UnifiedCompletionAction.Request request; + try (var parser = restRequest.contentParser()) { + request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); + } + + return channel -> client.execute( + UnifiedCompletionAction.INSTANCE, + request, + new ServerSentEventsRestActionListener(channel, threadPool) + ); + } else { + InferenceAction.Request.Builder requestBuilder; + try (var parser = restRequest.contentParser()) { + requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); + } + + requestBuilder.setInferenceTimeout(inferTimeout); + var request = prepareInferenceRequest(requestBuilder); + return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); + } + } }