From 028d4458d69e97f001bd5185ff0eeb24561b4cb1 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 24 Sep 2025 09:31:25 -0400
Subject: [PATCH 01/10] Refactoring openai
---
.../inference/services/ServiceUtils.java | 13 +
.../services/openai/OpenAiService.java | 4 +-
.../OpenAiChatCompletionServiceSettings.java | 2 +-
.../OpenAiEmbeddingsServiceSettings.java | 4 +-
.../AbstractInferenceServiceTests.java | 211 ++++-
.../services/ai21/Ai21ServiceTests.java | 10 +-
.../services/custom/CustomServiceTests.java | 10 +-
.../services/llama/LlamaServiceTests.java | 4 +-
.../services/openai/OpenAiServiceTests.java | 743 ++++--------------
9 files changed, 379 insertions(+), 622 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 874625c93a528..00f3023954980 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -1067,6 +1067,10 @@ public interface EnumConstructor> {
E apply(String name) throws IllegalArgumentException;
}
+ /**
+ * @deprecated use {@link #parsePersistedConfigErrorMsg(String, String, TaskType)} instead
+ */
+ @Deprecated
public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) {
return format(
"Failed to parse stored model [%s] for [%s] service, please delete and add the service again",
@@ -1075,6 +1079,15 @@ public static String parsePersistedConfigErrorMsg(String inferenceEntityId, Stri
);
}
+ public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
+ return format(
+ "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
+ inferenceEntityId,
+ serviceName,
+ TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName)
+ );
+ }
+
public static ElasticsearchStatusException createInvalidModelException(Model model) {
return new ElasticsearchStatusException(
format(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
index d2b7dcc527aaa..bef5670b52058 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
@@ -232,7 +232,7 @@ public OpenAiModel parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
);
}
@@ -255,7 +255,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
taskSettingsMap,
chunkingSettings,
null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java
index 88840cc1202ac..6d340320c655c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java
@@ -47,7 +47,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
// The rate limit for usage tier 1 is 500 request per minute for most of the completion models
// To find this information you need to access your account's limits https://platform.openai.com/account/limits
// 500 requests per minute
- private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
+ public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
public static OpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
index 20b9bf931290c..e9b6a7c77a5fa 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
@@ -50,11 +50,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
public static final String NAME = "openai_service_settings";
- static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
+ public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
// The rate limit for usage tier 1 is 3000 request per minute for the text embedding models
// To find this information you need to access your account's limits https://platform.openai.com/account/limits
// 3000 requests per minute
- private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
+ public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
public static OpenAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) {
return switch (context) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
index ec07d7b547004..9f7fdfdf17a11 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
@@ -7,11 +7,14 @@
package org.elasticsearch.xpack.inference.services;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
+import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
@@ -20,6 +23,8 @@
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.Utils;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.junit.After;
@@ -27,11 +32,14 @@
import org.junit.Before;
import java.io.IOException;
+import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.function.BiFunction;
+import java.util.function.Function;
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
@@ -39,6 +47,7 @@
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
@@ -56,6 +65,7 @@ public abstract class AbstractInferenceServiceTests extends InferenceServiceTest
protected final MockWebServer webServer = new MockWebServer();
protected ThreadPool threadPool;
protected HttpClientManager clientManager;
+ protected TestCase testCase;
@Override
@Before
@@ -77,8 +87,9 @@ public void tearDown() throws Exception {
private final TestConfiguration testConfiguration;
- public AbstractInferenceServiceTests(TestConfiguration testConfiguration) {
+ public AbstractInferenceServiceTests(TestConfiguration testConfiguration, TestCase testCase) {
this.testConfiguration = Objects.requireNonNull(testConfiguration);
+ this.testCase = testCase;
}
/**
@@ -105,7 +116,7 @@ public TestConfiguration build() {
}
/**
- * Configurations that useful for most tests
+ * Configurations that are useful for most tests
*/
public abstract static class CommonConfig {
@@ -121,6 +132,10 @@ public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
protected abstract Map createServiceSettingsMap(TaskType taskType);
+ protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ return createServiceSettingsMap(taskType);
+ }
+
protected abstract Map createTaskSettingsMap();
protected abstract Map createSecretSettingsMap();
@@ -154,12 +169,17 @@ protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
}
};
+ @Override
+ public InferenceService createInferenceService() {
+ return testConfiguration.commonConfig.createService(threadPool, clientManager);
+ }
+
public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception {
var parseRequestConfigTestConfig = testConfiguration.commonConfig;
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+ parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
@@ -167,7 +187,32 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception {
var listener = new PlainActionFuture();
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
- parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.TEXT_EMBEDDING);
+ var model = listener.actionGet(TIMEOUT);
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of());
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+ }
+ }
+
+ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws Exception {
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+
+ try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
+ var chunkingSettingsMap = createRandomChunkingSettingsMap();
+ var config = getRequestConfigMap(
+ parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST),
+ parseRequestConfigTestConfig.createTaskSettingsMap(),
+ chunkingSettingsMap,
+ parseRequestConfigTestConfig.createSecretSettingsMap()
+ );
+
+ var listener = new PlainActionFuture();
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
+
+ var model = listener.actionGet(TIMEOUT);
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
}
}
@@ -176,7 +221,7 @@ public void testParseRequestConfig_CreatesACompletionModel() throws Exception {
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION),
+ parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.REQUEST),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
@@ -193,7 +238,10 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.taskType,
+ ConfigurationParseContext.REQUEST
+ ),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
@@ -214,7 +262,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.taskType,
+ ConfigurationParseContext.REQUEST
+ ),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
@@ -231,7 +282,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
var parseRequestConfigTestConfig = testConfiguration.commonConfig;
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
- var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType);
+ var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.taskType,
+ ConfigurationParseContext.REQUEST
+ );
serviceSettings.put("extra_key", "value");
var config = getRequestConfigMap(
serviceSettings,
@@ -253,7 +307,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap()
var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap();
taskSettings.put("extra_key", "value");
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.taskType,
+ ConfigurationParseContext.REQUEST
+ ),
taskSettings,
parseRequestConfigTestConfig.createSecretSettingsMap()
);
@@ -272,7 +329,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap
var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap();
secretSettingsMap.put("extra_key", "value");
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.taskType,
+ ConfigurationParseContext.REQUEST
+ ),
parseRequestConfigTestConfig.createTaskSettingsMap(),
secretSettingsMap
);
@@ -285,26 +345,122 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap
}
}
- // parsePersistedConfigWithSecrets
+ @ParametersFactory
+ public static Iterable parameters() throws IOException {
+ return Arrays.asList(
+ new TestCase[][] {
+ {
+ new TestCase(
+ "Test parsing persisted config without chunking settings",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig.createServiceSettingsMap(
+ TaskType.TEXT_EMBEDDING,
+ ConfigurationParseContext.PERSISTENT
+ ),
+ testConfiguration.commonConfig.createTaskSettingsMap(),
+ null
+ ),
+ (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()),
+ null
+ ) } }
+ );
+ }
- public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception {
+ public record TestCase(
+ @Nullable String description,
+ Function createPersistedConfig,
+ BiFunction serviceCallback,
+ @Nullable Map chunkingSettingsMap
+ ) {}
+
+ public void testPersistedConfig() throws Exception {
var parseConfigTestConfig = testConfiguration.commonConfig;
+ var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration);
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
+
+ var model = testCase.serviceCallback.apply(service, persistedConfig);
+
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(
+ testCase.chunkingSettingsMap == null ? Map.of() : testCase.chunkingSettingsMap
);
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
- var model = service.parsePersistedConfigWithSecrets(
+ parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+ }
+
+ parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null);
+ }
+
+ // parsePersistedConfig tests
+
+ public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws Exception {
+ var parseConfigTestConfig = testConfiguration.commonConfig;
+ var persistedConfigMap = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ parseConfigTestConfig.createTaskSettingsMap(),
+ null
+ );
+
+ parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null);
+ }
+
+ private void parseConfigHelper(Function serviceParseCallback, @Nullable Map chunkingSettingsMap)
+ throws Exception {
+ var parseConfigTestConfig = testConfiguration.commonConfig;
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+
+ var model = serviceParseCallback.apply(service);
+
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap == null ? Map.of() : chunkingSettingsMap);
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+
+ parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
+ }
+ }
+
+ // parsePersistedConfigWithSecrets
+
+ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception {
+ var parseConfigTestConfig = testConfiguration.commonConfig;
+
+ var persistedConfigMap = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ parseConfigTestConfig.createTaskSettingsMap(),
+ parseConfigTestConfig.createSecretSettingsMap()
+ );
+
+ parseConfigHelper(
+ service -> service.parsePersistedConfigWithSecrets(
"id",
TaskType.TEXT_EMBEDDING,
persistedConfigMap.config(),
persistedConfigMap.secrets()
- );
- parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
- }
+ ),
+ null
+ );
+ }
+
+ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsAreProvided() throws Exception {
+ var parseConfigTestConfig = testConfiguration.commonConfig;
+
+ var chunkingSettingsMap = createRandomChunkingSettingsMap();
+ var persistedConfigMap = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ parseConfigTestConfig.createTaskSettingsMap(),
+ chunkingSettingsMap,
+ parseConfigTestConfig.createSecretSettingsMap()
+ );
+
+ parseConfigHelper(
+ service -> service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfigMap.config(),
+ persistedConfigMap.secrets()
+ ),
+ chunkingSettingsMap
+ );
}
public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception {
@@ -312,7 +468,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION),
+ parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
parseConfigTestConfig.createTaskSettingsMap(),
parseConfigTestConfig.createSecretSettingsMap()
);
@@ -332,7 +488,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() thr
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
parseConfigTestConfig.createTaskSettingsMap(),
parseConfigTestConfig.createSecretSettingsMap()
);
@@ -365,7 +521,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
parseConfigTestConfig.createTaskSettingsMap(),
parseConfigTestConfig.createSecretSettingsMap()
);
@@ -385,7 +541,10 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
var parseConfigTestConfig = testConfiguration.commonConfig;
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType);
+ var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(
+ parseConfigTestConfig.taskType,
+ ConfigurationParseContext.PERSISTENT
+ );
serviceSettings.put("extra_key", "value");
var persistedConfigMap = getPersistedConfigMap(
serviceSettings,
@@ -410,7 +569,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTask
var taskSettings = parseConfigTestConfig.createTaskSettingsMap();
taskSettings.put("extra_key", "value");
var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
taskSettings,
parseConfigTestConfig.createSecretSettingsMap()
);
@@ -427,7 +586,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr
var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap();
secretSettingsMap.put("extra_key", "value");
var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
parseConfigTestConfig.createTaskSettingsMap(),
secretSettingsMap
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
index cbb119d3e5710..5fb878403f05a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
@@ -19,7 +19,6 @@
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.EmptyTaskSettings;
-import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -82,8 +81,8 @@ public class Ai21ServiceTests extends AbstractInferenceServiceTests {
private ThreadPool threadPool;
private HttpClientManager clientManager;
- public Ai21ServiceTests() {
- super(createTestConfiguration());
+ public Ai21ServiceTests(TestCase testCase) {
+ super(createTestConfiguration(), testCase);
}
private static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() {
@@ -561,9 +560,4 @@ private Map getRequestConfigMap(Map serviceSetti
return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
}
-
- @Override
- public InferenceService createInferenceService() {
- return createService();
- }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index 55bb98705a2a3..ec303f0c7b796 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -16,7 +16,6 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
-import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
@@ -70,8 +69,8 @@
public class CustomServiceTests extends AbstractInferenceServiceTests {
- public CustomServiceTests() {
- super(createTestConfiguration());
+ public CustomServiceTests(TestCase testCase) {
+ super(createTestConfiguration(), testCase);
}
private static TestConfiguration createTestConfiguration() {
@@ -808,11 +807,6 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
}
}
- @Override
- public InferenceService createInferenceService() {
- return createService(threadPool, clientManager);
- }
-
@Override
protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index 243235211a7de..5d81c8b062492 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -102,8 +102,8 @@ public class LlamaServiceTests extends AbstractInferenceServiceTests {
private ThreadPool threadPool;
private HttpClientManager clientManager;
- public LlamaServiceTests() {
- super(createTestConfiguration());
+ public LlamaServiceTests(TestCase testCase) {
+ super(createTestConfiguration(), testCase);
}
private static TestConfiguration createTestConfiguration() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 676dca2778141..705afe78d3196 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -18,8 +18,10 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
@@ -49,24 +51,33 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
-import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
+import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
+import java.net.URI;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@@ -102,8 +113,22 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class OpenAiServiceTests extends InferenceServiceTestCase {
+public class OpenAiServiceTests extends AbstractInferenceServiceTests {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private static final String MODEL = "model";
+ private static final String URL = "http://www.elastic.co";
+ private static final String ORGANIZATION = "org";
+ private static final int MAX_INPUT_TOKENS = 123;
+ private static final SimilarityMeasure SIMILARITY = SimilarityMeasure.DOT_PRODUCT;
+ private static final int DIMENSIONS = 100;
+ private static final boolean DIMENSIONS_SET_BY_USER = true;
+ private static final String USER = "user";
+ private static final String HEADER_KEY = "header_key";
+ private static final String HEADER_VALUE = "header_value";
+ private static final Map HEADERS = Map.of(HEADER_KEY, HEADER_VALUE);
+ private static final String SECRET = "secret";
+ private static final String INFERENCE_ID = "id";
+
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@@ -122,221 +147,174 @@ public void shutdown() throws IOException {
webServer.close();
}
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
+ public OpenAiServiceTests(TestCase testCase) {
+ super(createTestConfiguration(), testCase);
}
- public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() throws IOException {
- var url = "url";
- var organization = "org";
- var model = "model";
- var user = "user";
- var secret = "secret";
-
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(m -> {
- assertThat(m, instanceOf(OpenAiChatCompletionModel.class));
+ private static TestConfiguration createTestConfiguration() {
+ return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) {
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return OpenAiServiceTests.createService(threadPool, clientManager);
+ }
- var completionsModel = (OpenAiChatCompletionModel) m;
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST);
+ }
- assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
- assertThat(completionsModel.getServiceSettings().organizationId(), is(organization));
- assertThat(completionsModel.getServiceSettings().modelId(), is(model));
- assertThat(completionsModel.getTaskSettings().user(), is(user));
- assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext);
+ }
- }, exception -> fail("Unexpected exception: " + exception));
+ @Override
+ protected Map createTaskSettingsMap() {
+ return OpenAiServiceTests.createTaskSettingsMap();
+ }
- service.parseRequestConfig(
- "id",
- TaskType.COMPLETION,
- getRequestConfigMap(
- getServiceSettingsMap(model, url, organization),
- getOpenAiTaskSettingsMap(user),
- getSecretSettingsMap(secret)
- ),
- modelVerificationListener
- );
- }
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return getSecretSettingsMap(SECRET);
+ }
- public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(
- model -> fail("Expected exception, but got model: " + model),
- exception -> {
- assertThat(exception, instanceOf(ElasticsearchStatusException.class));
- assertThat(exception.getMessage(), is("The [openai] service does not support task type [sparse_embedding]"));
- }
- );
+ @Override
+ protected void assertModel(Model model, TaskType taskType) {
+ OpenAiServiceTests.assertModel(model, taskType);
+ }
- service.parseRequestConfig(
- "id",
- TaskType.SPARSE_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ }
+ }).enableUpdateModelTests(new UpdateModelConfiguration() {
+ @Override
+ protected OpenAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+ return createInternalEmbeddingModel(similarityMeasure, null);
+ }
+ }).build();
}
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
- try (var service = createOpenAiService()) {
- var config = getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
- config.put("extra_key", "value");
-
- ActionListener modelVerificationListener = ActionListener.wrap(
- model -> fail("Expected exception, but got model: " + model),
- exception -> {
- assertThat(exception, instanceOf(ElasticsearchStatusException.class));
- assertThat(
- exception.getMessage(),
- is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")
- );
- }
+ private static Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ var settingsMap = new HashMap(
+ Map.of(
+ ServiceFields.MODEL_ID,
+ MODEL,
+ ServiceFields.URL,
+ URL,
+ OpenAiServiceFields.ORGANIZATION,
+ ORGANIZATION,
+ ServiceFields.MAX_INPUT_TOKENS,
+ MAX_INPUT_TOKENS
+ )
+ );
+
+ if (taskType == TaskType.TEXT_EMBEDDING) {
+ settingsMap.putAll(
+ Map.of(
+ ServiceFields.SIMILARITY,
+ SIMILARITY.toString(),
+ ServiceFields.DIMENSIONS,
+ DIMENSIONS
+ )
);
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+ if (parseContext == ConfigurationParseContext.PERSISTENT) {
+ settingsMap.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, DIMENSIONS_SET_BY_USER);
+ }
}
- }
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- try (var service = createOpenAiService()) {
- var serviceSettings = getServiceSettingsMap("model", "url", "org");
- serviceSettings.put("extra_key", "value");
-
- var config = getRequestConfigMap(serviceSettings, getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+ return settingsMap;
+ }
- ActionListener modelVerificationListener = ActionListener.wrap((model) -> {
- fail("Expected exception, but got model: " + model);
- }, e -> {
- assertThat(e, instanceOf(ElasticsearchStatusException.class));
- assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service"));
- });
+ private static Map createTaskSettingsMap() {
+ return new HashMap<>(Map.of(OpenAiServiceFields.USER, USER, OpenAiServiceFields.HEADERS, HEADERS));
+ }
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+ private static void assertModel(Model model, TaskType taskType) {
+ switch (taskType) {
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
+ case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model);
+ default -> fail("unexpected task type: " + taskType);
}
}
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- try (var service = createOpenAiService()) {
- var taskSettingsMap = getOpenAiTaskSettingsMap("user");
- taskSettingsMap.put("extra_key", "value");
-
- var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), taskSettingsMap, getSecretSettingsMap("secret"));
+ private static void assertTextEmbeddingModel(Model model) {
+ assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
- ActionListener modelVerificationListener = ActionListener.wrap((model) -> {
- fail("Expected exception, but got model: " + model);
- }, e -> {
- assertThat(e, instanceOf(ElasticsearchStatusException.class));
- assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service"));
- });
+ var embeddingsModel = (OpenAiEmbeddingsModel) model;
+ assertThat(
+ embeddingsModel.getServiceSettings(),
+ is(
+ new OpenAiEmbeddingsServiceSettings(
+ MODEL,
+ URI.create(URL),
+ ORGANIZATION,
+ SIMILARITY,
+ DIMENSIONS,
+ MAX_INPUT_TOKENS,
+ DIMENSIONS_SET_BY_USER,
+ OpenAiEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
+ )
+ )
+ );
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
- }
+ assertThat(embeddingsModel.getTaskSettings(), is(new OpenAiEmbeddingsTaskSettings(USER, HEADERS)));
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET));
}
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- try (var service = createOpenAiService()) {
- var secretSettingsMap = getSecretSettingsMap("secret");
- secretSettingsMap.put("extra_key", "value");
+ private static void assertCompletionModel(Model model) {
+ assertThat(model, instanceOf(OpenAiChatCompletionModel.class));
- var config = getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- secretSettingsMap
- );
+ var completionModel = (OpenAiChatCompletionModel) model;
- ActionListener modelVerificationListener = ActionListener.wrap((model) -> {
- fail("Expected exception, but got model: " + model);
- }, e -> {
- assertThat(e, instanceOf(ElasticsearchStatusException.class));
- assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service"));
- });
+ assertThat(
+ completionModel.getServiceSettings(),
+ is(
+ new OpenAiChatCompletionServiceSettings(
+ MODEL,
+ URI.create(URL),
+ ORGANIZATION,
+ MAX_INPUT_TOKENS,
+ OpenAiChatCompletionServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
+ )
+ )
+ );
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
- }
+ assertThat(completionModel.getTaskSettings(), is(new OpenAiChatCompletionTaskSettings(USER, HEADERS)));
+ assertThat(completionModel.getSecretSettings().apiKey().toString(), is(SECRET));
}
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", null, null),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
+ private static OpenAiEmbeddingsModel createInternalEmbeddingModel(
+ SimilarityMeasure similarityMeasure,
+ @Nullable ChunkingSettings chunkingSettings
+ ) {
+ return createInternalEmbeddingModel(similarityMeasure, URL, chunkingSettings);
}
- public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModelWithoutUserWithoutUserUrlOrganization() throws IOException {
- var model = "model";
- var secret = "secret";
-
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(m -> {
- assertThat(m, instanceOf(OpenAiChatCompletionModel.class));
-
- var completionsModel = (OpenAiChatCompletionModel) m;
- assertNull(completionsModel.getServiceSettings().uri());
- assertNull(completionsModel.getServiceSettings().organizationId());
- assertThat(completionsModel.getServiceSettings().modelId(), is(model));
- assertNull(completionsModel.getTaskSettings().user());
- assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
-
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.COMPLETION,
- getRequestConfigMap(getServiceSettingsMap(model, null, null), getOpenAiTaskSettingsMap(null), getSecretSettingsMap(secret)),
- modelVerificationListener
- );
- }
+ private static OpenAiEmbeddingsModel createInternalEmbeddingModel(
+ SimilarityMeasure similarityMeasure,
+ @Nullable String url,
+ @Nullable ChunkingSettings chunkingSettings
+ ) {
+ return new OpenAiEmbeddingsModel(
+ INFERENCE_ID,
+ TaskType.TEXT_EMBEDDING,
+ "service",
+ new OpenAiEmbeddingsServiceSettings(
+ MODEL,
+ url == null ? null : URI.create(url),
+ ORGANIZATION,
+ similarityMeasure,
+ DIMENSIONS,
+ DIMENSIONS,
+ false,
+ null
+ ),
+ new OpenAiEmbeddingsTaskSettings(USER, HEADERS),
+ chunkingSettings,
+ new DefaultSecretSettings(new SecureString(SECRET.toCharArray()))
+ );
}
public void testParseRequestConfig_MovesModel() throws IOException {
@@ -365,392 +343,6 @@ public void testParseRequestConfig_MovesModel() throws IOException {
}
}
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", null, null),
- getOpenAiTaskSettingsMap(null),
- createRandomChunkingSettingsMap(),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
- }
-
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", null, null),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", 100, null, false),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
-
- var thrownException = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.SPARSE_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- )
- );
-
- assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [openai] service, please delete and add the service again")
- );
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- createRandomChunkingSettingsMap(),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
- persistedConfig.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var secretSettingsMap = getSecretSettingsMap("secret");
- secretSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user"),
- secretSettingsMap
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
- persistedConfig.secrets().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true);
- serviceSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(
- serviceSettingsMap,
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var taskSettingsMap = getOpenAiTaskSettingsMap("user");
- taskSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- taskSettingsMap,
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user")
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getOpenAiTaskSettingsMap("user"));
-
- var thrownException = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
- );
-
- assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [openai] service, please delete and add the service again")
- );
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null)
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
try (var service = createOpenAiService()) {
var persistedConfig = getPersistedConfigMap(
@@ -1681,6 +1273,11 @@ public void testGetConfiguration() throws Exception {
}
}
+ private static OpenAiService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+ return new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty());
+ }
+
private OpenAiService createOpenAiService() {
return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
From c3243fda24918816e61a457ba860c09a0a61c724 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Wed, 24 Sep 2025 17:21:08 -0400
Subject: [PATCH 02/10] Splitting up parameterized tests
---
...actInferenceServiceParameterizedTests.java | 469 ++++++++++++++++++
.../AbstractInferenceServiceTests.java | 116 ++++-
.../services/ai21/Ai21ServiceTests.java | 25 +-
.../services/custom/CustomServiceTests.java | 30 +-
.../services/llama/LlamaServiceTests.java | 37 +-
.../OpenAiServiceParameterizedTests.java | 18 +
.../services/openai/OpenAiServiceTests.java | 84 +---
7 files changed, 659 insertions(+), 120 deletions(-)
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
new file mode 100644
index 0000000000000..f6fdefb1bac6f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
@@ -0,0 +1,469 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.Utils;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.junit.After;
+import org.junit.Assume;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
+
+import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
+import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
+import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Base class for testing inference services using parameterized tests.
+ */
+public abstract class AbstractInferenceServiceParameterizedTests extends InferenceServiceTestCase {
+
+ private final AbstractInferenceServiceTests.TestConfiguration testConfiguration;
+
+ protected final MockWebServer webServer = new MockWebServer();
+ protected ThreadPool threadPool;
+ protected HttpClientManager clientManager;
+ protected TestCase testCase;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ webServer.start();
+ threadPool = createThreadPool(inferenceUtilityExecutors());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @Override
+ @After
+ public void tearDown() throws Exception {
+ super.tearDown();
+ clientManager.close();
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ public AbstractInferenceServiceParameterizedTests(
+ AbstractInferenceServiceTests.TestConfiguration testConfiguration,
+ TestCase testCase
+ ) {
+ this.testConfiguration = Objects.requireNonNull(testConfiguration);
+ this.testCase = testCase;
+ }
+
+ @Override
+ public InferenceService createInferenceService() {
+ return testConfiguration.commonConfig().createService(threadPool, clientManager);
+ }
+
+ @ParametersFactory
+ public static Iterable parameters() throws IOException {
+ return Arrays.asList(
+ new TestCase[][] {
+ // parsePersistedConfig
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config without chunking settings",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ null
+ ),
+ (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config()
+ ),
+ TaskType.TEXT_EMBEDDING
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with chunking settings",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ createRandomChunkingSettingsMap(),
+ null
+ ),
+ (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config()
+ ),
+ TaskType.TEXT_EMBEDDING
+ ).build() },
+ // parsePersistedConfigWithSecrets
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with secrets creates an embeddings model",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ ),
+ TaskType.TEXT_EMBEDDING
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets creates an embeddings "
+ + "model when chunking settings are provided",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ createRandomChunkingSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ ),
+ TaskType.TEXT_EMBEDDING
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets creates a completion "
+ + "model when chunking settings are not provided",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets throws exception for unsupported task type",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ "id",
+ testConfiguration.commonConfig().unsupportedTaskType(),
+ persistedConfig.config(),
+ persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() } }
+ );
+ }
+
+ public record TestCase(
+ String description,
+ Function createPersistedConfig,
+ ServiceCallback serviceCallback,
+ TaskType expectedTaskType,
+ boolean modelIncludesSecrets,
+ boolean expectFailure
+ ) {}
+
+ @FunctionalInterface
+ interface ServiceCallback {
+ Model parseConfigs(
+ SenderService service,
+ Utils.PersistedConfig persistedConfig,
+ AbstractInferenceServiceTests.TestConfiguration testConfiguration
+ );
+ }
+
+ private static class TestCaseBuilder {
+ private final String description;
+ private final Function createPersistedConfig;
+ private final ServiceCallback serviceCallback;
+ private final TaskType expectedTaskType;
+ private boolean modelIncludesSecrets;
+ private boolean expectFailure;
+
+ TestCaseBuilder(
+ String description,
+ Function createPersistedConfig,
+ ServiceCallback serviceCallback,
+ TaskType expectedTaskType
+ ) {
+ this.description = description;
+ this.createPersistedConfig = createPersistedConfig;
+ this.serviceCallback = serviceCallback;
+ this.expectedTaskType = expectedTaskType;
+ }
+
+ public TestCaseBuilder withSecrets() {
+ this.modelIncludesSecrets = true;
+ return this;
+ }
+
+ public TestCaseBuilder withFailure() {
+ this.expectFailure = true;
+ return this;
+ }
+
+ public TestCase build() {
+ return new TestCase(description, createPersistedConfig, serviceCallback, expectedTaskType, modelIncludesSecrets, expectFailure);
+ }
+ }
+
+ public void testPersistedConfig() throws Exception {
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+ var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration);
+
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+ var model = testCase.serviceCallback.parseConfigs(service, persistedConfig, testConfiguration);
+
+ if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) {
+ @SuppressWarnings("unchecked")
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(
+ (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS)
+ );
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ }
+
+ parseConfigTestConfig.assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets);
+ }
+ }
+
+ // parsePersistedConfigWithSecrets
+
+ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception {
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+ var persistedConfigMap = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
+ parseConfigTestConfig.createTaskSettingsMap(),
+ parseConfigTestConfig.createSecretSettingsMap()
+ );
+
+ var exception = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parsePersistedConfigWithSecrets(
+ "id",
+ parseConfigTestConfig.unsupportedTaskType(),
+ persistedConfigMap.config(),
+ persistedConfigMap.secrets()
+ )
+ );
+
+ assertThat(
+ exception.getMessage(),
+ containsString(
+ Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType())
+ )
+ );
+ }
+ }
+
+ protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
+ return "service does not support task type [%s]";
+ }
+
+ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+ var persistedConfigMap = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
+ parseConfigTestConfig.createTaskSettingsMap(),
+ parseConfigTestConfig.createSecretSettingsMap()
+ );
+ persistedConfigMap.config().put("extra_key", "value");
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ parseConfigTestConfig.taskType(),
+ persistedConfigMap.config(),
+ persistedConfigMap.secrets()
+ );
+
+ parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+ var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(
+ parseConfigTestConfig.taskType(),
+ ConfigurationParseContext.PERSISTENT
+ );
+ serviceSettings.put("extra_key", "value");
+ var persistedConfigMap = getPersistedConfigMap(
+ serviceSettings,
+ parseConfigTestConfig.createTaskSettingsMap(),
+ parseConfigTestConfig.createSecretSettingsMap()
+ );
+
+ var model = service.parsePersistedConfigWithSecrets(
+ "id",
+ parseConfigTestConfig.taskType(),
+ persistedConfigMap.config(),
+ persistedConfigMap.secrets()
+ );
+
+ parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+ var taskSettings = parseConfigTestConfig.createTaskSettingsMap();
+ taskSettings.put("extra_key", "value");
+ var config = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
+ taskSettings,
+ parseConfigTestConfig.createSecretSettingsMap()
+ );
+
+ var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets());
+
+ parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+ var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap();
+ secretSettingsMap.put("extra_key", "value");
+ var config = getPersistedConfigMap(
+ parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
+ parseConfigTestConfig.createTaskSettingsMap(),
+ secretSettingsMap
+ );
+
+ var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets());
+
+ parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
+ }
+ }
+
+ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ var listener = new PlainActionFuture();
+
+ service.infer(
+ getInvalidModel("id", "service"),
+ null,
+ null,
+ null,
+ List.of(""),
+ false,
+ new HashMap<>(),
+ InputType.INTERNAL_SEARCH,
+ InferenceAction.Request.DEFAULT_TIMEOUT,
+ listener
+ );
+
+ var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+ assertThat(
+ exception.getMessage(),
+ is("The internal model was invalid, please delete the service [service] with id [id] and add it again.")
+ );
+ }
+ }
+
+ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
+
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ var exception = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt())
+ );
+
+ assertThat(exception.getMessage(), containsString("Can't update embedding details for model"));
+ }
+ }
+
+ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
+
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ var embeddingSize = randomNonNegativeInt();
+ var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null);
+
+ Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+ assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity());
+ assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+ }
+ }
+
+ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
+
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ var embeddingSize = randomNonNegativeInt();
+ var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE);
+
+ Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+ assertEquals(SimilarityMeasure.COSINE, updatedModel.getServiceSettings().similarity());
+ assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+ }
+ }
+
+ // streaming tests
+ public void testSupportedStreamingTasks() throws Exception {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks()));
+ assertFalse(service.canStream(TaskType.ANY));
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
index 9f7fdfdf17a11..b1ff02b6bbc17 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
@@ -62,6 +62,8 @@
*/
public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase {
+ private final TestConfiguration testConfiguration;
+
protected final MockWebServer webServer = new MockWebServer();
protected ThreadPool threadPool;
protected HttpClientManager clientManager;
@@ -85,8 +87,6 @@ public void tearDown() throws Exception {
webServer.close();
}
- private final TestConfiguration testConfiguration;
-
public AbstractInferenceServiceTests(TestConfiguration testConfiguration, TestCase testCase) {
this.testConfiguration = Objects.requireNonNull(testConfiguration);
this.testCase = testCase;
@@ -128,6 +128,14 @@ public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
this.unsupportedTaskType = unsupportedTaskType;
}
+ public TaskType taskType() {
+ return taskType;
+ }
+
+ public TaskType unsupportedTaskType() {
+ return unsupportedTaskType;
+ }
+
protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);
protected abstract Map createServiceSettingsMap(TaskType taskType);
@@ -140,7 +148,11 @@ protected Map createServiceSettingsMap(TaskType taskType, Config
protected abstract Map createSecretSettingsMap();
- protected abstract void assertModel(Model model, TaskType taskType);
+ protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets);
+
+ protected void assertModel(Model model, TaskType taskType) {
+ assertModel(model, taskType, true);
+ }
protected abstract EnumSet supportedStreamingTasks();
}
@@ -347,10 +359,12 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap
@ParametersFactory
public static Iterable parameters() throws IOException {
+ var chunkingSettingsMap = createRandomChunkingSettingsMap();
+
return Arrays.asList(
new TestCase[][] {
{
- new TestCase(
+ new TestCaseBuilder(
"Test parsing persisted config without chunking settings",
testConfiguration -> getPersistedConfigMap(
testConfiguration.commonConfig.createServiceSettingsMap(
@@ -360,36 +374,106 @@ public static Iterable parameters() throws IOException {
testConfiguration.commonConfig.createTaskSettingsMap(),
null
),
- (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()),
- null
- ) } }
+ (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config())
+ ).withNullChunkingSettingsMap().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with chunking settings",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig.createServiceSettingsMap(
+ TaskType.TEXT_EMBEDDING,
+ ConfigurationParseContext.PERSISTENT
+ ),
+ testConfiguration.commonConfig.createTaskSettingsMap(),
+ chunkingSettingsMap,
+ null
+ ),
+ (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config())
+ ).withChunkingSettingsMap(chunkingSettingsMap).build() } }
);
}
public record TestCase(
- @Nullable String description,
+ String description,
Function createPersistedConfig,
BiFunction serviceCallback,
- @Nullable Map chunkingSettingsMap
+ @Nullable Map chunkingSettingsMap,
+ boolean validateChunkingSettings,
+ boolean modelIncludesSecrets
) {}
+ private static class TestCaseBuilder {
+ private final String description;
+ private final Function createPersistedConfig;
+ private final BiFunction serviceCallback;
+ @Nullable
+ private Map chunkingSettingsMap;
+ private boolean validateChunkingSettings;
+ private boolean modelIncludesSecrets;
+
+ TestCaseBuilder(
+ String description,
+ Function createPersistedConfig,
+ BiFunction serviceCallback
+ ) {
+ this.description = description;
+ this.createPersistedConfig = createPersistedConfig;
+ this.serviceCallback = serviceCallback;
+ }
+
+ public TestCaseBuilder withSecrets() {
+ this.modelIncludesSecrets = true;
+ return this;
+ }
+
+ public TestCaseBuilder withChunkingSettingsMap(Map chunkingSettingsMap) {
+ this.chunkingSettingsMap = chunkingSettingsMap;
+ this.validateChunkingSettings = true;
+ return this;
+ }
+
+ /**
+ * Use an empty chunking settings map but still do validation that the chunking settings are set to the appropriate
+ * defaults.
+ */
+ public TestCaseBuilder withEmptyChunkingSettingsMap() {
+ this.chunkingSettingsMap = Map.of();
+ this.validateChunkingSettings = true;
+ return this;
+ }
+
+ public TestCaseBuilder withNullChunkingSettingsMap() {
+ this.chunkingSettingsMap = null;
+ this.validateChunkingSettings = true;
+ return this;
+ }
+
+ public TestCase build() {
+ return new TestCase(
+ description,
+ createPersistedConfig,
+ serviceCallback,
+ chunkingSettingsMap,
+ validateChunkingSettings,
+ modelIncludesSecrets
+ );
+ }
+ }
+
public void testPersistedConfig() throws Exception {
var parseConfigTestConfig = testConfiguration.commonConfig;
var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration);
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
-
var model = testCase.serviceCallback.apply(service, persistedConfig);
- var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(
- testCase.chunkingSettingsMap == null ? Map.of() : testCase.chunkingSettingsMap
- );
- assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ if (testCase.validateChunkingSettings) {
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(testCase.chunkingSettingsMap);
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ }
parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
}
-
- parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null);
}
// parsePersistedConfig tests
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
index 5fb878403f05a..50ca2f72abda0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
@@ -110,8 +110,8 @@ protected Map createSecretSettingsMap() {
}
@Override
- protected void assertModel(Model model, TaskType taskType) {
- Ai21ServiceTests.assertModel(model, taskType);
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ Ai21ServiceTests.assertModel(model, taskType, modelIncludesSecrets);
}
@Override
@@ -122,32 +122,35 @@ protected EnumSet supportedStreamingTasks() {
).build();
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case COMPLETION -> assertCompletionModel(model);
- case CHAT_COMPLETION -> assertChatCompletionModel(model);
+ case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
+ case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type [" + taskType + "]");
}
}
- private static Ai21Model assertCommonModelFields(Model model) {
+ private static Ai21Model assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(Ai21Model.class));
var customModel = (Ai21Model) model;
assertThat(customModel.uri.toString(), Matchers.is("https://api.ai21.com/studio/v1/chat/completions"));
assertThat(customModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE));
- assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray())));
+
+ if (modelIncludesSecrets) {
+ assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray())));
+ }
return customModel;
}
- private static void assertCompletionModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), Matchers.is(TaskType.COMPLETION));
}
- private static void assertChatCompletionModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION));
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index ec303f0c7b796..269642812e78f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -96,8 +96,8 @@ protected Map createSecretSettingsMap() {
}
@Override
- protected void assertModel(Model model, TaskType taskType) {
- CustomServiceTests.assertModel(model, taskType);
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets);
}
@Override
@@ -112,38 +112,40 @@ protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure)
}).build();
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
- case COMPLETION -> assertCompletionModel(model);
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets);
+ case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type [" + taskType + "]");
}
}
- private static void assertTextEmbeddingModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING));
assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class));
}
- private static CustomModel assertCommonModelFields(Model model) {
+ private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(CustomModel.class));
var customModel = (CustomModel) model;
assertThat(customModel.getServiceSettings().getUrl(), is("http://www.abc.com"));
assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value")));
- assertThat(
- customModel.getSecretSettings().getSecretParameters(),
- is(Map.of("test_key", new SecureString("test_value".toCharArray())))
- );
+ if (modelIncludesSecrets) {
+ assertThat(
+ customModel.getSecretSettings().getSecretParameters(),
+ is(Map.of("test_key", new SecureString("test_value".toCharArray())))
+ );
+ }
return customModel;
}
- private static void assertCompletionModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), is(TaskType.COMPLETION));
assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class));
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index 5d81c8b062492..a6773122a0d2b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -130,8 +130,8 @@ protected Map createSecretSettingsMap() {
}
@Override
- protected void assertModel(Model model, TaskType taskType) {
- LlamaServiceTests.assertModel(model, taskType);
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets);
}
@Override
@@ -146,43 +146,46 @@ protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarity
}).build();
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
- case COMPLETION -> assertCompletionModel(model);
- case CHAT_COMPLETION -> assertChatCompletionModel(model);
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets);
+ case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
+ case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type [" + taskType + "]");
}
}
- private static void assertTextEmbeddingModel(Model model) {
- var llamaModel = assertCommonModelFields(model);
+ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
+ var llamaModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING));
}
- private static LlamaModel assertCommonModelFields(Model model) {
+ private static LlamaModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(LlamaModel.class));
var llamaModel = (LlamaModel) model;
assertThat(llamaModel.getServiceSettings().modelId(), is("model_id"));
assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com"));
assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE));
- assertThat(
- ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(),
- Matchers.is(new SecureString("secret".toCharArray()))
- );
+
+ if (modelIncludesSecrets) {
+ assertThat(
+ ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(),
+ Matchers.is(new SecureString("secret".toCharArray()))
+ );
+ }
return llamaModel;
}
- private static void assertCompletionModel(Model model) {
- var llamaModel = assertCommonModelFields(model);
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var llamaModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION));
}
- private static void assertChatCompletionModel(Model model) {
- var llamaModel = assertCommonModelFields(model);
+ private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var llamaModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION));
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java
new file mode 100644
index 0000000000000..941dab6699575
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.openai;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceTests.createTestConfiguration;
+
+public class OpenAiServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public OpenAiServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) {
+ super(createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 705afe78d3196..5eceffe200bc4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -91,7 +91,6 @@
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
-import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -151,7 +150,7 @@ public OpenAiServiceTests(TestCase testCase) {
super(createTestConfiguration(), testCase);
}
- private static TestConfiguration createTestConfiguration() {
+ public static TestConfiguration createTestConfiguration() {
return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) {
@Override
protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
@@ -179,8 +178,8 @@ protected Map createSecretSettingsMap() {
}
@Override
- protected void assertModel(Model model, TaskType taskType) {
- OpenAiServiceTests.assertModel(model, taskType);
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets);
}
@Override
@@ -210,14 +209,7 @@ private static Map createServiceSettingsMap(TaskType taskType, C
);
if (taskType == TaskType.TEXT_EMBEDDING) {
- settingsMap.putAll(
- Map.of(
- ServiceFields.SIMILARITY,
- SIMILARITY.toString(),
- ServiceFields.DIMENSIONS,
- DIMENSIONS
- )
- );
+ settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, SIMILARITY.toString(), ServiceFields.DIMENSIONS, DIMENSIONS));
if (parseContext == ConfigurationParseContext.PERSISTENT) {
settingsMap.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, DIMENSIONS_SET_BY_USER);
@@ -231,15 +223,15 @@ private static Map createTaskSettingsMap() {
return new HashMap<>(Map.of(OpenAiServiceFields.USER, USER, OpenAiServiceFields.HEADERS, HEADERS));
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
- case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model);
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets);
+ case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type: " + taskType);
}
}
- private static void assertTextEmbeddingModel(Model model) {
+ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
var embeddingsModel = (OpenAiEmbeddingsModel) model;
@@ -260,10 +252,14 @@ private static void assertTextEmbeddingModel(Model model) {
);
assertThat(embeddingsModel.getTaskSettings(), is(new OpenAiEmbeddingsTaskSettings(USER, HEADERS)));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET));
+ if (modelIncludesSecrets) {
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET));
+ } else {
+ assertNull(embeddingsModel.getSecretSettings());
+ }
}
- private static void assertCompletionModel(Model model) {
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(OpenAiChatCompletionModel.class));
var completionModel = (OpenAiChatCompletionModel) model;
@@ -282,7 +278,14 @@ private static void assertCompletionModel(Model model) {
);
assertThat(completionModel.getTaskSettings(), is(new OpenAiChatCompletionTaskSettings(USER, HEADERS)));
- assertThat(completionModel.getSecretSettings().apiKey().toString(), is(SECRET));
+
+ assertSecrets(completionModel.getSecretSettings(), modelIncludesSecrets);
+ }
+
+ private static void assertSecrets(DefaultSecretSettings secretSettings, boolean modelIncludesSecrets) {
+ if (modelIncludesSecrets) {
+ assertThat(secretSettings.apiKey().toString(), is(SECRET));
+ }
}
private static OpenAiEmbeddingsModel createInternalEmbeddingModel(
@@ -343,49 +346,6 @@ public void testParseRequestConfig_MovesModel() throws IOException {
}
}
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- createRandomChunkingSettingsMap()
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null)
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
try (var service = createOpenAiService()) {
var persistedConfig = getPersistedConfigMap(
From 399a50b5754bd2f7c8008e88415d9b211773d59b Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Thu, 25 Sep 2025 13:04:57 -0400
Subject: [PATCH 03/10] Working tests
---
.../inference/services/ai21/Ai21Service.java | 10 +-
.../services/custom/CustomService.java | 23 +-
.../services/llama/LlamaService.java | 4 +-
.../AbstractInferenceServiceBaseTests.java | 173 ++++++
...actInferenceServiceParameterizedTests.java | 550 ++++++++----------
.../AbstractInferenceServiceTests.java | 537 ++---------------
.../ai21/Ai21ServiceParameterizedTests.java | 18 +
.../services/ai21/Ai21ServiceTests.java | 26 +-
.../CustomServiceParameterizedTests.java | 16 +
.../services/custom/CustomServiceTests.java | 78 +--
.../llama/LlamaServiceParameterizedTests.java | 16 +
.../services/llama/LlamaServiceTests.java | 65 ++-
.../services/openai/OpenAiServiceTests.java | 131 ++---
13 files changed, 638 insertions(+), 1009 deletions(-)
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
index 438d31d8dd411..64437685af12b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
@@ -213,7 +213,7 @@ public Ai21Model parsePersistedConfigWithSecrets(
taskType,
serviceSettingsMap,
secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME)
+ parsePersistedConfigErrorMsg(modelId, NAME, taskType)
);
}
@@ -222,7 +222,13 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME));
+ return createModelFromPersistent(
+ modelId,
+ taskType,
+ serviceSettingsMap,
+ null,
+ parsePersistedConfigErrorMsg(modelId, NAME, taskType)
+ );
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
index 7cd069ac2e3e0..f9e1dba847dc7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -105,7 +105,10 @@ public void parseRequestConfig(
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- var chunkingSettings = extractChunkingSettings(config, taskType);
+ ChunkingSettings chunkingSettings = null;
+ if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+ chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+ }
CustomModel model = createModel(
inferenceEntityId,
@@ -156,14 +159,6 @@ private static RequestParameters createParameters(CustomModel model) {
};
}
- private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) {
- if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
- }
-
- return null;
- }
-
@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
@@ -229,7 +224,10 @@ public CustomModel parsePersistedConfigWithSecrets(
Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
- var chunkingSettings = extractChunkingSettings(config, taskType);
+ ChunkingSettings chunkingSettings = null;
+ if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+ chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+ }
return createModelWithoutLoggingDeprecations(
inferenceEntityId,
@@ -246,7 +244,10 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
- var chunkingSettings = extractChunkingSettings(config, taskType);
+ ChunkingSettings chunkingSettings = null;
+ if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+ chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+ }
return createModelWithoutLoggingDeprecations(
inferenceEntityId,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
index 829dbe0a18955..c13026c36dd73 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
@@ -318,7 +318,7 @@ public Model parsePersistedConfigWithSecrets(
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME)
+ parsePersistedConfigErrorMsg(modelId, NAME, taskType)
);
}
@@ -357,7 +357,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map supportedTaskTypes;
+
+ public CommonConfig(TaskType targetTaskType, @Nullable TaskType unsupportedTaskType, EnumSet supportedTaskTypes) {
+ this.targetTaskType = Objects.requireNonNull(targetTaskType);
+ this.unsupportedTaskType = unsupportedTaskType;
+ this.supportedTaskTypes = Objects.requireNonNull(supportedTaskTypes);
+ }
+
+ public TaskType targetTaskType() {
+ return targetTaskType;
+ }
+
+ public TaskType unsupportedTaskType() {
+ return unsupportedTaskType;
+ }
+
+ public EnumSet supportedTaskTypes() {
+ return supportedTaskTypes;
+ }
+
+ protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);
+
+ protected abstract Map createServiceSettingsMap(TaskType taskType);
+
+ protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ return createServiceSettingsMap(taskType);
+ }
+
+ protected abstract Map createTaskSettingsMap();
+
+ protected abstract Map createSecretSettingsMap();
+
+ protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets);
+
+ protected void assertModel(Model model, TaskType taskType) {
+ assertModel(model, taskType, true);
+ }
+
+ protected abstract EnumSet supportedStreamingTasks();
+
+ /**
+ * Override this method if the service support reranking. This method won't be called if the service doesn't support reranking.
+ */
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ fail("Reranking services should override this test method to verify window size");
+ }
+ }
+
+ /**
+ * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests
+ */
+ public abstract static class UpdateModelConfiguration {
+
+ public boolean isEnabled() {
+ return true;
+ }
+
+ protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
+ }
+
+ private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
+ @Override
+ public boolean isEnabled() {
+ return false;
+ }
+
+ @Override
+ protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+ throw new UnsupportedOperationException("Update model tests are disabled");
+ }
+ };
+
+ @Override
+ public InferenceService createInferenceService() {
+ return testConfiguration.commonConfig.createService(threadPool, clientManager);
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ testConfiguration.commonConfig.assertRerankerWindowSize(rerankingInferenceService);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
index f6fdefb1bac6f..37ba073407df8 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
@@ -10,80 +10,35 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
-import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.http.MockWebServer;
-import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
-import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.junit.After;
import org.junit.Assume;
-import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.function.Function;
-import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
-import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
-import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
-import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
-import static org.mockito.Mockito.mock;
/**
* Base class for testing inference services using parameterized tests.
*/
-public abstract class AbstractInferenceServiceParameterizedTests extends InferenceServiceTestCase {
-
- private final AbstractInferenceServiceTests.TestConfiguration testConfiguration;
-
- protected final MockWebServer webServer = new MockWebServer();
- protected ThreadPool threadPool;
- protected HttpClientManager clientManager;
- protected TestCase testCase;
-
- @Override
- @Before
- public void setUp() throws Exception {
- super.setUp();
- webServer.start();
- threadPool = createThreadPool(inferenceUtilityExecutors());
- clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
- }
-
- @Override
- @After
- public void tearDown() throws Exception {
- super.tearDown();
- clientManager.close();
- terminate(threadPool);
- webServer.close();
- }
+public abstract class AbstractInferenceServiceParameterizedTests extends AbstractInferenceServiceBaseTests {
public AbstractInferenceServiceParameterizedTests(
- AbstractInferenceServiceTests.TestConfiguration testConfiguration,
+ AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration,
TestCase testCase
) {
- this.testConfiguration = Objects.requireNonNull(testConfiguration);
+ super(testConfiguration);
this.testCase = testCase;
}
@@ -92,11 +47,66 @@ public InferenceService createInferenceService() {
return testConfiguration.commonConfig().createService(threadPool, clientManager);
}
+ public record TestCase(
+ String description,
+ Function createPersistedConfig,
+ ServiceParser serviceParser,
+ TaskType expectedTaskType,
+ boolean modelIncludesSecrets,
+ boolean expectFailure
+ ) {}
+
+ private record ServiceParserParams(
+ SenderService service,
+ Utils.PersistedConfig persistedConfig,
+ AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration
+ ) {}
+
+ @FunctionalInterface
+ private interface ServiceParser {
+ Model parseConfigs(ServiceParserParams params);
+ }
+
+ private static class TestCaseBuilder {
+ private final String description;
+ private final Function createPersistedConfig;
+ private final ServiceParser serviceParser;
+ private final TaskType expectedTaskType;
+ private boolean modelIncludesSecrets;
+ private boolean expectFailure;
+
+ TestCaseBuilder(
+ String description,
+ Function createPersistedConfig,
+ ServiceParser serviceParser,
+ TaskType expectedTaskType
+ ) {
+ this.description = description;
+ this.createPersistedConfig = createPersistedConfig;
+ this.serviceParser = serviceParser;
+ this.expectedTaskType = expectedTaskType;
+ }
+
+ public TestCaseBuilder withSecrets() {
+ this.modelIncludesSecrets = true;
+ return this;
+ }
+
+ public TestCaseBuilder expectFailure() {
+ this.expectFailure = true;
+ return this;
+ }
+
+ public TestCase build() {
+ return new TestCase(description, createPersistedConfig, serviceParser, expectedTaskType, modelIncludesSecrets, expectFailure);
+ }
+ }
+
@ParametersFactory
public static Iterable parameters() throws IOException {
return Arrays.asList(
new TestCase[][] {
- // parsePersistedConfig
+ // Test cases for parsePersistedConfig method
{
new TestCaseBuilder(
"Test parsing persisted config without chunking settings",
@@ -106,11 +116,7 @@ public static Iterable parameters() throws IOException {
testConfiguration.commonConfig().createTaskSettingsMap(),
null
),
- (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config()
- ),
+ (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()),
TaskType.TEXT_EMBEDDING
).build() },
{
@@ -123,14 +129,56 @@ public static Iterable parameters() throws IOException {
createRandomChunkingSettingsMap(),
null
),
- (service, persistedConfig, testConfiguration) -> service.parsePersistedConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config()
- ),
+ (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()),
TaskType.TEXT_EMBEDDING
).build() },
- // parsePersistedConfigWithSecrets
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config does not throw when an extra key exists in config",
+ testConfiguration -> {
+ var persistedConfigMap = getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ null
+ );
+ persistedConfigMap.config().put("extra_key", "value");
+ return persistedConfigMap;
+ },
+ (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()),
+ TaskType.COMPLETION
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config does not throw when an extra key exists in service settings",
+ testConfiguration -> {
+ var serviceSettings = testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT);
+ serviceSettings.put("extra_key", "value");
+
+ return getPersistedConfigMap(serviceSettings, testConfiguration.commonConfig().createTaskSettingsMap(), null);
+ },
+ (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()),
+ TaskType.COMPLETION
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config does not throw when an extra key exists in task settings",
+ testConfiguration -> {
+ var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap();
+ taskSettingsMap.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ taskSettingsMap,
+ null
+ );
+ },
+ (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()),
+ TaskType.COMPLETION
+ ).build() },
+ // Test cases for parsePersistedConfigWithSecrets method
{
new TestCaseBuilder(
"Test parsing persisted config with secrets creates an embeddings model",
@@ -140,11 +188,11 @@ public static Iterable parameters() throws IOException {
testConfiguration.commonConfig().createTaskSettingsMap(),
testConfiguration.commonConfig().createSecretSettingsMap()
),
- (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ (params) -> params.service.parsePersistedConfigWithSecrets(
"id",
TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
),
TaskType.TEXT_EMBEDDING
).withSecrets().build() },
@@ -159,11 +207,11 @@ public static Iterable parameters() throws IOException {
createRandomChunkingSettingsMap(),
testConfiguration.commonConfig().createSecretSettingsMap()
),
- (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ (params) -> params.service.parsePersistedConfigWithSecrets(
"id",
TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
),
TaskType.TEXT_EMBEDDING
).withSecrets().build() },
@@ -177,11 +225,11 @@ public static Iterable parameters() throws IOException {
testConfiguration.commonConfig().createTaskSettingsMap(),
testConfiguration.commonConfig().createSecretSettingsMap()
),
- (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ (params) -> params.service.parsePersistedConfigWithSecrets(
"id",
TaskType.COMPLETION,
- persistedConfig.config(),
- persistedConfig.secrets()
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
),
TaskType.COMPLETION
).withSecrets().build() },
@@ -194,276 +242,146 @@ public static Iterable parameters() throws IOException {
testConfiguration.commonConfig().createTaskSettingsMap(),
testConfiguration.commonConfig().createSecretSettingsMap()
),
- (service, persistedConfig, testConfiguration) -> service.parsePersistedConfigWithSecrets(
+ (params) -> params.service.parsePersistedConfigWithSecrets(
"id",
- testConfiguration.commonConfig().unsupportedTaskType(),
- persistedConfig.config(),
- persistedConfig.secrets()
+ params.testConfiguration.commonConfig().unsupportedTaskType(),
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().expectFailure().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in config",
+ testConfiguration -> {
+ var persistedConfigMap = getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ );
+ persistedConfigMap.config().put("extra_key", "value");
+ return persistedConfigMap;
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in service settings",
+ testConfiguration -> {
+ var serviceSettings = testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT);
+ serviceSettings.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ serviceSettings,
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ );
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in task settings",
+ testConfiguration -> {
+ var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap();
+ taskSettingsMap.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ taskSettingsMap,
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ );
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in secret settings",
+ testConfiguration -> {
+ var secretSettingsMap = testConfiguration.commonConfig().createSecretSettingsMap();
+ secretSettingsMap.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ secretSettingsMap
+ );
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
),
TaskType.COMPLETION
).withSecrets().build() } }
);
}
- public record TestCase(
- String description,
- Function createPersistedConfig,
- ServiceCallback serviceCallback,
- TaskType expectedTaskType,
- boolean modelIncludesSecrets,
- boolean expectFailure
- ) {}
-
- @FunctionalInterface
- interface ServiceCallback {
- Model parseConfigs(
- SenderService service,
- Utils.PersistedConfig persistedConfig,
- AbstractInferenceServiceTests.TestConfiguration testConfiguration
- );
- }
-
- private static class TestCaseBuilder {
- private final String description;
- private final Function createPersistedConfig;
- private final ServiceCallback serviceCallback;
- private final TaskType expectedTaskType;
- private boolean modelIncludesSecrets;
- private boolean expectFailure;
-
- TestCaseBuilder(
- String description,
- Function createPersistedConfig,
- ServiceCallback serviceCallback,
- TaskType expectedTaskType
- ) {
- this.description = description;
- this.createPersistedConfig = createPersistedConfig;
- this.serviceCallback = serviceCallback;
- this.expectedTaskType = expectedTaskType;
- }
-
- public TestCaseBuilder withSecrets() {
- this.modelIncludesSecrets = true;
- return this;
- }
-
- public TestCaseBuilder withFailure() {
- this.expectFailure = true;
- return this;
- }
-
- public TestCase build() {
- return new TestCase(description, createPersistedConfig, serviceCallback, expectedTaskType, modelIncludesSecrets, expectFailure);
- }
- }
-
public void testPersistedConfig() throws Exception {
+ // If the service doesn't support the expected task type, then skip the test
+ Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(testCase.expectedTaskType));
+
var parseConfigTestConfig = testConfiguration.commonConfig();
var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration);
try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var model = testCase.serviceCallback.parseConfigs(service, persistedConfig, testConfiguration);
- if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) {
- @SuppressWarnings("unchecked")
- var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(
- (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS)
- );
- assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ if (testCase.expectFailure) {
+ assertFailedParse(service, persistedConfig);
+ } else {
+ assertSuccessfulParse(service, persistedConfig);
}
-
- parseConfigTestConfig.assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets);
- }
- }
-
- // parsePersistedConfigWithSecrets
-
- public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig();
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var exception = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.unsupportedTaskType(),
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- )
- );
-
- assertThat(
- exception.getMessage(),
- containsString(
- Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType())
- )
- );
- }
- }
-
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "service does not support task type [%s]";
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig();
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
- persistedConfigMap.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.taskType(),
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig();
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(
- parseConfigTestConfig.taskType(),
- ConfigurationParseContext.PERSISTENT
- );
- serviceSettings.put("extra_key", "value");
- var persistedConfigMap = getPersistedConfigMap(
- serviceSettings,
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.taskType(),
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
}
}
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig();
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var taskSettings = parseConfigTestConfig.createTaskSettingsMap();
- taskSettings.put("extra_key", "value");
- var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
- taskSettings,
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets());
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig();
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap();
- secretSettingsMap.put("extra_key", "value");
- var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType(), ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- secretSettingsMap
- );
-
- var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType(), config.config(), config.secrets());
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType());
- }
- }
-
- public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
- try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
- var listener = new PlainActionFuture();
-
- service.infer(
- getInvalidModel("id", "service"),
- null,
- null,
- null,
- List.of(""),
- false,
- new HashMap<>(),
- InputType.INTERNAL_SEARCH,
- InferenceAction.Request.DEFAULT_TIMEOUT,
- listener
- );
+ private void assertFailedParse(SenderService service, Utils.PersistedConfig persistedConfig) {
+ var exception = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration))
+ );
- var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
- assertThat(
- exception.getMessage(),
- is("The internal model was invalid, please delete the service [service] with id [id] and add it again.")
- );
- }
+ assertThat(
+ exception.getMessage(),
+ containsString(
+ Strings.format("service does not support task type [%s]", testConfiguration.commonConfig().unsupportedTaskType())
+ )
+ );
}
- public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
+ private void assertSuccessfulParse(SenderService service, Utils.PersistedConfig persistedConfig) throws Exception {
+ var model = testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration));
- try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
- var exception = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt())
+ if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) {
+ @SuppressWarnings("unchecked")
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(
+ (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS)
);
-
- assertThat(exception.getMessage(), containsString("Can't update embedding details for model"));
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
}
- }
-
- public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
- var embeddingSize = randomNonNegativeInt();
- var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null);
-
- Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
-
- assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity());
- assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
- }
- }
-
- public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
-
- try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
- var embeddingSize = randomNonNegativeInt();
- var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE);
-
- Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
-
- assertEquals(SimilarityMeasure.COSINE, updatedModel.getServiceSettings().similarity());
- assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
- }
- }
-
- // streaming tests
- public void testSupportedStreamingTasks() throws Exception {
- try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
- assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks()));
- assertFalse(service.canStream(TaskType.ANY));
- }
+ testConfiguration.commonConfig().assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
index b1ff02b6bbc17..9479a8ea6e10d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
@@ -7,50 +7,29 @@
package org.elasticsearch.xpack.inference.services;
-import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
-
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
-import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.http.MockWebServer;
-import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
-import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.junit.After;
import org.junit.Assume;
-import org.junit.Before;
import java.io.IOException;
-import java.util.Arrays;
-import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
-import java.util.function.BiFunction;
-import java.util.function.Function;
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
-import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
-import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
-import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
-import static org.mockito.Mockito.mock;
/**
* Base class for testing inference services.
@@ -60,134 +39,16 @@
* To use this class, extend it and pass the constructor a configuration.
*
*/
-public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase {
-
- private final TestConfiguration testConfiguration;
-
- protected final MockWebServer webServer = new MockWebServer();
- protected ThreadPool threadPool;
- protected HttpClientManager clientManager;
- protected TestCase testCase;
-
- @Override
- @Before
- public void setUp() throws Exception {
- super.setUp();
- webServer.start();
- threadPool = createThreadPool(inferenceUtilityExecutors());
- clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
- }
+public abstract class AbstractInferenceServiceTests extends AbstractInferenceServiceBaseTests {
- @Override
- @After
- public void tearDown() throws Exception {
- super.tearDown();
- clientManager.close();
- terminate(threadPool);
- webServer.close();
- }
-
- public AbstractInferenceServiceTests(TestConfiguration testConfiguration, TestCase testCase) {
- this.testConfiguration = Objects.requireNonNull(testConfiguration);
- this.testCase = testCase;
- }
-
- /**
- * Main configurations for the tests
- */
- public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) {
- public static class Builder {
- private final CommonConfig commonConfig;
- private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS;
-
- public Builder(CommonConfig commonConfig) {
- this.commonConfig = commonConfig;
- }
-
- public Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) {
- this.updateModelConfiguration = updateModelConfiguration;
- return this;
- }
-
- public TestConfiguration build() {
- return new TestConfiguration(commonConfig, updateModelConfiguration);
- }
- }
- }
-
- /**
- * Configurations that are useful for most tests
- */
- public abstract static class CommonConfig {
-
- private final TaskType taskType;
- private final TaskType unsupportedTaskType;
-
- public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
- this.taskType = Objects.requireNonNull(taskType);
- this.unsupportedTaskType = unsupportedTaskType;
- }
-
- public TaskType taskType() {
- return taskType;
- }
-
- public TaskType unsupportedTaskType() {
- return unsupportedTaskType;
- }
-
- protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);
-
- protected abstract Map createServiceSettingsMap(TaskType taskType);
-
- protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
- return createServiceSettingsMap(taskType);
- }
-
- protected abstract Map createTaskSettingsMap();
-
- protected abstract Map createSecretSettingsMap();
-
- protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets);
-
- protected void assertModel(Model model, TaskType taskType) {
- assertModel(model, taskType, true);
- }
-
- protected abstract EnumSet supportedStreamingTasks();
- }
-
- /**
- * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests
- */
- public abstract static class UpdateModelConfiguration {
-
- public boolean isEnabled() {
- return true;
- }
-
- protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
- }
-
- private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
- @Override
- public boolean isEnabled() {
- return false;
- }
-
- @Override
- protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
- throw new UnsupportedOperationException("Update model tests are disabled");
- }
- };
-
- @Override
- public InferenceService createInferenceService() {
- return testConfiguration.commonConfig.createService(threadPool, clientManager);
+ public AbstractInferenceServiceTests(TestConfiguration testConfiguration) {
+ super(testConfiguration);
}
public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING));
+
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
@@ -207,7 +68,9 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception {
}
public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING));
+
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var chunkingSettingsMap = createRandomChunkingSettingsMap();
@@ -229,7 +92,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP
}
public void testParseRequestConfig_CreatesACompletionModel() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
@@ -246,12 +109,12 @@ public void testParseRequestConfig_CreatesACompletionModel() throws Exception {
}
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
parseRequestConfigTestConfig.createServiceSettingsMap(
- parseRequestConfigTestConfig.taskType,
+ parseRequestConfigTestConfig.targetTaskType(),
ConfigurationParseContext.REQUEST
),
parseRequestConfigTestConfig.createTaskSettingsMap(),
@@ -259,23 +122,25 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
- containsString(Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType))
+ containsString(
+ Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType())
+ )
);
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
parseRequestConfigTestConfig.createServiceSettingsMap(
- parseRequestConfigTestConfig.taskType,
+ parseRequestConfigTestConfig.targetTaskType(),
ConfigurationParseContext.REQUEST
),
parseRequestConfigTestConfig.createTaskSettingsMap(),
@@ -284,7 +149,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I
config.put("extra_key", "value");
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
@@ -292,10 +157,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(
- parseRequestConfigTestConfig.taskType,
+ parseRequestConfigTestConfig.targetTaskType(),
ConfigurationParseContext.REQUEST
);
serviceSettings.put("extra_key", "value");
@@ -306,7 +171,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
@@ -314,13 +179,13 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap();
taskSettings.put("extra_key", "value");
var config = getRequestConfigMap(
parseRequestConfigTestConfig.createServiceSettingsMap(
- parseRequestConfigTestConfig.taskType,
+ parseRequestConfigTestConfig.targetTaskType(),
ConfigurationParseContext.REQUEST
),
taskSettings,
@@ -328,7 +193,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap()
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
@@ -336,13 +201,13 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap()
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap();
secretSettingsMap.put("extra_key", "value");
var config = getRequestConfigMap(
parseRequestConfigTestConfig.createServiceSettingsMap(
- parseRequestConfigTestConfig.taskType,
+ parseRequestConfigTestConfig.targetTaskType(),
ConfigurationParseContext.REQUEST
),
parseRequestConfigTestConfig.createTaskSettingsMap(),
@@ -350,339 +215,15 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
}
}
- @ParametersFactory
- public static Iterable parameters() throws IOException {
- var chunkingSettingsMap = createRandomChunkingSettingsMap();
-
- return Arrays.asList(
- new TestCase[][] {
- {
- new TestCaseBuilder(
- "Test parsing persisted config without chunking settings",
- testConfiguration -> getPersistedConfigMap(
- testConfiguration.commonConfig.createServiceSettingsMap(
- TaskType.TEXT_EMBEDDING,
- ConfigurationParseContext.PERSISTENT
- ),
- testConfiguration.commonConfig.createTaskSettingsMap(),
- null
- ),
- (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config())
- ).withNullChunkingSettingsMap().build() },
- {
- new TestCaseBuilder(
- "Test parsing persisted config with chunking settings",
- testConfiguration -> getPersistedConfigMap(
- testConfiguration.commonConfig.createServiceSettingsMap(
- TaskType.TEXT_EMBEDDING,
- ConfigurationParseContext.PERSISTENT
- ),
- testConfiguration.commonConfig.createTaskSettingsMap(),
- chunkingSettingsMap,
- null
- ),
- (service, persistedConfig) -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config())
- ).withChunkingSettingsMap(chunkingSettingsMap).build() } }
- );
- }
-
- public record TestCase(
- String description,
- Function createPersistedConfig,
- BiFunction serviceCallback,
- @Nullable Map chunkingSettingsMap,
- boolean validateChunkingSettings,
- boolean modelIncludesSecrets
- ) {}
-
- private static class TestCaseBuilder {
- private final String description;
- private final Function createPersistedConfig;
- private final BiFunction serviceCallback;
- @Nullable
- private Map chunkingSettingsMap;
- private boolean validateChunkingSettings;
- private boolean modelIncludesSecrets;
-
- TestCaseBuilder(
- String description,
- Function createPersistedConfig,
- BiFunction serviceCallback
- ) {
- this.description = description;
- this.createPersistedConfig = createPersistedConfig;
- this.serviceCallback = serviceCallback;
- }
-
- public TestCaseBuilder withSecrets() {
- this.modelIncludesSecrets = true;
- return this;
- }
-
- public TestCaseBuilder withChunkingSettingsMap(Map chunkingSettingsMap) {
- this.chunkingSettingsMap = chunkingSettingsMap;
- this.validateChunkingSettings = true;
- return this;
- }
-
- /**
- * Use an empty chunking settings map but still do validation that the chunking settings are set to the appropriate
- * defaults.
- */
- public TestCaseBuilder withEmptyChunkingSettingsMap() {
- this.chunkingSettingsMap = Map.of();
- this.validateChunkingSettings = true;
- return this;
- }
-
- public TestCaseBuilder withNullChunkingSettingsMap() {
- this.chunkingSettingsMap = null;
- this.validateChunkingSettings = true;
- return this;
- }
-
- public TestCase build() {
- return new TestCase(
- description,
- createPersistedConfig,
- serviceCallback,
- chunkingSettingsMap,
- validateChunkingSettings,
- modelIncludesSecrets
- );
- }
- }
-
- public void testPersistedConfig() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration);
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var model = testCase.serviceCallback.apply(service, persistedConfig);
-
- if (testCase.validateChunkingSettings) {
- var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(testCase.chunkingSettingsMap);
- assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
- }
-
- parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
- }
- }
-
- // parsePersistedConfig tests
-
- public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- null
- );
-
- parseConfigHelper(service -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfigMap.config()), null);
- }
-
- private void parseConfigHelper(Function serviceParseCallback, @Nullable Map chunkingSettingsMap)
- throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
-
- var model = serviceParseCallback.apply(service);
-
- var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap == null ? Map.of() : chunkingSettingsMap);
- assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
-
- parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
- }
- }
-
- // parsePersistedConfigWithSecrets
-
- public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- parseConfigHelper(
- service -> service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- ),
- null
- );
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsAreProvided() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- var chunkingSettingsMap = createRandomChunkingSettingsMap();
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- chunkingSettingsMap,
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- parseConfigHelper(
- service -> service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- ),
- chunkingSettingsMap
- );
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.COMPLETION,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
- parseConfigTestConfig.assertModel(model, TaskType.COMPLETION);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var exception = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.unsupportedTaskType,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- )
- );
-
- assertThat(
- exception.getMessage(),
- containsString(
- Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType)
- )
- );
- }
- }
-
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "service does not support task type [%s]";
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
- persistedConfigMap.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.taskType,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(
- parseConfigTestConfig.taskType,
- ConfigurationParseContext.PERSISTENT
- );
- serviceSettings.put("extra_key", "value");
- var persistedConfigMap = getPersistedConfigMap(
- serviceSettings,
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.taskType,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var taskSettings = parseConfigTestConfig.createTaskSettingsMap();
- taskSettings.put("extra_key", "value");
- var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
- taskSettings,
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap();
- secretSettingsMap.put("extra_key", "value");
- var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType, ConfigurationParseContext.PERSISTENT),
- parseConfigTestConfig.createTaskSettingsMap(),
- secretSettingsMap
- );
-
- var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var listener = new PlainActionFuture();
service.infer(
@@ -707,9 +248,9 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
}
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var exception = expectThrows(
ElasticsearchStatusException.class,
() -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt())
@@ -720,11 +261,11 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO
}
public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var embeddingSize = randomNonNegativeInt();
- var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(null);
+ var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null);
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
@@ -734,11 +275,11 @@ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel()
}
public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var embeddingSize = randomNonNegativeInt();
- var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(SimilarityMeasure.COSINE);
+ var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE);
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
@@ -749,8 +290,8 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel
// streaming tests
public void testSupportedStreamingTasks() throws Exception {
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
- assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig.supportedStreamingTasks()));
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks()));
assertFalse(service.canStream(TaskType.ANY));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java
new file mode 100644
index 0000000000000..1254610f32745
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ai21;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+import static org.elasticsearch.xpack.inference.services.ai21.Ai21ServiceTests.createTestConfiguration;
+
+public class Ai21ServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public Ai21ServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) {
+ super(createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
index 50ca2f72abda0..8a1f8d806a45b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
@@ -81,13 +81,17 @@ public class Ai21ServiceTests extends AbstractInferenceServiceTests {
private ThreadPool threadPool;
private HttpClientManager clientManager;
- public Ai21ServiceTests(TestCase testCase) {
- super(createTestConfiguration(), testCase);
+ public Ai21ServiceTests() {
+ super(createTestConfiguration());
}
- private static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() {
+ public static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() {
return new AbstractInferenceServiceTests.TestConfiguration.Builder(
- new AbstractInferenceServiceTests.CommonConfig(TaskType.COMPLETION, TaskType.TEXT_EMBEDDING) {
+ new AbstractInferenceServiceTests.CommonConfig(
+ TaskType.COMPLETION,
+ TaskType.TEXT_EMBEDDING,
+ EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
+ ) {
@Override
protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
@@ -167,10 +171,6 @@ private static Map createSecretSettingsMap() {
return new HashMap<>(Map.of("api_key", "secret"));
}
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "Failed to parse stored model [id] for [ai21] service, please delete and add the service again";
- }
-
@Before
public void init() throws Exception {
webServer.start();
@@ -185,16 +185,6 @@ public void shutdown() throws IOException {
webServer.close();
}
- @Override
- public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() {
- // The Ai21Service does not support Text Embedding, so this test is not applicable.
- }
-
- @Override
- public void testParseRequestConfig_CreatesAnEmbeddingsModel() {
- // The Ai21Service does not support Text Embedding, so this test is not applicable.
- }
-
public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException {
var url = "url";
var model = "model";
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java
new file mode 100644
index 0000000000000..a71a107f3b9d1
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+public class CustomServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public CustomServiceParameterizedTests(TestCase testCase) {
+ super(CustomServiceTests.createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index 269642812e78f..44bf17c3ac96d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -69,42 +69,56 @@
public class CustomServiceTests extends AbstractInferenceServiceTests {
- public CustomServiceTests(TestCase testCase) {
- super(createTestConfiguration(), testCase);
+ public CustomServiceTests() {
+ super(createTestConfiguration());
}
- private static TestConfiguration createTestConfiguration() {
- return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION) {
- @Override
- protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
- return CustomServiceTests.createService(threadPool, clientManager);
- }
+ public static TestConfiguration createTestConfiguration() {
+ return new TestConfiguration.Builder(
+ new CommonConfig(
+ TaskType.TEXT_EMBEDDING,
+ TaskType.CHAT_COMPLETION,
+ EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.COMPLETION)
+ ) {
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return CustomServiceTests.createService(threadPool, clientManager);
+ }
- @Override
- protected Map createServiceSettingsMap(TaskType taskType) {
- return CustomServiceTests.createServiceSettingsMap(taskType);
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return CustomServiceTests.createServiceSettingsMap(taskType);
+ }
- @Override
- protected Map createTaskSettingsMap() {
- return CustomServiceTests.createTaskSettingsMap();
- }
+ @Override
+ protected Map createTaskSettingsMap() {
+ return CustomServiceTests.createTaskSettingsMap();
+ }
- @Override
- protected Map createSecretSettingsMap() {
- return CustomServiceTests.createSecretSettingsMap();
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return CustomServiceTests.createSecretSettingsMap();
+ }
- @Override
- protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
- CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets);
- }
+ @Override
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets);
+ }
- @Override
- protected EnumSet supportedStreamingTasks() {
- return EnumSet.noneOf(TaskType.class);
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.noneOf(TaskType.class);
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(
+ rerankingInferenceService.rerankerWindowSize("any model"),
+ CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
+ );
+ }
}
- }).enableUpdateModelTests(new UpdateModelConfiguration() {
+ ).enableUpdateModelTests(new UpdateModelConfiguration() {
@Override
protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(similarityMeasure);
@@ -808,12 +822,4 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
assertThat(requestMap.get("input"), is(List.of("a")));
}
}
-
- @Override
- protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
- assertThat(
- rerankingInferenceService.rerankerWindowSize("any model"),
- CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
- );
- }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java
new file mode 100644
index 0000000000000..04afbc28af10a
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+public class LlamaServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public LlamaServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) {
+ super(LlamaServiceTests.createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index a6773122a0d2b..f6a0232db529c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -77,6 +77,9 @@
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
+import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
+import static org.elasticsearch.inference.TaskType.COMPLETION;
+import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -102,43 +105,45 @@ public class LlamaServiceTests extends AbstractInferenceServiceTests {
private ThreadPool threadPool;
private HttpClientManager clientManager;
- public LlamaServiceTests(TestCase testCase) {
- super(createTestConfiguration(), testCase);
+ public LlamaServiceTests() {
+ super(createTestConfiguration());
}
- private static TestConfiguration createTestConfiguration() {
- return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) {
+ public static TestConfiguration createTestConfiguration() {
+ return new TestConfiguration.Builder(
+ new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) {
- @Override
- protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
- return LlamaServiceTests.createService(threadPool, clientManager);
- }
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return LlamaServiceTests.createService(threadPool, clientManager);
+ }
- @Override
- protected Map createServiceSettingsMap(TaskType taskType) {
- return LlamaServiceTests.createServiceSettingsMap(taskType);
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return LlamaServiceTests.createServiceSettingsMap(taskType);
+ }
- @Override
- protected Map createTaskSettingsMap() {
- return new HashMap<>();
- }
+ @Override
+ protected Map createTaskSettingsMap() {
+ return new HashMap<>();
+ }
- @Override
- protected Map createSecretSettingsMap() {
- return LlamaServiceTests.createSecretSettingsMap();
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return LlamaServiceTests.createSecretSettingsMap();
+ }
- @Override
- protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
- LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets);
- }
+ @Override
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets);
+ }
- @Override
- protected EnumSet supportedStreamingTasks() {
- return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ }
}
- }).enableUpdateModelTests(new UpdateModelConfiguration() {
+ ).enableUpdateModelTests(new UpdateModelConfiguration() {
@Override
protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(similarityMeasure);
@@ -239,10 +244,6 @@ private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable Simil
);
}
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "Failed to parse stored model [id] for [llama] service, please delete and add the service again";
- }
-
@Before
public void init() throws Exception {
webServer.start();
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 5eceffe200bc4..c2cda175831ed 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -87,7 +87,6 @@
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
-import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -146,47 +145,53 @@ public void shutdown() throws IOException {
webServer.close();
}
- public OpenAiServiceTests(TestCase testCase) {
- super(createTestConfiguration(), testCase);
+ public OpenAiServiceTests() {
+ super(createTestConfiguration());
}
public static TestConfiguration createTestConfiguration() {
- return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.RERANK) {
- @Override
- protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
- return OpenAiServiceTests.createService(threadPool, clientManager);
- }
+ return new TestConfiguration.Builder(
+ new CommonConfig(
+ TaskType.TEXT_EMBEDDING,
+ TaskType.RERANK,
+ EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
+ ) {
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return OpenAiServiceTests.createService(threadPool, clientManager);
+ }
- @Override
- protected Map createServiceSettingsMap(TaskType taskType) {
- return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST);
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST);
+ }
- @Override
- protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
- return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext);
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext);
+ }
- @Override
- protected Map createTaskSettingsMap() {
- return OpenAiServiceTests.createTaskSettingsMap();
- }
+ @Override
+ protected Map createTaskSettingsMap() {
+ return OpenAiServiceTests.createTaskSettingsMap();
+ }
- @Override
- protected Map createSecretSettingsMap() {
- return getSecretSettingsMap(SECRET);
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return getSecretSettingsMap(SECRET);
+ }
- @Override
- protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
- OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets);
- }
+ @Override
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets);
+ }
- @Override
- protected EnumSet supportedStreamingTasks() {
- return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ }
}
- }).enableUpdateModelTests(new UpdateModelConfiguration() {
+ ).enableUpdateModelTests(new UpdateModelConfiguration() {
@Override
protected OpenAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(similarityMeasure, null);
@@ -346,68 +351,6 @@ public void testParseRequestConfig_MovesModel() throws IOException {
}
}
- public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user")
- );
- persistedConfig.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true);
- serviceSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getOpenAiTaskSettingsMap("user"));
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var taskSettingsMap = getOpenAiTaskSettingsMap("user");
- taskSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap);
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException {
var sender = mock(Sender.class);
From 060eed90f38084a2953e1362bc3131224e8e759c Mon Sep 17 00:00:00 2001
From: elasticsearchmachine
Date: Thu, 25 Sep 2025 17:25:48 +0000
Subject: [PATCH 04/10] [CI] Auto commit changes from spotless
---
.../xpack/inference/services/custom/CustomService.java | 4 +++-
.../inference/services/AbstractInferenceServiceBaseTests.java | 2 +-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
index f9e1dba847dc7..c69284117ef36 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -107,7 +107,9 @@ public void parseRequestConfig(
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+ chunkingSettings = ChunkingSettingsBuilder.fromMap(
+ removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+ );
}
CustomModel model = createModel(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java
index 46fedc1e71f4e..7657f7e79cf25 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java
@@ -29,7 +29,7 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.mockito.Mockito.mock;
-public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase{
+public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase {
protected final TestConfiguration testConfiguration;
protected final MockWebServer webServer = new MockWebServer();
From ac44cc6cff056612ed4f3ecfdc5332b2eedefd11 Mon Sep 17 00:00:00 2001
From: elasticsearchmachine
Date: Thu, 25 Sep 2025 17:32:57 +0000
Subject: [PATCH 05/10] [CI] Update transport version definitions
---
server/src/main/resources/transport/upper_bounds/8.18.csv | 2 +-
server/src/main/resources/transport/upper_bounds/8.19.csv | 2 +-
server/src/main/resources/transport/upper_bounds/9.0.csv | 2 +-
server/src/main/resources/transport/upper_bounds/9.1.csv | 2 +-
server/src/main/resources/transport/upper_bounds/9.2.csv | 2 +-
5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/server/src/main/resources/transport/upper_bounds/8.18.csv b/server/src/main/resources/transport/upper_bounds/8.18.csv
index ffc592e1809ee..266bfbbd3bf78 100644
--- a/server/src/main/resources/transport/upper_bounds/8.18.csv
+++ b/server/src/main/resources/transport/upper_bounds/8.18.csv
@@ -1 +1 @@
-initial_elasticsearch_8_18_8,8840010
+transform_check_for_dangling_tasks,8840011
diff --git a/server/src/main/resources/transport/upper_bounds/8.19.csv b/server/src/main/resources/transport/upper_bounds/8.19.csv
index 3cc6f439c5ea5..3600b3f8c633a 100644
--- a/server/src/main/resources/transport/upper_bounds/8.19.csv
+++ b/server/src/main/resources/transport/upper_bounds/8.19.csv
@@ -1 +1 @@
-initial_elasticsearch_8_19_5,8841069
+transform_check_for_dangling_tasks,8841070
diff --git a/server/src/main/resources/transport/upper_bounds/9.0.csv b/server/src/main/resources/transport/upper_bounds/9.0.csv
index 8ad2ed1a4cacf..c11e6837bb813 100644
--- a/server/src/main/resources/transport/upper_bounds/9.0.csv
+++ b/server/src/main/resources/transport/upper_bounds/9.0.csv
@@ -1 +1 @@
-initial_elasticsearch_9_0_8,9000017
+transform_check_for_dangling_tasks,9000018
diff --git a/server/src/main/resources/transport/upper_bounds/9.1.csv b/server/src/main/resources/transport/upper_bounds/9.1.csv
index 1cea5dc4d929b..80b97d85f7511 100644
--- a/server/src/main/resources/transport/upper_bounds/9.1.csv
+++ b/server/src/main/resources/transport/upper_bounds/9.1.csv
@@ -1 +1 @@
-initial_elasticsearch_9_1_5,9112008
+transform_check_for_dangling_tasks,9112009
diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv
index b1209b927d8a5..e4c91df18cda8 100644
--- a/server/src/main/resources/transport/upper_bounds/9.2.csv
+++ b/server/src/main/resources/transport/upper_bounds/9.2.csv
@@ -1 +1 @@
-inference_api_openai_embeddings_headers,9169000
+index_reshard_shardcount_summary,9172000
From cc99292fab2956b11320a59f5a81f130a016ee69 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Mon, 29 Sep 2025 15:08:14 -0400
Subject: [PATCH 06/10] Removing deprecated function
---
.../xpack/inference/services/ServiceUtils.java | 12 ------------
.../AlibabaCloudSearchService.java | 4 ++--
.../amazonbedrock/AmazonBedrockService.java | 4 ++--
.../services/anthropic/AnthropicService.java | 4 ++--
.../azureaistudio/AzureAiStudioService.java | 4 ++--
.../services/azureopenai/AzureOpenAiService.java | 4 ++--
.../inference/services/cohere/CohereService.java | 4 ++--
.../contextualai/ContextualAiService.java | 4 ++--
.../elastic/ElasticInferenceService.java | 4 ++--
.../googleaistudio/GoogleAiStudioService.java | 4 ++--
.../googlevertexai/GoogleVertexAiService.java | 4 ++--
.../huggingface/HuggingFaceBaseService.java | 4 ++--
.../services/ibmwatsonx/IbmWatsonxService.java | 4 ++--
.../inference/services/jinaai/JinaAIService.java | 4 ++--
.../services/mistral/MistralService.java | 4 ++--
.../services/voyageai/VoyageAIService.java | 4 ++--
.../amazonbedrock/AmazonBedrockServiceTests.java | 12 ++++++++++--
.../azureaistudio/AzureAiStudioServiceTests.java | 6 +++++-
.../azureopenai/AzureOpenAiServiceTests.java | 12 ++++++++++--
.../services/cohere/CohereServiceTests.java | 16 ++++++++++++----
.../services/jinaai/JinaAIServiceTests.java | 16 ++++++++++++----
.../services/mistral/MistralServiceTests.java | 6 +++++-
.../services/voyageai/VoyageAIServiceTests.java | 16 ++++++++++++----
23 files changed, 96 insertions(+), 60 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 69b15f3e32cc2..3fda2506b9721 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -1079,18 +1079,6 @@ public interface EnumConstructor> {
E apply(String name) throws IllegalArgumentException;
}
- /**
- * @deprecated use {@link #parsePersistedConfigErrorMsg(String, String, TaskType)} instead
- */
- @Deprecated
- public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) {
- return format(
- "Failed to parse stored model [%s] for [%s] service, please delete and add the service again",
- inferenceEntityId,
- serviceName
- );
- }
-
public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
return format(
"Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
index f474850b9f190..749a7f994d3c5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
@@ -256,7 +256,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
);
}
@@ -277,7 +277,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
taskSettingsMap,
chunkingSettings,
null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
index 11204018a5523..8275257608b33 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
@@ -241,7 +241,7 @@ public Model parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME),
+ parsePersistedConfigErrorMsg(modelId, NAME, taskType),
ConfigurationParseContext.PERSISTENT
);
}
@@ -263,7 +263,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- MatcherAssert.assertThat(
+ assertThat(
+ thrownException.getMessage(),
+ containsString("Failed to parse stored model [id] for [cohere] service")
+ );
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again")
+ containsString("The [cohere] service does not support task type [sparse_embedding]")
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index fc50acdbd39b6..e5ed8891b368a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -443,9 +443,13 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- MatcherAssert.assertThat(
+ assertThat(
+ thrownException.getMessage(),
+ containsString("Failed to parse stored model [id] for [jinaai] service")
+ );
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again")
+ containsString("The [jinaai] service does not support task type [sparse_embedding]")
);
}
}
@@ -683,9 +687,13 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- MatcherAssert.assertThat(
+ assertThat(
+ thrownException.getMessage(),
+ containsString("Failed to parse stored model [id] for [jinaai] service")
+ );
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again")
+ containsString("The [jinaai] service does not support task type [sparse_embedding]")
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 50731811e4164..1bf37948012e4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -744,7 +744,11 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [mistral] service, please delete and add the service again")
+ containsString("Failed to parse stored model [id] for [mistral] service")
+ );
+ assertThat(
+ thrownException.getMessage(),
+ containsString("The [mistral] service does not support task type [sparse_embedding]")
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 8cad8cbad208a..99c6d31c207b6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -409,9 +409,13 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- MatcherAssert.assertThat(
+ assertThat(
+ thrownException.getMessage(),
+ containsString("Failed to parse stored model [id] for [voyageai] service")
+ );
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again")
+ containsString("The [voyageai] service does not support task type [sparse_embedding]")
);
}
}
@@ -624,9 +628,13 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- MatcherAssert.assertThat(
+ assertThat(
+ thrownException.getMessage(),
+ containsString("Failed to parse stored model [id] for [voyageai] service")
+ );
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again")
+ containsString("The [voyageai] service does not support task type [sparse_embedding]")
);
}
}
From 0896e5c8390544ce8769a50bd288bcbda42b8da0 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Tue, 30 Sep 2025 09:27:44 -0400
Subject: [PATCH 07/10] Moving string creation and refactoring customservice
chunking
---
.../inference/services/ServiceUtils.java | 18 ++++++++++++++++++
.../inference/services/ai21/Ai21Service.java | 18 +++++-------------
.../AlibabaCloudSearchService.java | 18 +++++-------------
.../amazonbedrock/AmazonBedrockService.java | 8 ++------
.../services/anthropic/AnthropicService.java | 18 +++++-------------
.../azureaistudio/AzureAiStudioService.java | 16 +++++-----------
.../azureopenai/AzureOpenAiService.java | 18 +++++-------------
.../services/cohere/CohereService.java | 18 +++++-------------
.../contextualai/ContextualAiService.java | 8 +++-----
.../services/custom/CustomService.java | 19 +++++++++++--------
.../elastic/ElasticInferenceService.java | 16 +++++-----------
.../googleaistudio/GoogleAiStudioService.java | 18 +++++-------------
.../googlevertexai/GoogleVertexAiService.java | 18 +++++-------------
.../ibmwatsonx/IbmWatsonxService.java | 18 +++++-------------
.../services/jinaai/JinaAIService.java | 18 +++++-------------
.../services/llama/LlamaService.java | 19 +++++--------------
.../services/mistral/MistralService.java | 18 +++++-------------
.../services/openai/OpenAiService.java | 16 +++++-----------
.../services/voyageai/VoyageAIService.java | 18 +++++-------------
19 files changed, 109 insertions(+), 209 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 3fda2506b9721..2c383f5db2f5b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -1088,6 +1088,24 @@ public static String parsePersistedConfigErrorMsg(String inferenceEntityId, Stri
);
}
+ /**
+ * Create an exception for when the task type is not valid for the service.
+ */
+ public static ElasticsearchStatusException createInvalidTaskTypeException(
+ String inferenceEntityId,
+ String serviceName,
+ TaskType taskType,
+ ConfigurationParseContext parseContext
+ ) {
+ var message = parseContext == ConfigurationParseContext.PERSISTENT
+ ? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType)
+ : TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName);
+ return new ElasticsearchStatusException(
+ message,
+ RestStatus.BAD_REQUEST
+ );
+ }
+
public static ElasticsearchStatusException createInvalidModelException(Model model) {
return new ElasticsearchStatusException(
format(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
index f1375085c1f71..57bcc267ac644 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.ai21;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
@@ -27,7 +26,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
@@ -55,7 +53,7 @@
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -183,7 +181,6 @@ public void parseRequestConfig(
taskType,
serviceSettingsMap,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -212,8 +209,7 @@ public Ai21Model parsePersistedConfigWithSecrets(
modelId,
taskType,
serviceSettingsMap,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -226,8 +222,7 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
case CHAT_COMPLETION, COMPLETION:
return new Ai21ChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
default:
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(modelId, NAME, taskType, context);
}
}
@@ -261,15 +255,13 @@ private Ai21Model createModelFromPersistent(
String inferenceEntityId,
TaskType taskType,
Map serviceSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
taskType,
serviceSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
index 749a7f994d3c5..5f2378d40116d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -32,7 +31,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -59,7 +57,7 @@
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -135,7 +133,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -165,8 +162,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -175,7 +171,6 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -187,7 +182,6 @@ private static AlibabaCloudSearchModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -229,7 +223,7 @@ private static AlibabaCloudSearchModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -255,8 +249,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -276,8 +269,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
index 8275257608b33..5e3ba81b7e1bc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
@@ -60,7 +60,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -204,7 +204,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -241,7 +240,6 @@ public Model parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME, taskType),
ConfigurationParseContext.PERSISTENT
);
}
@@ -263,7 +261,6 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -318,7 +314,7 @@ private static AmazonBedrockModel createModel(
checkChatCompletionProviderForTopKParameter(model);
return model;
}
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
index ce39fc261312a..1f2c7f6cb4cdc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.anthropic;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -28,7 +27,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -48,7 +46,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -94,7 +92,6 @@ public void parseRequestConfig(
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -113,8 +110,7 @@ private static AnthropicModel createModelFromPersistent(
TaskType taskType,
Map serviceSettings,
Map taskSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -122,7 +118,6 @@ private static AnthropicModel createModelFromPersistent(
serviceSettings,
taskSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -133,7 +128,6 @@ private static AnthropicModel createModel(
Map serviceSettings,
Map taskSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -146,7 +140,7 @@ private static AnthropicModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -166,8 +160,7 @@ public AnthropicModel parsePersistedConfigWithSecrets(
taskType,
serviceSettingsMap,
taskSettingsMap,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -181,8 +174,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta
taskType,
serviceSettingsMap,
taskSettingsMap,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
index 3bda3ca281f92..23d46820b688f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
@@ -60,7 +60,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -185,7 +185,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -221,8 +220,7 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -242,8 +240,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
@@ -279,7 +276,6 @@ private static AzureAiStudioModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
@@ -305,7 +301,7 @@ private static AzureAiStudioModel createModel(
context
);
case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType());
@@ -318,8 +314,7 @@ private AzureAiStudioModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -328,7 +323,6 @@ private AzureAiStudioModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
index 46ec593b5b5b6..f3d70f004131b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.azureopenai;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -30,7 +29,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -55,7 +53,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -114,7 +112,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -134,8 +131,7 @@ private static AzureOpenAiModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -144,7 +140,6 @@ private static AzureOpenAiModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -156,7 +151,6 @@ private static AzureOpenAiModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -183,7 +177,7 @@ private static AzureOpenAiModel createModel(
context
);
}
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
}
@@ -209,8 +203,7 @@ public AzureOpenAiModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -230,8 +223,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
index 2374c67944501..f1e34138aad5c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.cohere;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -32,7 +31,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -60,7 +58,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -130,7 +128,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -150,8 +147,7 @@ private static CohereModel createModelWithoutLoggingDeprecations(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -160,7 +156,6 @@ private static CohereModel createModelWithoutLoggingDeprecations(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -172,7 +167,6 @@ private static CohereModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -186,7 +180,7 @@ private static CohereModel createModel(
);
case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -212,8 +206,7 @@ public CohereModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -233,8 +226,7 @@ public CohereModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
index 6937ac4703262..2f59400287a10 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
@@ -46,6 +46,8 @@
import java.util.List;
import java.util.Map;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
+
/**
* Contextual AI inference service for reranking tasks.
* This service uses the Contextual AI REST API to perform document reranking.
@@ -97,7 +99,6 @@ public void parseRequestConfig(
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -117,11 +118,10 @@ private static ContextualAiRerankModel createModel(
Map serviceSettings,
Map taskSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
if (taskType != TaskType.RERANK) {
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
return new ContextualAiRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
@@ -144,7 +144,6 @@ public ContextualAiRerankModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
secretSettingsMap,
- ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType),
ConfigurationParseContext.PERSISTENT
);
}
@@ -160,7 +159,6 @@ public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, Ta
serviceSettingsMap,
taskSettingsMap,
null,
- ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType),
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
index 91f97ac585d85..e0c91885999f4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -227,10 +227,7 @@ public CustomModel parsePersistedConfigWithSecrets(
Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
- ChunkingSettings chunkingSettings = null;
- if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
- }
+ var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
return createModelWithoutLoggingDeprecations(
inferenceEntityId,
@@ -242,15 +239,21 @@ public CustomModel parsePersistedConfigWithSecrets(
);
}
+ private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) {
+ if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+ // note there's
+ return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+ }
+
+ return null;
+ }
+
@Override
public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) {
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
- ChunkingSettings chunkingSettings = null;
- if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
- }
+ var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
return createModelWithoutLoggingDeprecations(
inferenceEntityId,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
index a1d7e331db8a2..7063b00c8ac99 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
@@ -79,7 +79,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -446,7 +446,6 @@ public void parseRequestConfig(
chunkingSettings,
serviceSettingsMap,
elasticInferenceServiceComponents,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -494,7 +493,6 @@ private static ElasticInferenceServiceModel createModel(
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -540,7 +538,7 @@ private static ElasticInferenceServiceModel createModel(
context,
chunkingSettings
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -566,8 +564,7 @@ public Model parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -587,8 +584,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
@@ -603,8 +599,7 @@ private ElasticInferenceServiceModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -614,7 +609,6 @@ private ElasticInferenceServiceModel createModelFromPersistent(
chunkingSettings,
secretSettings,
elasticInferenceServiceComponents,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
index 02fee095e71e1..f5924cf6f9854 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.googleaistudio;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -60,7 +58,7 @@
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -127,7 +125,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -149,7 +146,6 @@ private static GoogleAiStudioModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -172,7 +168,7 @@ private static GoogleAiStudioModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -198,8 +194,7 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -209,8 +204,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -219,7 +213,6 @@ private static GoogleAiStudioModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -240,8 +233,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
index 5fa251f74dd9a..f8318b1d6f838 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
@@ -8,7 +8,6 @@
package org.elasticsearch.xpack.inference.services.googlevertexai;
import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -63,7 +61,7 @@
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -149,7 +147,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -185,8 +182,7 @@ public Model parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -206,8 +202,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
@@ -356,8 +351,7 @@ private static GoogleVertexAiModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -366,7 +360,6 @@ private static GoogleVertexAiModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -378,7 +371,6 @@ private static GoogleVertexAiModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -412,7 +404,7 @@ private static GoogleVertexAiModel createModel(
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
index c1bfc2ec6807e..8b4d8fd1f7dae 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -30,7 +29,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -62,7 +60,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -128,7 +126,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -150,7 +147,6 @@ private static IbmWatsonxModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -181,7 +177,7 @@ private static IbmWatsonxModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -207,8 +203,7 @@ public IbmWatsonxModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -228,8 +223,7 @@ private static IbmWatsonxModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -238,7 +232,6 @@ private static IbmWatsonxModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -259,8 +252,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
index 4169fb3a62be5..e0f7ddf864876 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.jinaai;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -57,7 +55,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -121,7 +119,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -141,8 +138,7 @@ private static JinaAIModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -151,7 +147,6 @@ private static JinaAIModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -163,7 +158,6 @@ private static JinaAIModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -177,7 +171,7 @@ private static JinaAIModel createModel(
context
);
case RERANK -> new JinaAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -203,8 +197,7 @@ public JinaAIModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -224,8 +217,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
index 4d99151c65159..fed4824201e42 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.llama;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
@@ -28,7 +27,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -64,7 +62,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -138,7 +136,6 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
* @param serviceSettings the settings for the inference service
* @param chunkingSettings the settings for chunking, if applicable
* @param secretSettings the secret settings for the model, such as API keys or tokens
- * @param failureMessage the message to use in case of failure
* @param context the context for parsing configuration settings
* @return a new instance of LlamaModel based on the provided parameters
*/
@@ -148,7 +145,6 @@ protected LlamaModel createModel(
Map serviceSettings,
ChunkingSettings chunkingSettings,
Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -157,7 +153,7 @@ protected LlamaModel createModel(
case CHAT_COMPLETION, COMPLETION:
return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
default:
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(inferenceId, NAME, taskType, context);
}
}
@@ -283,7 +279,6 @@ public void parseRequestConfig(
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -318,8 +313,7 @@ public Model parsePersistedConfigWithSecrets(
taskType,
serviceSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -328,8 +322,7 @@ private LlamaModel createModelFromPersistent(
TaskType taskType,
Map serviceSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -337,7 +330,6 @@ private LlamaModel createModelFromPersistent(
serviceSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -357,8 +349,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -299,7 +293,7 @@ private static MistralModel createModel(
case CHAT_COMPLETION, COMPLETION:
return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
default:
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(modelId, NAME, taskType, context);
}
}
@@ -308,8 +302,7 @@ private MistralModel createModelFromPersistent(
TaskType taskType,
Map serviceSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -317,7 +310,6 @@ private MistralModel createModelFromPersistent(
serviceSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
index 3a8464a5b8464..944c0af330c2a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
@@ -65,7 +65,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -139,7 +139,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -159,8 +158,7 @@ private static OpenAiModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -169,7 +167,6 @@ private static OpenAiModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -181,7 +178,6 @@ private static OpenAiModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -204,7 +200,7 @@ private static OpenAiModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -232,8 +228,7 @@ public OpenAiModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -255,8 +250,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
index 3d2968b0c8a89..1b4be3842ca52 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.voyageai;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -56,7 +54,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -152,7 +150,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -172,8 +169,7 @@ private static VoyageAIModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -182,7 +178,6 @@ private static VoyageAIModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -194,7 +189,6 @@ private static VoyageAIModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -208,7 +202,7 @@ private static VoyageAIModel createModel(
context
);
case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -234,8 +228,7 @@ public VoyageAIModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ secretSettingsMap
);
}
@@ -255,8 +248,7 @@ public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType tas
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME, taskType)
+ null
);
}
From 49c1301afb6fed9bf1a561a02b55ffb3eee1d5b4 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Tue, 30 Sep 2025 09:41:46 -0400
Subject: [PATCH 08/10] Removing usages of persistent function
---
.../xpack/inference/services/ServiceUtils.java | 18 ++++++++++--------
.../services/custom/CustomService.java | 6 +++---
.../huggingface/HuggingFaceBaseService.java | 4 ----
.../HuggingFaceModelParameters.java | 1 -
.../huggingface/HuggingFaceService.java | 5 ++---
.../elser/HuggingFaceElserService.java | 5 ++---
6 files changed, 17 insertions(+), 22 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 2c383f5db2f5b..9d03278470d35 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -1079,14 +1079,7 @@ public interface EnumConstructor> {
E apply(String name) throws IllegalArgumentException;
}
- public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
- return format(
- "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
- inferenceEntityId,
- serviceName,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName)
- );
- }
+
/**
* Create an exception for when the task type is not valid for the service.
@@ -1106,6 +1099,15 @@ public static ElasticsearchStatusException createInvalidTaskTypeException(
);
}
+ private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
+ return format(
+ "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
+ inferenceEntityId,
+ serviceName,
+ TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName)
+ );
+ }
+
public static ElasticsearchStatusException createInvalidModelException(Model model) {
return new ElasticsearchStatusException(
format(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
index e0c91885999f4..c5c8853c7374a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -57,9 +57,9 @@
import java.util.List;
import java.util.Map;
-import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -211,7 +211,7 @@ private static CustomModel createModel(
ConfigurationParseContext context
) {
if (supportedTaskTypes.contains(taskType) == false) {
- throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
}
@@ -241,7 +241,7 @@ public CustomModel parsePersistedConfigWithSecrets(
private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) {
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- // note there's
+ // note there's
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
index 470228d80c5d1..403a1758983ac 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
@@ -31,7 +31,6 @@
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -84,7 +83,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
)
);
@@ -123,7 +121,6 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, name(), taskType),
ConfigurationParseContext.PERSISTENT
)
);
@@ -147,7 +144,6 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
taskSettingsMap,
chunkingSettings,
null,
- parsePersistedConfigErrorMsg(inferenceEntityId, name(), taskType),
ConfigurationParseContext.PERSISTENT
)
);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java
index 6dabaa66ffb2b..7600207eebad0 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java
@@ -20,6 +20,5 @@ public record HuggingFaceModelParameters(
Map taskSettings,
ChunkingSettings chunkingSettings,
Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
index d0a98d8252923..e727a9e20cd8c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.huggingface;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -26,7 +25,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
@@ -54,6 +52,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
/**
* This class is responsible for managing the Hugging Face inference service.
@@ -124,7 +123,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
params.secretSettings(),
params.context()
);
- default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(params.inferenceEntityId(), NAME, params.taskType(), params.context());
};
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
index 775a4e90ae034..081d5c63b84ff 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.huggingface.elser;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -25,7 +24,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
@@ -50,6 +48,7 @@
import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
@@ -87,7 +86,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
input.secretSettings(),
input.context()
);
- default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(input.inferenceEntityId(), NAME, input.taskType(), input.context());
};
}
From 6ec99f9a80ff0801d51c9bb32a51c306a19188cf Mon Sep 17 00:00:00 2001
From: elasticsearchmachine
Date: Tue, 30 Sep 2025 13:50:17 +0000
Subject: [PATCH 09/10] [CI] Auto commit changes from spotless
---
.../inference/services/ServiceUtils.java | 7 +----
.../inference/services/ai21/Ai21Service.java | 30 +++----------------
.../services/anthropic/AnthropicService.java | 16 ++--------
.../azureaistudio/AzureAiStudioService.java | 9 +-----
.../azureopenai/AzureOpenAiService.java | 9 +-----
.../contextualai/ContextualAiService.java | 9 +-----
.../elastic/ElasticInferenceService.java | 9 +-----
.../googleaistudio/GoogleAiStudioService.java | 9 +-----
.../googlevertexai/GoogleVertexAiService.java | 9 +-----
.../ibmwatsonx/IbmWatsonxService.java | 9 +-----
.../services/jinaai/JinaAIService.java | 9 +-----
.../services/llama/LlamaService.java | 16 ++--------
.../services/mistral/MistralService.java | 16 ++--------
.../services/openai/OpenAiService.java | 9 +-----
.../services/voyageai/VoyageAIService.java | 9 +-----
.../AmazonBedrockServiceTests.java | 10 ++-----
.../AzureAiStudioServiceTests.java | 5 +---
.../azureopenai/AzureOpenAiServiceTests.java | 10 ++-----
.../services/cohere/CohereServiceTests.java | 20 +++----------
.../services/jinaai/JinaAIServiceTests.java | 20 +++----------
.../services/mistral/MistralServiceTests.java | 10 ++-----
.../voyageai/VoyageAIServiceTests.java | 10 ++-----
22 files changed, 38 insertions(+), 222 deletions(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 9d03278470d35..2607e68acadbb 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -1079,8 +1079,6 @@ public interface EnumConstructor> {
E apply(String name) throws IllegalArgumentException;
}
-
-
/**
* Create an exception for when the task type is not valid for the service.
*/
@@ -1093,10 +1091,7 @@ public static ElasticsearchStatusException createInvalidTaskTypeException(
var message = parseContext == ConfigurationParseContext.PERSISTENT
? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType)
: TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName);
- return new ElasticsearchStatusException(
- message,
- RestStatus.BAD_REQUEST
- );
+ return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
}
private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
index 57bcc267ac644..69eddd7ecade2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
@@ -176,13 +176,7 @@ public void parseRequestConfig(
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- Ai21Model model = createModel(
- modelId,
- taskType,
- serviceSettingsMap,
- serviceSettingsMap,
- ConfigurationParseContext.REQUEST
- );
+ Ai21Model model = createModel(modelId, taskType, serviceSettingsMap, serviceSettingsMap, ConfigurationParseContext.REQUEST);
throwIfNotEmptyMap(config, NAME);
throwIfNotEmptyMap(serviceSettingsMap, NAME);
@@ -205,12 +199,7 @@ public Ai21Model parsePersistedConfigWithSecrets(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
- return createModelFromPersistent(
- modelId,
- taskType,
- serviceSettingsMap,
- secretSettingsMap
- );
+ return createModelFromPersistent(modelId, taskType, serviceSettingsMap, secretSettingsMap);
}
@Override
@@ -218,12 +207,7 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModelFromPersistent(
- modelId,
- taskType,
- serviceSettingsMap,
- null
- );
+ return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null);
}
@Override
@@ -257,13 +241,7 @@ private Ai21Model createModelFromPersistent(
Map serviceSettings,
Map secretSettings
) {
- return createModel(
- inferenceEntityId,
- taskType,
- serviceSettings,
- secretSettings,
- ConfigurationParseContext.PERSISTENT
- );
+ return createModel(inferenceEntityId, taskType, serviceSettings, secretSettings, ConfigurationParseContext.PERSISTENT);
}
/**
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
index 1f2c7f6cb4cdc..29a1582ce6236 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
@@ -155,13 +155,7 @@ public AnthropicModel parsePersistedConfigWithSecrets(
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- secretSettingsMap
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
}
@Override
@@ -169,13 +163,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
index 23d46820b688f..60b8bbdf86fa5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
@@ -234,14 +234,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
index f3d70f004131b..de067ae0096b5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
@@ -217,14 +217,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
index 2f59400287a10..9088de1dacd2d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
@@ -153,14 +153,7 @@ public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, Ta
Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModel(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- null,
- ConfigurationParseContext.PERSISTENT
- );
+ return createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, ConfigurationParseContext.PERSISTENT);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
index 7063b00c8ac99..c4156c0bfd6b9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
@@ -578,14 +578,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
index f5924cf6f9854..dc08ec8544e3c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
@@ -227,14 +227,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
index f8318b1d6f838..66a4dd0649730 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
@@ -196,14 +196,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
index 8b4d8fd1f7dae..ee836d03747de 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
@@ -246,14 +246,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
index e0f7ddf864876..5e95f85e78ecd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
@@ -211,14 +211,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
index fed4824201e42..b3c0e12927fff 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
@@ -308,13 +308,7 @@ public Model parsePersistedConfigWithSecrets(
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- modelId,
- taskType,
- serviceSettingsMap,
- chunkingSettings,
- secretSettingsMap
- );
+ return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap);
}
private LlamaModel createModelFromPersistent(
@@ -344,13 +338,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [amazonbedrock] service")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service"));
assertThat(
thrownException.getMessage(),
containsString("The [amazonbedrock] service does not support task type [sparse_embedding]")
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index 3e17a603c85c2..127b7d1c4cfae 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -807,10 +807,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
() -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [azureaistudio] service")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureaistudio] service"));
assertThat(
thrownException.getMessage(),
containsString("The [azureaistudio] service does not support task type [sparse_embedding]")
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index 0012e1f27be02..cf3ac6979b8e3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -435,10 +435,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [azureopenai] service")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service"));
assertThat(
thrownException.getMessage(),
containsString("The [azureopenai] service does not support task type [sparse_embedding]")
@@ -672,10 +669,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [azureopenai] service")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service"));
assertThat(
thrownException.getMessage(),
containsString("The [azureopenai] service does not support task type [sparse_embedding]")
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 453a5ee72d4c1..9642e7b85cdc7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -450,14 +450,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [cohere] service")
- );
- assertThat(
- thrownException.getMessage(),
- containsString("The [cohere] service does not support task type [sparse_embedding]")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service"));
+ assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]"));
}
}
@@ -691,14 +685,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [cohere] service")
- );
- assertThat(
- thrownException.getMessage(),
- containsString("The [cohere] service does not support task type [sparse_embedding]")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service"));
+ assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]"));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index e5ed8891b368a..998559d102ab7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -443,14 +443,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [jinaai] service")
- );
- assertThat(
- thrownException.getMessage(),
- containsString("The [jinaai] service does not support task type [sparse_embedding]")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service"));
+ assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]"));
}
}
@@ -687,14 +681,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [jinaai] service")
- );
- assertThat(
- thrownException.getMessage(),
- containsString("The [jinaai] service does not support task type [sparse_embedding]")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service"));
+ assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]"));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 1bf37948012e4..c8f6d0ee0e2fe 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -742,14 +742,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
() -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [mistral] service")
- );
- assertThat(
- thrownException.getMessage(),
- containsString("The [mistral] service does not support task type [sparse_embedding]")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [mistral] service"));
+ assertThat(thrownException.getMessage(), containsString("The [mistral] service does not support task type [sparse_embedding]"));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 99c6d31c207b6..d7f5726af85e0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -409,10 +409,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [voyageai] service")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service"));
assertThat(
thrownException.getMessage(),
containsString("The [voyageai] service does not support task type [sparse_embedding]")
@@ -628,10 +625,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- assertThat(
- thrownException.getMessage(),
- containsString("Failed to parse stored model [id] for [voyageai] service")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service"));
assertThat(
thrownException.getMessage(),
containsString("The [voyageai] service does not support task type [sparse_embedding]")
From d52d4dc8d4580ede6f18bada75eb702968a1ad97 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner
Date: Tue, 30 Sep 2025 10:35:17 -0400
Subject: [PATCH 10/10] Finishing comment
---
.../xpack/inference/services/custom/CustomService.java | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
index c5c8853c7374a..aa7fd1337428f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -241,7 +241,13 @@ public CustomModel parsePersistedConfigWithSecrets(
private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) {
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- // note there's
+ /*
+ * There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context.
+ * For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will
+ * return null which results in the older word boundary chunking settings being used as the default.
+ * For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking
+ * settings being used as the default.
+ */
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}