From 028d4458d69e97f001bd5185ff0eeb24561b4cb1 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 24 Sep 2025 09:31:25 -0400 Subject: [PATCH 01/10] Refactoring openai --- .../inference/services/ServiceUtils.java | 13 + .../services/openai/OpenAiService.java | 4 +- .../OpenAiChatCompletionServiceSettings.java | 2 +- .../OpenAiEmbeddingsServiceSettings.java | 4 +- .../AbstractInferenceServiceTests.java | 211 ++++- .../services/ai21/Ai21ServiceTests.java | 10 +- .../services/custom/CustomServiceTests.java | 10 +- .../services/llama/LlamaServiceTests.java | 4 +- .../services/openai/OpenAiServiceTests.java | 743 ++++-------------- 9 files changed, 379 insertions(+), 622 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 874625c93a528..00f3023954980 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1067,6 +1067,10 @@ public interface EnumConstructor> { E apply(String name) throws IllegalArgumentException; } + /** + * @deprecated use {@link #parsePersistedConfigErrorMsg(String, String, TaskType)} instead + */ + @Deprecated public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) { return format( "Failed to parse stored model [%s] for [%s] service, please delete and add the service again", @@ -1075,6 +1079,15 @@ public static String parsePersistedConfigErrorMsg(String inferenceEntityId, Stri ); } + public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { + return format( + "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again", + inferenceEntityId, + serviceName, + TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName) + ); + } + public static ElasticsearchStatusException createInvalidModelException(Model model) { return new ElasticsearchStatusException( format( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index d2b7dcc527aaa..bef5670b52058 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -232,7 +232,7 @@ public OpenAiModel parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) ); } @@ -255,7 +255,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT taskSettingsMap, chunkingSettings, null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java index 88840cc1202ac..6d340320c655c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java @@ -47,7 +47,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject // The rate limit for usage tier 1 is 500 request per minute for most of the completion models // To find this information you need to access your account's limits https://platform.openai.com/account/limits // 500 requests per minute - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); public static OpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index 20b9bf931290c..e9b6a7c77a5fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -50,11 +50,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl public static final String NAME = "openai_service_settings"; - static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; // The rate limit for usage tier 1 is 3000 request per minute for the text embedding models // To find this information you need to access your account's limits https://platform.openai.com/account/limits // 3000 requests per minute - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); public static OpenAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { return switch (context) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index ec07d7b547004..9f7fdfdf17a11 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -7,11 +7,14 @@ package org.elasticsearch.xpack.inference.services; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -20,6 +23,8 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.junit.After; @@ -27,11 +32,14 @@ import org.junit.Before; import java.io.IOException; +import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -39,6 +47,7 @@ import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -56,6 +65,7 @@ public abstract class AbstractInferenceServiceTests extends InferenceServiceTest protected final MockWebServer webServer = new MockWebServer(); protected ThreadPool threadPool; protected HttpClientManager clientManager; + protected TestCase testCase; @Override @Before @@ -77,8 +87,9 @@ public void tearDown() throws Exception { private final TestConfiguration testConfiguration; - public AbstractInferenceServiceTests(TestConfiguration testConfiguration) { + public AbstractInferenceServiceTests(TestConfiguration testConfiguration, TestCase testCase) { this.testConfiguration = Objects.requireNonNull(testConfiguration); + this.testCase = testCase; } /** @@ -105,7 +116,7 @@ public TestConfiguration build() { } /** - * Configurations that useful for most tests + * Configurations that are useful for most tests */ public abstract static class CommonConfig { @@ -121,6 +132,10 @@ public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) { protected abstract Map createServiceSettingsMap(TaskType taskType); + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return createServiceSettingsMap(taskType); + } + protected abstract Map createTaskSettingsMap(); protected abstract Map createSecretSettingsMap(); @@ -154,12 +169,17 @@ protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { } }; + @Override + public InferenceService createInferenceService() { + return testConfiguration.commonConfig.createService(threadPool, clientManager); + } + public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { var parseRequestConfigTestConfig = testConfiguration.commonConfig; try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); @@ -167,7 +187,32 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { var listener = new PlainActionFuture(); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); - parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.TEXT_EMBEDDING); + var model = listener.actionGet(TIMEOUT); + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of()); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws Exception { + var parseRequestConfigTestConfig = testConfiguration.commonConfig; + + try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { + var chunkingSettingsMap = createRandomChunkingSettingsMap(); + var config = getRequestConfigMap( + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST), + parseRequestConfigTestConfig.createTaskSettingsMap(), + chunkingSettingsMap, + parseRequestConfigTestConfig.createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); + + var model = listener.actionGet(TIMEOUT); + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); } } @@ -176,7 +221,7 @@ public void testParseRequestConfig_CreatesACompletionModel() throws Exception { try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.REQUEST), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); @@ -193,7 +238,10 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.taskType, + ConfigurationParseContext.REQUEST + ), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); @@ -214,7 +262,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.taskType, + ConfigurationParseContext.REQUEST + ), parseRequestConfigTestConfig.createTaskSettingsMap(), parseRequestConfigTestConfig.createSecretSettingsMap() ); @@ -231,7 +282,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { var parseRequestConfigTestConfig = testConfiguration.commonConfig; try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { - var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType); + var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.taskType, + ConfigurationParseContext.REQUEST + ); serviceSettings.put("extra_key", "value"); var config = getRequestConfigMap( serviceSettings, @@ -253,7 +307,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap(); taskSettings.put("extra_key", "value"); var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.taskType, + ConfigurationParseContext.REQUEST + ), taskSettings, parseRequestConfigTestConfig.createSecretSettingsMap() ); @@ -272,7 +329,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap(); secretSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType), + parseRequestConfigTestConfig.createServiceSettingsMap( + parseRequestConfigTestConfig.taskType, + ConfigurationParseContext.REQUEST + ), parseRequestConfigTestConfig.createTaskSettingsMap(), secretSettingsMap ); @@ -285,26 +345,122 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - // parsePersistedConfigWithSecrets + @ParametersFactory + public static Iterable parameters() throws IOException { + return Arrays.asList( + new TestCase[][] { + { + new TestCase( + "Test parsing persisted config without chunking settings", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig.createServiceSettingsMap( + TaskType.TEXT_EMBEDDING, + ConfigurationParseContext.PERSISTENT + ), + testConfiguration.commonConfig.createTaskSettingsMap(), + null + ), + (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()), + null + ) } } + ); + } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception { + public record TestCase( + @Nullable String description, + Function createPersistedConfig, + BiFunction serviceCallback, + @Nullable Map chunkingSettingsMap + ) {} + + public void testPersistedConfig() throws Exception { var parseConfigTestConfig = testConfiguration.commonConfig; + var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration); try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() + + var model = testCase.serviceCallback.apply(service, persistedConfig); + + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap( + testCase.chunkingSettingsMap == null ? Map.of() : testCase.chunkingSettingsMap ); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); - var model = service.parsePersistedConfigWithSecrets( + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + + parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null); + } + + // parsePersistedConfig tests + + public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + parseConfigTestConfig.createTaskSettingsMap(), + null + ); + + parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null); + } + + private void parseConfigHelper(Function serviceParseCallback, @Nullable Map chunkingSettingsMap) + throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + + var model = serviceParseCallback.apply(service); + + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap == null ? Map.of() : chunkingSettingsMap); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + + parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + } + } + + // parsePersistedConfigWithSecrets + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + parseConfigHelper( + service -> service.parsePersistedConfigWithSecrets( "id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config(), persistedConfigMap.secrets() - ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); - } + ), + null + ); + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsAreProvided() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig; + + var chunkingSettingsMap = createRandomChunkingSettingsMap(); + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + parseConfigTestConfig.createTaskSettingsMap(), + chunkingSettingsMap, + parseConfigTestConfig.createSecretSettingsMap() + ); + + parseConfigHelper( + service -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ), + chunkingSettingsMap + ); } public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception { @@ -312,7 +468,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION), + parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), parseConfigTestConfig.createTaskSettingsMap(), parseConfigTestConfig.createSecretSettingsMap() ); @@ -332,7 +488,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() thr try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), parseConfigTestConfig.createTaskSettingsMap(), parseConfigTestConfig.createSecretSettingsMap() ); @@ -365,7 +521,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), parseConfigTestConfig.createTaskSettingsMap(), parseConfigTestConfig.createSecretSettingsMap() ); @@ -385,7 +541,10 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { var parseConfigTestConfig = testConfiguration.commonConfig; try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType); + var serviceSettings = parseConfigTestConfig.createServiceSettingsMap( + parseConfigTestConfig.taskType, + ConfigurationParseContext.PERSISTENT + ); serviceSettings.put("extra_key", "value"); var persistedConfigMap = getPersistedConfigMap( serviceSettings, @@ -410,7 +569,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTask var taskSettings = parseConfigTestConfig.createTaskSettingsMap(); taskSettings.put("extra_key", "value"); var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), taskSettings, parseConfigTestConfig.createSecretSettingsMap() ); @@ -427,7 +586,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap(); secretSettingsMap.put("extra_key", "value"); var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType), + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), parseConfigTestConfig.createTaskSettingsMap(), secretSettingsMap ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index cbb119d3e5710..5fb878403f05a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -82,8 +81,8 @@ public class Ai21ServiceTests extends AbstractInferenceServiceTests { private ThreadPool threadPool; private HttpClientManager clientManager; - public Ai21ServiceTests() { - super(createTestConfiguration()); + public Ai21ServiceTests(TestCase testCase) { + super(createTestConfiguration(), testCase); } private static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() { @@ -561,9 +560,4 @@ private Map getRequestConfigMap(Map serviceSetti return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); } - - @Override - public InferenceService createInferenceService() { - return createService(); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 55bb98705a2a3..ec303f0c7b796 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -70,8 +69,8 @@ public class CustomServiceTests extends AbstractInferenceServiceTests { - public CustomServiceTests() { - super(createTestConfiguration()); + public CustomServiceTests(TestCase testCase) { + super(createTestConfiguration(), testCase); } private static TestConfiguration createTestConfiguration() { @@ -808,11 +807,6 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { } } - @Override - public InferenceService createInferenceService() { - return createService(threadPool, clientManager); - } - @Override protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 243235211a7de..5d81c8b062492 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -102,8 +102,8 @@ public class LlamaServiceTests extends AbstractInferenceServiceTests { private ThreadPool threadPool; private HttpClientManager clientManager; - public LlamaServiceTests() { - super(createTestConfiguration()); + public LlamaServiceTests(TestCase testCase) { + super(createTestConfiguration(), testCase); } private static TestConfiguration createTestConfiguration() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 676dca2778141..705afe78d3196 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -18,8 +18,10 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; @@ -49,24 +51,33 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; +import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.net.URI; import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -102,8 +113,22 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -public class OpenAiServiceTests extends InferenceServiceTestCase { +public class OpenAiServiceTests extends AbstractInferenceServiceTests { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private static final String MODEL = "model"; + private static final String URL = "http://www.elastic.co"; + private static final String ORGANIZATION = "org"; + private static final int MAX_INPUT_TOKENS = 123; + private static final SimilarityMeasure SIMILARITY = SimilarityMeasure.DOT_PRODUCT; + private static final int DIMENSIONS = 100; + private static final boolean DIMENSIONS_SET_BY_USER = true; + private static final String USER = "user"; + private static final String HEADER_KEY = "header_key"; + private static final String HEADER_VALUE = "header_value"; + private static final Map HEADERS = Map.of(HEADER_KEY, HEADER_VALUE); + private static final String SECRET = "secret"; + private static final String INFERENCE_ID = "id"; + private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -122,221 +147,174 @@ public void shutdown() throws IOException { webServer.close(); } - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } + public OpenAiServiceTests(TestCase testCase) { + super(createTestConfiguration(), testCase); } - public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() throws IOException { - var url = "url"; - var organization = "org"; - var model = "model"; - var user = "user"; - var secret = "secret"; - - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(m -> { - assertThat(m, instanceOf(OpenAiChatCompletionModel.class)); + private static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) { + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return OpenAiServiceTests.createService(threadPool, clientManager); + } - var completionsModel = (OpenAiChatCompletionModel) m; + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST); + } - assertThat(completionsModel.getServiceSettings().uri().toString(), is(url)); - assertThat(completionsModel.getServiceSettings().organizationId(), is(organization)); - assertThat(completionsModel.getServiceSettings().modelId(), is(model)); - assertThat(completionsModel.getTaskSettings().user(), is(user)); - assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + @Override + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext); + } - }, exception -> fail("Unexpected exception: " + exception)); + @Override + protected Map createTaskSettingsMap() { + return OpenAiServiceTests.createTaskSettingsMap(); + } - service.parseRequestConfig( - "id", - TaskType.COMPLETION, - getRequestConfigMap( - getServiceSettingsMap(model, url, organization), - getOpenAiTaskSettingsMap(user), - getSecretSettingsMap(secret) - ), - modelVerificationListener - ); - } - } + @Override + protected Map createSecretSettingsMap() { + return getSecretSettingsMap(SECRET); + } - public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [openai] service does not support task type [sparse_embedding]")); - } - ); + @Override + protected void assertModel(Model model, TaskType taskType) { + OpenAiServiceTests.assertModel(model, taskType); + } - service.parseRequestConfig( - "id", - TaskType.SPARSE_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + }).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected OpenAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure, null); + } + }).build(); } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { - try (var service = createOpenAiService()) { - var config = getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - config.put("extra_key", "value"); - - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat( - exception.getMessage(), - is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service") - ); - } + private static Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + var settingsMap = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + MODEL, + ServiceFields.URL, + URL, + OpenAiServiceFields.ORGANIZATION, + ORGANIZATION, + ServiceFields.MAX_INPUT_TOKENS, + MAX_INPUT_TOKENS + ) + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SIMILARITY.toString(), + ServiceFields.DIMENSIONS, + DIMENSIONS + ) ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + if (parseContext == ConfigurationParseContext.PERSISTENT) { + settingsMap.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, DIMENSIONS_SET_BY_USER); + } } - } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - try (var service = createOpenAiService()) { - var serviceSettings = getServiceSettingsMap("model", "url", "org"); - serviceSettings.put("extra_key", "value"); - - var config = getRequestConfigMap(serviceSettings, getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret")); + return settingsMap; + } - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); - }); + private static Map createTaskSettingsMap() { + return new HashMap<>(Map.of(OpenAiServiceFields.USER, USER, OpenAiServiceFields.HEADERS, HEADERS)); + } - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + private static void assertModel(Model model, TaskType taskType) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); + case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model); + default -> fail("unexpected task type: " + taskType); } } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - try (var service = createOpenAiService()) { - var taskSettingsMap = getOpenAiTaskSettingsMap("user"); - taskSettingsMap.put("extra_key", "value"); - - var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), taskSettingsMap, getSecretSettingsMap("secret")); + private static void assertTextEmbeddingModel(Model model) { + assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); - }); + var embeddingsModel = (OpenAiEmbeddingsModel) model; + assertThat( + embeddingsModel.getServiceSettings(), + is( + new OpenAiEmbeddingsServiceSettings( + MODEL, + URI.create(URL), + ORGANIZATION, + SIMILARITY, + DIMENSIONS, + MAX_INPUT_TOKENS, + DIMENSIONS_SET_BY_USER, + OpenAiEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ) + ) + ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); - } + assertThat(embeddingsModel.getTaskSettings(), is(new OpenAiEmbeddingsTaskSettings(USER, HEADERS))); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET)); } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - try (var service = createOpenAiService()) { - var secretSettingsMap = getSecretSettingsMap("secret"); - secretSettingsMap.put("extra_key", "value"); + private static void assertCompletionModel(Model model) { + assertThat(model, instanceOf(OpenAiChatCompletionModel.class)); - var config = getRequestConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - secretSettingsMap - ); + var completionModel = (OpenAiChatCompletionModel) model; - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")); - }); + assertThat( + completionModel.getServiceSettings(), + is( + new OpenAiChatCompletionServiceSettings( + MODEL, + URI.create(URL), + ORGANIZATION, + MAX_INPUT_TOKENS, + OpenAiChatCompletionServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ) + ) + ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); - } + assertThat(completionModel.getTaskSettings(), is(new OpenAiChatCompletionTaskSettings(USER, HEADERS))); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(SECRET)); } - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", null, null), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } + private static OpenAiEmbeddingsModel createInternalEmbeddingModel( + SimilarityMeasure similarityMeasure, + @Nullable ChunkingSettings chunkingSettings + ) { + return createInternalEmbeddingModel(similarityMeasure, URL, chunkingSettings); } - public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModelWithoutUserWithoutUserUrlOrganization() throws IOException { - var model = "model"; - var secret = "secret"; - - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(m -> { - assertThat(m, instanceOf(OpenAiChatCompletionModel.class)); - - var completionsModel = (OpenAiChatCompletionModel) m; - assertNull(completionsModel.getServiceSettings().uri()); - assertNull(completionsModel.getServiceSettings().organizationId()); - assertThat(completionsModel.getServiceSettings().modelId(), is(model)); - assertNull(completionsModel.getTaskSettings().user()); - assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); - - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.COMPLETION, - getRequestConfigMap(getServiceSettingsMap(model, null, null), getOpenAiTaskSettingsMap(null), getSecretSettingsMap(secret)), - modelVerificationListener - ); - } + private static OpenAiEmbeddingsModel createInternalEmbeddingModel( + SimilarityMeasure similarityMeasure, + @Nullable String url, + @Nullable ChunkingSettings chunkingSettings + ) { + return new OpenAiEmbeddingsModel( + INFERENCE_ID, + TaskType.TEXT_EMBEDDING, + "service", + new OpenAiEmbeddingsServiceSettings( + MODEL, + url == null ? null : URI.create(url), + ORGANIZATION, + similarityMeasure, + DIMENSIONS, + DIMENSIONS, + false, + null + ), + new OpenAiEmbeddingsTaskSettings(USER, HEADERS), + chunkingSettings, + new DefaultSecretSettings(new SecureString(SECRET.toCharArray())) + ); } public void testParseRequestConfig_MovesModel() throws IOException { @@ -365,392 +343,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { } } - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", null, null), - getOpenAiTaskSettingsMap(null), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } - } - - public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = createOpenAiService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getServiceSettingsMap("model", null, null), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ), - modelVerificationListener - ); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", 100, null, false), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org"), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - - var thrownException = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - "id", - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ) - ); - - assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [openai] service, please delete and add the service again") - ); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - persistedConfig.config().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { - try (var service = createOpenAiService()) { - var secretSettingsMap = getSecretSettingsMap("secret"); - secretSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user"), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - persistedConfig.secrets().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { - try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); - serviceSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - serviceSettingsMap, - getOpenAiTaskSettingsMap("user"), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { - try (var service = createOpenAiService()) { - var taskSettingsMap = getOpenAiTaskSettingsMap("user"); - taskSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - taskSettingsMap, - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user") - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getOpenAiTaskSettingsMap("user")); - - var thrownException = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) - ); - - assertThat( - thrownException.getMessage(), - is("Failed to parse stored model [id] for [openai] service, please delete and add the service again") - ); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null) - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertNull(embeddingsModel.getSecretSettings()); - } - } - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( @@ -1681,6 +1273,11 @@ public void testGetConfiguration() throws Exception { } } + private static OpenAiService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + private OpenAiService createOpenAiService() { return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } From c3243fda24918816e61a457ba860c09a0a61c724 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 24 Sep 2025 17:21:08 -0400 Subject: [PATCH 02/10] Splitting up parameterized tests --- ...actInferenceServiceParameterizedTests.java | 469 ++++++++++++++++++ .../AbstractInferenceServiceTests.java | 116 ++++- .../services/ai21/Ai21ServiceTests.java | 25 +- .../services/custom/CustomServiceTests.java | 30 +- .../services/llama/LlamaServiceTests.java | 37 +- .../OpenAiServiceParameterizedTests.java | 18 + .../services/openai/OpenAiServiceTests.java | 84 +--- 7 files changed, 659 insertions(+), 120 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java new file mode 100644 index 0000000000000..f6fdefb1bac6f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java @@ -0,0 +1,469 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +/** + * Base class for testing inference services using parameterized tests. + */ +public abstract class AbstractInferenceServiceParameterizedTests extends InferenceServiceTestCase { + + private final AbstractInferenceServiceTests.TestConfiguration testConfiguration; + + protected final MockWebServer webServer = new MockWebServer(); + protected ThreadPool threadPool; + protected HttpClientManager clientManager; + protected TestCase testCase; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public AbstractInferenceServiceParameterizedTests( + AbstractInferenceServiceTests.TestConfiguration testConfiguration, + TestCase testCase + ) { + this.testConfiguration = Objects.requireNonNull(testConfiguration); + this.testCase = testCase; + } + + @Override + public InferenceService createInferenceService() { + return testConfiguration.commonConfig().createService(threadPool, clientManager); + } + + @ParametersFactory + public static Iterable parameters() throws IOException { + return Arrays.asList( + new TestCase[][] { + // parsePersistedConfig + { + new TestCaseBuilder( + "Test parsing persisted config without chunking settings", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + null + ), + (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config() + ), + TaskType.TEXT_EMBEDDING + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with chunking settings", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + createRandomChunkingSettingsMap(), + null + ), + (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config() + ), + TaskType.TEXT_EMBEDDING + ).build() }, + // parsePersistedConfigWithSecrets + { + new TestCaseBuilder( + "Test parsing persisted config with secrets creates an embeddings model", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ), + TaskType.TEXT_EMBEDDING + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets creates an embeddings " + + "model when chunking settings are provided", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + createRandomChunkingSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ), + TaskType.TEXT_EMBEDDING + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets creates a completion " + + "model when chunking settings are not provided", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets throws exception for unsupported task type", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ), + (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + "id", + testConfiguration.commonConfig().unsupportedTaskType(), + persistedConfig.config(), + persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() } } + ); + } + + public record TestCase( + String description, + Function createPersistedConfig, + ServiceCallback serviceCallback, + TaskType expectedTaskType, + boolean modelIncludesSecrets, + boolean expectFailure + ) {} + + @FunctionalInterface + interface ServiceCallback { + Model parseConfigs( + SenderService service, + Utils.PersistedConfig persistedConfig, + AbstractInferenceServiceTests.TestConfiguration testConfiguration + ); + } + + private static class TestCaseBuilder { + private final String description; + private final Function createPersistedConfig; + private final ServiceCallback serviceCallback; + private final TaskType expectedTaskType; + private boolean modelIncludesSecrets; + private boolean expectFailure; + + TestCaseBuilder( + String description, + Function createPersistedConfig, + ServiceCallback serviceCallback, + TaskType expectedTaskType + ) { + this.description = description; + this.createPersistedConfig = createPersistedConfig; + this.serviceCallback = serviceCallback; + this.expectedTaskType = expectedTaskType; + } + + public TestCaseBuilder withSecrets() { + this.modelIncludesSecrets = true; + return this; + } + + public TestCaseBuilder withFailure() { + this.expectFailure = true; + return this; + } + + public TestCase build() { + return new TestCase(description, createPersistedConfig, serviceCallback, expectedTaskType, modelIncludesSecrets, expectFailure); + } + } + + public void testPersistedConfig() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig(); + var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration); + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var model = testCase.serviceCallback.parseConfigs(service, persistedConfig, testConfiguration); + + if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) { + @SuppressWarnings("unchecked") + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap( + (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS) + ); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } + + parseConfigTestConfig.assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets); + } + } + + // parsePersistedConfigWithSecrets + + public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception { + var parseConfigTestConfig = testConfiguration.commonConfig(); + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + parseConfigTestConfig.unsupportedTaskType(), + persistedConfigMap.config(), + persistedConfigMap.secrets() + ) + ); + + assertThat( + exception.getMessage(), + containsString( + Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType()) + ) + ); + } + } + + protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "service does not support task type [%s]"; + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig(); + + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + persistedConfigMap.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + parseConfigTestConfig.taskType(), + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig(); + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var serviceSettings = parseConfigTestConfig.createServiceSettingsMap( + parseConfigTestConfig.taskType(), + ConfigurationParseContext.PERSISTENT + ); + serviceSettings.put("extra_key", "value"); + var persistedConfigMap = getPersistedConfigMap( + serviceSettings, + parseConfigTestConfig.createTaskSettingsMap(), + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + parseConfigTestConfig.taskType(), + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig(); + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var taskSettings = parseConfigTestConfig.createTaskSettingsMap(); + taskSettings.put("extra_key", "value"); + var config = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), + taskSettings, + parseConfigTestConfig.createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets()); + + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + var parseConfigTestConfig = testConfiguration.commonConfig(); + try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { + var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap(); + secretSettingsMap.put("extra_key", "value"); + var config = getPersistedConfigMap( + parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), + parseConfigTestConfig.createTaskSettingsMap(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets()); + + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + var listener = new PlainActionFuture(); + + service.infer( + getInvalidModel("id", "service"), + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is("The internal model was invalid, please delete the service [service] with id [id] and add it again.") + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); + + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) + ); + + assertThat(exception.getMessage(), containsString("Can't update embedding details for model")); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); + + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + var embeddingSize = randomNonNegativeInt(); + var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); + + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + var embeddingSize = randomNonNegativeInt(); + var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.COSINE, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + // streaming tests + public void testSupportedStreamingTasks() throws Exception { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks())); + assertFalse(service.canStream(TaskType.ANY)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index 9f7fdfdf17a11..b1ff02b6bbc17 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -62,6 +62,8 @@ */ public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase { + private final TestConfiguration testConfiguration; + protected final MockWebServer webServer = new MockWebServer(); protected ThreadPool threadPool; protected HttpClientManager clientManager; @@ -85,8 +87,6 @@ public void tearDown() throws Exception { webServer.close(); } - private final TestConfiguration testConfiguration; - public AbstractInferenceServiceTests(TestConfiguration testConfiguration, TestCase testCase) { this.testConfiguration = Objects.requireNonNull(testConfiguration); this.testCase = testCase; @@ -128,6 +128,14 @@ public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) { this.unsupportedTaskType = unsupportedTaskType; } + public TaskType taskType() { + return taskType; + } + + public TaskType unsupportedTaskType() { + return unsupportedTaskType; + } + protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); protected abstract Map createServiceSettingsMap(TaskType taskType); @@ -140,7 +148,11 @@ protected Map createServiceSettingsMap(TaskType taskType, Config protected abstract Map createSecretSettingsMap(); - protected abstract void assertModel(Model model, TaskType taskType); + protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets); + + protected void assertModel(Model model, TaskType taskType) { + assertModel(model, taskType, true); + } protected abstract EnumSet supportedStreamingTasks(); } @@ -347,10 +359,12 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap @ParametersFactory public static Iterable parameters() throws IOException { + var chunkingSettingsMap = createRandomChunkingSettingsMap(); + return Arrays.asList( new TestCase[][] { { - new TestCase( + new TestCaseBuilder( "Test parsing persisted config without chunking settings", testConfiguration -> getPersistedConfigMap( testConfiguration.commonConfig.createServiceSettingsMap( @@ -360,36 +374,106 @@ public static Iterable parameters() throws IOException { testConfiguration.commonConfig.createTaskSettingsMap(), null ), - (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()), - null - ) } } + (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()) + ).withNullChunkingSettingsMap().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with chunking settings", + testConfiguration -> getPersistedConfigMap( + testConfiguration.commonConfig.createServiceSettingsMap( + TaskType.TEXT_EMBEDDING, + ConfigurationParseContext.PERSISTENT + ), + testConfiguration.commonConfig.createTaskSettingsMap(), + chunkingSettingsMap, + null + ), + (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()) + ).withChunkingSettingsMap(chunkingSettingsMap).build() } } ); } public record TestCase( - @Nullable String description, + String description, Function createPersistedConfig, BiFunction serviceCallback, - @Nullable Map chunkingSettingsMap + @Nullable Map chunkingSettingsMap, + boolean validateChunkingSettings, + boolean modelIncludesSecrets ) {} + private static class TestCaseBuilder { + private final String description; + private final Function createPersistedConfig; + private final BiFunction serviceCallback; + @Nullable + private Map chunkingSettingsMap; + private boolean validateChunkingSettings; + private boolean modelIncludesSecrets; + + TestCaseBuilder( + String description, + Function createPersistedConfig, + BiFunction serviceCallback + ) { + this.description = description; + this.createPersistedConfig = createPersistedConfig; + this.serviceCallback = serviceCallback; + } + + public TestCaseBuilder withSecrets() { + this.modelIncludesSecrets = true; + return this; + } + + public TestCaseBuilder withChunkingSettingsMap(Map chunkingSettingsMap) { + this.chunkingSettingsMap = chunkingSettingsMap; + this.validateChunkingSettings = true; + return this; + } + + /** + * Use an empty chunking settings map but still do validation that the chunking settings are set to the appropriate + * defaults. + */ + public TestCaseBuilder withEmptyChunkingSettingsMap() { + this.chunkingSettingsMap = Map.of(); + this.validateChunkingSettings = true; + return this; + } + + public TestCaseBuilder withNullChunkingSettingsMap() { + this.chunkingSettingsMap = null; + this.validateChunkingSettings = true; + return this; + } + + public TestCase build() { + return new TestCase( + description, + createPersistedConfig, + serviceCallback, + chunkingSettingsMap, + validateChunkingSettings, + modelIncludesSecrets + ); + } + } + public void testPersistedConfig() throws Exception { var parseConfigTestConfig = testConfiguration.commonConfig; var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration); try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var model = testCase.serviceCallback.apply(service, persistedConfig); - var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap( - testCase.chunkingSettingsMap == null ? Map.of() : testCase.chunkingSettingsMap - ); - assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + if (testCase.validateChunkingSettings) { + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(testCase.chunkingSettingsMap); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); } - - parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null); } // parsePersistedConfig tests diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index 5fb878403f05a..50ca2f72abda0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -110,8 +110,8 @@ protected Map createSecretSettingsMap() { } @Override - protected void assertModel(Model model, TaskType taskType) { - Ai21ServiceTests.assertModel(model, taskType); + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + Ai21ServiceTests.assertModel(model, taskType, modelIncludesSecrets); } @Override @@ -122,32 +122,35 @@ protected EnumSet supportedStreamingTasks() { ).build(); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case COMPLETION -> assertCompletionModel(model); - case CHAT_COMPLETION -> assertChatCompletionModel(model); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type [" + taskType + "]"); } } - private static Ai21Model assertCommonModelFields(Model model) { + private static Ai21Model assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(Ai21Model.class)); var customModel = (Ai21Model) model; assertThat(customModel.uri.toString(), Matchers.is("https://api.ai21.com/studio/v1/chat/completions")); assertThat(customModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); - assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray()))); + + if (modelIncludesSecrets) { + assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray()))); + } return customModel; } - private static void assertCompletionModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); } - private static void assertChatCompletionModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index ec303f0c7b796..269642812e78f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -96,8 +96,8 @@ protected Map createSecretSettingsMap() { } @Override - protected void assertModel(Model model, TaskType taskType) { - CustomServiceTests.assertModel(model, taskType); + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets); } @Override @@ -112,38 +112,40 @@ protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) }).build(); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); - case COMPLETION -> assertCompletionModel(model); + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type [" + taskType + "]"); } } - private static void assertTextEmbeddingModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class)); } - private static CustomModel assertCommonModelFields(Model model) { + private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(CustomModel.class)); var customModel = (CustomModel) model; assertThat(customModel.getServiceSettings().getUrl(), is("http://www.abc.com")); assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value"))); - assertThat( - customModel.getSecretSettings().getSecretParameters(), - is(Map.of("test_key", new SecureString("test_value".toCharArray()))) - ); + if (modelIncludesSecrets) { + assertThat( + customModel.getSecretSettings().getSecretParameters(), + is(Map.of("test_key", new SecureString("test_value".toCharArray()))) + ); + } return customModel; } - private static void assertCompletionModel(Model model) { - var customModel = assertCommonModelFields(model); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), is(TaskType.COMPLETION)); assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 5d81c8b062492..a6773122a0d2b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -130,8 +130,8 @@ protected Map createSecretSettingsMap() { } @Override - protected void assertModel(Model model, TaskType taskType) { - LlamaServiceTests.assertModel(model, taskType); + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets); } @Override @@ -146,43 +146,46 @@ protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarity }).build(); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); - case COMPLETION -> assertCompletionModel(model); - case CHAT_COMPLETION -> assertChatCompletionModel(model); + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); + case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type [" + taskType + "]"); } } - private static void assertTextEmbeddingModel(Model model) { - var llamaModel = assertCommonModelFields(model); + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { + var llamaModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); } - private static LlamaModel assertCommonModelFields(Model model) { + private static LlamaModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(LlamaModel.class)); var llamaModel = (LlamaModel) model; assertThat(llamaModel.getServiceSettings().modelId(), is("model_id")); assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com")); assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); - assertThat( - ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), - Matchers.is(new SecureString("secret".toCharArray())) - ); + + if (modelIncludesSecrets) { + assertThat( + ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), + Matchers.is(new SecureString("secret".toCharArray())) + ); + } return llamaModel; } - private static void assertCompletionModel(Model model) { - var llamaModel = assertCommonModelFields(model); + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { + var llamaModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); } - private static void assertChatCompletionModel(Model model) { - var llamaModel = assertCommonModelFields(model); + private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) { + var llamaModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java new file mode 100644 index 0000000000000..941dab6699575 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.openai; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceTests.createTestConfiguration; + +public class OpenAiServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public OpenAiServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) { + super(createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 705afe78d3196..5eceffe200bc4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -91,7 +91,6 @@ import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -151,7 +150,7 @@ public OpenAiServiceTests(TestCase testCase) { super(createTestConfiguration(), testCase); } - private static TestConfiguration createTestConfiguration() { + public static TestConfiguration createTestConfiguration() { return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) { @Override protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { @@ -179,8 +178,8 @@ protected Map createSecretSettingsMap() { } @Override - protected void assertModel(Model model, TaskType taskType) { - OpenAiServiceTests.assertModel(model, taskType); + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets); } @Override @@ -210,14 +209,7 @@ private static Map createServiceSettingsMap(TaskType taskType, C ); if (taskType == TaskType.TEXT_EMBEDDING) { - settingsMap.putAll( - Map.of( - ServiceFields.SIMILARITY, - SIMILARITY.toString(), - ServiceFields.DIMENSIONS, - DIMENSIONS - ) - ); + settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, SIMILARITY.toString(), ServiceFields.DIMENSIONS, DIMENSIONS)); if (parseContext == ConfigurationParseContext.PERSISTENT) { settingsMap.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, DIMENSIONS_SET_BY_USER); @@ -231,15 +223,15 @@ private static Map createTaskSettingsMap() { return new HashMap<>(Map.of(OpenAiServiceFields.USER, USER, OpenAiServiceFields.HEADERS, HEADERS)); } - private static void assertModel(Model model, TaskType taskType) { + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { switch (taskType) { - case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); - case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model); + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets); + case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model, modelIncludesSecrets); default -> fail("unexpected task type: " + taskType); } } - private static void assertTextEmbeddingModel(Model model) { + private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); var embeddingsModel = (OpenAiEmbeddingsModel) model; @@ -260,10 +252,14 @@ private static void assertTextEmbeddingModel(Model model) { ); assertThat(embeddingsModel.getTaskSettings(), is(new OpenAiEmbeddingsTaskSettings(USER, HEADERS))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET)); + if (modelIncludesSecrets) { + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET)); + } else { + assertNull(embeddingsModel.getSecretSettings()); + } } - private static void assertCompletionModel(Model model) { + private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) { assertThat(model, instanceOf(OpenAiChatCompletionModel.class)); var completionModel = (OpenAiChatCompletionModel) model; @@ -282,7 +278,14 @@ private static void assertCompletionModel(Model model) { ); assertThat(completionModel.getTaskSettings(), is(new OpenAiChatCompletionTaskSettings(USER, HEADERS))); - assertThat(completionModel.getSecretSettings().apiKey().toString(), is(SECRET)); + + assertSecrets(completionModel.getSecretSettings(), modelIncludesSecrets); + } + + private static void assertSecrets(DefaultSecretSettings secretSettings, boolean modelIncludesSecrets) { + if (modelIncludesSecrets) { + assertThat(secretSettings.apiKey().toString(), is(SECRET)); + } } private static OpenAiEmbeddingsModel createInternalEmbeddingModel( @@ -343,49 +346,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { } } - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null), - createRandomChunkingSettingsMap() - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, null, true), - getOpenAiTaskSettingsMap(null) - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().uri()); - assertNull(embeddingsModel.getServiceSettings().organizationId()); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertNull(embeddingsModel.getTaskSettings().user()); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); - } - } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( From 399a50b5754bd2f7c8008e88415d9b211773d59b Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 25 Sep 2025 13:04:57 -0400 Subject: [PATCH 03/10] Working tests --- .../inference/services/ai21/Ai21Service.java | 10 +- .../services/custom/CustomService.java | 23 +- .../services/llama/LlamaService.java | 4 +- .../AbstractInferenceServiceBaseTests.java | 173 ++++++ ...actInferenceServiceParameterizedTests.java | 550 ++++++++---------- .../AbstractInferenceServiceTests.java | 537 ++--------------- .../ai21/Ai21ServiceParameterizedTests.java | 18 + .../services/ai21/Ai21ServiceTests.java | 26 +- .../CustomServiceParameterizedTests.java | 16 + .../services/custom/CustomServiceTests.java | 78 +-- .../llama/LlamaServiceParameterizedTests.java | 16 + .../services/llama/LlamaServiceTests.java | 65 ++- .../services/openai/OpenAiServiceTests.java | 131 ++--- 13 files changed, 638 insertions(+), 1009 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index 438d31d8dd411..64437685af12b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -213,7 +213,7 @@ public Ai21Model parsePersistedConfigWithSecrets( taskType, serviceSettingsMap, secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME) + parsePersistedConfigErrorMsg(modelId, NAME, taskType) ); } @@ -222,7 +222,13 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME)); + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + null, + parsePersistedConfigErrorMsg(modelId, NAME, taskType) + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 7cd069ac2e3e0..f9e1dba847dc7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -105,7 +105,10 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } CustomModel model = createModel( inferenceEntityId, @@ -156,14 +159,6 @@ private static RequestParameters createParameters(CustomModel model) { }; } - private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return null; - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -229,7 +224,10 @@ public CustomModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } return createModelWithoutLoggingDeprecations( inferenceEntityId, @@ -246,7 +244,10 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } return createModelWithoutLoggingDeprecations( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index 829dbe0a18955..c13026c36dd73 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -318,7 +318,7 @@ public Model parsePersistedConfigWithSecrets( serviceSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME) + parsePersistedConfigErrorMsg(modelId, NAME, taskType) ); } @@ -357,7 +357,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map supportedTaskTypes; + + public CommonConfig(TaskType targetTaskType, @Nullable TaskType unsupportedTaskType, EnumSet supportedTaskTypes) { + this.targetTaskType = Objects.requireNonNull(targetTaskType); + this.unsupportedTaskType = unsupportedTaskType; + this.supportedTaskTypes = Objects.requireNonNull(supportedTaskTypes); + } + + public TaskType targetTaskType() { + return targetTaskType; + } + + public TaskType unsupportedTaskType() { + return unsupportedTaskType; + } + + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); + + protected abstract Map createServiceSettingsMap(TaskType taskType); + + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return createServiceSettingsMap(taskType); + } + + protected abstract Map createTaskSettingsMap(); + + protected abstract Map createSecretSettingsMap(); + + protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets); + + protected void assertModel(Model model, TaskType taskType) { + assertModel(model, taskType, true); + } + + protected abstract EnumSet supportedStreamingTasks(); + + /** + * Override this method if the service support reranking. This method won't be called if the service doesn't support reranking. + */ + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + fail("Reranking services should override this test method to verify window size"); + } + } + + /** + * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests + */ + public abstract static class UpdateModelConfiguration { + + public boolean isEnabled() { + return true; + } + + protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); + } + + private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { + @Override + public boolean isEnabled() { + return false; + } + + @Override + protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { + throw new UnsupportedOperationException("Update model tests are disabled"); + } + }; + + @Override + public InferenceService createInferenceService() { + return testConfiguration.commonConfig.createService(threadPool, clientManager); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + testConfiguration.commonConfig.assertRerankerWindowSize(rerankingInferenceService); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java index f6fdefb1bac6f..37ba073407df8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java @@ -10,80 +10,35 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Strings; import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.junit.After; import org.junit.Assume; -import org.junit.Before; import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.function.Function; -import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; -import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -import static org.mockito.Mockito.mock; /** * Base class for testing inference services using parameterized tests. */ -public abstract class AbstractInferenceServiceParameterizedTests extends InferenceServiceTestCase { - - private final AbstractInferenceServiceTests.TestConfiguration testConfiguration; - - protected final MockWebServer webServer = new MockWebServer(); - protected ThreadPool threadPool; - protected HttpClientManager clientManager; - protected TestCase testCase; - - @Override - @Before - public void setUp() throws Exception { - super.setUp(); - webServer.start(); - threadPool = createThreadPool(inferenceUtilityExecutors()); - clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - } - - @Override - @After - public void tearDown() throws Exception { - super.tearDown(); - clientManager.close(); - terminate(threadPool); - webServer.close(); - } +public abstract class AbstractInferenceServiceParameterizedTests extends AbstractInferenceServiceBaseTests { public AbstractInferenceServiceParameterizedTests( - AbstractInferenceServiceTests.TestConfiguration testConfiguration, + AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration, TestCase testCase ) { - this.testConfiguration = Objects.requireNonNull(testConfiguration); + super(testConfiguration); this.testCase = testCase; } @@ -92,11 +47,66 @@ public InferenceService createInferenceService() { return testConfiguration.commonConfig().createService(threadPool, clientManager); } + public record TestCase( + String description, + Function createPersistedConfig, + ServiceParser serviceParser, + TaskType expectedTaskType, + boolean modelIncludesSecrets, + boolean expectFailure + ) {} + + private record ServiceParserParams( + SenderService service, + Utils.PersistedConfig persistedConfig, + AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration + ) {} + + @FunctionalInterface + private interface ServiceParser { + Model parseConfigs(ServiceParserParams params); + } + + private static class TestCaseBuilder { + private final String description; + private final Function createPersistedConfig; + private final ServiceParser serviceParser; + private final TaskType expectedTaskType; + private boolean modelIncludesSecrets; + private boolean expectFailure; + + TestCaseBuilder( + String description, + Function createPersistedConfig, + ServiceParser serviceParser, + TaskType expectedTaskType + ) { + this.description = description; + this.createPersistedConfig = createPersistedConfig; + this.serviceParser = serviceParser; + this.expectedTaskType = expectedTaskType; + } + + public TestCaseBuilder withSecrets() { + this.modelIncludesSecrets = true; + return this; + } + + public TestCaseBuilder expectFailure() { + this.expectFailure = true; + return this; + } + + public TestCase build() { + return new TestCase(description, createPersistedConfig, serviceParser, expectedTaskType, modelIncludesSecrets, expectFailure); + } + } + @ParametersFactory public static Iterable parameters() throws IOException { return Arrays.asList( new TestCase[][] { - // parsePersistedConfig + // Test cases for parsePersistedConfig method { new TestCaseBuilder( "Test parsing persisted config without chunking settings", @@ -106,11 +116,7 @@ public static Iterable parameters() throws IOException { testConfiguration.commonConfig().createTaskSettingsMap(), null ), - (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config() - ), + (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()), TaskType.TEXT_EMBEDDING ).build() }, { @@ -123,14 +129,56 @@ public static Iterable parameters() throws IOException { createRandomChunkingSettingsMap(), null ), - (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config() - ), + (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()), TaskType.TEXT_EMBEDDING ).build() }, - // parsePersistedConfigWithSecrets + { + new TestCaseBuilder( + "Test parsing persisted config does not throw when an extra key exists in config", + testConfiguration -> { + var persistedConfigMap = getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + null + ); + persistedConfigMap.config().put("extra_key", "value"); + return persistedConfigMap; + }, + (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()), + TaskType.COMPLETION + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config does not throw when an extra key exists in service settings", + testConfiguration -> { + var serviceSettings = testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT); + serviceSettings.put("extra_key", "value"); + + return getPersistedConfigMap(serviceSettings, testConfiguration.commonConfig().createTaskSettingsMap(), null); + }, + (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()), + TaskType.COMPLETION + ).build() }, + { + new TestCaseBuilder( + "Test parsing persisted config does not throw when an extra key exists in task settings", + testConfiguration -> { + var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap(); + taskSettingsMap.put("extra_key", "value"); + + return getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + taskSettingsMap, + null + ); + }, + (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()), + TaskType.COMPLETION + ).build() }, + // Test cases for parsePersistedConfigWithSecrets method { new TestCaseBuilder( "Test parsing persisted config with secrets creates an embeddings model", @@ -140,11 +188,11 @@ public static Iterable parameters() throws IOException { testConfiguration.commonConfig().createTaskSettingsMap(), testConfiguration.commonConfig().createSecretSettingsMap() ), - (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + (params) -> params.service.parsePersistedConfigWithSecrets( "id", TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + params.persistedConfig.config(), + params.persistedConfig.secrets() ), TaskType.TEXT_EMBEDDING ).withSecrets().build() }, @@ -159,11 +207,11 @@ public static Iterable parameters() throws IOException { createRandomChunkingSettingsMap(), testConfiguration.commonConfig().createSecretSettingsMap() ), - (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + (params) -> params.service.parsePersistedConfigWithSecrets( "id", TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + params.persistedConfig.config(), + params.persistedConfig.secrets() ), TaskType.TEXT_EMBEDDING ).withSecrets().build() }, @@ -177,11 +225,11 @@ public static Iterable parameters() throws IOException { testConfiguration.commonConfig().createTaskSettingsMap(), testConfiguration.commonConfig().createSecretSettingsMap() ), - (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + (params) -> params.service.parsePersistedConfigWithSecrets( "id", TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + params.persistedConfig.config(), + params.persistedConfig.secrets() ), TaskType.COMPLETION ).withSecrets().build() }, @@ -194,276 +242,146 @@ public static Iterable parameters() throws IOException { testConfiguration.commonConfig().createTaskSettingsMap(), testConfiguration.commonConfig().createSecretSettingsMap() ), - (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets( + (params) -> params.service.parsePersistedConfigWithSecrets( "id", - testConfiguration.commonConfig().unsupportedTaskType(), - persistedConfig.config(), - persistedConfig.secrets() + params.testConfiguration.commonConfig().unsupportedTaskType(), + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().expectFailure().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in config", + testConfiguration -> { + var persistedConfigMap = getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ); + persistedConfigMap.config().put("extra_key", "value"); + return persistedConfigMap; + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in service settings", + testConfiguration -> { + var serviceSettings = testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT); + serviceSettings.put("extra_key", "value"); + + return getPersistedConfigMap( + serviceSettings, + testConfiguration.commonConfig().createTaskSettingsMap(), + testConfiguration.commonConfig().createSecretSettingsMap() + ); + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in task settings", + testConfiguration -> { + var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap(); + taskSettingsMap.put("extra_key", "value"); + + return getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + taskSettingsMap, + testConfiguration.commonConfig().createSecretSettingsMap() + ); + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() + ), + TaskType.COMPLETION + ).withSecrets().build() }, + { + new TestCaseBuilder( + "Test parsing persisted config with with secrets does not throw when an extra key exists in secret settings", + testConfiguration -> { + var secretSettingsMap = testConfiguration.commonConfig().createSecretSettingsMap(); + secretSettingsMap.put("extra_key", "value"); + + return getPersistedConfigMap( + testConfiguration.commonConfig() + .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), + testConfiguration.commonConfig().createTaskSettingsMap(), + secretSettingsMap + ); + }, + (params) -> params.service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + params.persistedConfig.config(), + params.persistedConfig.secrets() ), TaskType.COMPLETION ).withSecrets().build() } } ); } - public record TestCase( - String description, - Function createPersistedConfig, - ServiceCallback serviceCallback, - TaskType expectedTaskType, - boolean modelIncludesSecrets, - boolean expectFailure - ) {} - - @FunctionalInterface - interface ServiceCallback { - Model parseConfigs( - SenderService service, - Utils.PersistedConfig persistedConfig, - AbstractInferenceServiceTests.TestConfiguration testConfiguration - ); - } - - private static class TestCaseBuilder { - private final String description; - private final Function createPersistedConfig; - private final ServiceCallback serviceCallback; - private final TaskType expectedTaskType; - private boolean modelIncludesSecrets; - private boolean expectFailure; - - TestCaseBuilder( - String description, - Function createPersistedConfig, - ServiceCallback serviceCallback, - TaskType expectedTaskType - ) { - this.description = description; - this.createPersistedConfig = createPersistedConfig; - this.serviceCallback = serviceCallback; - this.expectedTaskType = expectedTaskType; - } - - public TestCaseBuilder withSecrets() { - this.modelIncludesSecrets = true; - return this; - } - - public TestCaseBuilder withFailure() { - this.expectFailure = true; - return this; - } - - public TestCase build() { - return new TestCase(description, createPersistedConfig, serviceCallback, expectedTaskType, modelIncludesSecrets, expectFailure); - } - } - public void testPersistedConfig() throws Exception { + // If the service doesn't support the expected task type, then skip the test + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(testCase.expectedTaskType)); + var parseConfigTestConfig = testConfiguration.commonConfig(); var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration); try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var model = testCase.serviceCallback.parseConfigs(service, persistedConfig, testConfiguration); - if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) { - @SuppressWarnings("unchecked") - var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap( - (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS) - ); - assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + if (testCase.expectFailure) { + assertFailedParse(service, persistedConfig); + } else { + assertSuccessfulParse(service, persistedConfig); } - - parseConfigTestConfig.assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets); - } - } - - // parsePersistedConfigWithSecrets - - public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig(); - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.unsupportedTaskType(), - persistedConfigMap.config(), - persistedConfigMap.secrets() - ) - ); - - assertThat( - exception.getMessage(), - containsString( - Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType()) - ) - ); - } - } - - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "service does not support task type [%s]"; - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig(); - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - persistedConfigMap.config().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.taskType(), - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig(); - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var serviceSettings = parseConfigTestConfig.createServiceSettingsMap( - parseConfigTestConfig.taskType(), - ConfigurationParseContext.PERSISTENT - ); - serviceSettings.put("extra_key", "value"); - var persistedConfigMap = getPersistedConfigMap( - serviceSettings, - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.taskType(), - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); } } - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig(); - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var taskSettings = parseConfigTestConfig.createTaskSettingsMap(); - taskSettings.put("extra_key", "value"); - var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), - taskSettings, - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets()); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig(); - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap(); - secretSettingsMap.put("extra_key", "value"); - var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets()); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType()); - } - } - - public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { - try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { - var listener = new PlainActionFuture(); - - service.infer( - getInvalidModel("id", "service"), - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INTERNAL_SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + private void assertFailedParse(SenderService service, Utils.PersistedConfig persistedConfig) { + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration)) + ); - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is("The internal model was invalid, please delete the service [service] with id [id] and add it again.") - ); - } + assertThat( + exception.getMessage(), + containsString( + Strings.format("service does not support task type [%s]", testConfiguration.commonConfig().unsupportedTaskType()) + ) + ); } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); + private void assertSuccessfulParse(SenderService service, Utils.PersistedConfig persistedConfig) throws Exception { + var model = testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration)); - try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) + if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) { + @SuppressWarnings("unchecked") + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap( + (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS) ); - - assertThat(exception.getMessage(), containsString("Can't update embedding details for model")); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); } - } - - public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { - var embeddingSize = randomNonNegativeInt(); - var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null); - - Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); - - assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity()); - assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); - } - } - - public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - - try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { - var embeddingSize = randomNonNegativeInt(); - var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE); - - Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); - - assertEquals(SimilarityMeasure.COSINE, updatedModel.getServiceSettings().similarity()); - assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); - } - } - - // streaming tests - public void testSupportedStreamingTasks() throws Exception { - try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { - assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks())); - assertFalse(service.canStream(TaskType.ANY)); - } + testConfiguration.commonConfig().assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index b1ff02b6bbc17..9479a8ea6e10d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -7,50 +7,29 @@ package org.elasticsearch.xpack.inference.services; -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; - import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.junit.After; import org.junit.Assume; -import org.junit.Before; import java.io.IOException; -import java.util.Arrays; -import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.function.BiFunction; -import java.util.function.Function; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; -import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -import static org.mockito.Mockito.mock; /** * Base class for testing inference services. @@ -60,134 +39,16 @@ * To use this class, extend it and pass the constructor a configuration. *

*/ -public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase { - - private final TestConfiguration testConfiguration; - - protected final MockWebServer webServer = new MockWebServer(); - protected ThreadPool threadPool; - protected HttpClientManager clientManager; - protected TestCase testCase; - - @Override - @Before - public void setUp() throws Exception { - super.setUp(); - webServer.start(); - threadPool = createThreadPool(inferenceUtilityExecutors()); - clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - } +public abstract class AbstractInferenceServiceTests extends AbstractInferenceServiceBaseTests { - @Override - @After - public void tearDown() throws Exception { - super.tearDown(); - clientManager.close(); - terminate(threadPool); - webServer.close(); - } - - public AbstractInferenceServiceTests(TestConfiguration testConfiguration, TestCase testCase) { - this.testConfiguration = Objects.requireNonNull(testConfiguration); - this.testCase = testCase; - } - - /** - * Main configurations for the tests - */ - public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) { - public static class Builder { - private final CommonConfig commonConfig; - private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS; - - public Builder(CommonConfig commonConfig) { - this.commonConfig = commonConfig; - } - - public Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) { - this.updateModelConfiguration = updateModelConfiguration; - return this; - } - - public TestConfiguration build() { - return new TestConfiguration(commonConfig, updateModelConfiguration); - } - } - } - - /** - * Configurations that are useful for most tests - */ - public abstract static class CommonConfig { - - private final TaskType taskType; - private final TaskType unsupportedTaskType; - - public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) { - this.taskType = Objects.requireNonNull(taskType); - this.unsupportedTaskType = unsupportedTaskType; - } - - public TaskType taskType() { - return taskType; - } - - public TaskType unsupportedTaskType() { - return unsupportedTaskType; - } - - protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); - - protected abstract Map createServiceSettingsMap(TaskType taskType); - - protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { - return createServiceSettingsMap(taskType); - } - - protected abstract Map createTaskSettingsMap(); - - protected abstract Map createSecretSettingsMap(); - - protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets); - - protected void assertModel(Model model, TaskType taskType) { - assertModel(model, taskType, true); - } - - protected abstract EnumSet supportedStreamingTasks(); - } - - /** - * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests - */ - public abstract static class UpdateModelConfiguration { - - public boolean isEnabled() { - return true; - } - - protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); - } - - private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { - @Override - public boolean isEnabled() { - return false; - } - - @Override - protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { - throw new UnsupportedOperationException("Update model tests are disabled"); - } - }; - - @Override - public InferenceService createInferenceService() { - return testConfiguration.commonConfig.createService(threadPool, clientManager); + public AbstractInferenceServiceTests(TestConfiguration testConfiguration) { + super(testConfiguration); } public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING)); + + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( @@ -207,7 +68,9 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception { } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING)); + + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var chunkingSettingsMap = createRandomChunkingSettingsMap(); @@ -229,7 +92,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP } public void testParseRequestConfig_CreatesACompletionModel() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( @@ -246,12 +109,12 @@ public void testParseRequestConfig_CreatesACompletionModel() throws Exception { } public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( parseRequestConfigTestConfig.createServiceSettingsMap( - parseRequestConfigTestConfig.taskType, + parseRequestConfigTestConfig.targetTaskType(), ConfigurationParseContext.REQUEST ), parseRequestConfigTestConfig.createTaskSettingsMap(), @@ -259,23 +122,25 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( exception.getMessage(), - containsString(Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType)) + containsString( + Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType()) + ) ); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( parseRequestConfigTestConfig.createServiceSettingsMap( - parseRequestConfigTestConfig.taskType, + parseRequestConfigTestConfig.targetTaskType(), ConfigurationParseContext.REQUEST ), parseRequestConfigTestConfig.createTaskSettingsMap(), @@ -284,7 +149,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I config.put("extra_key", "value"); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); @@ -292,10 +157,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap( - parseRequestConfigTestConfig.taskType, + parseRequestConfigTestConfig.targetTaskType(), ConfigurationParseContext.REQUEST ); serviceSettings.put("extra_key", "value"); @@ -306,7 +171,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); @@ -314,13 +179,13 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap(); taskSettings.put("extra_key", "value"); var config = getRequestConfigMap( parseRequestConfigTestConfig.createServiceSettingsMap( - parseRequestConfigTestConfig.taskType, + parseRequestConfigTestConfig.targetTaskType(), ConfigurationParseContext.REQUEST ), taskSettings, @@ -328,7 +193,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); @@ -336,13 +201,13 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - var parseRequestConfigTestConfig = testConfiguration.commonConfig; + var parseRequestConfigTestConfig = testConfiguration.commonConfig(); try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap(); secretSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( parseRequestConfigTestConfig.createServiceSettingsMap( - parseRequestConfigTestConfig.taskType, + parseRequestConfigTestConfig.targetTaskType(), ConfigurationParseContext.REQUEST ), parseRequestConfigTestConfig.createTaskSettingsMap(), @@ -350,339 +215,15 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap ); var listener = new PlainActionFuture(); - service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener); + service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]")); } } - @ParametersFactory - public static Iterable parameters() throws IOException { - var chunkingSettingsMap = createRandomChunkingSettingsMap(); - - return Arrays.asList( - new TestCase[][] { - { - new TestCaseBuilder( - "Test parsing persisted config without chunking settings", - testConfiguration -> getPersistedConfigMap( - testConfiguration.commonConfig.createServiceSettingsMap( - TaskType.TEXT_EMBEDDING, - ConfigurationParseContext.PERSISTENT - ), - testConfiguration.commonConfig.createTaskSettingsMap(), - null - ), - (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()) - ).withNullChunkingSettingsMap().build() }, - { - new TestCaseBuilder( - "Test parsing persisted config with chunking settings", - testConfiguration -> getPersistedConfigMap( - testConfiguration.commonConfig.createServiceSettingsMap( - TaskType.TEXT_EMBEDDING, - ConfigurationParseContext.PERSISTENT - ), - testConfiguration.commonConfig.createTaskSettingsMap(), - chunkingSettingsMap, - null - ), - (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()) - ).withChunkingSettingsMap(chunkingSettingsMap).build() } } - ); - } - - public record TestCase( - String description, - Function createPersistedConfig, - BiFunction serviceCallback, - @Nullable Map chunkingSettingsMap, - boolean validateChunkingSettings, - boolean modelIncludesSecrets - ) {} - - private static class TestCaseBuilder { - private final String description; - private final Function createPersistedConfig; - private final BiFunction serviceCallback; - @Nullable - private Map chunkingSettingsMap; - private boolean validateChunkingSettings; - private boolean modelIncludesSecrets; - - TestCaseBuilder( - String description, - Function createPersistedConfig, - BiFunction serviceCallback - ) { - this.description = description; - this.createPersistedConfig = createPersistedConfig; - this.serviceCallback = serviceCallback; - } - - public TestCaseBuilder withSecrets() { - this.modelIncludesSecrets = true; - return this; - } - - public TestCaseBuilder withChunkingSettingsMap(Map chunkingSettingsMap) { - this.chunkingSettingsMap = chunkingSettingsMap; - this.validateChunkingSettings = true; - return this; - } - - /** - * Use an empty chunking settings map but still do validation that the chunking settings are set to the appropriate - * defaults. - */ - public TestCaseBuilder withEmptyChunkingSettingsMap() { - this.chunkingSettingsMap = Map.of(); - this.validateChunkingSettings = true; - return this; - } - - public TestCaseBuilder withNullChunkingSettingsMap() { - this.chunkingSettingsMap = null; - this.validateChunkingSettings = true; - return this; - } - - public TestCase build() { - return new TestCase( - description, - createPersistedConfig, - serviceCallback, - chunkingSettingsMap, - validateChunkingSettings, - modelIncludesSecrets - ); - } - } - - public void testPersistedConfig() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration); - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var model = testCase.serviceCallback.apply(service, persistedConfig); - - if (testCase.validateChunkingSettings) { - var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(testCase.chunkingSettingsMap); - assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); - } - - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); - } - } - - // parsePersistedConfig tests - - public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - null - ); - - parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null); - } - - private void parseConfigHelper(Function serviceParseCallback, @Nullable Map chunkingSettingsMap) - throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - - var model = serviceParseCallback.apply(service); - - var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap == null ? Map.of() : chunkingSettingsMap); - assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); - - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); - } - } - - // parsePersistedConfigWithSecrets - - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - parseConfigHelper( - service -> service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ), - null - ); - } - - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsAreProvided() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - var chunkingSettingsMap = createRandomChunkingSettingsMap(); - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - chunkingSettingsMap, - parseConfigTestConfig.createSecretSettingsMap() - ); - - parseConfigHelper( - service -> service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ), - chunkingSettingsMap - ); - } - - public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.COMPLETION, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - parseConfigTestConfig.assertModel(model, TaskType.COMPLETION); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.unsupportedTaskType, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ) - ); - - assertThat( - exception.getMessage(), - containsString( - Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType) - ) - ); - } - } - - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "service does not support task type [%s]"; - } - - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var persistedConfigMap = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - persistedConfigMap.config().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.taskType, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var serviceSettings = parseConfigTestConfig.createServiceSettingsMap( - parseConfigTestConfig.taskType, - ConfigurationParseContext.PERSISTENT - ); - serviceSettings.put("extra_key", "value"); - var persistedConfigMap = getPersistedConfigMap( - serviceSettings, - parseConfigTestConfig.createTaskSettingsMap(), - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - parseConfigTestConfig.taskType, - persistedConfigMap.config(), - persistedConfigMap.secrets() - ); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var taskSettings = parseConfigTestConfig.createTaskSettingsMap(); - taskSettings.put("extra_key", "value"); - var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), - taskSettings, - parseConfigTestConfig.createSecretSettingsMap() - ); - - var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - - public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { - var parseConfigTestConfig = testConfiguration.commonConfig; - try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) { - var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap(); - secretSettingsMap.put("extra_key", "value"); - var config = getPersistedConfigMap( - parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT), - parseConfigTestConfig.createTaskSettingsMap(), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - - parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); - } - } - public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var listener = new PlainActionFuture(); service.infer( @@ -707,9 +248,9 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { } public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var exception = expectThrows( ElasticsearchStatusException.class, () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) @@ -720,11 +261,11 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO } public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var embeddingSize = randomNonNegativeInt(); - var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(null); + var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); @@ -734,11 +275,11 @@ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() } public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { - Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled()); + Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled()); - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { var embeddingSize = randomNonNegativeInt(); - var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(SimilarityMeasure.COSINE); + var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); @@ -749,8 +290,8 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel // streaming tests public void testSupportedStreamingTasks() throws Exception { - try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) { - assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig.supportedStreamingTasks())); + try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) { + assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks())); assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java new file mode 100644 index 0000000000000..1254610f32745 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ai21; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +import static org.elasticsearch.xpack.inference.services.ai21.Ai21ServiceTests.createTestConfiguration; + +public class Ai21ServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public Ai21ServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) { + super(createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index 50ca2f72abda0..8a1f8d806a45b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -81,13 +81,17 @@ public class Ai21ServiceTests extends AbstractInferenceServiceTests { private ThreadPool threadPool; private HttpClientManager clientManager; - public Ai21ServiceTests(TestCase testCase) { - super(createTestConfiguration(), testCase); + public Ai21ServiceTests() { + super(createTestConfiguration()); } - private static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() { + public static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() { return new AbstractInferenceServiceTests.TestConfiguration.Builder( - new AbstractInferenceServiceTests.CommonConfig(TaskType.COMPLETION, TaskType.TEXT_EMBEDDING) { + new AbstractInferenceServiceTests.CommonConfig( + TaskType.COMPLETION, + TaskType.TEXT_EMBEDDING, + EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION) + ) { @Override protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { @@ -167,10 +171,6 @@ private static Map createSecretSettingsMap() { return new HashMap<>(Map.of("api_key", "secret")); } - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "Failed to parse stored model [id] for [ai21] service, please delete and add the service again"; - } - @Before public void init() throws Exception { webServer.start(); @@ -185,16 +185,6 @@ public void shutdown() throws IOException { webServer.close(); } - @Override - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() { - // The Ai21Service does not support Text Embedding, so this test is not applicable. - } - - @Override - public void testParseRequestConfig_CreatesAnEmbeddingsModel() { - // The Ai21Service does not support Text Embedding, so this test is not applicable. - } - public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException { var url = "url"; var model = "model"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java new file mode 100644 index 0000000000000..a71a107f3b9d1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +public class CustomServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public CustomServiceParameterizedTests(TestCase testCase) { + super(CustomServiceTests.createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 269642812e78f..44bf17c3ac96d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -69,42 +69,56 @@ public class CustomServiceTests extends AbstractInferenceServiceTests { - public CustomServiceTests(TestCase testCase) { - super(createTestConfiguration(), testCase); + public CustomServiceTests() { + super(createTestConfiguration()); } - private static TestConfiguration createTestConfiguration() { - return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION) { - @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { - return CustomServiceTests.createService(threadPool, clientManager); - } + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig( + TaskType.TEXT_EMBEDDING, + TaskType.CHAT_COMPLETION, + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.COMPLETION) + ) { + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return CustomServiceTests.createService(threadPool, clientManager); + } - @Override - protected Map createServiceSettingsMap(TaskType taskType) { - return CustomServiceTests.createServiceSettingsMap(taskType); - } + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return CustomServiceTests.createServiceSettingsMap(taskType); + } - @Override - protected Map createTaskSettingsMap() { - return CustomServiceTests.createTaskSettingsMap(); - } + @Override + protected Map createTaskSettingsMap() { + return CustomServiceTests.createTaskSettingsMap(); + } - @Override - protected Map createSecretSettingsMap() { - return CustomServiceTests.createSecretSettingsMap(); - } + @Override + protected Map createSecretSettingsMap() { + return CustomServiceTests.createSecretSettingsMap(); + } - @Override - protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { - CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets); - } + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } - @Override - protected EnumSet supportedStreamingTasks() { - return EnumSet.noneOf(TaskType.class); + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.noneOf(TaskType.class); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat( + rerankingInferenceService.rerankerWindowSize("any model"), + CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) + ); + } } - }).enableUpdateModelTests(new UpdateModelConfiguration() { + ).enableUpdateModelTests(new UpdateModelConfiguration() { @Override protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel(similarityMeasure); @@ -808,12 +822,4 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { assertThat(requestMap.get("input"), is(List.of("a"))); } } - - @Override - protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { - assertThat( - rerankingInferenceService.rerankerWindowSize("any model"), - CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE) - ); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java new file mode 100644 index 0000000000000..04afbc28af10a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests; + +public class LlamaServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests { + public LlamaServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) { + super(LlamaServiceTests.createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index a6773122a0d2b..f6a0232db529c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -77,6 +77,9 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -102,43 +105,45 @@ public class LlamaServiceTests extends AbstractInferenceServiceTests { private ThreadPool threadPool; private HttpClientManager clientManager; - public LlamaServiceTests(TestCase testCase) { - super(createTestConfiguration(), testCase); + public LlamaServiceTests() { + super(createTestConfiguration()); } - private static TestConfiguration createTestConfiguration() { - return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) { + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder( + new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) { - @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { - return LlamaServiceTests.createService(threadPool, clientManager); - } + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return LlamaServiceTests.createService(threadPool, clientManager); + } - @Override - protected Map createServiceSettingsMap(TaskType taskType) { - return LlamaServiceTests.createServiceSettingsMap(taskType); - } + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return LlamaServiceTests.createServiceSettingsMap(taskType); + } - @Override - protected Map createTaskSettingsMap() { - return new HashMap<>(); - } + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } - @Override - protected Map createSecretSettingsMap() { - return LlamaServiceTests.createSecretSettingsMap(); - } + @Override + protected Map createSecretSettingsMap() { + return LlamaServiceTests.createSecretSettingsMap(); + } - @Override - protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { - LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets); - } + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } - @Override - protected EnumSet supportedStreamingTasks() { - return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } } - }).enableUpdateModelTests(new UpdateModelConfiguration() { + ).enableUpdateModelTests(new UpdateModelConfiguration() { @Override protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel(similarityMeasure); @@ -239,10 +244,6 @@ private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable Simil ); } - protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { - return "Failed to parse stored model [id] for [llama] service, please delete and add the service again"; - } - @Before public void init() throws Exception { webServer.start(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 5eceffe200bc4..c2cda175831ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -87,7 +87,6 @@ import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; -import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -146,47 +145,53 @@ public void shutdown() throws IOException { webServer.close(); } - public OpenAiServiceTests(TestCase testCase) { - super(createTestConfiguration(), testCase); + public OpenAiServiceTests() { + super(createTestConfiguration()); } public static TestConfiguration createTestConfiguration() { - return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) { - @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { - return OpenAiServiceTests.createService(threadPool, clientManager); - } + return new TestConfiguration.Builder( + new CommonConfig( + TaskType.TEXT_EMBEDDING, + TaskType.RERANK, + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION) + ) { + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return OpenAiServiceTests.createService(threadPool, clientManager); + } - @Override - protected Map createServiceSettingsMap(TaskType taskType) { - return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST); - } + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST); + } - @Override - protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { - return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext); - } + @Override + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext); + } - @Override - protected Map createTaskSettingsMap() { - return OpenAiServiceTests.createTaskSettingsMap(); - } + @Override + protected Map createTaskSettingsMap() { + return OpenAiServiceTests.createTaskSettingsMap(); + } - @Override - protected Map createSecretSettingsMap() { - return getSecretSettingsMap(SECRET); - } + @Override + protected Map createSecretSettingsMap() { + return getSecretSettingsMap(SECRET); + } - @Override - protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { - OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets); - } + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } - @Override - protected EnumSet supportedStreamingTasks() { - return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } } - }).enableUpdateModelTests(new UpdateModelConfiguration() { + ).enableUpdateModelTests(new UpdateModelConfiguration() { @Override protected OpenAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel(similarityMeasure, null); @@ -346,68 +351,6 @@ public void testParseRequestConfig_MovesModel() throws IOException { } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { - try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, null, true), - getOpenAiTaskSettingsMap("user") - ); - persistedConfig.config().put("extra_key", "value"); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { - try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); - serviceSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getOpenAiTaskSettingsMap("user")); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - - public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { - try (var service = createOpenAiService()) { - var taskSettingsMap = getOpenAiTaskSettingsMap("user"); - taskSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); - - var embeddingsModel = (OpenAiEmbeddingsModel) model; - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); - assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org")); - assertThat(embeddingsModel.getServiceSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings().user(), is("user")); - assertNull(embeddingsModel.getSecretSettings()); - } - } - public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException { var sender = mock(Sender.class); From 060eed90f38084a2953e1362bc3131224e8e759c Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 25 Sep 2025 17:25:48 +0000 Subject: [PATCH 04/10] [CI] Auto commit changes from spotless --- .../xpack/inference/services/custom/CustomService.java | 4 +++- .../inference/services/AbstractInferenceServiceBaseTests.java | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index f9e1dba847dc7..c69284117ef36 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -107,7 +107,9 @@ public void parseRequestConfig( ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); } CustomModel model = createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java index 46fedc1e71f4e..7657f7e79cf25 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java @@ -29,7 +29,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.mockito.Mockito.mock; -public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase{ +public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase { protected final TestConfiguration testConfiguration; protected final MockWebServer webServer = new MockWebServer(); From ac44cc6cff056612ed4f3ecfdc5332b2eedefd11 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 25 Sep 2025 17:32:57 +0000 Subject: [PATCH 05/10] [CI] Update transport version definitions --- server/src/main/resources/transport/upper_bounds/8.18.csv | 2 +- server/src/main/resources/transport/upper_bounds/8.19.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.0.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.1.csv | 2 +- server/src/main/resources/transport/upper_bounds/9.2.csv | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server/src/main/resources/transport/upper_bounds/8.18.csv b/server/src/main/resources/transport/upper_bounds/8.18.csv index ffc592e1809ee..266bfbbd3bf78 100644 --- a/server/src/main/resources/transport/upper_bounds/8.18.csv +++ b/server/src/main/resources/transport/upper_bounds/8.18.csv @@ -1 +1 @@ -initial_elasticsearch_8_18_8,8840010 +transform_check_for_dangling_tasks,8840011 diff --git a/server/src/main/resources/transport/upper_bounds/8.19.csv b/server/src/main/resources/transport/upper_bounds/8.19.csv index 3cc6f439c5ea5..3600b3f8c633a 100644 --- a/server/src/main/resources/transport/upper_bounds/8.19.csv +++ b/server/src/main/resources/transport/upper_bounds/8.19.csv @@ -1 +1 @@ -initial_elasticsearch_8_19_5,8841069 +transform_check_for_dangling_tasks,8841070 diff --git a/server/src/main/resources/transport/upper_bounds/9.0.csv b/server/src/main/resources/transport/upper_bounds/9.0.csv index 8ad2ed1a4cacf..c11e6837bb813 100644 --- a/server/src/main/resources/transport/upper_bounds/9.0.csv +++ b/server/src/main/resources/transport/upper_bounds/9.0.csv @@ -1 +1 @@ -initial_elasticsearch_9_0_8,9000017 +transform_check_for_dangling_tasks,9000018 diff --git a/server/src/main/resources/transport/upper_bounds/9.1.csv b/server/src/main/resources/transport/upper_bounds/9.1.csv index 1cea5dc4d929b..80b97d85f7511 100644 --- a/server/src/main/resources/transport/upper_bounds/9.1.csv +++ b/server/src/main/resources/transport/upper_bounds/9.1.csv @@ -1 +1 @@ -initial_elasticsearch_9_1_5,9112008 +transform_check_for_dangling_tasks,9112009 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index b1209b927d8a5..e4c91df18cda8 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -inference_api_openai_embeddings_headers,9169000 +index_reshard_shardcount_summary,9172000 From cc99292fab2956b11320a59f5a81f130a016ee69 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 29 Sep 2025 15:08:14 -0400 Subject: [PATCH 06/10] Removing deprecated function --- .../xpack/inference/services/ServiceUtils.java | 12 ------------ .../AlibabaCloudSearchService.java | 4 ++-- .../amazonbedrock/AmazonBedrockService.java | 4 ++-- .../services/anthropic/AnthropicService.java | 4 ++-- .../azureaistudio/AzureAiStudioService.java | 4 ++-- .../services/azureopenai/AzureOpenAiService.java | 4 ++-- .../inference/services/cohere/CohereService.java | 4 ++-- .../contextualai/ContextualAiService.java | 4 ++-- .../elastic/ElasticInferenceService.java | 4 ++-- .../googleaistudio/GoogleAiStudioService.java | 4 ++-- .../googlevertexai/GoogleVertexAiService.java | 4 ++-- .../huggingface/HuggingFaceBaseService.java | 4 ++-- .../services/ibmwatsonx/IbmWatsonxService.java | 4 ++-- .../inference/services/jinaai/JinaAIService.java | 4 ++-- .../services/mistral/MistralService.java | 4 ++-- .../services/voyageai/VoyageAIService.java | 4 ++-- .../amazonbedrock/AmazonBedrockServiceTests.java | 12 ++++++++++-- .../azureaistudio/AzureAiStudioServiceTests.java | 6 +++++- .../azureopenai/AzureOpenAiServiceTests.java | 12 ++++++++++-- .../services/cohere/CohereServiceTests.java | 16 ++++++++++++---- .../services/jinaai/JinaAIServiceTests.java | 16 ++++++++++++---- .../services/mistral/MistralServiceTests.java | 6 +++++- .../services/voyageai/VoyageAIServiceTests.java | 16 ++++++++++++---- 23 files changed, 96 insertions(+), 60 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 69b15f3e32cc2..3fda2506b9721 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1079,18 +1079,6 @@ public interface EnumConstructor> { E apply(String name) throws IllegalArgumentException; } - /** - * @deprecated use {@link #parsePersistedConfigErrorMsg(String, String, TaskType)} instead - */ - @Deprecated - public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) { - return format( - "Failed to parse stored model [%s] for [%s] service, please delete and add the service again", - inferenceEntityId, - serviceName - ); - } - public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { return format( "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index f474850b9f190..749a7f994d3c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -256,7 +256,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) ); } @@ -277,7 +277,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta taskSettingsMap, chunkingSettings, null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 11204018a5523..8275257608b33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -241,7 +241,7 @@ public Model parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME), + parsePersistedConfigErrorMsg(modelId, NAME, taskType), ConfigurationParseContext.PERSISTENT ); } @@ -263,7 +263,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - MatcherAssert.assertThat( + assertThat( + thrownException.getMessage(), + containsString("Failed to parse stored model [id] for [cohere] service") + ); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again") + containsString("The [cohere] service does not support task type [sparse_embedding]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index fc50acdbd39b6..e5ed8891b368a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -443,9 +443,13 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - MatcherAssert.assertThat( + assertThat( + thrownException.getMessage(), + containsString("Failed to parse stored model [id] for [jinaai] service") + ); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again") + containsString("The [jinaai] service does not support task type [sparse_embedding]") ); } } @@ -683,9 +687,13 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - MatcherAssert.assertThat( + assertThat( + thrownException.getMessage(), + containsString("Failed to parse stored model [id] for [jinaai] service") + ); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again") + containsString("The [jinaai] service does not support task type [sparse_embedding]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 50731811e4164..1bf37948012e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -744,7 +744,11 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [mistral] service, please delete and add the service again") + containsString("Failed to parse stored model [id] for [mistral] service") + ); + assertThat( + thrownException.getMessage(), + containsString("The [mistral] service does not support task type [sparse_embedding]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 8cad8cbad208a..99c6d31c207b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -409,9 +409,13 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - MatcherAssert.assertThat( + assertThat( + thrownException.getMessage(), + containsString("Failed to parse stored model [id] for [voyageai] service") + ); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + containsString("The [voyageai] service does not support task type [sparse_embedding]") ); } } @@ -624,9 +628,13 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - MatcherAssert.assertThat( + assertThat( + thrownException.getMessage(), + containsString("Failed to parse stored model [id] for [voyageai] service") + ); + assertThat( thrownException.getMessage(), - is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + containsString("The [voyageai] service does not support task type [sparse_embedding]") ); } } From 0896e5c8390544ce8769a50bd288bcbda42b8da0 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 30 Sep 2025 09:27:44 -0400 Subject: [PATCH 07/10] Moving string creation and refactoring customservice chunking --- .../inference/services/ServiceUtils.java | 18 ++++++++++++++++++ .../inference/services/ai21/Ai21Service.java | 18 +++++------------- .../AlibabaCloudSearchService.java | 18 +++++------------- .../amazonbedrock/AmazonBedrockService.java | 8 ++------ .../services/anthropic/AnthropicService.java | 18 +++++------------- .../azureaistudio/AzureAiStudioService.java | 16 +++++----------- .../azureopenai/AzureOpenAiService.java | 18 +++++------------- .../services/cohere/CohereService.java | 18 +++++------------- .../contextualai/ContextualAiService.java | 8 +++----- .../services/custom/CustomService.java | 19 +++++++++++-------- .../elastic/ElasticInferenceService.java | 16 +++++----------- .../googleaistudio/GoogleAiStudioService.java | 18 +++++------------- .../googlevertexai/GoogleVertexAiService.java | 18 +++++------------- .../ibmwatsonx/IbmWatsonxService.java | 18 +++++------------- .../services/jinaai/JinaAIService.java | 18 +++++------------- .../services/llama/LlamaService.java | 19 +++++-------------- .../services/mistral/MistralService.java | 18 +++++------------- .../services/openai/OpenAiService.java | 16 +++++----------- .../services/voyageai/VoyageAIService.java | 18 +++++------------- 19 files changed, 109 insertions(+), 209 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 3fda2506b9721..2c383f5db2f5b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1088,6 +1088,24 @@ public static String parsePersistedConfigErrorMsg(String inferenceEntityId, Stri ); } + /** + * Create an exception for when the task type is not valid for the service. + */ + public static ElasticsearchStatusException createInvalidTaskTypeException( + String inferenceEntityId, + String serviceName, + TaskType taskType, + ConfigurationParseContext parseContext + ) { + var message = parseContext == ConfigurationParseContext.PERSISTENT + ? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType) + : TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName); + return new ElasticsearchStatusException( + message, + RestStatus.BAD_REQUEST + ); + } + public static ElasticsearchStatusException createInvalidModelException(Model model) { return new ElasticsearchStatusException( format( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index f1375085c1f71..57bcc267ac644 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.ai21; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -27,7 +26,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; @@ -55,7 +53,7 @@ import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -183,7 +181,6 @@ public void parseRequestConfig( taskType, serviceSettingsMap, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -212,8 +209,7 @@ public Ai21Model parsePersistedConfigWithSecrets( modelId, taskType, serviceSettingsMap, - secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME, taskType) + secretSettingsMap ); } @@ -226,8 +222,7 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { case CHAT_COMPLETION, COMPLETION: return new Ai21ChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(modelId, NAME, taskType, context); } } @@ -261,15 +255,13 @@ private Ai21Model createModelFromPersistent( String inferenceEntityId, TaskType taskType, Map serviceSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, taskType, serviceSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 749a7f994d3c5..5f2378d40116d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -32,7 +31,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -59,7 +57,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -135,7 +133,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -165,8 +162,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -175,7 +171,6 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -187,7 +182,6 @@ private static AlibabaCloudSearchModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -229,7 +223,7 @@ private static AlibabaCloudSearchModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -255,8 +249,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -276,8 +269,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 8275257608b33..5e3ba81b7e1bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -60,7 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -204,7 +204,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -241,7 +240,6 @@ public Model parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME, taskType), ConfigurationParseContext.PERSISTENT ); } @@ -263,7 +261,6 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -318,7 +314,7 @@ private static AmazonBedrockModel createModel( checkChatCompletionProviderForTopKParameter(model); return model; } - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index ce39fc261312a..1f2c7f6cb4cdc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.anthropic; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -28,7 +27,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -48,7 +46,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -94,7 +92,6 @@ public void parseRequestConfig( serviceSettingsMap, taskSettingsMap, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -113,8 +110,7 @@ private static AnthropicModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -122,7 +118,6 @@ private static AnthropicModel createModelFromPersistent( serviceSettings, taskSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -133,7 +128,6 @@ private static AnthropicModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -146,7 +140,7 @@ private static AnthropicModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -166,8 +160,7 @@ public AnthropicModel parsePersistedConfigWithSecrets( taskType, serviceSettingsMap, taskSettingsMap, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -181,8 +174,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta taskType, serviceSettingsMap, taskSettingsMap, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 3bda3ca281f92..23d46820b688f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -60,7 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -185,7 +185,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -221,8 +220,7 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -242,8 +240,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } @@ -279,7 +276,6 @@ private static AzureAiStudioModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { @@ -305,7 +301,7 @@ private static AzureAiStudioModel createModel( context ); case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings(); checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType()); @@ -318,8 +314,7 @@ private AzureAiStudioModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -328,7 +323,6 @@ private AzureAiStudioModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 46ec593b5b5b6..f3d70f004131b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.azureopenai; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -30,7 +29,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -55,7 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -114,7 +112,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -134,8 +131,7 @@ private static AzureOpenAiModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -144,7 +140,6 @@ private static AzureOpenAiModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -156,7 +151,6 @@ private static AzureOpenAiModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -183,7 +177,7 @@ private static AzureOpenAiModel createModel( context ); } - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } } @@ -209,8 +203,7 @@ public AzureOpenAiModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -230,8 +223,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 2374c67944501..f1e34138aad5c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -32,7 +31,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -60,7 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -130,7 +128,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -150,8 +147,7 @@ private static CohereModel createModelWithoutLoggingDeprecations( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -160,7 +156,6 @@ private static CohereModel createModelWithoutLoggingDeprecations( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -172,7 +167,6 @@ private static CohereModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -186,7 +180,7 @@ private static CohereModel createModel( ); case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -212,8 +206,7 @@ public CohereModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -233,8 +226,7 @@ public CohereModel parsePersistedConfig(String inferenceEntityId, TaskType taskT serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java index 6937ac4703262..2f59400287a10 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java @@ -46,6 +46,8 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; + /** * Contextual AI inference service for reranking tasks. * This service uses the Contextual AI REST API to perform document reranking. @@ -97,7 +99,6 @@ public void parseRequestConfig( serviceSettingsMap, taskSettingsMap, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -117,11 +118,10 @@ private static ContextualAiRerankModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { if (taskType != TaskType.RERANK) { - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } return new ContextualAiRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); @@ -144,7 +144,6 @@ public ContextualAiRerankModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, secretSettingsMap, - ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType), ConfigurationParseContext.PERSISTENT ); } @@ -160,7 +159,6 @@ public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, Ta serviceSettingsMap, taskSettingsMap, null, - ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType), ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 91f97ac585d85..e0c91885999f4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -227,10 +227,7 @@ public CustomModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } + var chunkingSettings = extractPersistentChunkingSettings(config, taskType); return createModelWithoutLoggingDeprecations( inferenceEntityId, @@ -242,15 +239,21 @@ public CustomModel parsePersistedConfigWithSecrets( ); } + private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) { + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + // note there's + return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return null; + } + @Override public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } + var chunkingSettings = extractPersistentChunkingSettings(config, taskType); return createModelWithoutLoggingDeprecations( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index a1d7e331db8a2..7063b00c8ac99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -79,7 +79,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -446,7 +446,6 @@ public void parseRequestConfig( chunkingSettings, serviceSettingsMap, elasticInferenceServiceComponents, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -494,7 +493,6 @@ private static ElasticInferenceServiceModel createModel( ChunkingSettings chunkingSettings, @Nullable Map secretSettings, ElasticInferenceServiceComponents elasticInferenceServiceComponents, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -540,7 +538,7 @@ private static ElasticInferenceServiceModel createModel( context, chunkingSettings ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -566,8 +564,7 @@ public Model parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -587,8 +584,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } @@ -603,8 +599,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -614,7 +609,6 @@ private ElasticInferenceServiceModel createModelFromPersistent( chunkingSettings, secretSettings, elasticInferenceServiceComponents, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 02fee095e71e1..f5924cf6f9854 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.googleaistudio; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -60,7 +58,7 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -127,7 +125,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -149,7 +146,6 @@ private static GoogleAiStudioModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -172,7 +168,7 @@ private static GoogleAiStudioModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -198,8 +194,7 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -209,8 +204,7 @@ private static GoogleAiStudioModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -219,7 +213,6 @@ private static GoogleAiStudioModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -240,8 +233,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 5fa251f74dd9a..f8318b1d6f838 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -63,7 +61,7 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -149,7 +147,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -185,8 +182,7 @@ public Model parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -206,8 +202,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } @@ -356,8 +351,7 @@ private static GoogleVertexAiModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -366,7 +360,6 @@ private static GoogleVertexAiModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -378,7 +371,6 @@ private static GoogleVertexAiModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -412,7 +404,7 @@ private static GoogleVertexAiModel createModel( context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index c1bfc2ec6807e..8b4d8fd1f7dae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -30,7 +29,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -62,7 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -128,7 +126,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -150,7 +147,6 @@ private static IbmWatsonxModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -181,7 +177,7 @@ private static IbmWatsonxModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -207,8 +203,7 @@ public IbmWatsonxModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -228,8 +223,7 @@ private static IbmWatsonxModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -238,7 +232,6 @@ private static IbmWatsonxModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -259,8 +252,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 4169fb3a62be5..e0f7ddf864876 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.jinaai; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -57,7 +55,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -121,7 +119,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -141,8 +138,7 @@ private static JinaAIModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -151,7 +147,6 @@ private static JinaAIModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -163,7 +158,6 @@ private static JinaAIModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -177,7 +171,7 @@ private static JinaAIModel createModel( context ); case RERANK -> new JinaAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -203,8 +197,7 @@ public JinaAIModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -224,8 +217,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index 4d99151c65159..fed4824201e42 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.llama; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -28,7 +27,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; @@ -64,7 +62,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -138,7 +136,6 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc * @param serviceSettings the settings for the inference service * @param chunkingSettings the settings for chunking, if applicable * @param secretSettings the secret settings for the model, such as API keys or tokens - * @param failureMessage the message to use in case of failure * @param context the context for parsing configuration settings * @return a new instance of LlamaModel based on the provided parameters */ @@ -148,7 +145,6 @@ protected LlamaModel createModel( Map serviceSettings, ChunkingSettings chunkingSettings, Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -157,7 +153,7 @@ protected LlamaModel createModel( case CHAT_COMPLETION, COMPLETION: return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context); default: - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(inferenceId, NAME, taskType, context); } } @@ -283,7 +279,6 @@ public void parseRequestConfig( serviceSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -318,8 +313,7 @@ public Model parsePersistedConfigWithSecrets( taskType, serviceSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(modelId, NAME, taskType) + secretSettingsMap ); } @@ -328,8 +322,7 @@ private LlamaModel createModelFromPersistent( TaskType taskType, Map serviceSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -337,7 +330,6 @@ private LlamaModel createModelFromPersistent( serviceSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -357,8 +349,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { switch (taskType) { @@ -299,7 +293,7 @@ private static MistralModel createModel( case CHAT_COMPLETION, COMPLETION: return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(modelId, NAME, taskType, context); } } @@ -308,8 +302,7 @@ private MistralModel createModelFromPersistent( TaskType taskType, Map serviceSettings, ChunkingSettings chunkingSettings, - Map secretSettings, - String failureMessage + Map secretSettings ) { return createModel( inferenceEntityId, @@ -317,7 +310,6 @@ private MistralModel createModelFromPersistent( serviceSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 3a8464a5b8464..944c0af330c2a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -65,7 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -139,7 +139,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -159,8 +158,7 @@ private static OpenAiModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -169,7 +167,6 @@ private static OpenAiModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -181,7 +178,6 @@ private static OpenAiModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -204,7 +200,7 @@ private static OpenAiModel createModel( secretSettings, context ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -232,8 +228,7 @@ public OpenAiModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -255,8 +250,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 3d2968b0c8a89..1b4be3842ca52 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.voyageai; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -31,7 +30,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -56,7 +54,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -152,7 +150,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST ); @@ -172,8 +169,7 @@ private static VoyageAIModel createModelFromPersistent( Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage + @Nullable Map secretSettings ) { return createModel( inferenceEntityId, @@ -182,7 +178,6 @@ private static VoyageAIModel createModelFromPersistent( taskSettings, chunkingSettings, secretSettings, - failureMessage, ConfigurationParseContext.PERSISTENT ); } @@ -194,7 +189,6 @@ private static VoyageAIModel createModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, ConfigurationParseContext context ) { return switch (taskType) { @@ -208,7 +202,7 @@ private static VoyageAIModel createModel( context ); case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); }; } @@ -234,8 +228,7 @@ public VoyageAIModel parsePersistedConfigWithSecrets( serviceSettingsMap, taskSettingsMap, chunkingSettings, - secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + secretSettingsMap ); } @@ -255,8 +248,7 @@ public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType tas serviceSettingsMap, taskSettingsMap, chunkingSettings, - null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType) + null ); } From 49c1301afb6fed9bf1a561a02b55ffb3eee1d5b4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 30 Sep 2025 09:41:46 -0400 Subject: [PATCH 08/10] Removing usages of persistent function --- .../xpack/inference/services/ServiceUtils.java | 18 ++++++++++-------- .../services/custom/CustomService.java | 6 +++--- .../huggingface/HuggingFaceBaseService.java | 4 ---- .../HuggingFaceModelParameters.java | 1 - .../huggingface/HuggingFaceService.java | 5 ++--- .../elser/HuggingFaceElserService.java | 5 ++--- 6 files changed, 17 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 2c383f5db2f5b..9d03278470d35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1079,14 +1079,7 @@ public interface EnumConstructor> { E apply(String name) throws IllegalArgumentException; } - public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { - return format( - "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again", - inferenceEntityId, - serviceName, - TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName) - ); - } + /** * Create an exception for when the task type is not valid for the service. @@ -1106,6 +1099,15 @@ public static ElasticsearchStatusException createInvalidTaskTypeException( ); } + private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { + return format( + "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again", + inferenceEntityId, + serviceName, + TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName) + ); + } + public static ElasticsearchStatusException createInvalidModelException(Model model) { return new ElasticsearchStatusException( format( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index e0c91885999f4..c5c8853c7374a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -57,9 +57,9 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -211,7 +211,7 @@ private static CustomModel createModel( ConfigurationParseContext context ) { if (supportedTaskTypes.contains(taskType) == false) { - throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); + throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); } return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context); } @@ -241,7 +241,7 @@ public CustomModel parsePersistedConfigWithSecrets( private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) { if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - // note there's + // note there's return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 470228d80c5d1..403a1758983ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -31,7 +31,6 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; @@ -84,7 +83,6 @@ public void parseRequestConfig( taskSettingsMap, chunkingSettings, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), ConfigurationParseContext.REQUEST ) ); @@ -123,7 +121,6 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, name(), taskType), ConfigurationParseContext.PERSISTENT ) ); @@ -147,7 +144,6 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskSettingsMap, chunkingSettings, null, - parsePersistedConfigErrorMsg(inferenceEntityId, name(), taskType), ConfigurationParseContext.PERSISTENT ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java index 6dabaa66ffb2b..7600207eebad0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java @@ -20,6 +20,5 @@ public record HuggingFaceModelParameters( Map taskSettings, ChunkingSettings chunkingSettings, Map secretSettings, - String failureMessage, ConfigurationParseContext context ) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d0a98d8252923..e727a9e20cd8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.huggingface; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -26,7 +25,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -54,6 +52,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; /** * This class is responsible for managing the Hugging Face inference service. @@ -124,7 +123,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters params) { params.secretSettings(), params.context() ); - default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(params.inferenceEntityId(), NAME, params.taskType(), params.context()); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 775a4e90ae034..081d5c63b84ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.huggingface.elser; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -25,7 +24,6 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; @@ -50,6 +48,7 @@ import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; @@ -87,7 +86,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters input) { input.secretSettings(), input.context() ); - default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST); + default -> throw createInvalidTaskTypeException(input.inferenceEntityId(), NAME, input.taskType(), input.context()); }; } From 6ec99f9a80ff0801d51c9bb32a51c306a19188cf Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 30 Sep 2025 13:50:17 +0000 Subject: [PATCH 09/10] [CI] Auto commit changes from spotless --- .../inference/services/ServiceUtils.java | 7 +---- .../inference/services/ai21/Ai21Service.java | 30 +++---------------- .../services/anthropic/AnthropicService.java | 16 ++-------- .../azureaistudio/AzureAiStudioService.java | 9 +----- .../azureopenai/AzureOpenAiService.java | 9 +----- .../contextualai/ContextualAiService.java | 9 +----- .../elastic/ElasticInferenceService.java | 9 +----- .../googleaistudio/GoogleAiStudioService.java | 9 +----- .../googlevertexai/GoogleVertexAiService.java | 9 +----- .../ibmwatsonx/IbmWatsonxService.java | 9 +----- .../services/jinaai/JinaAIService.java | 9 +----- .../services/llama/LlamaService.java | 16 ++-------- .../services/mistral/MistralService.java | 16 ++-------- .../services/openai/OpenAiService.java | 9 +----- .../services/voyageai/VoyageAIService.java | 9 +----- .../AmazonBedrockServiceTests.java | 10 ++----- .../AzureAiStudioServiceTests.java | 5 +--- .../azureopenai/AzureOpenAiServiceTests.java | 10 ++----- .../services/cohere/CohereServiceTests.java | 20 +++---------- .../services/jinaai/JinaAIServiceTests.java | 20 +++---------- .../services/mistral/MistralServiceTests.java | 10 ++----- .../voyageai/VoyageAIServiceTests.java | 10 ++----- 22 files changed, 38 insertions(+), 222 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 9d03278470d35..2607e68acadbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1079,8 +1079,6 @@ public interface EnumConstructor> { E apply(String name) throws IllegalArgumentException; } - - /** * Create an exception for when the task type is not valid for the service. */ @@ -1093,10 +1091,7 @@ public static ElasticsearchStatusException createInvalidTaskTypeException( var message = parseContext == ConfigurationParseContext.PERSISTENT ? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType) : TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName); - return new ElasticsearchStatusException( - message, - RestStatus.BAD_REQUEST - ); + return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST); } private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index 57bcc267ac644..69eddd7ecade2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -176,13 +176,7 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Ai21Model model = createModel( - modelId, - taskType, - serviceSettingsMap, - serviceSettingsMap, - ConfigurationParseContext.REQUEST - ); + Ai21Model model = createModel(modelId, taskType, serviceSettingsMap, serviceSettingsMap, ConfigurationParseContext.REQUEST); throwIfNotEmptyMap(config, NAME); throwIfNotEmptyMap(serviceSettingsMap, NAME); @@ -205,12 +199,7 @@ public Ai21Model parsePersistedConfigWithSecrets( removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelFromPersistent( - modelId, - taskType, - serviceSettingsMap, - secretSettingsMap - ); + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, secretSettingsMap); } @Override @@ -218,12 +207,7 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModelFromPersistent( - modelId, - taskType, - serviceSettingsMap, - null - ); + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null); } @Override @@ -257,13 +241,7 @@ private Ai21Model createModelFromPersistent( Map serviceSettings, Map secretSettings ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); + return createModel(inferenceEntityId, taskType, serviceSettings, secretSettings, ConfigurationParseContext.PERSISTENT); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 1f2c7f6cb4cdc..29a1582ce6236 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -155,13 +155,7 @@ public AnthropicModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); } @Override @@ -169,13 +163,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 23d46820b688f..60b8bbdf86fa5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -234,14 +234,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index f3d70f004131b..de067ae0096b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -217,14 +217,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java index 2f59400287a10..9088de1dacd2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java @@ -153,14 +153,7 @@ public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, Ta Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null, - ConfigurationParseContext.PERSISTENT - ); + return createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, ConfigurationParseContext.PERSISTENT); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 7063b00c8ac99..c4156c0bfd6b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -578,14 +578,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index f5924cf6f9854..dc08ec8544e3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -227,14 +227,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index f8318b1d6f838..66a4dd0649730 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -196,14 +196,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 8b4d8fd1f7dae..ee836d03747de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -246,14 +246,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index e0f7ddf864876..5e95f85e78ecd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -211,14 +211,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); + return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index fed4824201e42..b3c0e12927fff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -308,13 +308,7 @@ public Model parsePersistedConfigWithSecrets( chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } - return createModelFromPersistent( - modelId, - taskType, - serviceSettingsMap, - chunkingSettings, - secretSettingsMap - ); + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap); } private LlamaModel createModelFromPersistent( @@ -344,13 +338,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [amazonbedrock] service") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service")); assertThat( thrownException.getMessage(), containsString("The [amazonbedrock] service does not support task type [sparse_embedding]") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 3e17a603c85c2..127b7d1c4cfae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -807,10 +807,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM () -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [azureaistudio] service") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureaistudio] service")); assertThat( thrownException.getMessage(), containsString("The [azureaistudio] service does not support task type [sparse_embedding]") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 0012e1f27be02..cf3ac6979b8e3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -435,10 +435,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [azureopenai] service") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service")); assertThat( thrownException.getMessage(), containsString("The [azureopenai] service does not support task type [sparse_embedding]") @@ -672,10 +669,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [azureopenai] service") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service")); assertThat( thrownException.getMessage(), containsString("The [azureopenai] service does not support task type [sparse_embedding]") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 453a5ee72d4c1..9642e7b85cdc7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -450,14 +450,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [cohere] service") - ); - assertThat( - thrownException.getMessage(), - containsString("The [cohere] service does not support task type [sparse_embedding]") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service")); + assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]")); } } @@ -691,14 +685,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [cohere] service") - ); - assertThat( - thrownException.getMessage(), - containsString("The [cohere] service does not support task type [sparse_embedding]") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service")); + assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index e5ed8891b368a..998559d102ab7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -443,14 +443,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [jinaai] service") - ); - assertThat( - thrownException.getMessage(), - containsString("The [jinaai] service does not support task type [sparse_embedding]") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); + assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); } } @@ -687,14 +681,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [jinaai] service") - ); - assertThat( - thrownException.getMessage(), - containsString("The [jinaai] service does not support task type [sparse_embedding]") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); + assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 1bf37948012e4..c8f6d0ee0e2fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -742,14 +742,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM () -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [mistral] service") - ); - assertThat( - thrownException.getMessage(), - containsString("The [mistral] service does not support task type [sparse_embedding]") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [mistral] service")); + assertThat(thrownException.getMessage(), containsString("The [mistral] service does not support task type [sparse_embedding]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 99c6d31c207b6..d7f5726af85e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -409,10 +409,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM ) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [voyageai] service") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service")); assertThat( thrownException.getMessage(), containsString("The [voyageai] service does not support task type [sparse_embedding]") @@ -628,10 +625,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) ); - assertThat( - thrownException.getMessage(), - containsString("Failed to parse stored model [id] for [voyageai] service") - ); + assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service")); assertThat( thrownException.getMessage(), containsString("The [voyageai] service does not support task type [sparse_embedding]") From d52d4dc8d4580ede6f18bada75eb702968a1ad97 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 30 Sep 2025 10:35:17 -0400 Subject: [PATCH 10/10] Finishing comment --- .../xpack/inference/services/custom/CustomService.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index c5c8853c7374a..aa7fd1337428f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -241,7 +241,13 @@ public CustomModel parsePersistedConfigWithSecrets( private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) { if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - // note there's + /* + * There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context. + * For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will + * return null which results in the older word boundary chunking settings being used as the default. + * For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking + * settings being used as the default. + */ return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); }