From 0fb460a0285daba5b5c8348712ea9cfd659c30d9 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 6 Dec 2024 11:15:27 -0500 Subject: [PATCH 01/42] Starting completion model --- ...lasticInferenceServiceCompletionModel.java | 19 +++++++++++++++++++ ...renceServiceCompletionServiceSettings.java | 1 + 2 files changed, 20 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 84039cd7cc33c..389fed811b527 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -33,6 +33,25 @@ public static ElasticInferenceServiceCompletionModel of( ElasticInferenceServiceCompletionModel model, UnifiedCompletionRequest request ) { +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; + +public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel { + + public static ElasticInferenceServiceCompletionModel of(ElasticInferenceServiceCompletionModel model, UnifiedCompletionRequest request) { var originalModelServiceSettings = model.getServiceSettings(); var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings( Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index 3c8182a7d41a4..ef0d7958ec184 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; From 8ef932a017043a198bd01e8fe0c43170fc1ac0a4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 6 Dec 2024 15:40:33 -0500 Subject: [PATCH 02/42] Adding model --- 
.../openai/OpenAiUnifiedChatCompletionRequest.java | 10 +++++++++- .../OpenAiUnifiedChatCompletionRequestEntity.java | 6 ++++-- .../services/elastic/ElasticInferenceService.java | 1 + .../ElasticInferenceServiceCompletionModel.java | 12 +++++------- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index e5b85633a499b..a547b4c1c52f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -44,7 +44,15 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString( + new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields( + model.getServiceSettings().modelId(), + model.getTaskSettings().user() + ) + ) + ).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index b80100c9e2f79..4f50d97107ded 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -30,6 +30,8 @@ public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInpu this.model = Objects.requireNonNull(model); } + public record ModelFields(String modelId, @Nullable String user) {} + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -37,8 +39,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); - if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { - builder.field(USER_FIELD, model.getTaskSettings().user()); + if (Strings.isNullOrEmpty(modelFields.user()) == false) { + builder.field(USER_FIELD, modelFields.user()); } builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 48416faac6a06..13a0efb3d3d23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import 
org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 389fed811b527..04b815df988d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; @@ -37,21 +38,18 @@ public static ElasticInferenceServiceCompletionModel of( import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; import java.net.URI; import java.net.URISyntaxException; -import 
java.util.Locale; import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; - public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel { - public static ElasticInferenceServiceCompletionModel of(ElasticInferenceServiceCompletionModel model, UnifiedCompletionRequest request) { + public static ElasticInferenceServiceCompletionModel of( + ElasticInferenceServiceCompletionModel model, + UnifiedCompletionRequest request + ) { var originalModelServiceSettings = model.getServiceSettings(); var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings( Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), From bb97600e2566881b278054c53e94479c6d22491c Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 9 Dec 2024 13:45:13 -0500 Subject: [PATCH 03/42] initial implementation of request and response handling, manager, and entity --- ...SUnifiedChatCompletionResponseHandler.java | 35 + .../EISUnifiedCompletionRequestManager.java | 62 ++ .../EISUnifiedChatCompletionRequest.java | 107 +++ ...EISUnifiedChatCompletionRequestEntity.java | 176 ++++ .../elastic/ElasticInferenceService.java | 1 - ...ifiedChatCompletionRequestEntityTests.java | 891 ++++++++++++++++++ ...ifiedChatCompletionRequestEntityTests.java | 5 +- 7 files changed, 1275 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..540e1484bc179 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.elastic; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; + +public class EISUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { + public EISUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingUnifiedChatCompletionResults(openAiProcessor); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java new file mode 100644 index 0000000000000..3445736e70f15 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java @@ -0,0 +1,62 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.elastic.EISUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class EISUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager { + + private static final Logger logger = LogManager.getLogger(EISUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static EISUnifiedCompletionRequestManager of(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) { + return new EISUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final ElasticInferenceServiceCompletionModel model; + + private EISUnifiedCompletionRequestManager(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + } + + @Override + 
public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + EISUnifiedChatCompletionRequest request = new EISUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model, + null // TODO + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new EISUnifiedChatCompletionResponseHandler("eis completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..73eacc09d83d1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiRequest; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class EISUnifiedChatCompletionRequest implements OpenAiRequest { + + private final ElasticInferenceServiceCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + private final URI uri; + private final TraceContext traceContext; + + public EISUnifiedChatCompletionRequest( + UnifiedChatInput unifiedChatInput, + ElasticInferenceServiceCompletionModel model, + TraceContext traceContext + ) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + this.traceContext = traceContext; + + } + + @Override + public HttpRequest createHttpRequest() { + var httpPost = new HttpPost(uri); + var requestEntity = Strings.toString(new EISUnifiedChatCompletionRequestEntity(unifiedChatInput)); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + if (traceContext != null) { + propagateTraceContext(httpPost); + } + + httpPost.setHeader(new 
BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + // No truncation for EIS chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for EIS chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return true; + } + + public TraceContext getTraceContext() { + return traceContext; + } + + private void propagateTraceContext(HttpPost httpPost) { + var traceParent = traceContext.traceParent(); + var traceState = traceContext.traceState(); + + if (traceParent != null) { + httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); + } + + if (traceState != null) { + httpPost.setHeader(Task.TRACE_STATE, traceState); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..727d5e82640c6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,176 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.io.IOException; +import java.util.Objects; + +public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject { + + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String USER_FIELD = "user"; + public static final String STREAM_FIELD = "stream"; + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + private static final String MODEL_FIELD = "model"; + public static final String MESSAGES_FIELD = "messages"; + private static final String ROLE_FIELD = "role"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tools"; + private static final String TEXT_FIELD = "text"; + private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; + 
+ private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; + + public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { + Objects.requireNonNull(unifiedChatInput); + + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { + builder.startObject(); + { + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); + builder.endObject(); + } + builder.endArray(); + } + } + + builder.field(ROLE_FIELD, message.role()); + if (message.name() != null) { + builder.field(NAME_FIELD, message.name()); + } + if (message.toolCallId() != null) { + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray(TOOL_CALLS_FIELD); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + builder.startObject(); + { + builder.field(ID_FIELD, toolCall.id()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); + } + builder.endObject(); + builder.field(TYPE_FIELD, toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + if (unifiedRequest.maxCompletionTokens() != 
null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); + } + + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { + builder.field(STOP_FIELD, unifiedRequest.stop()); + } + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); + } + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field( + NAME_FIELD, + ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); + } + builder.endObject(); + } + builder.endObject(); + } + } + if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { + builder.startObject(); + { + builder.field(TYPE_FIELD, t.type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(DESCRIPTION_FIELD, t.function().description()); + builder.field(NAME_FIELD, t.function().name()); + builder.field(PARAMETERS_FIELD, t.function().parameters()); + if (t.function().strict() != null) { + builder.field(STRICT_FIELD, t.function().strict()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); + } + + builder.field(STREAM_FIELD, stream); + if (stream) { + 
builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 13a0efb3d3d23..48416faac6a06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..64e875e3dd4ef --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,891 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.equalTo; + +public class EISUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + // TODO these tests were copied from the OpenAI tests and need to be cleaned up and improved to be correct for EIS + + // 1. Basic Serialization + // Test with minimal required fields to ensure basic serialization works.
+ public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 2. Serialization with All Fields + // Test with all possible fields populated to ensure complete serialization. 
+ public void testSerializationWithAllFields() throws IOException { + // Create a message with all fields populated + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name", + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + // Create a tool with all fields populated + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with all fields populated + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ 
+ { + "content": "Hello, world!", + "role": "user", + "name": "name", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + // 3. Serialization with Null Optional Fields + // Test with optional fields set to null to ensure they are correctly omitted from the output. 
+ public void testSerializationWithNullOptionalFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + // Create the unified request with optional fields set to null + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 4. Serialization with Empty Lists + // Test with fields that are lists set to empty lists to ensure they are correctly serialized. 
+ public void testSerializationWithEmptyLists() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with empty lists + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + Collections.emptyList(), // empty stop list + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 5. Serialization with Nested Objects + // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. 
+ public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + // Generate random values + String randomContent = "Hello, world! " + random.nextInt(1000); + String randomName = "name" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + // Create a message with nested toolCalls + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + randomName, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + // Create a tool with nested function fields + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with nested objects + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // maxCompletionTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + 
randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + // Expected JSON should be dynamically generated based on random values + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "name": "%s", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_completion_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + "stream": true, + "stream_options": { + "include_usage": true + } + } + 
""", + randomContent, + randomName, + randomToolCallId, + randomArguments, + randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + // 6. Serialization with Different Content Types + // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + // Generate random values for ContentString + String randomContentString = "Hello, world! " + random.nextInt(1000); + + // Generate random values for ContentObjects + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + // Create messages with different content types + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( + contentObjects, + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + // Create the unified request with both types of messages + UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); + + // Create the unified chat input + 
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + // 7. Serialization with Special Characters + // Test with special characters in string fields to ensure they are correctly escaped and serialized. + public void testSerializationWithSpecialCharacters() throws IOException { + // Create a message with special characters + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! 
\n \"Special\" characters: \t \\ /"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! 
\\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 8. Serialization with Boolean Fields + // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. + public void testSerializationWithBooleanFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Test with stream set to true + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInputTrue, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + XContentBuilder builderTrue = JsonXContent.contentBuilder(); + entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); + + String jsonStringTrue = Strings.toString(builderTrue); + String 
expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + // Test with stream set to false + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInputFalse, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + XContentBuilder builderFalse = JsonXContent.contentBuilder(); + entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); + + String jsonStringFalse = Strings.toString(builderFalse); + String expectedJsonFalse = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + // 9. Serialization with Missing Required Fields + // Test with missing required fields to ensure appropriate exceptions are thrown. 
+ public void testSerializationWithMissingRequiredFields() { + // Create a message with missing content (required field) + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, // missing content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Attempt to serialize to XContent and expect an exception + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to missing required fields"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + // 10. Serialization with Mixed Valid and Invalid Data + // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. 
+ public void testSerializationWithMixedValidAndInvalidData() throws IOException { + // Create a valid message + UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Valid content"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "validName", + "validToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "validId", + new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), + "validType" + ) + ) + ); + + // Create an invalid message with null content + UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( + null, // invalid content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "invalidName", + "invalidToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "invalidId", + new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), + "invalidType" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(validMessage); + messageList.add(invalidMessage); + // Create the unified request with both valid and invalid messages + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model-name", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList( + new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ) + ), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + 
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) + ); + + // Serialize to XContent and verify + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to invalid data"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + public static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } + + private void assertJsonEquals(String actual, String expected) throws IOException { + try ( + var actualParser = createParser(JsonXContent.jsonXContent, actual); + var expectedParser = createParser(JsonXContent.jsonXContent, expected) + ) { + assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index f43b185391697..b4ecf367cc485 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -43,7 +43,10 @@ public void testModelUserFieldsSerialization() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("Open AI model", null) + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); From 31f3f2c671233a1c1e1d920ee9ac96eea0407853 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 9 Dec 2024 16:25:39 -0500 Subject: [PATCH 04/42] Working response from openai --- ...SUnifiedChatCompletionResponseHandler.java | 5 + .../EISUnifiedCompletionRequestManager.java | 25 +- .../EISUnifiedChatCompletionRequest.java | 9 +- ...EISUnifiedChatCompletionRequestEntity.java | 156 +--- .../OpenAiUnifiedChatCompletionRequest.java | 10 +- ...nAiUnifiedChatCompletionRequestEntity.java | 2 - .../ElasticInferenceServiceSettings.java | 5 +- ...lasticInferenceServiceCompletionModel.java | 6 +- ...renceServiceCompletionServiceSettings.java | 2 +- ...ifiedChatCompletionRequestEntityTests.java | 842 +----------------- ...ifiedChatCompletionRequestEntityTests.java | 5 +- 11 files changed, 64 insertions(+), 1003 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java index 540e1484bc179..f1007b1c30d74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java @@ -23,6 +23,11 @@ public EISUnifiedChatCompletionResponseHandler(String requestType, ResponseParse super(requestType, parseFunction); } + @Override + public boolean canHandleStreamingResponses() { + return true; + } + @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java index 3445736e70f15..9bfdc7fc7e3ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; import 
java.util.Objects; import java.util.function.Supplier; @@ -28,15 +29,29 @@ public class EISUnifiedCompletionRequestManager extends ElasticInferenceServiceR private static final ResponseHandler HANDLER = createCompletionHandler(); - public static EISUnifiedCompletionRequestManager of(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) { - return new EISUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + public static EISUnifiedCompletionRequestManager of( + ElasticInferenceServiceCompletionModel model, + ThreadPool threadPool, + TraceContext traceContext + ) { + return new EISUnifiedCompletionRequestManager( + Objects.requireNonNull(model), + Objects.requireNonNull(threadPool), + Objects.requireNonNull(traceContext) + ); } private final ElasticInferenceServiceCompletionModel model; + private final TraceContext traceContext; - private EISUnifiedCompletionRequestManager(ElasticInferenceServiceCompletionModel model, ThreadPool threadPool) { + private EISUnifiedCompletionRequestManager( + ElasticInferenceServiceCompletionModel model, + ThreadPool threadPool, + TraceContext traceContext + ) { super(threadPool, model); - this.model = Objects.requireNonNull(model); + this.model = model; + this.traceContext = traceContext; } @Override @@ -50,7 +65,7 @@ public void execute( EISUnifiedChatCompletionRequest request = new EISUnifiedChatCompletionRequest( inferenceInputs.castTo(UnifiedChatInput.class), model, - null // TODO + traceContext ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java index 73eacc09d83d1..0d778def96723 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java @@ -25,6 +25,8 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + public class EISUnifiedChatCompletionRequest implements OpenAiRequest { private final ElasticInferenceServiceCompletionModel model; @@ -47,7 +49,10 @@ public EISUnifiedChatCompletionRequest( @Override public HttpRequest createHttpRequest() { var httpPost = new HttpPost(uri); - var requestEntity = Strings.toString(new EISUnifiedChatCompletionRequestEntity(unifiedChatInput)); + var requestEntity = Strings.toString( + // TODO remove the modelId() call if not used + new EISUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()) + ); ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); @@ -57,6 +62,8 @@ public HttpRequest createHttpRequest() { } httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + // TODO remove EIS doesn't use an API key + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); return new HttpRequest(httpPost, getInferenceEntityId()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java index 727d5e82640c6..64120693172cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java @@ -7,168 +7,32 @@ package org.elasticsearch.xpack.inference.external.request.elastic; -import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; import java.io.IOException; import java.util.Objects; public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject { - - public static final String NAME_FIELD = "name"; - public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; - public static final String TOOL_CALLS_FIELD = "tool_calls"; - public static final String ID_FIELD = "id"; - public static final String FUNCTION_FIELD = "function"; - public static final String ARGUMENTS_FIELD = "arguments"; - public static final String DESCRIPTION_FIELD = "description"; - public static final String PARAMETERS_FIELD = "parameters"; - public static final String STRICT_FIELD = "strict"; - public static final String TOP_P_FIELD = "top_p"; - public static final String USER_FIELD = "user"; - public static final String STREAM_FIELD = "stream"; - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + // TODO remove this if EIS doesn't use it private static final String MODEL_FIELD = "model"; - public static final String MESSAGES_FIELD = "messages"; - private static final String ROLE_FIELD = "role"; - private static final String CONTENT_FIELD = "content"; - private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; - private static final String STOP_FIELD = "stop"; - private static final String TEMPERATURE_FIELD = "temperature"; - private static final String TOOL_CHOICE_FIELD = "tool_choice"; - private static 
final String TOOL_FIELD = "tools"; - private static final String TEXT_FIELD = "text"; - private static final String TYPE_FIELD = "type"; - private static final String STREAM_OPTIONS_FIELD = "stream_options"; - private static final String INCLUDE_USAGE_FIELD = "include_usage"; - - private final UnifiedCompletionRequest unifiedRequest; - private final boolean stream; - public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { - Objects.requireNonNull(unifiedChatInput); + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + private final String modelId; - this.unifiedRequest = unifiedChatInput.getRequest(); - this.stream = unifiedChatInput.stream(); + public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); + this.modelId = Objects.requireNonNull(modelId); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.startArray(MESSAGES_FIELD); - { - for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { - builder.startObject(); - { - switch (message.content()) { - case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); - case UnifiedCompletionRequest.ContentObjects contentObjects -> { - builder.startArray(CONTENT_FIELD); - for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { - builder.startObject(); - builder.field(TEXT_FIELD, contentObject.text()); - builder.field(TYPE_FIELD, contentObject.type()); - builder.endObject(); - } - builder.endArray(); - } - } - - builder.field(ROLE_FIELD, message.role()); - if (message.name() != null) { - builder.field(NAME_FIELD, message.name()); - } - if (message.toolCallId() != null) { - builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); 
- } - if (message.toolCalls() != null) { - builder.startArray(TOOL_CALLS_FIELD); - for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { - builder.startObject(); - { - builder.field(ID_FIELD, toolCall.id()); - builder.startObject(FUNCTION_FIELD); - { - builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); - builder.field(NAME_FIELD, toolCall.function().name()); - } - builder.endObject(); - builder.field(TYPE_FIELD, toolCall.type()); - } - builder.endObject(); - } - builder.endArray(); - } - } - builder.endObject(); - } - } - builder.endArray(); - - if (unifiedRequest.maxCompletionTokens() != null) { - builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); - } - - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { - builder.field(STOP_FIELD, unifiedRequest.stop()); - } - if (unifiedRequest.temperature() != null) { - builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); - } - if (unifiedRequest.toolChoice() != null) { - if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { - builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); - } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { - builder.startObject(TOOL_CHOICE_FIELD); - { - builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); - builder.startObject(FUNCTION_FIELD); - { - builder.field( - NAME_FIELD, - ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() - ); - } - builder.endObject(); - } - builder.endObject(); - } - } - if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { - builder.startArray(TOOL_FIELD); - for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { - builder.startObject(); - { - 
builder.field(TYPE_FIELD, t.type()); - builder.startObject(FUNCTION_FIELD); - { - builder.field(DESCRIPTION_FIELD, t.function().description()); - builder.field(NAME_FIELD, t.function().name()); - builder.field(PARAMETERS_FIELD, t.function().parameters()); - if (t.function().strict() != null) { - builder.field(STRICT_FIELD, t.function().strict()); - } - } - builder.endObject(); - } - builder.endObject(); - } - builder.endArray(); - } - if (unifiedRequest.topP() != null) { - builder.field(TOP_P_FIELD, unifiedRequest.topP()); - } - - builder.field(STREAM_FIELD, stream); - if (stream) { - builder.startObject(STREAM_OPTIONS_FIELD); - builder.field(INCLUDE_USAGE_FIELD, true); - builder.endObject(); - } + unifiedRequestEntity.toXContent(builder, params); + // TODO remove this if EIS doesn't use it + builder.field(MODEL_FIELD, modelId); builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index a547b4c1c52f5..e5b85633a499b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -44,15 +44,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields( - model.getServiceSettings().modelId(), - model.getTaskSettings().user() - ) - ) - ).getBytes(StandardCharsets.UTF_8) + Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, 
model)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 4f50d97107ded..fd05e1b98ac32 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -30,8 +30,6 @@ public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInpu this.model = Objects.requireNonNull(model); } - public record ModelFields(String modelId, @Nullable String user) {} - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 5146cec1552af..c4b846cda980c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -34,8 +34,9 @@ public class ElasticInferenceServiceSettings { public ElasticInferenceServiceSettings(Settings settings) { eisGatewayUrl = EIS_GATEWAY_URL.get(settings); - elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); - + // TODO fix this + // elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); + 
elasticInferenceServiceUrl = "abc"; } public static List> getSettingsDefinitions() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 04b815df988d0..6c0b4284fab23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -9,7 +9,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -37,7 +36,7 @@ public static ElasticInferenceServiceCompletionModel of( import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.net.URI; import java.net.URISyntaxException; @@ -77,7 +76,8 @@ public ElasticInferenceServiceCompletionModel( service, ElasticInferenceServiceCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, + // TODO remove this as EIS doesn't use it + DefaultSecretSettings.fromMap(secrets), elasticInferenceServiceComponents ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index ef0d7958ec184..05e5b2a12a7a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -42,6 +41,7 @@ public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXC public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); + // TODO does EIS have this? 
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java index 64e875e3dd4ef..ccea29bcc66dc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.request.openai; -import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; @@ -19,229 +18,37 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Random; +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; -import static org.hamcrest.Matchers.equalTo; public class EISUnifiedChatCompletionRequestEntityTests extends ESTestCase { - // TODO these tests were copied from the OpenAI tests and need to cleaned up and improved to be correct for EIS + private static final String ROLE = "user"; + private static final String USER = "a_user"; - // 1. Basic Serialization - // Test with minimal required fields to ensure basic serialization works. 
- public void testBasicSerialization() throws IOException { + // TODO remove if EIS doesn't use the model and user fields + public void testModelUserFieldsSerialization() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + ROLE, null, null, null ); var messageList = new ArrayList(); messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); - - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "test-endpoint", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - // 2. Serialization with All Fields - // Test with all possible fields populated to ensure complete serialization. 
- public void testSerializationWithAllFields() throws IOException { - // Create a message with all fields populated - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "name", - "tool_call_id", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), - "type" - ) - ) - ); - - // Create a tool with all fields populated - UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( - "type", - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request with all fields populated - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - "model", - 100L, // maxCompletionTokens - Collections.singletonList("stop"), - 0.9f, // temperature - new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), - Collections.singletonList(tool), - 0.8f // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ 
- { - "content": "Hello, world!", - "role": "user", - "name": "name", - "tool_call_id": "tool_call_id", - "tool_calls": [ - { - "id": "id", - "function": { - "arguments": "arguments", - "name": "function_name" - }, - "type": "type" - } - ] - } - ], - "model": "model-name", - "max_completion_tokens": 100, - "n": 1, - "stop": ["stop"], - "temperature": 0.9, - "tool_choice": "tool_choice", - "tools": [ - { - "type": "type", - "function": { - "description": "Fetches the weather in the given location", - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "description": "The location to get the weather for", - "type": "string" - }, - "unit": { - "description": "The unit to return the temperature in", - "type": "string", - "enum": ["F", "C"] - } - }, - "additionalProperties": false, - "required": ["location", "unit"] - }, - "strict": true - } - } - ], - "top_p": 0.8, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - - } - - // 3. Serialization with Null Optional Fields - // Test with optional fields set to null to ensure they are correctly omitted from the output. 
- public void testSerializationWithNullOptionalFields() throws IOException { - // Create a message with minimal required fields - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); + var unifiedRequest = UnifiedCompletionRequest.of(messageList); - // Create the unified request with optional fields set to null - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - // Convert to string and verify String jsonString = Strings.toString(builder); String expectedJson = """ { @@ -251,641 +58,16 @@ public void testSerializationWithNullOptionalFields() throws IOException { "role": "user" } ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - 
assertJsonEquals(jsonString, expectedJson); - } - - // 4. Serialization with Empty Lists - // Test with fields that are lists set to empty lists to ensure they are correctly serialized. - public void testSerializationWithEmptyLists() throws IOException { - // Create a message with minimal required fields - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - Collections.emptyList() // empty toolCalls list - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request with empty lists - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - Collections.emptyList(), // empty stop list - null, // temperature - null, // toolChoice - Collections.emptyList(), // empty tools list - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user", - "tool_calls": [] - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - // 5. 
Serialization with Nested Objects - // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. - public void testSerializationWithNestedObjects() throws IOException { - Random random = Randomness.get(); - - // Generate random values - String randomContent = "Hello, world! " + random.nextInt(1000); - String randomName = "name" + random.nextInt(1000); - String randomToolCallId = "tool_call_id" + random.nextInt(1000); - String randomArguments = "arguments" + random.nextInt(1000); - String randomFunctionName = "function_name" + random.nextInt(1000); - String randomType = "type" + random.nextInt(1000); - String randomModel = "model" + random.nextInt(1000); - String randomStop = "stop" + random.nextInt(1000); - float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); - float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); - - // Create a message with nested toolCalls - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(randomContent), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - randomName, - randomToolCallId, - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id", - new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), - randomType - ) - ) - ); - - // Create a tool with nested function fields - UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( - randomType, - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request with nested objects - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - randomModel, - 100L, // 
maxCompletionTokens - Collections.singletonList(randomStop), - randomTemperature, // temperature - new UnifiedCompletionRequest.ToolChoiceObject( - randomType, - new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) - ), - Collections.singletonList(tool), - randomTopP // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - // Expected JSON should be dynamically generated based on random values - String expectedJson = String.format( - Locale.US, - """ - { - "messages": [ - { - "content": "%s", - "role": "user", - "name": "%s", - "tool_call_id": "%s", - "tool_calls": [ - { - "id": "id", - "function": { - "arguments": "%s", - "name": "%s" - }, - "type": "%s" - } - ] - } - ], - "model": "%s", - "max_completion_tokens": 100, - "n": 1, - "stop": ["%s"], - "temperature": %.5f, - "tool_choice": { - "type": "%s", - "function": { - "name": "%s" - } - }, - "tools": [ - { - "type": "%s", - "function": { - "description": "Fetches the weather in the given location", - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "unit": { - "description": "The unit to return the temperature in", - "type": "string", - "enum": ["F", "C"] - }, - "location": { - "description": "The location to get the weather for", - "type": "string" - } - }, - "additionalProperties": false, - "required": 
["location", "unit"] - }, - "strict": true - } - } - ], - "top_p": %.5f, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """, - randomContent, - randomName, - randomToolCallId, - randomArguments, - randomFunctionName, - randomType, - randomModel, - randomStop, - randomTemperature, - randomType, - randomFunctionName, - randomType, - randomTopP - ); - assertJsonEquals(jsonString, expectedJson); - } - - // 6. Serialization with Different Content Types - // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. - public void testSerializationWithDifferentContentTypes() throws IOException { - Random random = Randomness.get(); - - // Generate random values for ContentString - String randomContentString = "Hello, world! " + random.nextInt(1000); - - // Generate random values for ContentObjects - String randomText = "Random text " + random.nextInt(1000); - String randomType = "type" + random.nextInt(1000); - UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); - - var contentObjectsList = new ArrayList(); - contentObjectsList.add(contentObject); - UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); - - // Create messages with different content types - UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(randomContentString), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - - UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( - contentObjects, - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(messageWithString); - messageList.add(messageWithObjects); - - // Create the unified request with both types 
of messages - UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = String.format(Locale.US, """ - { - "messages": [ - { - "content": "%s", - "role": "user" - }, - { - "content": [ - { - "text": "%s", - "type": "%s" - } - ], - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """, randomContentString, randomText, randomType); - assertJsonEquals(jsonString, expectedJson); - } - - // 7. Serialization with Special Characters - // Test with special characters in string fields to ensure they are correctly escaped and serialized. - public void testSerializationWithSpecialCharacters() throws IOException { - // Create a message with special characters - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world! 
\n \"Special\" characters: \t \\ /"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "name\nwith\nnewlines", - "tool_call_id\twith\ttabs", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id\\with\\backslashes", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), - "type" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world! 
\\n \\"Special\\" characters: \\t \\\\ /", - "role": "user", - "name": "name\\nwith\\nnewlines", - "tool_call_id": "tool_call_id\\twith\\ttabs", - "tool_calls": [ - { - "id": "id\\\\with\\\\backslashes", - "function": { - "arguments": "arguments\\"with\\"quotes", - "name": "function_name/with/slashes" - }, - "type": "type" - } - ] - } - ], - "model": "model-name", + "model": "test-endpoint", "n": 1, "stream": true, "stream_options": { "include_usage": true - } + }, + "user": "a_user" } """; assertJsonEquals(jsonString, expectedJson); } - // 8. Serialization with Boolean Fields - // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. - public void testSerializationWithBooleanFields() throws IOException { - // Create a message with minimal required fields - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Test with stream set to true - UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); - OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInputTrue, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - XContentBuilder builderTrue = JsonXContent.contentBuilder(); - entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); - - String jsonStringTrue = 
Strings.toString(builderTrue); - String expectedJsonTrue = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(expectedJsonTrue, jsonStringTrue); - - // Test with stream set to false - UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); - OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInputFalse, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - XContentBuilder builderFalse = JsonXContent.contentBuilder(); - entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); - - String jsonStringFalse = Strings.toString(builderFalse); - String expectedJsonFalse = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": false - } - """; - assertJsonEquals(expectedJsonFalse, jsonStringFalse); - } - - // 9. Serialization with Missing Required Fields - // Test with missing required fields to ensure appropriate exceptions are thrown. 
- public void testSerializationWithMissingRequiredFields() { - // Create a message with missing content (required field) - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - null, // missing content - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Attempt to serialize to XContent and expect an exception - try { - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - fail("Expected an exception due to missing required fields"); - } catch (NullPointerException | IOException e) { - // Expected exception - } - } - - // 10. Serialization with Mixed Valid and Invalid Data - // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. 
- public void testSerializationWithMixedValidAndInvalidData() throws IOException { - // Create a valid message - UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Valid content"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "validName", - "validToolCallId", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "validId", - new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), - "validType" - ) - ) - ); - - // Create an invalid message with null content - UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( - null, // invalid content - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "invalidName", - "invalidToolCallId", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "invalidId", - new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), - "invalidType" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(validMessage); - messageList.add(invalidMessage); - // Create the unified request with both valid and invalid messages - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - "model-name", - 100L, // maxCompletionTokens - Collections.singletonList("stop"), - 0.9f, // temperature - new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), - Collections.singletonList( - new UnifiedCompletionRequest.Tool( - "type", - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ) - ), - 0.8f // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - 
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("EIS Model", null) - ); - - // Serialize to XContent and verify - try { - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - fail("Expected an exception due to invalid data"); - } catch (NullPointerException | IOException e) { - // Expected exception - } - } - - public static Map createParameters() { - Map parameters = new LinkedHashMap<>(); - parameters.put("type", "object"); - - Map properties = new HashMap<>(); - - Map location = new HashMap<>(); - location.put("type", "string"); - location.put("description", "The location to get the weather for"); - properties.put("location", location); - - Map unit = new HashMap<>(); - unit.put("type", "string"); - unit.put("description", "The unit to return the temperature in"); - unit.put("enum", new String[] { "F", "C" }); - properties.put("unit", unit); - - parameters.put("properties", properties); - parameters.put("additionalProperties", false); - parameters.put("required", new String[] { "location", "unit" }); - - return parameters; - } - - private void assertJsonEquals(String actual, String expected) throws IOException { - try ( - var actualParser = createParser(JsonXContent.jsonXContent, actual); - var expectedParser = createParser(JsonXContent.jsonXContent, expected) - ) { - assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); - } - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index b4ecf367cc485..f43b185391697 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -43,10 +43,7 @@ public void testModelUserFieldsSerialization() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - new OpenAiUnifiedChatCompletionRequestEntity.ModelFields("Open AI model", null) - ); + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); From 72139327dcbd00fe44ef3c70100b62cea13cfe4a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:32:56 -0500 Subject: [PATCH 05/42] Update docs/changelog/118301.yaml --- docs/changelog/118301.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/118301.yaml diff --git a/docs/changelog/118301.yaml b/docs/changelog/118301.yaml new file mode 100644 index 0000000000000..53b7deff92794 --- /dev/null +++ b/docs/changelog/118301.yaml @@ -0,0 +1,5 @@ +pr: 118301 +summary: EIS Unified chat completions integration +area: Machine Learning +type: enhancement +issues: [] From ae9dbf73bda4f3bfbf125c23d34b594a9d089aed Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 10 Dec 2024 08:34:48 -0500 Subject: [PATCH 06/42] Fixing comment --- .../request/elastic/EISUnifiedChatCompletionRequest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java index 0d778def96723..81a002d31ef8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java @@ -75,13 +75,13 @@ public URI getURI() { @Override public Request truncate() { - // No truncation for OpenAI chat completions + // No truncation return this; } @Override public boolean[] getTruncationInfo() { - // No truncation for OpenAI chat completions + // No truncation return null; } From 147ba776aa108b1acf652dafe84d8ec33f734b97 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 11 Dec 2024 16:06:24 -0500 Subject: [PATCH 07/42] Adding some initial tests --- .../completion/EISCompletionModelTests.java | 52 ++++++++++++ .../EISCompletionServiceSettingsTests.java | 82 +++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java new file mode 100644 index 0000000000000..642f18e211013 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.completion; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.is; + +// TODO determine if we need the model id +public class EISCompletionModelTests extends ESTestCase { + + public void testOverridingModelId() { + var originalModel = new ElasticInferenceServiceCompletionModel( + "id", + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents("url") + ); + + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)), + "new_model_id", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = ElasticInferenceServiceCompletionModel.of(originalModel, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id")); + assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION)); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..780bba28657cd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class EISCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return ElasticInferenceServiceCompletionServiceSettings::new; + } + + @Override + protected ElasticInferenceServiceCompletionServiceSettings 
createTestInstance() { + return createRandom(); + } + + @Override + protected ElasticInferenceServiceCompletionServiceSettings mutateInstance(ElasticInferenceServiceCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, EISCompletionServiceSettingsTests::createRandom); + } + + public void testFromMap() { + var modelId = "model_id"; + + var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + ConfigurationParseContext.REQUEST + ); + + assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000)))); + } + + public void testFromMap_MissingModelId_ThrowsException() { + ValidationException validationException = expectThrows( + ValidationException.class, + () -> ElasticInferenceServiceCompletionServiceSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST) + ); + + assertThat(validationException.getMessage(), containsString("does not contain the required setting [model_id]")); + } + + public void testToXContent_WritesAllFields() throws IOException { + var modelId = "model_id"; + var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"model_id":"%s","rate_limit":{"requests_per_minute":1000}}""", modelId))); + } + + public static ElasticInferenceServiceCompletionServiceSettings createRandom() { + return new ElasticInferenceServiceCompletionServiceSettings(randomAlphaOfLength(4), RateLimitSettingsTests.createRandom()); + } +} From 9d4e02e629534660653ff723ce384c54b3cddd12 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 11 Dec 2024 16:31:33 -0500 Subject: [PATCH 08/42] 
Moving tests around --- .../EISUnifiedChatCompletionRequestEntityTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/{openai => elastic}/EISUnifiedChatCompletionRequestEntityTests.java (94%) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java similarity index 94% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java index ccea29bcc66dc..2a1b597725c4b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/EISUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.openai; +package org.elasticsearch.xpack.inference.external.request.elastic; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import java.io.IOException; From 9f78b40a51f8c72c5e9738fd74524c826ef405bf Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:22:31 -0500 Subject: [PATCH 09/42] Address some TODOs --- .../request/elastic/EISUnifiedChatCompletionRequest.java | 3 --- .../elastic/EISUnifiedChatCompletionRequestEntity.java | 2 -- .../completion/ElasticInferenceServiceCompletionModel.java | 5 ++--- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java index 81a002d31ef8c..749a65fa0509b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java @@ -50,7 +50,6 @@ public EISUnifiedChatCompletionRequest( public HttpRequest createHttpRequest() { var httpPost = new HttpPost(uri); var requestEntity = Strings.toString( - // TODO remove the modelId() call if not used new 
EISUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()) ); @@ -62,8 +61,6 @@ public HttpRequest createHttpRequest() { } httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); - // TODO remove EIS doesn't use an API key - httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); return new HttpRequest(httpPost, getInferenceEntityId()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java index 64120693172cb..ecc758c5d2f18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java @@ -16,7 +16,6 @@ import java.util.Objects; public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject { - // TODO remove this if EIS doesn't use it private static final String MODEL_FIELD = "model"; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; @@ -31,7 +30,6 @@ public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); unifiedRequestEntity.toXContent(builder, params); - // TODO remove this if EIS doesn't use it builder.field(MODEL_FIELD, modelId); builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 6c0b4284fab23..a091376ade7fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -36,7 +37,6 @@ public static ElasticInferenceServiceCompletionModel of( import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; -import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.net.URI; import java.net.URISyntaxException; @@ -76,8 +76,7 @@ public ElasticInferenceServiceCompletionModel( service, ElasticInferenceServiceCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, - // TODO remove this as EIS doesn't use it - DefaultSecretSettings.fromMap(secrets), + EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents ); } From b09b3f5eae7ecba35c2a4ef2f35355ce8750315a Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:30:50 -0500 Subject: [PATCH 10/42] Remove a TODO --- .../services/elastic/completion/EISCompletionModelTests.java | 1 - 1 file changed, 1 deletion(-) diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java index 642f18e211013..6947f96d9b7c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java @@ -19,7 +19,6 @@ import static org.hamcrest.Matchers.is; -// TODO determine if we need the model id public class EISCompletionModelTests extends ESTestCase { public void testOverridingModelId() { From b92724b711733532e6429d6be1811dcbdb02ccf9 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 18 Dec 2024 21:38:47 +0000 Subject: [PATCH 11/42] [CI] Auto commit changes from spotless --- .../request/elastic/EISUnifiedChatCompletionRequest.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java index 749a65fa0509b..66a3da88d7c4b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java @@ -25,8 +25,6 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; - public class EISUnifiedChatCompletionRequest implements OpenAiRequest { private final ElasticInferenceServiceCompletionModel model; 
From 25ab348a4445b2a5c4ffc514bda7f46cfc2219ed Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 09:45:27 -0500 Subject: [PATCH 12/42] Delete docs/changelog/118301.yaml --- docs/changelog/118301.yaml | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 docs/changelog/118301.yaml diff --git a/docs/changelog/118301.yaml b/docs/changelog/118301.yaml deleted file mode 100644 index 53b7deff92794..0000000000000 --- a/docs/changelog/118301.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 118301 -summary: EIS Unified chat completions integration -area: Machine Learning -type: enhancement -issues: [] From 7168dc69132322db6406c3e24c8db225db84c195 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:20:36 -0500 Subject: [PATCH 13/42] Rename EISUnifiedChatCompletionResponseHandler --- ...SUnifiedChatCompletionResponseHandler.java | 40 ------------------- .../EISUnifiedCompletionRequestManager.java | 4 +- ...renceServiceCompletionServiceSettings.java | 1 - 3 files changed, 2 insertions(+), 43 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java deleted file mode 100644 index f1007b1c30d74..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/EISUnifiedChatCompletionResponseHandler.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.elastic; - -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; - -import java.util.concurrent.Flow; - -public class EISUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { - public EISUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction); - } - - @Override - public boolean canHandleStreamingResponses() { - return true; - } - - @Override - public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { - var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); - var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec - - flow.subscribe(serverSentEventProcessor); - serverSentEventProcessor.subscribe(openAiProcessor); - return new StreamingUnifiedChatCompletionResults(openAiProcessor); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java index 9bfdc7fc7e3ef..d3dcba050c8f6 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java @@ -12,7 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.elastic.EISUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest; @@ -72,6 +72,6 @@ public void execute( } private static ResponseHandler createCompletionHandler() { - return new EISUnifiedChatCompletionResponseHandler("eis completion", OpenAiChatCompletionResponseEntity::fromResponse); + return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler("eis completion", OpenAiChatCompletionResponseEntity::fromResponse); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index 05e5b2a12a7a9..3c8182a7d41a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -41,7 +41,6 @@ 
public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXC public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); - // TODO does EIS have this? String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, From 91daa01949d0dfb4f4bfc309a096dddf682d739d Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:21:41 -0500 Subject: [PATCH 14/42] Renames to ElasticInferenceServiceUnifiedCompletionRequestManager --- .../EISUnifiedCompletionRequestManager.java | 77 ------------------- 1 file changed, 77 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java deleted file mode 100644 index d3dcba050c8f6..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EISUnifiedCompletionRequestManager.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.telemetry.TraceContext; - -import java.util.Objects; -import java.util.function.Supplier; - -public class EISUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager { - - private static final Logger logger = LogManager.getLogger(EISUnifiedCompletionRequestManager.class); - - private static final ResponseHandler HANDLER = createCompletionHandler(); - - public static EISUnifiedCompletionRequestManager of( - ElasticInferenceServiceCompletionModel model, - ThreadPool threadPool, - TraceContext traceContext - ) { - return new EISUnifiedCompletionRequestManager( - Objects.requireNonNull(model), - Objects.requireNonNull(threadPool), - Objects.requireNonNull(traceContext) - ); - } - - private final ElasticInferenceServiceCompletionModel model; - private final TraceContext traceContext; - - private EISUnifiedCompletionRequestManager( - ElasticInferenceServiceCompletionModel model, - ThreadPool threadPool, - TraceContext traceContext - ) { - super(threadPool, model); - this.model = model; - this.traceContext = 
traceContext; - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - - EISUnifiedChatCompletionRequest request = new EISUnifiedChatCompletionRequest( - inferenceInputs.castTo(UnifiedChatInput.class), - model, - traceContext - ); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } - - private static ResponseHandler createCompletionHandler() { - return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler("eis completion", OpenAiChatCompletionResponseEntity::fromResponse); - } -} From 7842b6ce04367a98c5f15a1806d054dfc63842d9 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:25:48 -0500 Subject: [PATCH 15/42] Renames EISUnifiedChatCompletionRequest --- .../EISUnifiedChatCompletionRequest.java | 109 ------------------ ...ifiedChatCompletionRequestEntityTests.java | 74 ------------ 2 files changed, 183 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java deleted file mode 100644 index 66a3da88d7c4b..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequest.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.elastic; - -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.entity.ByteArrayEntity; -import org.apache.http.message.BasicHeader; -import org.elasticsearch.common.Strings; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiRequest; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.telemetry.TraceContext; - -import java.net.URI; -import java.nio.charset.StandardCharsets; -import java.util.Objects; - -public class EISUnifiedChatCompletionRequest implements OpenAiRequest { - - private final ElasticInferenceServiceCompletionModel model; - private final UnifiedChatInput unifiedChatInput; - private final URI uri; - private final TraceContext traceContext; - - public EISUnifiedChatCompletionRequest( - UnifiedChatInput unifiedChatInput, - ElasticInferenceServiceCompletionModel model, - TraceContext traceContext - ) { - this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); - this.model = Objects.requireNonNull(model); - this.uri = model.uri(); - this.traceContext = traceContext; - - } - - @Override - public HttpRequest createHttpRequest() { - var httpPost = new HttpPost(uri); - var requestEntity = Strings.toString( - new EISUnifiedChatCompletionRequestEntity(unifiedChatInput, 
model.getServiceSettings().modelId()) - ); - - ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); - httpPost.setEntity(byteEntity); - - if (traceContext != null) { - propagateTraceContext(httpPost); - } - - httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public URI getURI() { - return uri; - } - - @Override - public Request truncate() { - // No truncation - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // No truncation - return null; - } - - @Override - public String getInferenceEntityId() { - return model.getInferenceEntityId(); - } - - @Override - public boolean isStreaming() { - return true; - } - - public TraceContext getTraceContext() { - return traceContext; - } - - private void propagateTraceContext(HttpPost httpPost) { - var traceParent = traceContext.traceParent(); - var traceState = traceContext.traceState(); - - if (traceParent != null) { - httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); - } - - if (traceState != null) { - httpPost.setHeader(Task.TRACE_STATE, traceState); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java deleted file mode 100644 index 2a1b597725c4b..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntityTests.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.elastic; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; - -import java.io.IOException; -import java.util.ArrayList; - -import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; -import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; - -public class EISUnifiedChatCompletionRequestEntityTests extends ESTestCase { - - private static final String ROLE = "user"; - private static final String USER = "a_user"; - - // TODO remove if EIS doesn't use the model and user fields - public void testModelUserFieldsSerialization() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - - var unifiedRequest = UnifiedCompletionRequest.of(messageList); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); - - OpenAiUnifiedChatCompletionRequestEntity entity = new 
OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "test-endpoint", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - }, - "user": "a_user" - } - """; - assertJsonEquals(jsonString, expectedJson); - } - -} From 7f44d0445d69f9fc1dd3a348004986905906e082 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:31:28 -0500 Subject: [PATCH 16/42] Renames and comments --- ...EISUnifiedChatCompletionRequestEntity.java | 38 ------------------- ...ceServiceUnifiedChatCompletionRequest.java | 4 ++ 2 files changed, 4 insertions(+), 38 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java deleted file mode 100644 index ecc758c5d2f18..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/EISUnifiedChatCompletionRequestEntity.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.elastic; - -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; - -import java.io.IOException; -import java.util.Objects; - -public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject { - private static final String MODEL_FIELD = "model"; - - private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; - private final String modelId; - - public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { - this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); - this.modelId = Objects.requireNonNull(modelId); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - unifiedRequestEntity.toXContent(builder, params); - builder.field(MODEL_FIELD, modelId); - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 112ead7057933..6304548a3d32b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -26,6 +26,10 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request { + 
// Implementing OpenAiRequest to ensure compatibility with the OpenAI API interface + // This allows the ElasticInferenceService to handle requests in a standardized manner + // and leverage existing infrastructure for processing OpenAI-like requests. + private final ElasticInferenceServiceCompletionModel model; private final UnifiedChatInput unifiedChatInput; private final TraceContextHandler traceContextHandler; From afc8ebc7f0d9f9bd7928dec6681339edcfc3095e Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:47:51 -0500 Subject: [PATCH 17/42] propagateTraceContext extraction --- .../request/TraceContextPropagator.java | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java new file mode 100644 index 0000000000000..2a05016766ea5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +public final class TraceContextPropagator { + private TraceContextPropagator() {} // Utility class + + public static void propagateTraceContext(HttpPost httpPost, TraceContext traceContext) { + if (traceContext == null) { + return; + } + + var traceParent = traceContext.traceParent(); + var traceState = traceContext.traceState(); + + if (traceParent != null) { + httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); + } + + if (traceState != null) { + httpPost.setHeader(Task.TRACE_STATE, traceState); + } + } +} From 346ceba5adf04e3469f43ef33528d2698a4091df Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:25:51 -0500 Subject: [PATCH 18/42] Clean up trace --- .../TraceContextAware.java} | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/{external/request/TraceContextPropagator.java => telemetry/TraceContextAware.java} (70%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java similarity index 70% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java index 2a05016766ea5..667d0992790fe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/TraceContextPropagator.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java @@ -4,17 +4,16 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ - -package org.elasticsearch.xpack.inference.external.request; +package org.elasticsearch.xpack.inference.telemetry; import org.apache.http.client.methods.HttpPost; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.telemetry.TraceContext; -public final class TraceContextPropagator { - private TraceContextPropagator() {} // Utility class +public interface TraceContextAware { + TraceContext getTraceContext(); - public static void propagateTraceContext(HttpPost httpPost, TraceContext traceContext) { + default void propagateTraceContext(HttpPost httpPost) { + TraceContext traceContext = this.getTraceContext(); if (traceContext == null) { return; } From a206fabde622acb98601b614c3391a4b1318ef97 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:28:56 -0500 Subject: [PATCH 19/42] Address comments --- .../telemetry/TraceContextAware.java | 32 ------------------- .../EISCompletionServiceSettingsTests.java | 2 +- 2 files changed, 1 insertion(+), 33 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java deleted file mode 100644 index 667d0992790fe..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextAware.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.inference.telemetry; - -import org.apache.http.client.methods.HttpPost; -import org.elasticsearch.tasks.Task; - -public interface TraceContextAware { - TraceContext getTraceContext(); - - default void propagateTraceContext(HttpPost httpPost) { - TraceContext traceContext = this.getTraceContext(); - if (traceContext == null) { - return; - } - - var traceParent = traceContext.traceParent(); - var traceState = traceContext.traceState(); - - if (traceParent != null) { - httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); - } - - if (traceState != null) { - httpPost.setHeader(Task.TRACE_STATE, traceState); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java index 780bba28657cd..44225bb4074b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java @@ -52,7 +52,7 @@ public void testFromMap() { ConfigurationParseContext.REQUEST ); - assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000)))); + assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L)))); } public void testFromMap_MissingModelId_ThrowsException() { From cef606d69c9f443b4e989465a420d69b2b70cfb6 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Tue, 7 Jan 2025 09:31:08 -0500 
Subject: [PATCH 20/42] Update x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java Co-authored-by: Tim Grein --- .../elastic/completion/EISCompletionServiceSettingsTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java index 44225bb4074b7..489c20bb3dd6b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java @@ -26,7 +26,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -public class EISCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { +public class ElasticInferenceServiceCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { @Override protected Writeable.Reader instanceReader() { From a5e91d9d06f177624c6f4647f488173fd3ef4a06 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Tue, 7 Jan 2025 09:31:37 -0500 Subject: [PATCH 21/42] Update x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java Co-authored-by: Tim Grein --- .../services/elastic/completion/EISCompletionModelTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java index 6947f96d9b7c5..cc1463232e7e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java @@ -19,7 +19,7 @@ import static org.hamcrest.Matchers.is; -public class EISCompletionModelTests extends ESTestCase { +public class ElasticInferenceServiceCompletionModelTests extends ESTestCase { public void testOverridingModelId() { var originalModel = new ElasticInferenceServiceCompletionModel( From 6933c491677310edb1e65b484cfb63fde35b5191 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 7 Jan 2025 14:38:47 +0000 Subject: [PATCH 22/42] [CI] Auto commit changes from spotless --- .../elastic/completion/EISCompletionServiceSettingsTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java index 489c20bb3dd6b..aa0810ebd4abe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java @@ -26,7 +26,8 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -public class ElasticInferenceServiceCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { +public class ElasticInferenceServiceCompletionServiceSettingsTests extends 
AbstractWireSerializingTestCase< + ElasticInferenceServiceCompletionServiceSettings> { @Override protected Writeable.Reader instanceReader() { From 31fe29cd98eb4ea0087a39f5c97d4251f2c3bda5 Mon Sep 17 00:00:00 2001 From: Jason Botzas-Coluni <44372106+jaybcee@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:03:25 -0500 Subject: [PATCH 23/42] Address comments --- .../completion/EISCompletionModelTests.java | 51 ------------ .../EISCompletionServiceSettingsTests.java | 83 ------------------- 2 files changed, 134 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java deleted file mode 100644 index cc1463232e7e5..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionModelTests.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.elastic.completion; - -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; - -import java.util.List; - -import static org.hamcrest.Matchers.is; - -public class ElasticInferenceServiceCompletionModelTests extends ESTestCase { - - public void testOverridingModelId() { - var originalModel = new ElasticInferenceServiceCompletionModel( - "id", - TaskType.COMPLETION, - "elastic", - new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - new ElasticInferenceServiceComponents("url") - ); - - var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)), - "new_model_id", - null, - null, - null, - null, - null, - null - ); - - var overriddenModel = ElasticInferenceServiceCompletionModel.of(originalModel, request); - - assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id")); - assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java deleted file mode 100644 index aa0810ebd4abe..0000000000000 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.completion; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.is; - -public class ElasticInferenceServiceCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< - ElasticInferenceServiceCompletionServiceSettings> { - - @Override - protected Writeable.Reader instanceReader() { - return ElasticInferenceServiceCompletionServiceSettings::new; - } - - @Override - protected ElasticInferenceServiceCompletionServiceSettings createTestInstance() { - return createRandom(); - } - - @Override - protected ElasticInferenceServiceCompletionServiceSettings mutateInstance(ElasticInferenceServiceCompletionServiceSettings instance) - throws IOException { - return randomValueOtherThan(instance, 
EISCompletionServiceSettingsTests::createRandom); - } - - public void testFromMap() { - var modelId = "model_id"; - - var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), - ConfigurationParseContext.REQUEST - ); - - assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L)))); - } - - public void testFromMap_MissingModelId_ThrowsException() { - ValidationException validationException = expectThrows( - ValidationException.class, - () -> ElasticInferenceServiceCompletionServiceSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST) - ); - - assertThat(validationException.getMessage(), containsString("does not contain the required setting [model_id]")); - } - - public void testToXContent_WritesAllFields() throws IOException { - var modelId = "model_id"; - var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000)); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - serviceSettings.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(Strings.format(""" - {"model_id":"%s","rate_limit":{"requests_per_minute":1000}}""", modelId))); - } - - public static ElasticInferenceServiceCompletionServiceSettings createRandom() { - return new ElasticInferenceServiceCompletionServiceSettings(randomAlphaOfLength(4), RateLimitSettingsTests.createRandom()); - } -} From 8ebe833c47596c9fc477f516c9745ab2af7c8f0c Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Tue, 7 Jan 2025 12:48:23 -0500 Subject: [PATCH 24/42] add default endpoint for EIS completion --- .../elastic/ElasticInferenceService.java | 61 ++++++++++++++++--- ...lasticInferenceServiceCompletionModel.java | 2 +- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 48416faac6a06..dc2a735a077f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -72,6 +73,10 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final String SERVICE_NAME = "Elastic"; + public static final String V1_EIS_COMPLETION_MODEL_ID = "temp1"; + + public static final String DEFAULT_EIS_COMPLETION_ENDPOINT_ID = "eis-alpha-1"; + public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, @@ -210,6 +215,32 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + public List defaultConfigIds() { + return List.of(new DefaultConfigId(DEFAULT_EIS_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION, this)); + } + + @Override + public void defaultConfigs(ActionListener> defaultsListener) { + var serviceSettings = new HashMap(1); + serviceSettings.put(MODEL_ID, "elastic-model"); // TODO + + 
defaultsListener.onResponse( + List.of( + new ElasticInferenceServiceCompletionModel( + DEFAULT_EIS_COMPLETION_ENDPOINT_ID, + TaskType.COMPLETION, + NAME, + serviceSettings, + null, + null, + new ElasticInferenceServiceComponents("http://localhost:8080"), // TODO + ConfigurationParseContext.PERSISTENT + ) + ) + ); + } + private static ElasticInferenceServiceModel createModel( String inferenceEntityId, TaskType taskType, @@ -271,16 +302,30 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + if (DEFAULT_EIS_COMPLETION_ENDPOINT_ID.equals(inferenceEntityId)) { + return V1_EIS_COMPLETION_MODEL; + } else { + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } } + private static final Model V1_EIS_COMPLETION_MODEL = new ElasticInferenceServiceCompletionModel( + V1_EIS_COMPLETION_MODEL_ID, + TaskType.COMPLETION, + NAME, + (ElasticInferenceServiceCompletionServiceSettings) Map.of(MODEL_ID, DEFAULT_EIS_COMPLETION_ENDPOINT_ID), + EmptyTaskSettings.INSTANCE, + null, + null + ); + @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index a091376ade7fa..2ca48b628ad1b 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -90,7 +90,7 @@ public ElasticInferenceServiceCompletionModel( } - ElasticInferenceServiceCompletionModel( + public ElasticInferenceServiceCompletionModel( String inferenceEntityId, TaskType taskType, String service, From 1be3aca6a0e6fc8dd90889229ae114b37c7babcd Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Tue, 7 Jan 2025 14:25:47 -0500 Subject: [PATCH 25/42] Avoid using immutable map for constructing EISCompletionModel --- .../elastic/ElasticInferenceService.java | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index dc2a735a077f6..09e3c5a922725 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -303,7 +303,21 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); if (DEFAULT_EIS_COMPLETION_ENDPOINT_ID.equals(inferenceEntityId)) { - return V1_EIS_COMPLETION_MODEL; + var defaultServiceSettings = new HashMap(1); + var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap( + defaultServiceSettings, + ConfigurationParseContext.PERSISTENT + ); + + return new ElasticInferenceServiceCompletionModel( + V1_EIS_COMPLETION_MODEL_ID, + TaskType.COMPLETION, + 
NAME, + serviceSettings, + EmptyTaskSettings.INSTANCE, + null, + null + ); } else { return createModelFromPersistent( inferenceEntityId, @@ -316,15 +330,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M } } - private static final Model V1_EIS_COMPLETION_MODEL = new ElasticInferenceServiceCompletionModel( - V1_EIS_COMPLETION_MODEL_ID, - TaskType.COMPLETION, - NAME, - (ElasticInferenceServiceCompletionServiceSettings) Map.of(MODEL_ID, DEFAULT_EIS_COMPLETION_ENDPOINT_ID), - EmptyTaskSettings.INSTANCE, - null, - null - ); + @Override public TransportVersion getMinimalSupportedVersion() { From 02787761158f6a77833cd7423ac70c15905ba5d5 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Tue, 7 Jan 2025 13:38:00 -0500 Subject: [PATCH 26/42] Update docs/changelog/119694.yaml --- docs/changelog/119694.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/119694.yaml diff --git a/docs/changelog/119694.yaml b/docs/changelog/119694.yaml new file mode 100644 index 0000000000000..daa53fdd12b63 --- /dev/null +++ b/docs/changelog/119694.yaml @@ -0,0 +1,5 @@ +pr: 119694 +summary: "[Inference API] Default eis endpoint" +area: Machine Learning +type: enhancement +issues: [] From eeb12b3e6539b8e4a39f8ea31fc12d35aaf27fe4 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Tue, 7 Jan 2025 15:17:27 -0500 Subject: [PATCH 27/42] actually include service settings --- .../inference/services/elastic/ElasticInferenceService.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 09e3c5a922725..cdfc1b5bb85a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -304,6 +304,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M if (DEFAULT_EIS_COMPLETION_ENDPOINT_ID.equals(inferenceEntityId)) { var defaultServiceSettings = new HashMap(1); + defaultServiceSettings.put(MODEL_ID, "elastic-model"); // TODO var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap( defaultServiceSettings, ConfigurationParseContext.PERSISTENT From a619561a3f09c5bb7b897f64f664e053b338d788 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 7 Jan 2025 19:54:40 +0000 Subject: [PATCH 28/42] [CI] Auto commit changes from spotless --- .../inference/services/elastic/ElasticInferenceService.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index cdfc1b5bb85a6..8cd1e4e6c5021 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -331,8 +331,6 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M } } - - @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; From f9e6b7c7e4d7c6e90f4e83c16c0b31895db5da38 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Tue, 7 Jan 2025 15:41:53 -0500 Subject: [PATCH 29/42] Update changemessage --- docs/changelog/119694.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/119694.yaml b/docs/changelog/119694.yaml index 
daa53fdd12b63..0205004980ba2 100644 --- a/docs/changelog/119694.yaml +++ b/docs/changelog/119694.yaml @@ -1,5 +1,5 @@ pr: 119694 -summary: "[Inference API] Default eis endpoint" +summary: "[Inference API] Add default endpoint for completion in elastic inference service" area: Machine Learning type: enhancement issues: [] From 59c8791b07b79a0911b8f05ae3dea0396c334b35 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 8 Jan 2025 15:42:09 -0500 Subject: [PATCH 30/42] match ElasticsearchInternalService implementation of defaults --- .../elastic/ElasticInferenceService.java | 80 +++++++++---------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8cd1e4e6c5021..fd4777f87ae57 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -77,6 +77,8 @@ public class ElasticInferenceService extends SenderService { public static final String DEFAULT_EIS_COMPLETION_ENDPOINT_ID = "eis-alpha-1"; + public static final List DEFAULT_EIS_ENDPOINT_IDS = List.of(DEFAULT_EIS_COMPLETION_ENDPOINT_ID); + public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, @@ -180,8 +182,18 @@ public void parseRequestConfig( Map config, ActionListener parsedModelListener ) { - try { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + if (DEFAULT_EIS_ENDPOINT_IDS.contains(inferenceEntityId)) { + parsedModelListener.onFailure( + new ElasticsearchStatusException( + "[{}] is a reserved inference Id. 
Cannot create a new inference endpoint with a reserved Id", + RestStatus.BAD_REQUEST, + inferenceEntityId + ) + ); + return; + } + + try {Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ElasticInferenceServiceModel model = createModel( @@ -222,25 +234,30 @@ public List defaultConfigIds() { @Override public void defaultConfigs(ActionListener> defaultsListener) { - var serviceSettings = new HashMap(1); - serviceSettings.put(MODEL_ID, "elastic-model"); // TODO defaultsListener.onResponse( List.of( - new ElasticInferenceServiceCompletionModel( - DEFAULT_EIS_COMPLETION_ENDPOINT_ID, - TaskType.COMPLETION, - NAME, - serviceSettings, - null, - null, - new ElasticInferenceServiceComponents("http://localhost:8080"), // TODO - ConfigurationParseContext.PERSISTENT - ) + firstDefaultCompletionModel() ) ); } + private static ElasticInferenceServiceCompletionModel firstDefaultCompletionModel() { + var serviceSettings = new HashMap(1); + serviceSettings.put(MODEL_ID, "elastic-model"); // TODO + + return new ElasticInferenceServiceCompletionModel( + DEFAULT_EIS_COMPLETION_ENDPOINT_ID, + TaskType.COMPLETION, + NAME, + serviceSettings, + null, + null, + new ElasticInferenceServiceComponents("http://localhost:8080"), // TODO + ConfigurationParseContext.PERSISTENT + ); + } + private static ElasticInferenceServiceModel createModel( String inferenceEntityId, TaskType taskType, @@ -302,33 +319,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - if (DEFAULT_EIS_COMPLETION_ENDPOINT_ID.equals(inferenceEntityId)) { - var defaultServiceSettings = new HashMap(1); - defaultServiceSettings.put(MODEL_ID, "elastic-model"); // TODO 
- var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap( - defaultServiceSettings, - ConfigurationParseContext.PERSISTENT - ); - - return new ElasticInferenceServiceCompletionModel( - V1_EIS_COMPLETION_MODEL_ID, - TaskType.COMPLETION, - NAME, - serviceSettings, - EmptyTaskSettings.INSTANCE, - null, - null - ); - } else { - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); - } + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); } @Override From 5d430eb0e6767d66e46fc2b76fdc4a8b2304d864 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 8 Jan 2025 16:07:15 -0500 Subject: [PATCH 31/42] Update tests --- .../org/elasticsearch/xpack/inference/InferenceCrudIT.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 6634eecc2c959..5db87a0efe513 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -58,7 +58,7 @@ public void testCRUD() throws IOException { } var getAllModels = getAllModels(); - int numModels = 12; + int numModels = 13; assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); @@ -543,8 +543,8 @@ private static String expectedResult(String input) { } } - public void testGetZeroModels() throws IOException { + public void 
testGetCompletionModels() throws IOException { var models = getModels("_all", TaskType.COMPLETION); - assertThat(models, empty()); + assertEquals(models.size(), 1); } } From ad8c7ab042e8aa3f938590b6e9dfae56617ed59d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 8 Jan 2025 20:48:37 +0000 Subject: [PATCH 32/42] [CI] Auto commit changes from spotless --- .../services/elastic/ElasticInferenceService.java | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index fd4777f87ae57..89754b3ac8832 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -193,7 +192,8 @@ public void parseRequestConfig( return; } - try {Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ElasticInferenceServiceModel model = createModel( @@ -235,11 +235,7 @@ public 
List defaultConfigIds() { @Override public void defaultConfigs(ActionListener> defaultsListener) { - defaultsListener.onResponse( - List.of( - firstDefaultCompletionModel() - ) - ); + defaultsListener.onResponse(List.of(firstDefaultCompletionModel())); } private static ElasticInferenceServiceCompletionModel firstDefaultCompletionModel() { From 84b654ce102d0e8524f0a74e9cff443d746f9531 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 8 Jan 2025 21:12:23 +0000 Subject: [PATCH 33/42] [CI] Auto commit changes from spotless --- .../java/org/elasticsearch/xpack/inference/InferenceCrudIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 5db87a0efe513..9fe8c7ac2a028 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -35,7 +35,6 @@ import java.util.stream.Stream; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; From 1687bbaabb30aec51286f567919e7459ccfb56c1 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 9 Jan 2025 08:28:16 -0500 Subject: [PATCH 34/42] update model name constant --- .../services/elastic/ElasticInferenceService.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 89754b3ac8832..6a79068b51e50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -72,11 +72,10 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final String SERVICE_NAME = "Elastic"; - public static final String V1_EIS_COMPLETION_MODEL_ID = "temp1"; - public static final String DEFAULT_EIS_COMPLETION_ENDPOINT_ID = "eis-alpha-1"; + public static final String DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; - public static final List DEFAULT_EIS_ENDPOINT_IDS = List.of(DEFAULT_EIS_COMPLETION_ENDPOINT_ID); + public static final List DEFAULT_EIS_ENDPOINT_IDS = List.of(DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1); public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -229,7 +228,7 @@ public EnumSet supportedTaskTypes() { @Override public List defaultConfigIds() { - return List.of(new DefaultConfigId(DEFAULT_EIS_COMPLETION_ENDPOINT_ID, TaskType.COMPLETION, this)); + return List.of(new DefaultConfigId(DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1, TaskType.COMPLETION, this)); } @Override @@ -243,7 +242,7 @@ private static ElasticInferenceServiceCompletionModel firstDefaultCompletionMode serviceSettings.put(MODEL_ID, "elastic-model"); // TODO return new ElasticInferenceServiceCompletionModel( - DEFAULT_EIS_COMPLETION_ENDPOINT_ID, + DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1, TaskType.COMPLETION, NAME, serviceSettings, From ef802b08f15ea4535b685b1dda1717f8e323dd90 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 9 Jan 2025 13:35:10 +0000 Subject: [PATCH 35/42] [CI] Auto commit changes from 
spotless --- .../inference/services/elastic/ElasticInferenceService.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 6a79068b51e50..1099845bae543 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -72,7 +72,6 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final String SERVICE_NAME = "Elastic"; - public static final String DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; public static final List DEFAULT_EIS_ENDPOINT_IDS = List.of(DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1); From 986654b5108cdc27ab5df33391a567c36dcb9880 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 9 Jan 2025 09:58:54 -0500 Subject: [PATCH 36/42] fix merge conflicts --- ...OpenAiUnifiedChatCompletionRequestEntity.java | 4 ++-- .../ElasticInferenceServiceCompletionModel.java | 16 ---------------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index fd05e1b98ac32..b80100c9e2f79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -37,8 +37,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); - if (Strings.isNullOrEmpty(modelFields.user()) == false) { - builder.field(USER_FIELD, modelFields.user()); + if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { + builder.field(USER_FIELD, model.getTaskSettings().user()); } builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 2ca48b628ad1b..d129dcf9fea6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -17,7 +17,6 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; @@ -28,21 +27,6 @@ import java.util.Map; import java.util.Objects; -public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel { - - public static ElasticInferenceServiceCompletionModel of( - ElasticInferenceServiceCompletionModel model, - 
UnifiedCompletionRequest request - ) { -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; - -import java.net.URI; -import java.net.URISyntaxException; -import java.util.Map; -import java.util.Objects; - public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel { public static ElasticInferenceServiceCompletionModel of( From 1d79df735acd594248e8a06a3b9068c212d99954 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 9 Jan 2025 10:00:10 -0500 Subject: [PATCH 37/42] remove uncessary comment --- .../ElasticInferenceServiceUnifiedChatCompletionRequest.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 6304548a3d32b..112ead7057933 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -26,10 +26,6 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request { - // Implementing OpenAiRequest to ensure compatibility with the OpenAI API interface - // This allows the ElasticInferenceService to handle requests in a standardized manner - // and leverage existing infrastructure for processing OpenAI-like requests. 
- private final ElasticInferenceServiceCompletionModel model; private final UnifiedChatInput unifiedChatInput; private final TraceContextHandler traceContextHandler; From fa7448907a7e8d01713cb0c91fd72c8def01753b Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 9 Jan 2025 10:01:05 -0500 Subject: [PATCH 38/42] remove todo --- .../services/elastic/ElasticInferenceServiceSettings.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index c4b846cda980c..3801e0acc8727 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -34,9 +34,7 @@ public class ElasticInferenceServiceSettings { public ElasticInferenceServiceSettings(Settings settings) { eisGatewayUrl = EIS_GATEWAY_URL.get(settings); - // TODO fix this - // elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); - elasticInferenceServiceUrl = "abc"; + elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); } public static List> getSettingsDefinitions() { From 5b1a509121644bcb065d1bd62e93dfeb0d2a07ca Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 9 Jan 2025 15:03:50 -0500 Subject: [PATCH 39/42] Replace local constant with class variable --- .../inference/services/elastic/ElasticInferenceService.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 1099845bae543..ed4fffc299040 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -232,11 +232,10 @@ public List defaultConfigIds() { @Override public void defaultConfigs(ActionListener> defaultsListener) { - defaultsListener.onResponse(List.of(firstDefaultCompletionModel())); } - private static ElasticInferenceServiceCompletionModel firstDefaultCompletionModel() { + private ElasticInferenceServiceCompletionModel firstDefaultCompletionModel() { var serviceSettings = new HashMap(1); serviceSettings.put(MODEL_ID, "elastic-model"); // TODO @@ -247,7 +246,7 @@ private static ElasticInferenceServiceCompletionModel firstDefaultCompletionMode serviceSettings, null, null, - new ElasticInferenceServiceComponents("http://localhost:8080"), // TODO + elasticInferenceServiceComponents, ConfigurationParseContext.PERSISTENT ); } From 1cd60f6d94113656b02c0d78755134b5eff9ecaf Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:25:17 -0500 Subject: [PATCH 40/42] Delete docs/changelog/119694.yaml --- docs/changelog/119694.yaml | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 docs/changelog/119694.yaml diff --git a/docs/changelog/119694.yaml b/docs/changelog/119694.yaml deleted file mode 100644 index 0205004980ba2..0000000000000 --- a/docs/changelog/119694.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 119694 -summary: "[Inference API] Add default endpoint for completion in elastic inference service" -area: Machine Learning -type: enhancement -issues: [] From 8657c09fb6c79d2ce63991b275779cb05c5ef1ee Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 13 Jan 2025 10:58:52 -0500 
Subject: [PATCH 41/42] Adding the model id name and refactoring --- .../elastic/ElasticInferenceService.java | 26 +++++++++---------- ...renceServiceCompletionServiceSettings.java | 3 ++- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index ed4fffc299040..bee0e2e59ec09 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,6 +16,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -42,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -71,10 +74,9 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final String SERVICE_NAME = "Elastic"; - - public 
static final String DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; - - public static final List DEFAULT_EIS_ENDPOINT_IDS = List.of(DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1); + private static final String DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + private static final String DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; + private static final Set DEFAULT_EIS_ENDPOINT_IDS = Set.of(DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1); public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -227,7 +229,7 @@ public EnumSet supportedTaskTypes() { @Override public List defaultConfigIds() { - return List.of(new DefaultConfigId(DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1, TaskType.COMPLETION, this)); + return List.of(new DefaultConfigId(DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, TaskType.COMPLETION, this)); } @Override @@ -236,18 +238,14 @@ public void defaultConfigs(ActionListener> defaultsListener) { } private ElasticInferenceServiceCompletionModel firstDefaultCompletionModel() { - var serviceSettings = new HashMap(1); - serviceSettings.put(MODEL_ID, "elastic-model"); // TODO - return new ElasticInferenceServiceCompletionModel( - DEFAULT_EIS_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, TaskType.COMPLETION, NAME, - serviceSettings, - null, - null, - elasticInferenceServiceComponents, - ConfigurationParseContext.PERSISTENT + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index 3c8182a7d41a4..931ce8109462e 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; @@ -60,7 +61,7 @@ public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map Date: Mon, 13 Jan 2025 14:15:57 -0500 Subject: [PATCH 42/42] Refactoring so we only create the model once --- .../elastic/ElasticInferenceService.java | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index bee0e2e59ec09..e19034644862a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -70,14 +70,15 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final 
String SERVICE_NAME = "Elastic"; private static final String DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; private static final String DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; private static final Set DEFAULT_EIS_ENDPOINT_IDS = Set.of(DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1); + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private final List defaultEndpoints; + public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, @@ -85,6 +86,23 @@ public ElasticInferenceService( ) { super(factory, serviceComponents); this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; + this.defaultEndpoints = initDefaultEndpoints(); + } + + private List initDefaultEndpoints() { + return List.of(v1DefaultCompletionModel()); + } + + private ElasticInferenceServiceCompletionModel v1DefaultCompletionModel() { + return new ElasticInferenceServiceCompletionModel( + DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.COMPLETION, + NAME, + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ); } @Override @@ -234,19 +252,7 @@ public List defaultConfigIds() { @Override public void defaultConfigs(ActionListener> defaultsListener) { - defaultsListener.onResponse(List.of(firstDefaultCompletionModel())); - } - - private ElasticInferenceServiceCompletionModel firstDefaultCompletionModel() { - return new ElasticInferenceServiceCompletionModel( - DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.COMPLETION, - NAME, - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1, null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents - ); + defaultsListener.onResponse(defaultEndpoints); } private static 
ElasticInferenceServiceModel createModel(