diff --git a/docs/changelog/137677.yaml b/docs/changelog/137677.yaml new file mode 100644 index 0000000000000..56b41374dc73c --- /dev/null +++ b/docs/changelog/137677.yaml @@ -0,0 +1,5 @@ +pr: 137677 +summary: "[Inference] Implementing the completion task type on EIS" +area: "Inference" +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 4ce3b65866de1..589b59eaf3102 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -305,12 +305,12 @@ private void executeTaskImmediately(RejectableTask task) { e ); - task.onRejection( - new EsRejectedExecutionException( - format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()), - false - ) + var rejectionException = new EsRejectedExecutionException( + format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()), + false ); + rejectionException.initCause(e); + task.onRejection(rejectionException); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 0ba5dddb5fcae..3ebf971d48d53 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -39,6 +39,7 @@ import 
org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -72,6 +73,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage; +import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.USER_ROLE; public class ElasticInferenceService extends SenderService { @@ -86,6 +88,7 @@ public class ElasticInferenceService extends SenderService { public static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, + TaskType.COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING ); @@ -101,6 +104,7 @@ public class ElasticInferenceService extends SenderService { */ private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, + TaskType.COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING ); @@ -162,7 +166,8 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener<InferenceServiceResults> listener ) { - if (model instanceof ElasticInferenceServiceCompletionModel == false + if (model instanceof ElasticInferenceServiceCompletionModel == false + || (model.getTaskType() != TaskType.CHAT_COMPLETION && model.getTaskType() != TaskType.COMPLETION)) { listener.onFailure(createInvalidModelException(model)); return; } @@ -212,10 +217,15 @@ protected 
void doInfer( var elasticInferenceServiceModel = (ElasticInferenceServiceModel) model; + // For ElasticInferenceServiceCompletionModel, convert ChatCompletionInput to UnifiedChatInput + // since the request manager expects UnifiedChatInput + final InferenceInputs finalInputs = (elasticInferenceServiceModel instanceof ElasticInferenceServiceCompletionModel + && inputs instanceof ChatCompletionInput) ? new UnifiedChatInput((ChatCompletionInput) inputs, USER_ROLE) : inputs; + actionCreator.create( elasticInferenceServiceModel, currentTraceInfo, - listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate)) + listener.delegateFailureAndWrap((delegate, action) -> action.execute(finalInputs, timeout, delegate)) ); } @@ -379,7 +389,7 @@ private static ElasticInferenceServiceModel createModel( context, chunkingSettings ); - case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel( + case CHAT_COMPLETION, COMPLETION -> new ElasticInferenceServiceCompletionModel( inferenceEntityId, taskType, NAME, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 20388bee77957..10d49a2086376 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -84,6 +84,6 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return true; + return unifiedChatInput.stream(); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 5e70fe01eee97..cff7c0ef5d97f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -492,7 +492,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws thrownException.getMessage(), is( "Inference entity [model_id] does not support task type [chat_completion] " - + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. " + + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank, completion]. " + "The task type for the inference entity is chat_completion, " + "please use the _inference/chat_completion/model_id/_stream URL." 
) @@ -1133,7 +1133,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson)); var model = new ElasticInferenceServiceCompletionModel( "id", - TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, "elastic", new ElasticInferenceServiceCompletionServiceSettings("model_id"), EmptyTaskSettings.INSTANCE, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java index 75ca0b00a8e12..5d9c236b20a13 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -419,8 +419,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, - // EIS does not yet support completions so this model will be ignored - EnumSet.of(TaskType.COMPLETION) + EnumSet.noneOf(TaskType.class) ) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index 58750e7d8c456..89ea08edc46fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -47,4 +47,63 @@ public void testOverridingModelId() { assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id")); assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION)); } + + public void testUriCreation() { + var url = "http://eis-gateway.com"; + var model = createModel(url, "my-model-id"); + + var uri = model.uri(); + assertThat(uri.toString(), is(url + "/api/v1/chat")); + } + + public void testGetServiceSettings() { + var modelId = "test-model"; + var model = createModel("http://eis-gateway.com", modelId); + + var serviceSettings = model.getServiceSettings(); + assertThat(serviceSettings.modelId(), is(modelId)); + } + + public void testGetTaskType() { + var model = createModel("http://eis-gateway.com", "my-model-id"); + assertThat(model.getTaskType(), is(TaskType.COMPLETION)); + } + + public void testGetInferenceEntityId() { + var inferenceEntityId = "test-id"; + var model = new ElasticInferenceServiceCompletionModel( + inferenceEntityId, + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings("my-model-id"), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of("http://eis-gateway.com") + ); + + assertThat(model.getInferenceEntityId(), is(inferenceEntityId)); + } + + public void testModelWithOverriddenServiceSettings() { + var originalModel = createModel("http://eis-gateway.com", "original-model"); + var newServiceSettings = new ElasticInferenceServiceCompletionServiceSettings("new-model"); + + var overriddenModel = new ElasticInferenceServiceCompletionModel(originalModel, newServiceSettings); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("new-model")); + assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION)); + assertThat(overriddenModel.uri().toString(), 
is(originalModel.uri().toString())); + } + + public static ElasticInferenceServiceCompletionModel createModel(String url, String modelId) { + return new ElasticInferenceServiceCompletionModel( + "id", + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings(modelId), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java index b067350e26aa5..aaf57345a8512 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java @@ -14,11 +14,13 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModelTests; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.request.OpenAiUnifiedChatCompletionRequestEntity; import java.io.IOException; import java.util.ArrayList; +import java.util.List; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; @@ -67,4 +69,152 @@ public void 
testModelUserFieldsSerialization() throws IOException { assertJsonEquals(jsonString, expectedJson); } + public void testSerialization_NonStreaming_ForCompletion() throws IOException { + // Test non-streaming case (used for COMPLETION task type) + var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "What is 2+2?", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_MultipleInputs_NonStreaming() throws IOException { + // Test multiple inputs converted to messages (used for COMPLETION task type) + var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?", "What is the capital of France?"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "What is 2+2?", + "role": "user" + }, + { + "content": "What is the capital of France?", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void 
testSerialization_EmptyInput_NonStreaming() throws IOException { + var unifiedChatInput = new UnifiedChatInput(List.of(""), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_AlwaysSetsNToOne_NonStreaming() throws IOException { + // Verify n is always 1 regardless of number of inputs + var unifiedChatInput = new UnifiedChatInput(List.of("input1", "input2", "input3"), ROLE, false); + var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "input1", + "role": "user" + }, + { + "content": "input2", + "role": "user" + }, + { + "content": "input3", + "role": "user" + } + ], + "model": "my-model-id", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_AllMessagesHaveUserRole_NonStreaming() throws IOException { + // Verify all messages have "user" role when converting from simple inputs + var unifiedChatInput = new UnifiedChatInput(List.of("first", "second", "third"), ROLE, false); 
+ var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "test-model"); + var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "first", + "role": "user" + }, + { + "content": "second", + "role": "user" + }, + { + "content": "third", + "role": "user" + } + ], + "model": "test-model", + "n": 1, + "stream": false + } + """; + assertJsonEquals(jsonString, expectedJson); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestTests.java new file mode 100644 index 0000000000000..37d2a9cff9154 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestTests.java @@ -0,0 +1,247 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModelTests; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class ElasticInferenceServiceUnifiedChatCompletionRequestTests extends ESTestCase { + + public void testCreateHttpRequest_SingleInput() throws IOException { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), false); + var httpRequest = 
request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is(url + "/api/v1/chat")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + @SuppressWarnings("unchecked") + var messages = (List<Map<String, Object>>) requestMap.get("messages"); + assertThat(messages.size(), is(1)); + assertThat(messages.get(0).get("content"), is(input)); + assertThat(messages.get(0).get("role"), is("user")); + } + + public void testCreateHttpRequest_MultipleInputs() throws IOException { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var inputs = List.of("What is 2+2?", "What is the capital of France?"); + + var request = createRequest(url, modelId, inputs, false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + @SuppressWarnings("unchecked") + var messages = (List<Map<String, Object>>) requestMap.get("messages"); + assertThat(messages.size(), is(2)); + assertThat(messages.get(0).get("content"), is(inputs.get(0))); + assertThat(messages.get(0).get("role"), is("user")); + assertThat(messages.get(1).get("content"), is(inputs.get(1))); + assertThat(messages.get(1).get("role"), is("user")); + } + + public void testCreateHttpRequest_NonStreaming() throws IOException { + // Test non-streaming case (used for COMPLETION task type) + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = 
createRequest(url, modelId, List.of(input), false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(false)); + assertFalse(request.isStreaming()); + } + + public void testCreateHttpRequest_Streaming() throws IOException { + // Test streaming case (used for CHAT_COMPLETION task type) + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(true)); + assertTrue(request.isStreaming()); + } + + public void testGetURI() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), false); + + assertThat(request.getURI().toString(), is(url + "/api/v1/chat")); + } + + public void testGetInferenceEntityId() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var inferenceEntityId = "test-endpoint-id"; + + var model = new ElasticInferenceServiceCompletionModel( + inferenceEntityId, + TaskType.COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings(modelId), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url) + ); + + var unifiedChatInput = new UnifiedChatInput(List.of("input"), "user", false); + var request = new ElasticInferenceServiceUnifiedChatCompletionRequest( + unifiedChatInput, + model, + new TraceContext("trace-parent", "trace-state"), + 
randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER + ); + + assertThat(request.getInferenceEntityId(), is(inferenceEntityId)); + } + + public void testTruncate_ReturnsSameInstance() throws IOException { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var input = "What is 2+2?"; + + var request = createRequest(url, modelId, List.of(input), false); + var truncatedRequest = request.truncate(); + + // Should return the same instance (no truncation) + assertThat(truncatedRequest, is(request)); + + // Verify content is unchanged + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + @SuppressWarnings("unchecked") + var messages = (List<Map<String, Object>>) requestMap.get("messages"); + assertThat(messages.size(), is(1)); + assertThat(messages.get(0).get("content"), is(input)); + } + + public void testGetTruncationInfo_ReturnsNull() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), false); + + assertThat(request.getTruncationInfo(), nullValue()); + } + + public void testIsStreaming_NonStreamingReturnsFalse() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), false); + + assertFalse(request.isStreaming()); + } + + public void testIsStreaming_StreamingReturnsTrue() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + + var request = createRequest(url, modelId, List.of("input"), true); + + assertTrue(request.isStreaming()); + } + + public void testTraceContextPropagatedThroughHTTPHeaders() { + var url = "http://eis-gateway.com"; + var modelId = "my-model-id"; + var traceParent = randomAlphaOfLength(10); + var traceState = 
randomAlphaOfLength(10); + + var model = ElasticInferenceServiceCompletionModelTests.createModel(url, modelId); + var unifiedChatInput = new UnifiedChatInput(List.of("input"), "user", false); + var request = new ElasticInferenceServiceUnifiedChatCompletionRequest( + unifiedChatInput, + model, + new TraceContext(traceParent, traceState), + randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER + ); + + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent)); + assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState)); + } + + private ElasticInferenceServiceUnifiedChatCompletionRequest createRequest( + String url, + String modelId, + List<String> inputs, + boolean stream + ) { + var model = ElasticInferenceServiceCompletionModelTests.createModel(url, modelId); + var unifiedChatInput = new UnifiedChatInput(inputs, "user", stream); + + return new ElasticInferenceServiceUnifiedChatCompletionRequest( + unifiedChatInput, + model, + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomElasticInferenceServiceRequestMetadata(), + CCMAuthenticationApplierFactory.NOOP_APPLIER + ); + } +}