Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,23 @@ public static void init() {
public void testGetDefaultEndpoints() throws IOException {
var allModels = getAllModels();
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
var completionModels = getModels("_all", TaskType.COMPLETION);

assertThat(allModels, hasSize(8));
assertThat(allModels, hasSize(9));
assertThat(chatCompletionModels, hasSize(2));
assertThat(completionModels, hasSize(1));

for (var model : chatCompletionModels) {
assertEquals("chat_completion", model.get("task_type"));
}

for (var model : completionModels) {
assertEquals("completion", model.get("task_type"));
}

assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".gp-llm-v2-chat_completion", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".gp-llm-v2-completion", TaskType.COMPLETION);
assertInferenceIdTaskType(allModels, ".elser-2-elastic", TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, ".jina-embeddings-v3", TaskType.TEXT_EMBEDDING);
assertInferenceIdTaskType(allModels, ".elastic-rerank-v1", TaskType.RERANK);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class InternalPreconfiguredEndpoints {
// gp-llm-v2
public static final String GP_LLM_V2_MODEL_ID = "gp-llm-v2";
public static final String GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-chat_completion";
public static final String GP_LLM_V2_COMPLETION_ENDPOINT_ID = ".gp-llm-v2-completion";

// elser-2
public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2";
Expand Down Expand Up @@ -80,8 +81,7 @@ public record MinimalModel(
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
TaskType.CHAT_COMPLETION,
ElasticInferenceService.NAME,
COMPLETION_SERVICE_SETTINGS,
ChunkingSettingsBuilder.DEFAULT_SETTINGS
COMPLETION_SERVICE_SETTINGS
),
COMPLETION_SERVICE_SETTINGS
)
Expand All @@ -93,8 +93,16 @@ public record MinimalModel(
GP_LLM_V2_CHAT_COMPLETION_ENDPOINT_ID,
TaskType.CHAT_COMPLETION,
ElasticInferenceService.NAME,
GP_LLM_V2_COMPLETION_SERVICE_SETTINGS,
ChunkingSettingsBuilder.DEFAULT_SETTINGS
GP_LLM_V2_COMPLETION_SERVICE_SETTINGS
),
GP_LLM_V2_COMPLETION_SERVICE_SETTINGS
),
new MinimalModel(
new ModelConfigurations(
GP_LLM_V2_COMPLETION_ENDPOINT_ID,
TaskType.COMPLETION,
ElasticInferenceService.NAME,
GP_LLM_V2_COMPLETION_SERVICE_SETTINGS
),
GP_LLM_V2_COMPLETION_SERVICE_SETTINGS
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

package org.elasticsearch.xpack.inference.services.elastic;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;

import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

public class InternalPreconfiguredEndpointsTests extends ESTestCase {
public void testGetWithModelName_ReturnsAnEmptyList_IfNameDoesNotExist() {
Expand All @@ -20,4 +22,18 @@ public void testGetWithModelName_ReturnsChatCompletionModels() {
var models = InternalPreconfiguredEndpoints.getWithModelName(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
assertThat(models, hasSize(1));
}

public void testGetWithModelName_ReturnsGpLlmV2Models() {
    // The gp-llm-v2 model id backs two preconfigured endpoints: one chat_completion and one completion.
    var gpLlmV2Models = InternalPreconfiguredEndpoints.getWithModelName(InternalPreconfiguredEndpoints.GP_LLM_V2_MODEL_ID);
    assertThat(gpLlmV2Models, hasSize(2));

    // Scan the returned models once, recording which of the two expected task types we observed.
    boolean sawChatCompletion = false;
    boolean sawCompletion = false;
    for (var gpLlmV2Model : gpLlmV2Models) {
        var taskType = gpLlmV2Model.configurations().getTaskType();
        sawChatCompletion |= taskType == TaskType.CHAT_COMPLETION;
        sawCompletion |= taskType == TaskType.COMPLETION;
    }
    assertTrue("Should contain CHAT_COMPLETION", sawChatCompletion);
    assertTrue("Should contain COMPLETION", sawCompletion);
}

public void testGetWithInferenceId_ReturnsGpLlmV2CompletionEndpoint() {
    // Looking up the completion endpoint id should yield a model whose configuration
    // echoes that id and carries the COMPLETION task type.
    var endpointId = InternalPreconfiguredEndpoints.GP_LLM_V2_COMPLETION_ENDPOINT_ID;
    var configurations = InternalPreconfiguredEndpoints.getWithInferenceId(endpointId).configurations();
    assertThat(configurations.getInferenceEntityId(), is(endpointId));
    assertThat(configurations.getTaskType(), is(TaskType.COMPLETION));
}
}