Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137677.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137677
summary: "[Inference] Implementing the completion task type on EIS"
area: "Inference"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,12 @@ private void executeTaskImmediately(RejectableTask task) {
e
);

task.onRejection(
new EsRejectedExecutionException(
format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()),
false
)
var rejectionException = new EsRejectedExecutionException(
format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()),
false
);
rejectionException.initCause(e);
task.onRejection(rejectionException);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand Down Expand Up @@ -72,6 +73,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.USER_ROLE;

public class ElasticInferenceService extends SenderService {

Expand All @@ -86,6 +88,7 @@ public class ElasticInferenceService extends SenderService {
public static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING
);
Expand All @@ -101,6 +104,7 @@ public class ElasticInferenceService extends SenderService {
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION,
TaskType.RERANK,
TaskType.TEXT_EMBEDDING
);
Expand Down Expand Up @@ -162,7 +166,8 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof ElasticInferenceServiceCompletionModel == false) {
if (model instanceof ElasticInferenceServiceCompletionModel == false
|| (model.getTaskType() != TaskType.CHAT_COMPLETION && model.getTaskType() != TaskType.COMPLETION)) {
listener.onFailure(createInvalidModelException(model));
return;
}
Expand Down Expand Up @@ -212,10 +217,15 @@ protected void doInfer(

var elasticInferenceServiceModel = (ElasticInferenceServiceModel) model;

// For ElasticInferenceServiceCompletionModel, convert ChatCompletionInput to UnifiedChatInput
// since the request manager expects UnifiedChatInput
final InferenceInputs finalInputs = (elasticInferenceServiceModel instanceof ElasticInferenceServiceCompletionModel
&& inputs instanceof ChatCompletionInput) ? new UnifiedChatInput((ChatCompletionInput) inputs, USER_ROLE) : inputs;

actionCreator.create(
elasticInferenceServiceModel,
currentTraceInfo,
listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate))
listener.delegateFailureAndWrap((delegate, action) -> action.execute(finalInputs, timeout, delegate))
);
}

Expand Down Expand Up @@ -379,7 +389,7 @@ private static ElasticInferenceServiceModel createModel(
context,
chunkingSettings
);
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
case CHAT_COMPLETION, COMPLETION -> new ElasticInferenceServiceCompletionModel(
inferenceEntityId,
taskType,
NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ public String getInferenceEntityId() {

@Override
public boolean isStreaming() {
    // Diff residue resolved: the hard-coded `return true;` was removed in this change.
    // Whether the request streams is now delegated to the wrapped unified chat input,
    // so non-streaming requests (e.g. the COMPLETION task type) are represented correctly.
    return unifiedChatInput.stream();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
thrownException.getMessage(),
is(
"Inference entity [model_id] does not support task type [chat_completion] "
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. "
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank, completion]. "
+ "The task type for the inference entity is chat_completion, "
+ "please use the _inference/chat_completion/model_id/_stream URL."
)
Expand Down Expand Up @@ -1133,7 +1133,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp
webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson));
var model = new ElasticInferenceServiceCompletionModel(
"id",
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION,
"elastic",
new ElasticInferenceServiceCompletionServiceSettings("model_id"),
EmptyTaskSettings.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra
List.of(
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID,
// EIS does not yet support completions so this model will be ignored
EnumSet.of(TaskType.COMPLETION)
EnumSet.noneOf(TaskType.class)
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,63 @@ public void testOverridingModelId() {
assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id"));
assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
}

public void testUriCreation() {
    // The model should append the chat endpoint path to the configured gateway URL.
    var baseUrl = "http://eis-gateway.com";
    var completionModel = createModel(baseUrl, "my-model-id");

    assertThat(completionModel.uri().toString(), is(baseUrl + "/api/v1/chat"));
}

public void testGetServiceSettings() {
    // The service settings returned by the model should carry the model id it was built with.
    var expectedModelId = "test-model";
    var serviceSettings = createModel("http://eis-gateway.com", expectedModelId).getServiceSettings();

    assertThat(serviceSettings.modelId(), is(expectedModelId));
}

public void testGetTaskType() {
    // Models built by the factory helper are always COMPLETION-task models.
    var completionModel = createModel("http://eis-gateway.com", "my-model-id");

    assertThat(completionModel.getTaskType(), is(TaskType.COMPLETION));
}

public void testGetInferenceEntityId() {
    // The inference entity id passed to the constructor should be returned unchanged.
    var expectedEntityId = "test-id";
    var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings("my-model-id");
    var completionModel = new ElasticInferenceServiceCompletionModel(
        expectedEntityId,
        TaskType.COMPLETION,
        "elastic",
        serviceSettings,
        EmptyTaskSettings.INSTANCE,
        EmptySecretSettings.INSTANCE,
        ElasticInferenceServiceComponents.of("http://eis-gateway.com")
    );

    assertThat(completionModel.getInferenceEntityId(), is(expectedEntityId));
}

public void testModelWithOverriddenServiceSettings() {
    // Overriding service settings should replace the model id while leaving the
    // task type and the resolved URI untouched.
    var baseModel = createModel("http://eis-gateway.com", "original-model");
    var replacementSettings = new ElasticInferenceServiceCompletionServiceSettings("new-model");

    var updatedModel = new ElasticInferenceServiceCompletionModel(baseModel, replacementSettings);

    assertThat(updatedModel.getServiceSettings().modelId(), is("new-model"));
    assertThat(updatedModel.getTaskType(), is(TaskType.COMPLETION));
    assertThat(updatedModel.uri().toString(), is(baseModel.uri().toString()));
}

/**
 * Builds a COMPLETION-task {@link ElasticInferenceServiceCompletionModel} for tests,
 * pointing at the given gateway {@code url} with the given service-settings {@code modelId}.
 * The inference entity id is fixed to {@code "id"}; task and secret settings are empty.
 */
public static ElasticInferenceServiceCompletionModel createModel(String url, String modelId) {
return new ElasticInferenceServiceCompletionModel(
"id",
TaskType.COMPLETION,
"elastic",
new ElasticInferenceServiceCompletionServiceSettings(modelId),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
ElasticInferenceServiceComponents.of(url)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModelTests;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.request.OpenAiUnifiedChatCompletionRequestEntity;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals;
import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
Expand Down Expand Up @@ -67,4 +69,152 @@ public void testModelUserFieldsSerialization() throws IOException {
assertJsonEquals(jsonString, expectedJson);
}

public void testSerialization_NonStreaming_ForCompletion() throws IOException {
    // Non-streaming serialization, as produced for the COMPLETION task type:
    // the single input becomes one user-role message and "stream" is false.
    var input = new UnifiedChatInput(List.of("What is 2+2?"), ROLE, false);
    var completionModel = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
    var requestEntity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(
        input,
        completionModel.getServiceSettings().modelId()
    );

    var builder = JsonXContent.contentBuilder();
    requestEntity.toXContent(builder, ToXContent.EMPTY_PARAMS);
    var actualJson = Strings.toString(builder);

    var expectedJson = """
{
"messages": [
{
"content": "What is 2+2?",
"role": "user"
}
],
"model": "my-model-id",
"n": 1,
"stream": false
}
""";
    assertJsonEquals(actualJson, expectedJson);
}

public void testSerialization_MultipleInputs_NonStreaming() throws IOException {
// Test multiple inputs converted to messages (used for COMPLETION task type)
// Each input string becomes its own user-role message, and message order follows input order.
var unifiedChatInput = new UnifiedChatInput(List.of("What is 2+2?", "What is the capital of France?"), ROLE, false);
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());

XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);

String jsonString = Strings.toString(builder);
// Expected body: two user messages, the configured model id, n=1, and stream=false.
String expectedJson = """
{
"messages": [
{
"content": "What is 2+2?",
"role": "user"
},
{
"content": "What is the capital of France?",
"role": "user"
}
],
"model": "my-model-id",
"n": 1,
"stream": false
}
""";
assertJsonEquals(jsonString, expectedJson);
}

public void testSerialization_EmptyInput_NonStreaming() throws IOException {
    // An empty string is still serialized as a user message with empty content.
    var input = new UnifiedChatInput(List.of(""), ROLE, false);
    var completionModel = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
    var requestEntity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(
        input,
        completionModel.getServiceSettings().modelId()
    );

    var builder = JsonXContent.contentBuilder();
    requestEntity.toXContent(builder, ToXContent.EMPTY_PARAMS);
    var actualJson = Strings.toString(builder);

    var expectedJson = """
{
"messages": [
{
"content": "",
"role": "user"
}
],
"model": "my-model-id",
"n": 1,
"stream": false
}
""";
    assertJsonEquals(actualJson, expectedJson);
}

public void testSerialization_AlwaysSetsNToOne_NonStreaming() throws IOException {
// Verify n is always 1 regardless of number of inputs
// "n" is a single top-level field of the request body, not a per-message value.
var unifiedChatInput = new UnifiedChatInput(List.of("input1", "input2", "input3"), ROLE, false);
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "my-model-id");
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());

XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);

String jsonString = Strings.toString(builder);
// Three inputs yield three user messages, but "n" stays 1 and "stream" stays false.
String expectedJson = """
{
"messages": [
{
"content": "input1",
"role": "user"
},
{
"content": "input2",
"role": "user"
},
{
"content": "input3",
"role": "user"
}
],
"model": "my-model-id",
"n": 1,
"stream": false
}
""";
assertJsonEquals(jsonString, expectedJson);
}

public void testSerialization_AllMessagesHaveUserRole_NonStreaming() throws IOException {
// Verify all messages have "user" role when converting from simple inputs
// NOTE(review): ROLE is defined elsewhere in this test class and is not visible here —
// presumably "user", matching the expected JSON below; verify against the class constant.
var unifiedChatInput = new UnifiedChatInput(List.of("first", "second", "third"), ROLE, false);
var model = ElasticInferenceServiceCompletionModelTests.createModel("http://eis-gateway.com", "test-model");
var entity = new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId());

XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);

String jsonString = Strings.toString(builder);
// Every serialized message carries role "user"; input order is preserved.
String expectedJson = """
{
"messages": [
{
"content": "first",
"role": "user"
},
{
"content": "second",
"role": "user"
},
{
"content": "third",
"role": "user"
}
],
"model": "test-model",
"n": 1,
"stream": false
}
""";
assertJsonEquals(jsonString, expectedJson);
}
}
Loading