diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index b8ff560ced006..2607e68acadbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1079,11 +1079,27 @@ public interface EnumConstructor> { E apply(String name) throws IllegalArgumentException; } - public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) { + /** + * Create an exception for when the task type is not valid for the service. + */ + public static ElasticsearchStatusException createInvalidTaskTypeException( + String inferenceEntityId, + String serviceName, + TaskType taskType, + ConfigurationParseContext parseContext + ) { + var message = parseContext == ConfigurationParseContext.PERSISTENT + ? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType) + : TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName); + return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST); + } + + private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { return format( - "Failed to parse stored model [%s] for [%s] service, please delete and add the service again", + "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again", inferenceEntityId, - serviceName + serviceName, + TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index b677ec642075e..69eddd7ecade2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.ai21; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -27,7 +26,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; @@ -55,7 +53,7 @@ import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -178,14 +176,7 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Ai21Model model = createModel( - modelId, - taskType, - serviceSettingsMap, - serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), - ConfigurationParseContext.REQUEST - ); + Ai21Model model = createModel(modelId, taskType, serviceSettingsMap, serviceSettingsMap, ConfigurationParseContext.REQUEST); throwIfNotEmptyMap(config, NAME); throwIfNotEmptyMap(serviceSettingsMap, NAME); @@ -208,13 +199,7 @@ public Ai21Model parsePersistedConfigWithSecrets( removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelFromPersistent( - modelId, - taskType, - serviceSettingsMap, - secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME) - ); + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, secretSettingsMap); } @Override @@ -222,7 +207,7 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME)); + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null); } @Override @@ -240,14 +225,13 @@ private static Ai21Model createModel( TaskType taskType, Map serviceSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { case CHAT_COMPLETION, COMPLETION: return new Ai21ChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(modelId, NAME, taskType, context); } } @@ -255,17 +239,9 @@ private Ai21Model createModelFromPersistent( String inferenceEntityId, TaskType taskType, Map serviceSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - secretSettings, - failureMessage, - ConfigurationParseContext.PERSISTENT - ); + return createModel(inferenceEntityId, taskType, serviceSettings, secretSettings, ConfigurationParseContext.PERSISTENT); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index f474850b9f190..5f2378d40116d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -32,7 +31,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -59,7 +57,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -135,7 +133,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -165,8 +162,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -175,7 +171,6 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -187,7 +182,6 @@ private static AlibabaCloudSearchModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -229,7 +223,7 @@ private static AlibabaCloudSearchModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -255,8 +249,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -276,8 +269,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 11204018a5523..5e3ba81b7e1bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -60,7 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -204,7 +204,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -241,7 +240,6 @@ public Model parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME), ConfigurationParseContext.PERSISTENT ); } @@ -263,7 +261,6 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -318,7 +314,7 @@ private static AmazonBedrockModel createModel( checkChatCompletionProviderForTopKParameter(model); return model; } - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 224d62f83f28b..29a1582ce6236 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.anthropic; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -28,7 +27,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -48,7 +46,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -94,7 +92,6 @@ public void parseRequestConfig( serviceSettingsMap, taskSettingsMap, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -113,8 +110,7 @@ private static AnthropicModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -122,7 +118,6 @@ private static AnthropicModel createModelFromPersistent( serviceSettings, taskSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -133,7 +128,6 @@ private static AnthropicModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -146,7 +140,7 @@ private static AnthropicModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -161,14 +155,7 @@ public AnthropicModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); } @Override @@ -176,14 +163,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 7578aa702ad7c..60b8bbdf86fa5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -60,7 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -185,7 +185,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -221,8 +220,7 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -236,15 +234,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override @@ -279,7 +269,6 @@ private static AzureAiStudioModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { @@ -305,7 +294,7 @@ private static AzureAiStudioModel createModel( context ); case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings(); checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType()); @@ -318,8 +307,7 @@ private AzureAiStudioModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -328,7 +316,6 @@ private AzureAiStudioModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 077e5361dd46f..de067ae0096b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.azureopenai; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -30,7 +29,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -55,7 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -114,7 +112,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -134,8 +131,7 @@ private static AzureOpenAiModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -144,7 +140,6 @@ private static AzureOpenAiModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -156,7 +151,6 @@ private static AzureOpenAiModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -183,7 +177,7 @@ private static AzureOpenAiModel createModel( context ); } - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } } @@ -209,8 +203,7 @@ public AzureOpenAiModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -224,15 +217,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 2561f198075e2..f1e34138aad5c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -32,7 +31,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -60,7 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -130,7 +128,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -150,8 +147,7 @@ private static CohereModel createModelWithoutLoggingDeprecations( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -160,7 +156,6 @@ private static CohereModel createModelWithoutLoggingDeprecations( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -172,7 +167,6 @@ private static CohereModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -186,7 +180,7 @@ private static CohereModel createModel( ); case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -212,8 +206,7 @@ public CohereModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -233,8 +226,7 @@ public CohereModel parsePersistedConfig(String inferenceEntityId, TaskType taskT serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java index c67fc328acdd2..9088de1dacd2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java @@ -46,6 +46,8 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; + /** * Contextual AI inference service for reranking tasks. * This service uses the Contextual AI REST API to perform document reranking. @@ -97,7 +99,6 @@ public void parseRequestConfig( serviceSettingsMap, taskSettingsMap, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -117,11 +118,10 @@ private static ContextualAiRerankModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { if (taskType != TaskType.RERANK) { - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } return new ContextualAiRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); @@ -144,7 +144,6 @@ public ContextualAiRerankModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, secretSettingsMap, - ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME), ConfigurationParseContext.PERSISTENT ); } @@ -154,15 +153,7 @@ public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, Ta Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null, - ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME), - ConfigurationParseContext.PERSISTENT - ); + return createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, ConfigurationParseContext.PERSISTENT); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index fd29b02012185..aa7fd1337428f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -57,9 +57,9 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -106,7 +106,12 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } CustomModel model = createModel( inferenceEntityId, @@ -157,14 +162,6 @@ private static RequestParameters createParameters(CustomModel model) { }; } - private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return null; - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -214,7 +211,7 @@ private static CustomModel createModel( ConfigurationParseContext context ) { if (supportedTaskTypes.contains(taskType) == false) { - throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context); } @@ -230,7 +227,7 @@ public CustomModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + var chunkingSettings = extractPersistentChunkingSettings(config, taskType); return createModelWithoutLoggingDeprecations( inferenceEntityId, @@ -242,12 +239,27 @@ public CustomModel parsePersistedConfigWithSecrets( ); } + private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) { + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + /* + * There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context. + * For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will + * return null which results in the older word boundary chunking settings being used as the default. + * For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking + * settings being used as the default. + */ + return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return null; + } + @Override public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + var chunkingSettings = extractPersistentChunkingSettings(config, taskType); return createModelWithoutLoggingDeprecations( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index cc871da8eb860..c4156c0bfd6b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -79,7 +79,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -446,7 +446,6 @@ public void parseRequestConfig( chunkingSettings, serviceSettingsMap, elasticInferenceServiceComponents, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -494,7 +493,6 @@ private static ElasticInferenceServiceModel createModel( ChunkingSettings chunkingSettings, @Nullable Map secretSettings, ElasticInferenceServiceComponents elasticInferenceServiceComponents, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -540,7 +538,7 @@ private static ElasticInferenceServiceModel createModel( context, chunkingSettings ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -566,8 +564,7 @@ public Model parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -581,15 +578,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override @@ -603,8 +592,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -614,7 +602,6 @@ private ElasticInferenceServiceModel createModelFromPersistent( chunkingSettings, secretSettings, elasticInferenceServiceComponents, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 97bd2502d25b6..dc08ec8544e3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.googleaistudio; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -60,7 +58,7 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -127,7 +125,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -149,7 +146,6 @@ private static GoogleAiStudioModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -172,7 +168,7 @@ private static GoogleAiStudioModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -198,8 +194,7 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -209,8 +204,7 @@ private static GoogleAiStudioModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -219,7 +213,6 @@ private static GoogleAiStudioModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -234,15 +227,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 19ab67b920f04..66a4dd0649730 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -63,7 +61,7 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -149,7 +147,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -185,8 +182,7 @@ public Model parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -200,15 +196,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override @@ -356,8 +344,7 @@ private static GoogleVertexAiModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -366,7 +353,6 @@ private static GoogleVertexAiModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -378,7 +364,6 @@ private static GoogleVertexAiModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -412,7 +397,7 @@ private static GoogleVertexAiModel createModel( context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 325f88c8904a3..403a1758983ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -31,7 +31,6 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -84,7 +83,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), ConfigurationParseContext.REQUEST ) ); @@ -123,7 +121,6 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT ) ); @@ -147,7 +144,6 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskSettingsMap, chunkingSettings, null, - parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java index 6dabaa66ffb2b..7600207eebad0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java @@ -20,6 +20,5 @@ public record HuggingFaceModelParameters( Map taskSettings, ChunkingSettings chunkingSettings, Map secretSettings, - String failureMessage, ConfigurationParseContext context ) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d0a98d8252923..e727a9e20cd8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.huggingface; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -26,7 +25,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -54,6 +52,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; /** * This class is responsible for managing the Hugging Face inference service. @@ -124,7 +123,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters params) { params.secretSettings(), params.context() ); - default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(params.inferenceEntityId(), NAME, params.taskType(), params.context()); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 775a4e90ae034..081d5c63b84ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.huggingface.elser; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -25,7 +24,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; @@ -50,6 +48,7 @@ import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; @@ -87,7 +86,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters input) { input.secretSettings(), input.context() ); - default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(input.inferenceEntityId(), NAME, input.taskType(), input.context()); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 8cdc8cd182425..ee836d03747de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -30,7 +29,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -62,7 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -128,7 +126,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -150,7 +147,6 @@ private static IbmWatsonxModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -181,7 +177,7 @@ private static IbmWatsonxModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -207,8 +203,7 @@ public IbmWatsonxModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -228,8 +223,7 @@ private static IbmWatsonxModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -238,7 +232,6 @@ private static IbmWatsonxModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -253,15 +246,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index f6bd954617b76..5e95f85e78ecd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.jinaai; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -57,7 +55,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -121,7 +119,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -141,8 +138,7 @@ private static JinaAIModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -151,7 +147,6 @@ private static JinaAIModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -163,7 +158,6 @@ private static JinaAIModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -177,7 +171,7 @@ private static JinaAIModel createModel( context ); case RERANK -> new JinaAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -203,8 +197,7 @@ public JinaAIModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -218,15 +211,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index a74f3202e5fb4..b3c0e12927fff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.llama; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -28,7 +27,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -64,7 +62,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -138,7 +136,6 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc * @param serviceSettings the settings for the inference service * @param chunkingSettings the settings for chunking, if applicable * @param secretSettings the secret settings for the model, such as API keys or tokens - * @param failureMessage the message to use in case of failure * @param context the context for parsing configuration settings * @return a new instance of LlamaModel based on the provided parameters */ @@ -148,7 +145,6 @@ protected LlamaModel createModel( Map serviceSettings, ChunkingSettings chunkingSettings, Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -157,7 +153,7 @@ protected LlamaModel createModel( case CHAT_COMPLETION, COMPLETION: return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context); default: - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(inferenceId, NAME, taskType, context); } } @@ -283,7 +279,6 @@ public void parseRequestConfig( serviceSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -313,14 +308,7 @@ public Model parsePersistedConfigWithSecrets( chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - modelId, - taskType, - serviceSettingsMap, - chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME) - ); + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap); } private LlamaModel createModelFromPersistent( @@ -328,8 +316,7 @@ private LlamaModel createModelFromPersistent( TaskType taskType, Map serviceSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -337,7 +324,6 @@ private LlamaModel createModelFromPersistent( serviceSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -352,14 +338,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -299,7 +281,7 @@ private static MistralModel createModel( case CHAT_COMPLETION, COMPLETION: return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(modelId, NAME, taskType, context); } } @@ -308,8 +290,7 @@ private MistralModel createModelFromPersistent( TaskType taskType, Map serviceSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -317,7 +298,6 @@ private MistralModel createModelFromPersistent( serviceSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index ae49f5dcef13b..9e77c0d63e336 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -65,7 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -139,7 +139,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -159,8 +158,7 @@ private static OpenAiModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -169,7 +167,6 @@ private static OpenAiModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -181,7 +178,6 @@ private static OpenAiModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -204,7 +200,7 @@ private static OpenAiModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -232,8 +228,7 @@ public OpenAiModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -249,15 +244,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap); - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java index 88840cc1202ac..6d340320c655c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java @@ -47,7 +47,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject // The rate limit for usage tier 1 is 500 request per minute for most of the completion models // To find this information you need to access your account's limits https://platform.openai.com/account/limits // 500 requests per minute - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); public static OpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index 20b9bf931290c..e9b6a7c77a5fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -50,11 +50,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl public static final String NAME = "openai_service_settings"; - static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; // The rate limit for usage tier 1 is 3000 request per minute for the text embedding models // To find this information you need to access your account's limits https://platform.openai.com/account/limits // 3000 requests per minute - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); public static OpenAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { return switch (context) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index c69aeec203e4c..f2db85f263f9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.voyageai; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -56,7 +54,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -152,7 +150,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -172,8 +169,7 @@ private static VoyageAIModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -182,7 +178,6 @@ private static VoyageAIModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -194,7 +189,6 @@ private static VoyageAIModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -208,7 +202,7 @@ private static VoyageAIModel createModel( context ); case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -234,8 +228,7 @@ public VoyageAIModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + secretSettingsMap ); } @@ -249,15 +242,7 @@ public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType tas chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java new file mode 100644 index 0000000000000..7657f7e79cf25 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; + +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.mockito.Mockito.mock; + +public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase { + protected final TestConfiguration testConfiguration; + + protected final MockWebServer webServer = new MockWebServer(); + protected ThreadPool threadPool; + protected HttpClientManager clientManager; + protected AbstractInferenceServiceParameterizedTests.TestCase testCase; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public AbstractInferenceServiceBaseTests(TestConfiguration testConfiguration) { + this.testConfiguration = Objects.requireNonNull(testConfiguration); + } + + /** + * Main configurations for the tests + */ + public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) { + public static class Builder { + private final CommonConfig commonConfig; + private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS; + + public Builder(CommonConfig commonConfig) { + this.commonConfig = commonConfig; + } + + public TestConfiguration.Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) { + this.updateModelConfiguration = updateModelConfiguration; + return this; + } + + public TestConfiguration build() { + return new TestConfiguration(commonConfig, updateModelConfiguration); + } + } + } + + /** + * Configurations that are useful for most tests + */ + public abstract static class CommonConfig { + + private final TaskType targetTaskType; + private final TaskType unsupportedTaskType; + private final EnumSet supportedTaskTypes; + + public CommonConfig(TaskType targetTaskType, @Nullable TaskType unsupportedTaskType, EnumSet supportedTaskTypes) { + this.targetTaskType = Objects.requireNonNull(targetTaskType); + this.unsupportedTaskType = unsupportedTaskType; + this.supportedTaskTypes = Objects.requireNonNull(supportedTaskTypes); + } + + public TaskType targetTaskType() { + return targetTaskType; + } + + public TaskType unsupportedTaskType() { + return unsupportedTaskType; + } + + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); + + protected abstract Map createServiceSettingsMap(TaskType taskType); + + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return createServiceSettingsMap(taskType); + } + + protected abstract Map createTaskSettingsMap(); + + protected abstract Map createSecretSettingsMap(); + + protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets); + + protected void assertModel(Model model, TaskType taskType) { + assertModel(model, taskType, true); + } + + protected abstract EnumSet supportedStreamingTasks(); + + /** + * Override this method if the service support reranking. This method won't be called if the service doesn't support reranking. + */ + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + fail("Reranking services should override this test method to verify window size"); + } + } + + /** + * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests + */ + public abstract static class UpdateModelConfiguration { + + public boolean isEnabled() { + return true; + } + + protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); + } + + private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { + @Override + public boolean isEnabled() { + return false; + } + + @Override + protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { + throw new UnsupportedOperationException("Update model tests are disabled"); + } + }; + + @Override + public InferenceService createInferenceService() { + return testConfiguration.commonConfig.createService(threadPool, clientManager); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + testConfiguration.commonConfig.assertRerankerWindowSize(rerankingInferenceService); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java new file mode 100644 index 0000000000000..37ba073407df8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java @@ -0,0 +1,387 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.junit.Assume; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +/** + * Base class for testing inference services using parameterized tests. + */ +public abstract class AbstractInferenceServiceParameterizedTests extends AbstractInferenceServiceBaseTests { + + public AbstractInferenceServiceParameterizedTests( + AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration, + TestCase testCase + ) { + super(testConfiguration); + this.testCase = testCase; + } + + @Override + public InferenceService createInferenceService() { + return testConfiguration.commonConfig().createService(threadPool, clientManager); + } + + public record TestCase( + String description, + Function createPersistedConfig, + ServiceParser serviceParser, + TaskType expectedTaskType, + boolean modelIncludesSecrets, + boolean expectFailure + ) {} + + private record ServiceParserParams( + SenderService service, + Utils.PersistedConfig persistedConfig, + AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration + ) {} + + @FunctionalInterface + private interface ServiceParser { + Model parseConfigs(ServiceParserParams params); + } + + private static class TestCaseBuilder { + private final String description; + private final Function createPersistedConfig; + private final ServiceParser serviceParser; + private final TaskType expectedTaskType; + private boolean modelIncludesSecrets; + private boolean expectFailure; + + TestCaseBuilder( + String description, + Function createPersistedConfig, + ServiceParser serviceParser, + TaskType expectedTaskType + ) { + this.description = description; + this.createPersistedConfig = createPersistedConfig; + this.serviceParser = serviceParser; + this.expectedTaskType = expectedTaskType; + } + + public TestCaseBuilder withSecrets() { + this.modelIncludesSecrets = true; + return this; + } + + public TestCaseBuilder expectFailure() { + this.expectFailure = true; + return this; + } + + public TestCase build() { + return new TestCase(description, createPersistedConfig, serviceParser, expectedTaskType, modelIncludesSecrets, expectFailure); + } + } + + @ParametersFactory + public static Iterable parameters() throws IOException { + return Arrays.asList( + new TestCase[][] { + // Test cases for parsePersistedConfig method + { + new TestCaseBuilder( + "Test parsing persisted config without chunking settings", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + null + ), + (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()), + TaskType.TEXT_EMBEDDING + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with chunking settings", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + createRandomChunkingSettingsMap(), + null + ), + (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()), + TaskType.TEXT_EMBEDDING + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config does not throw when an extra key exists in config", + testConfiguration -> { + var persistedConfigMap = getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + null + ); + persistedConfigMap.config().put("extra_key", "value"); + return persistedConfigMap; + }, + (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()), + TaskType.COMPLETION + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config does not throw when an extra key exists in service settings", + testConfiguration -> { + var serviceSettings = testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT); + serviceSettings.put("extra_key", "value"); + + return getPersistedConfigMap(serviceSettings, testConfiguration.commonConfig().createTaskSettingsMap(), null); + }, + (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()), + TaskType.COMPLETION + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config does not throw when an extra key exists in task settings", + testConfiguration -> { + var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap(); + taskSettingsMap.put("extra_key", "value"); + + return getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + taskSettingsMap, + null + ); + }, + (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()), + TaskType.COMPLETION + ).build() }, + // Test cases for parsePersistedConfigWithSecrets method + { + new TestCaseBuilder( + "Test parsing persisted config with secrets creates an embeddings model", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.TEXT_EMBEDDING + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets creates an embeddings " + + "model when chunking settings are provided", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + createRandomChunkingSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.TEXT_EMBEDDING + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets creates a completion " + + "model when chunking settings are not provided", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets throws exception for unsupported task type", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + params.testConfiguration.commonConfig().unsupportedTaskType(), + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().expectFailure().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in config", + testConfiguration -> { + var persistedConfigMap = getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ); + persistedConfigMap.config().put("extra_key", "value"); + return persistedConfigMap; + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in service settings", + testConfiguration -> { + var serviceSettings = testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT); + serviceSettings.put("extra_key", "value"); + + return getPersistedConfigMap( + serviceSettings, + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ); + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in task settings", + testConfiguration -> { + var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap(); + taskSettingsMap.put("extra_key", "value"); + + return getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + taskSettingsMap, + testConfiguration.commonConfig().createSecretSettingsMap() + ); + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in secret settings", + testConfiguration -> { + var secretSettingsMap = testConfiguration.commonConfig().createSecretSettingsMap(); + secretSettingsMap.put("extra_key", "value"); + + return getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + secretSettingsMap + ); + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() } } + ); + } + + public void testPersistedConfig() throws Exception { + // If the service doesn't support the expected task type, then skip the test + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(testCase.expectedTaskType)); + + var parseConfigTestConfig = testConfiguration.commonConfig(); + var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration); + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + + if (testCase.expectFailure) { + assertFailedParse(service, persistedConfig); + } else { + assertSuccessfulParse(service, persistedConfig); + } + } + } + + private void assertFailedParse(SenderService service, Utils.PersistedConfig persistedConfig) { + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration)) + ); + + assertThat( + exception.getMessage(), + containsString( + Strings.format("service does not support task type [%s]", testConfiguration.commonConfig().unsupportedTaskType()) + ) + ); + } + + private void assertSuccessfulParse(SenderService service, Utils.PersistedConfig persistedConfig) throws Exception { + var model = testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration)); + + if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) { + @SuppressWarnings("unchecked") + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap( + (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS) + ); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } + + testConfiguration.commonConfig().assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index ec07d7b547004..9479a8ea6e10d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -9,39 +9,27 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.junit.After; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.junit.Assume; -import org.junit.Before; import java.io.IOException; -import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; -import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -import static org.mockito.Mockito.mock; /** * Base class for testing inference services. @@ -51,132 +39,64 @@ * To use this class, extend it and pass the constructor a configuration. *

*/ -public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase { - - protected final MockWebServer webServer = new MockWebServer(); - protected ThreadPool threadPool; - protected HttpClientManager clientManager; - - @Override - @Before - public void setUp() throws Exception { - super.setUp(); - webServer.start(); - threadPool = createThreadPool(inferenceUtilityExecutors()); - clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - } - - @Override - @After - public void tearDown() throws Exception { - super.tearDown(); - clientManager.close(); - terminate(threadPool); - webServer.close(); - } - - private final TestConfiguration testConfiguration; +public abstract class AbstractInferenceServiceTests extends AbstractInferenceServiceBaseTests { public AbstractInferenceServiceTests(TestConfiguration testConfiguration) { - this.testConfiguration = Objects.requireNonNull(testConfiguration); - } - - /** - * Main configurations for the tests - */ - public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) { - public static class Builder { - private final CommonConfig commonConfig; - private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS; - - public Builder(CommonConfig commonConfig) { - this.commonConfig = commonConfig; - } - - public Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) { - this.updateModelConfiguration = updateModelConfiguration; - return this; - } - - public TestConfiguration build() { - return new TestConfiguration(commonConfig, updateModelConfiguration); - } - } + super(testConfiguration); } - /** - * Configurations that useful for most tests - */ - public abstract static class CommonConfig { - - private final TaskType taskType; - private final TaskType unsupportedTaskType; - - public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) { - this.taskType = Objects.requireNonNull(taskType); - this.unsupportedTaskType = unsupportedTaskType; - } - - protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); - - protected abstract Map createServiceSettingsMap(TaskType taskType); - - protected abstract Map createTaskSettingsMap(); - - protected abstract Map createSecretSettingsMap(); + public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING)); - protected abstract void assertModel(Model model, TaskType taskType); + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); - protected abstract EnumSet supportedStreamingTasks(); - } + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST), + parseRequestConfigTestConfig.createTaskSettingsMap(), + parseRequestConfigTestConfig.createSecretSettingsMap() + ); - /** - * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests - */ - public abstract static class UpdateModelConfiguration { + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); - public boolean isEnabled() { - return true; + var model = listener.actionGet(TIMEOUT); + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of()); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); } - - protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); } - private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { - @Override - public boolean isEnabled() { - return false; - } + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws Exception { + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING)); - @Override - protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { - throw new UnsupportedOperationException("Update model tests are disabled"); - } - }; - - public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var chunkingSettingsMap = createRandomChunkingSettingsMap(); var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST), parseRequestConfigTestConfig.createTaskSettingsMap(), + chunkingSettingsMap, parseRequestConfigTestConfig.createSecretSettingsMap() ); var listener = new PlainActionFuture(); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); - parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.TEXT_EMBEDDING); + var model = listener.actionGet(TIMEOUT); + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); } } public void testParseRequestConfig_CreatesACompletionModel() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.REQUEST), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); @@ -189,39 +109,47 @@ public void testParseRequestConfig_CreatesACompletionModel() throws Exception { } public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.targetTaskType(), + ConfigurationParseContext.REQUEST + ), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( exception.getMessage(), - containsString(Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType)) + containsString( + Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType()) + ) ); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.targetTaskType(), + ConfigurationParseContext.REQUEST + ), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); config.put("extra_key", "value"); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); @@ -229,9 +157,12 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { - var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType); + var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.targetTaskType(), + ConfigurationParseContext.REQUEST + ); serviceSettings.put("extra_key", "value"); var config = getRequestConfigMap( serviceSettings, @@ -240,7 +171,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); @@ -248,18 +179,21 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap(); taskSettings.put("extra_key", "value"); var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.targetTaskType(), + ConfigurationParseContext.REQUEST + ), taskSettings, parseRequestConfigTestConfig.createSecretSettingsMap() ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); @@ -267,179 +201,29 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap(); secretSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.targetTaskType(), + ConfigurationParseContext.REQUEST + ), parseRequestConfigTestConfig.createTaskSettingsMap(), secretSettingsMap ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); } } - // parsePersistedConfigWithSecrets - - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.COMPLETION, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - parseConfigTestConfig.assertModel(model, TaskType.COMPLETION); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.unsupportedTaskType, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ) - ); - - assertThat( - exception.getMessage(), - containsString( - Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType) - ) - ); - } - } - - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "service does not support task type [%s]"; - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - persistedConfigMap.config().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.taskType, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType); - serviceSettings.put("extra_key", "value"); - var persistedConfigMap = getPersistedConfigMap( - serviceSettings, - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.taskType, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var taskSettings = parseConfigTestConfig.createTaskSettingsMap(); - taskSettings.put("extra_key", "value"); - var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), - taskSettings, - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap(); - secretSettingsMap.put("extra_key", "value"); - var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), - parseConfigTestConfig.createTaskSettingsMap(), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var listener = new PlainActionFuture(); service.infer( @@ -464,9 +248,9 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { } public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var exception = expectThrows( ElasticsearchStatusException.class, () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) @@ -477,11 +261,11 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO } public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var embeddingSize = randomNonNegativeInt(); - var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(null); + var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); @@ -491,11 +275,11 @@ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() } public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var embeddingSize = randomNonNegativeInt(); - var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(SimilarityMeasure.COSINE); + var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); @@ -506,8 +290,8 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel // streaming tests public void testSupportedStreamingTasks() throws Exception { - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { - assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig.supportedStreamingTasks())); + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks())); assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java new file mode 100644 index 0000000000000..1254610f32745 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ai21; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +import static org.elasticsearch.xpack.inference.services.ai21.Ai21ServiceTests.createTestConfiguration; + +public class Ai21ServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public Ai21ServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) { + super(createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index cbb119d3e5710..8a1f8d806a45b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -86,9 +85,13 @@ public Ai21ServiceTests() { super(createTestConfiguration()); } - private static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() { + public static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() { return new AbstractInferenceServiceTests.TestConfiguration.Builder( - new AbstractInferenceServiceTests.CommonConfig(TaskType.COMPLETION, TaskType.TEXT_EMBEDDING) { + new AbstractInferenceServiceTests.CommonConfig( + TaskType.COMPLETION, + TaskType.TEXT_EMBEDDING, + EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION) + ) { @Override protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { @@ -111,8 +114,8 @@ protected Map createSecretSettingsMap() { } @Override - protected void assertModel(Model model, TaskType taskType) { - Ai21ServiceTests.assertModel(model, taskType); + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + Ai21ServiceTests.assertModel(model, taskType, modelIncludesSecrets); } @Override @@ -123,32 +126,35 @@ protected EnumSet supportedStreamingTasks() { ).build(); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case COMPLETION -> assertCompletionModel(model); - case CHAT_COMPLETION -> assertChatCompletionModel(model); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type [" + taskType + "]"); } } - private static Ai21Model assertCommonModelFields(Model model) { + private static Ai21Model assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(Ai21Model.class)); var customModel = (Ai21Model) model; assertThat(customModel.uri.toString(), Matchers.is("https://api.ai21.com/studio/v1/chat/completions")); assertThat(customModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); - assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray()))); + + if (modelIncludesSecrets) { + assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray()))); + } return customModel; } - private static void assertCompletionModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); } - private static void assertChatCompletionModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); } @@ -165,10 +171,6 @@ private static Map createSecretSettingsMap() { return new HashMap<>(Map.of("api_key", "secret")); } - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "Failed to parse stored model [id] for [ai21] service, please delete and add the service again"; - } - @Before public void init() throws Exception { webServer.start(); @@ -183,16 +185,6 @@ public void shutdown() throws IOException { webServer.close(); } - @Override - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() { - // The Ai21Service does not support Text Embedding, so this test is not applicable. - } - - @Override - public void testParseRequestConfig_CreatesAnEmbeddingsModel() { - // The Ai21Service does not support Text Embedding, so this test is not applicable. - } - public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException { var url = "url"; var model = "model"; @@ -561,9 +553,4 @@ private Map getRequestConfigMap(Map serviceSetti return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); } - - @Override - public InferenceService createInferenceService() { - return createService(); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index d7eb32861da92..36a4ce7b18c46 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -628,9 +628,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service")); assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again") + containsString("The [amazonbedrock] service does not support task type [sparse_embedding]") ); } } @@ -876,9 +877,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service")); assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again") + containsString("The [amazonbedrock] service does not support task type [sparse_embedding]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index f1531929db8c3..127b7d1c4cfae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -807,9 +807,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM () -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets()) ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureaistudio] service")); assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [azureaistudio] service, please delete and add the service again") + containsString("The [azureaistudio] service does not support task type [sparse_embedding]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 55b10f2e2b9d7..cf3ac6979b8e3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -435,9 +435,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service")); assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again") + containsString("The [azureopenai] service does not support task type [sparse_embedding]") ); } } @@ -668,9 +669,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service")); assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again") + containsString("The [azureopenai] service does not support task type [sparse_embedding]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 67f545a8104de..9642e7b85cdc7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -450,10 +450,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - MatcherAssert.assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service")); + assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]")); } } @@ -687,10 +685,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - MatcherAssert.assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service")); + assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java new file mode 100644 index 0000000000000..a71a107f3b9d1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +public class CustomServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public CustomServiceParameterizedTests(TestCase testCase) { + super(CustomServiceTests.createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 55bb98705a2a3..44bf17c3ac96d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -74,38 +73,52 @@ public CustomServiceTests() { super(createTestConfiguration()); } - private static TestConfiguration createTestConfiguration() { - return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION) { - @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { - return CustomServiceTests.createService(threadPool, clientManager); - } + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig( + TaskType.TEXT_EMBEDDING, + TaskType.CHAT_COMPLETION, + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.COMPLETION) + ) { + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return CustomServiceTests.createService(threadPool, clientManager); + } - @Override - protected Map createServiceSettingsMap(TaskType taskType) { - return CustomServiceTests.createServiceSettingsMap(taskType); - } + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return CustomServiceTests.createServiceSettingsMap(taskType); + } - @Override - protected Map createTaskSettingsMap() { - return CustomServiceTests.createTaskSettingsMap(); - } + @Override + protected Map createTaskSettingsMap() { + return CustomServiceTests.createTaskSettingsMap(); + } - @Override - protected Map createSecretSettingsMap() { - return CustomServiceTests.createSecretSettingsMap(); - } + @Override + protected Map createSecretSettingsMap() { + return CustomServiceTests.createSecretSettingsMap(); + } - @Override - protected void assertModel(Model model, TaskType taskType) { - CustomServiceTests.assertModel(model, taskType); - } + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } - @Override - protected EnumSet supportedStreamingTasks() { - return EnumSet.noneOf(TaskType.class); + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.noneOf(TaskType.class); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("any model"), + CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } } - }).enableUpdateModelTests(new UpdateModelConfiguration() { + ).enableUpdateModelTests(new UpdateModelConfiguration() { @Override protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel(similarityMeasure); @@ -113,38 +126,40 @@ protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) }).build(); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); - case COMPLETION -> assertCompletionModel(model); + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type [" + taskType + "]"); } } - private static void assertTextEmbeddingModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class)); } - private static CustomModel assertCommonModelFields(Model model) { + private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(CustomModel.class)); var customModel = (CustomModel) model; assertThat(customModel.getServiceSettings().getUrl(), is("http://www.abc.com")); assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value"))); - assertThat( - customModel.getSecretSettings().getSecretParameters(), - is(Map.of("test_key", new SecureString("test_value".toCharArray()))) - ); + if (modelIncludesSecrets) { + assertThat( + customModel.getSecretSettings().getSecretParameters(), + is(Map.of("test_key", new SecureString("test_value".toCharArray()))) + ); + } return customModel; } - private static void assertCompletionModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), is(TaskType.COMPLETION)); assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class)); } @@ -807,17 +822,4 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { assertThat(requestMap.get("input"), is(List.of("a"))); } } - - @Override - public InferenceService createInferenceService() { - return createService(threadPool, clientManager); - } - - @Override - protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { - assertThat( - rerankingInferenceService.rerankerWindowSize("any model"), - CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) - ); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index fc50acdbd39b6..998559d102ab7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -443,10 +443,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - MatcherAssert.assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); + assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); } } @@ -683,10 +681,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - MatcherAssert.assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); + assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java new file mode 100644 index 0000000000000..04afbc28af10a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +public class LlamaServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public LlamaServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) { + super(LlamaServiceTests.createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 243235211a7de..f6a0232db529c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -77,6 +77,9 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -106,39 +109,41 @@ public LlamaServiceTests() { super(createTestConfiguration()); } - private static TestConfiguration createTestConfiguration() { - return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) { + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) { - @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { - return LlamaServiceTests.createService(threadPool, clientManager); - } + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return LlamaServiceTests.createService(threadPool, clientManager); + } - @Override - protected Map createServiceSettingsMap(TaskType taskType) { - return LlamaServiceTests.createServiceSettingsMap(taskType); - } + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return LlamaServiceTests.createServiceSettingsMap(taskType); + } - @Override - protected Map createTaskSettingsMap() { - return new HashMap<>(); - } + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } - @Override - protected Map createSecretSettingsMap() { - return LlamaServiceTests.createSecretSettingsMap(); - } + @Override + protected Map createSecretSettingsMap() { + return LlamaServiceTests.createSecretSettingsMap(); + } - @Override - protected void assertModel(Model model, TaskType taskType) { - LlamaServiceTests.assertModel(model, taskType); - } + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } - @Override - protected EnumSet supportedStreamingTasks() { - return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } } - }).enableUpdateModelTests(new UpdateModelConfiguration() { + ).enableUpdateModelTests(new UpdateModelConfiguration() { @Override protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel(similarityMeasure); @@ -146,43 +151,46 @@ protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarity }).build(); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); - case COMPLETION -> assertCompletionModel(model); - case CHAT_COMPLETION -> assertChatCompletionModel(model); + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type [" + taskType + "]"); } } - private static void assertTextEmbeddingModel(Model model) { - var llamaModel = assertCommonModelFields(model); + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + var llamaModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); } - private static LlamaModel assertCommonModelFields(Model model) { + private static LlamaModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(LlamaModel.class)); var llamaModel = (LlamaModel) model; assertThat(llamaModel.getServiceSettings().modelId(), is("model_id")); assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com")); assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); - assertThat( - ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), - Matchers.is(new SecureString("secret".toCharArray())) - ); + + if (modelIncludesSecrets) { + assertThat( + ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), + Matchers.is(new SecureString("secret".toCharArray())) + ); + } return llamaModel; } - private static void assertCompletionModel(Model model) { - var llamaModel = assertCommonModelFields(model); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var llamaModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); } - private static void assertChatCompletionModel(Model model) { - var llamaModel = assertCommonModelFields(model); + private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { + var llamaModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); } @@ -236,10 +244,6 @@ private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable Simil ); } - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "Failed to parse stored model [id] for [llama] service, please delete and add the service again"; - } - @Before public void init() throws Exception { webServer.start(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 50731811e4164..c8f6d0ee0e2fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -742,10 +742,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM () -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets()) ); - assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [mistral] service, please delete and add the service again") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [mistral] service")); + assertThat(thrownException.getMessage(), containsString("The [mistral] service does not support task type [sparse_embedding]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java new file mode 100644 index 0000000000000..941dab6699575 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openai; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceTests.createTestConfiguration; + +public class OpenAiServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public OpenAiServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) { + super(createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 676dca2778141..c2cda175831ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -18,8 +18,10 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; @@ -49,24 +51,33 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; +import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.net.URI; import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -76,11 +87,9 @@ import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; -import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -102,8 +111,22 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class OpenAiServiceTests extends InferenceServiceTestCase { +public class OpenAiServiceTests extends AbstractInferenceServiceTests { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private static final String MODEL = "model"; + private static final String URL = "http://www.elastic.co"; + private static final String ORGANIZATION = "org"; + private static final int MAX_INPUT_TOKENS = 123; + private static final SimilarityMeasure SIMILARITY = SimilarityMeasure.DOT_PRODUCT; + private static final int DIMENSIONS = 100; + private static final boolean DIMENSIONS_SET_BY_USER = true; + private static final String USER = "user"; + private static final String HEADER_KEY = "header_key"; + private static final String HEADER_VALUE = "header_value"; + private static final Map HEADERS = Map.of(HEADER_KEY, HEADER_VALUE); + private static final String SECRET = "secret"; + private static final String INFERENCE_ID = "id"; + private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -122,221 +145,184 @@ public void shutdown() throws IOException { webServer.close(); } - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } + public OpenAiServiceTests() { + super(createTestConfiguration()); } - public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() throws IOException { - var url = "url"; - var organization = "org"; - var model = "model"; - var user = "user"; - var secret = "secret"; - - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(m -> { - assertThat(m, instanceOf(OpenAiChatCompletionModel.class)); - - var completionsModel = (OpenAiChatCompletionModel) m; - - assertThat(completionsModel.getServiceSettings().uri().toString(), is(url)); - assertThat(completionsModel.getServiceSettings().organizationId(), is(organization)); - assertThat(completionsModel.getServiceSettings().modelId(), is(model)); - assertThat(completionsModel.getTaskSettings().user(), is(user)); - assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig( + TaskType.TEXT_EMBEDDING, + TaskType.RERANK, + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION) + ) { + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return OpenAiServiceTests.createService(threadPool, clientManager); + } - }, exception -> fail("Unexpected exception: " + exception)); + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST); + } - service.parseRequestConfig( - "id", - TaskType.COMPLETION, - getRequestConfigMap( - getServiceSettingsMap(model, url, organization), - getOpenAiTaskSettingsMap(user), - getSecretSettingsMap(secret) - ), - modelVerificationListener - ); - } - } + @Override + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext); + } - public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [openai] service does not support task type [sparse_embedding]")); + @Override + protected Map createTaskSettingsMap() { + return OpenAiServiceTests.createTaskSettingsMap(); } - ); - service.parseRequestConfig( - "id", - TaskType.SPARSE_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } - } + @Override + protected Map createSecretSettingsMap() { + return getSecretSettingsMap(SECRET); + } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { - try (var service = createOpenAiService()) { - var config = getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - config.put("extra_key", "value"); - - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat( - exception.getMessage(), - is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service") - ); + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets); } - ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); - } + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + } + ).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected OpenAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure, null); + } + }).build(); } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - try (var service = createOpenAiService()) { - var serviceSettings = getServiceSettingsMap("model", "url", "org"); - serviceSettings.put("extra_key", "value"); - - var config = getRequestConfigMap(serviceSettings, getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret")); + private static Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + var settingsMap = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + MODEL, + ServiceFields.URL, + URL, + OpenAiServiceFields.ORGANIZATION, + ORGANIZATION, + ServiceFields.MAX_INPUT_TOKENS, + MAX_INPUT_TOKENS + ) + ); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); - }); + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, SIMILARITY.toString(), ServiceFields.DIMENSIONS, DIMENSIONS)); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + if (parseContext == ConfigurationParseContext.PERSISTENT) { + settingsMap.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, DIMENSIONS_SET_BY_USER); + } } - } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - try (var service = createOpenAiService()) { - var taskSettingsMap = getOpenAiTaskSettingsMap("user"); - taskSettingsMap.put("extra_key", "value"); - - var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), taskSettingsMap, getSecretSettingsMap("secret")); + return settingsMap; + } - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); - }); + private static Map createTaskSettingsMap() { + return new HashMap<>(Map.of(OpenAiServiceFields.USER, USER, OpenAiServiceFields.HEADERS, HEADERS)); + } - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + default -> fail("unexpected task type: " + taskType); } } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - try (var service = createOpenAiService()) { - var secretSettingsMap = getSecretSettingsMap("secret"); - secretSettingsMap.put("extra_key", "value"); - - var config = getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - secretSettingsMap - ); + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); - }); + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertThat( + embeddingsModel.getServiceSettings(), + is( + new OpenAiEmbeddingsServiceSettings( + MODEL, + URI.create(URL), + ORGANIZATION, + SIMILARITY, + DIMENSIONS, + MAX_INPUT_TOKENS, + DIMENSIONS_SET_BY_USER, + OpenAiEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ) + ) + ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + assertThat(embeddingsModel.getTaskSettings(), is(new OpenAiEmbeddingsTaskSettings(USER, HEADERS))); + if (modelIncludesSecrets) { + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET)); + } else { + assertNull(embeddingsModel.getSecretSettings()); } } - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + assertThat(model, instanceOf(OpenAiChatCompletionModel.class)); - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); + var completionModel = (OpenAiChatCompletionModel) model; - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", null, null), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } - } + assertThat( + completionModel.getServiceSettings(), + is( + new OpenAiChatCompletionServiceSettings( + MODEL, + URI.create(URL), + ORGANIZATION, + MAX_INPUT_TOKENS, + OpenAiChatCompletionServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ) + ) + ); - public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModelWithoutUserWithoutUserUrlOrganization() throws IOException { - var model = "model"; - var secret = "secret"; + assertThat(completionModel.getTaskSettings(), is(new OpenAiChatCompletionTaskSettings(USER, HEADERS))); - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(m -> { - assertThat(m, instanceOf(OpenAiChatCompletionModel.class)); + assertSecrets(completionModel.getSecretSettings(), modelIncludesSecrets); + } - var completionsModel = (OpenAiChatCompletionModel) m; - assertNull(completionsModel.getServiceSettings().uri()); - assertNull(completionsModel.getServiceSettings().organizationId()); - assertThat(completionsModel.getServiceSettings().modelId(), is(model)); - assertNull(completionsModel.getTaskSettings().user()); - assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + private static void assertSecrets(DefaultSecretSettings secretSettings, boolean modelIncludesSecrets) { + if (modelIncludesSecrets) { + assertThat(secretSettings.apiKey().toString(), is(SECRET)); + } + } - }, exception -> fail("Unexpected exception: " + exception)); + private static OpenAiEmbeddingsModel createInternalEmbeddingModel( + SimilarityMeasure similarityMeasure, + @Nullable ChunkingSettings chunkingSettings + ) { + return createInternalEmbeddingModel(similarityMeasure, URL, chunkingSettings); + } - service.parseRequestConfig( - "id", - TaskType.COMPLETION, - getRequestConfigMap(getServiceSettingsMap(model, null, null), getOpenAiTaskSettingsMap(null), getSecretSettingsMap(secret)), - modelVerificationListener - ); - } + private static OpenAiEmbeddingsModel createInternalEmbeddingModel( + SimilarityMeasure similarityMeasure, + @Nullable String url, + @Nullable ChunkingSettings chunkingSettings + ) { + return new OpenAiEmbeddingsModel( + INFERENCE_ID, + TaskType.TEXT_EMBEDDING, + "service", + new OpenAiEmbeddingsServiceSettings( + MODEL, + url == null ? null : URI.create(url), + ORGANIZATION, + similarityMeasure, + DIMENSIONS, + DIMENSIONS, + false, + null + ), + new OpenAiEmbeddingsTaskSettings(USER, HEADERS), + chunkingSettings, + new DefaultSecretSettings(new SecureString(SECRET.toCharArray())) + ); } public void testParseRequestConfig_MovesModel() throws IOException { @@ -365,497 +351,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { } } - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", null, null), - getOpenAiTaskSettingsMap(null), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } - } - - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", null, null), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", 100, null, false), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - - var thrownException = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - "id", - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ) - ); - - assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [openai] service, please delete and add the service again") - ); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - persistedConfig.config().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { - try (var service = createOpenAiService()) { - var secretSettingsMap = getSecretSettingsMap("secret"); - secretSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user"), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - persistedConfig.secrets().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { - try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); - serviceSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - serviceSettingsMap, - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { - try (var service = createOpenAiService()) { - var taskSettingsMap = getOpenAiTaskSettingsMap("user"); - taskSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - taskSettingsMap, - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user") - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getOpenAiTaskSettingsMap("user")); - - var thrownException = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) - ); - - assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [openai] service, please delete and add the service again") - ); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null) - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - createRandomChunkingSettingsMap() - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null) - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user") - ); - persistedConfig.config().put("extra_key", "value"); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { - try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); - serviceSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getOpenAiTaskSettingsMap("user")); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { - try (var service = createOpenAiService()) { - var taskSettingsMap = getOpenAiTaskSettingsMap("user"); - taskSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException { var sender = mock(Sender.class); @@ -1681,6 +1176,11 @@ public void testGetConfiguration() throws Exception { } } + private static OpenAiService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + private OpenAiService createOpenAiService() { return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 8cad8cbad208a..d7f5726af85e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -409,9 +409,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - MatcherAssert.assertThat( + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service")); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + containsString("The [voyageai] service does not support task type [sparse_embedding]") ); } } @@ -624,9 +625,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - MatcherAssert.assertThat( + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service")); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + containsString("The [voyageai] service does not support task type [sparse_embedding]") ); } }