diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index b8ff560ced006..2607e68acadbb 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -1079,11 +1079,27 @@ public interface EnumConstructor> {
E apply(String name) throws IllegalArgumentException;
}
- public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) {
+ /**
+ * Create an exception for when the task type is not valid for the service.
+ */
+ public static ElasticsearchStatusException createInvalidTaskTypeException(
+ String inferenceEntityId,
+ String serviceName,
+ TaskType taskType,
+ ConfigurationParseContext parseContext
+ ) {
+ var message = parseContext == ConfigurationParseContext.PERSISTENT
+ ? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType)
+ : TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName);
+ return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
+ }
+
+ private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
return format(
- "Failed to parse stored model [%s] for [%s] service, please delete and add the service again",
+ "Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
inferenceEntityId,
- serviceName
+ serviceName,
+ TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName)
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
index b677ec642075e..69eddd7ecade2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.ai21;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
@@ -27,7 +26,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
@@ -55,7 +53,7 @@
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -178,14 +176,7 @@ public void parseRequestConfig(
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- Ai21Model model = createModel(
- modelId,
- taskType,
- serviceSettingsMap,
- serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
- ConfigurationParseContext.REQUEST
- );
+ Ai21Model model = createModel(modelId, taskType, serviceSettingsMap, serviceSettingsMap, ConfigurationParseContext.REQUEST);
throwIfNotEmptyMap(config, NAME);
throwIfNotEmptyMap(serviceSettingsMap, NAME);
@@ -208,13 +199,7 @@ public Ai21Model parsePersistedConfigWithSecrets(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
- return createModelFromPersistent(
- modelId,
- taskType,
- serviceSettingsMap,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME)
- );
+ return createModelFromPersistent(modelId, taskType, serviceSettingsMap, secretSettingsMap);
}
@Override
@@ -222,7 +207,7 @@ public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME));
+ return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null);
}
@Override
@@ -240,14 +225,13 @@ private static Ai21Model createModel(
TaskType taskType,
Map serviceSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
case CHAT_COMPLETION, COMPLETION:
return new Ai21ChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
default:
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(modelId, NAME, taskType, context);
}
}
@@ -255,17 +239,9 @@ private Ai21Model createModelFromPersistent(
String inferenceEntityId,
TaskType taskType,
Map serviceSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
- return createModel(
- inferenceEntityId,
- taskType,
- serviceSettings,
- secretSettings,
- failureMessage,
- ConfigurationParseContext.PERSISTENT
- );
+ return createModel(inferenceEntityId, taskType, serviceSettings, secretSettings, ConfigurationParseContext.PERSISTENT);
}
/**
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
index f474850b9f190..5f2378d40116d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -32,7 +31,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -59,7 +57,7 @@
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -135,7 +133,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -165,8 +162,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -175,7 +171,6 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -187,7 +182,6 @@ private static AlibabaCloudSearchModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -229,7 +223,7 @@ private static AlibabaCloudSearchModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -255,8 +249,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -276,8 +269,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
index 11204018a5523..5e3ba81b7e1bc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
@@ -60,7 +60,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -204,7 +204,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -241,7 +240,6 @@ public Model parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME),
ConfigurationParseContext.PERSISTENT
);
}
@@ -263,7 +261,6 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -318,7 +314,7 @@ private static AmazonBedrockModel createModel(
checkChatCompletionProviderForTopKParameter(model);
return model;
}
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
index 224d62f83f28b..29a1582ce6236 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.anthropic;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -28,7 +27,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -48,7 +46,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -94,7 +92,6 @@ public void parseRequestConfig(
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -113,8 +110,7 @@ private static AnthropicModel createModelFromPersistent(
TaskType taskType,
Map serviceSettings,
Map taskSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -122,7 +118,6 @@ private static AnthropicModel createModelFromPersistent(
serviceSettings,
taskSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -133,7 +128,6 @@ private static AnthropicModel createModel(
Map serviceSettings,
Map taskSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -146,7 +140,7 @@ private static AnthropicModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -161,14 +155,7 @@ public AnthropicModel parsePersistedConfigWithSecrets(
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
}
@Override
@@ -176,14 +163,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
index 7578aa702ad7c..60b8bbdf86fa5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
@@ -60,7 +60,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -185,7 +185,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -221,8 +220,7 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -236,15 +234,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
@@ -279,7 +269,6 @@ private static AzureAiStudioModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
@@ -305,7 +294,7 @@ private static AzureAiStudioModel createModel(
context
);
case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType());
@@ -318,8 +307,7 @@ private AzureAiStudioModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -328,7 +316,6 @@ private AzureAiStudioModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
index 077e5361dd46f..de067ae0096b5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.azureopenai;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -30,7 +29,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -55,7 +53,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -114,7 +112,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -134,8 +131,7 @@ private static AzureOpenAiModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -144,7 +140,6 @@ private static AzureOpenAiModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -156,7 +151,6 @@ private static AzureOpenAiModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -183,7 +177,7 @@ private static AzureOpenAiModel createModel(
context
);
}
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
}
@@ -209,8 +203,7 @@ public AzureOpenAiModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -224,15 +217,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
index 2561f198075e2..f1e34138aad5c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.cohere;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -32,7 +31,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -60,7 +58,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -130,7 +128,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -150,8 +147,7 @@ private static CohereModel createModelWithoutLoggingDeprecations(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -160,7 +156,6 @@ private static CohereModel createModelWithoutLoggingDeprecations(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -172,7 +167,6 @@ private static CohereModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -186,7 +180,7 @@ private static CohereModel createModel(
);
case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -212,8 +206,7 @@ public CohereModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -233,8 +226,7 @@ public CohereModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ null
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
index c67fc328acdd2..9088de1dacd2d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java
@@ -46,6 +46,8 @@
import java.util.List;
import java.util.Map;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
+
/**
* Contextual AI inference service for reranking tasks.
* This service uses the Contextual AI REST API to perform document reranking.
@@ -97,7 +99,6 @@ public void parseRequestConfig(
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -117,11 +118,10 @@ private static ContextualAiRerankModel createModel(
Map serviceSettings,
Map taskSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
if (taskType != TaskType.RERANK) {
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
return new ContextualAiRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
@@ -144,7 +144,6 @@ public ContextualAiRerankModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
secretSettingsMap,
- ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
ConfigurationParseContext.PERSISTENT
);
}
@@ -154,15 +153,7 @@ public ContextualAiRerankModel parsePersistedConfig(String inferenceEntityId, Ta
Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- return createModel(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- null,
- ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
- ConfigurationParseContext.PERSISTENT
- );
+ return createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, ConfigurationParseContext.PERSISTENT);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
index fd29b02012185..aa7fd1337428f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -57,9 +57,9 @@
import java.util.List;
import java.util.Map;
-import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -106,7 +106,12 @@ public void parseRequestConfig(
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
- var chunkingSettings = extractChunkingSettings(config, taskType);
+ ChunkingSettings chunkingSettings = null;
+ if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+ chunkingSettings = ChunkingSettingsBuilder.fromMap(
+ removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+ );
+ }
CustomModel model = createModel(
inferenceEntityId,
@@ -157,14 +162,6 @@ private static RequestParameters createParameters(CustomModel model) {
};
}
- private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) {
- if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
- return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
- }
-
- return null;
- }
-
@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
@@ -214,7 +211,7 @@ private static CustomModel createModel(
ConfigurationParseContext context
) {
if (supportedTaskTypes.contains(taskType) == false) {
- throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
}
return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
}
@@ -230,7 +227,7 @@ public CustomModel parsePersistedConfigWithSecrets(
Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
- var chunkingSettings = extractChunkingSettings(config, taskType);
+ var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
return createModelWithoutLoggingDeprecations(
inferenceEntityId,
@@ -242,12 +239,27 @@ public CustomModel parsePersistedConfigWithSecrets(
);
}
+ private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) {
+ if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+ /*
+ * There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context.
+ * For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will
+ * return null which results in the older word boundary chunking settings being used as the default.
+ * For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking
+ * settings being used as the default.
+ */
+ return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+ }
+
+ return null;
+ }
+
@Override
public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) {
Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
- var chunkingSettings = extractChunkingSettings(config, taskType);
+ var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
return createModelWithoutLoggingDeprecations(
inferenceEntityId,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
index cc871da8eb860..c4156c0bfd6b9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
@@ -79,7 +79,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -446,7 +446,6 @@ public void parseRequestConfig(
chunkingSettings,
serviceSettingsMap,
elasticInferenceServiceComponents,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -494,7 +493,6 @@ private static ElasticInferenceServiceModel createModel(
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -540,7 +538,7 @@ private static ElasticInferenceServiceModel createModel(
context,
chunkingSettings
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -566,8 +564,7 @@ public Model parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -581,15 +578,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
@@ -603,8 +592,7 @@ private ElasticInferenceServiceModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -614,7 +602,6 @@ private ElasticInferenceServiceModel createModelFromPersistent(
chunkingSettings,
secretSettings,
elasticInferenceServiceComponents,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
index 97bd2502d25b6..dc08ec8544e3c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.googleaistudio;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -60,7 +58,7 @@
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -127,7 +125,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -149,7 +146,6 @@ private static GoogleAiStudioModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -172,7 +168,7 @@ private static GoogleAiStudioModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -198,8 +194,7 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -209,8 +204,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -219,7 +213,6 @@ private static GoogleAiStudioModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -234,15 +227,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
index 19ab67b920f04..66a4dd0649730 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
@@ -8,7 +8,6 @@
package org.elasticsearch.xpack.inference.services.googlevertexai;
import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -63,7 +61,7 @@
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -149,7 +147,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -185,8 +182,7 @@ public Model parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -200,15 +196,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
@@ -356,8 +344,7 @@ private static GoogleVertexAiModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -366,7 +353,6 @@ private static GoogleVertexAiModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -378,7 +364,6 @@ private static GoogleVertexAiModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -412,7 +397,7 @@ private static GoogleVertexAiModel createModel(
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
index 325f88c8904a3..403a1758983ac 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
@@ -31,7 +31,6 @@
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -84,7 +83,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
)
);
@@ -123,7 +121,6 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
@@ -147,7 +144,6 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
taskSettingsMap,
chunkingSettings,
null,
- parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java
index 6dabaa66ffb2b..7600207eebad0 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModelParameters.java
@@ -20,6 +20,5 @@ public record HuggingFaceModelParameters(
Map taskSettings,
ChunkingSettings chunkingSettings,
Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
index d0a98d8252923..e727a9e20cd8c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.huggingface;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -26,7 +25,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
@@ -54,6 +52,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
/**
* This class is responsible for managing the Hugging Face inference service.
@@ -124,7 +123,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
params.secretSettings(),
params.context()
);
- default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(params.inferenceEntityId(), NAME, params.taskType(), params.context());
};
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
index 775a4e90ae034..081d5c63b84ff 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.huggingface.elser;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -25,7 +24,6 @@
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
@@ -50,6 +48,7 @@
import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
@@ -87,7 +86,7 @@ protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
input.secretSettings(),
input.context()
);
- default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(input.inferenceEntityId(), NAME, input.taskType(), input.context());
};
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
index 8cdc8cd182425..ee836d03747de 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -30,7 +29,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -62,7 +60,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -128,7 +126,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -150,7 +147,6 @@ private static IbmWatsonxModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -181,7 +177,7 @@ private static IbmWatsonxModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -207,8 +203,7 @@ public IbmWatsonxModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -228,8 +223,7 @@ private static IbmWatsonxModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -238,7 +232,6 @@ private static IbmWatsonxModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -253,15 +246,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
index f6bd954617b76..5e95f85e78ecd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.jinaai;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -57,7 +55,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -121,7 +119,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -141,8 +138,7 @@ private static JinaAIModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -151,7 +147,6 @@ private static JinaAIModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -163,7 +158,6 @@ private static JinaAIModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -177,7 +171,7 @@ private static JinaAIModel createModel(
context
);
case RERANK -> new JinaAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -203,8 +197,7 @@ public JinaAIModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -218,15 +211,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
index a74f3202e5fb4..b3c0e12927fff 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.llama;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
@@ -28,7 +27,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -64,7 +62,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -138,7 +136,6 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
* @param serviceSettings the settings for the inference service
* @param chunkingSettings the settings for chunking, if applicable
* @param secretSettings the secret settings for the model, such as API keys or tokens
- * @param failureMessage the message to use in case of failure
* @param context the context for parsing configuration settings
* @return a new instance of LlamaModel based on the provided parameters
*/
@@ -148,7 +145,6 @@ protected LlamaModel createModel(
Map serviceSettings,
ChunkingSettings chunkingSettings,
Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -157,7 +153,7 @@ protected LlamaModel createModel(
case CHAT_COMPLETION, COMPLETION:
return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
default:
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(inferenceId, NAME, taskType, context);
}
}
@@ -283,7 +279,6 @@ public void parseRequestConfig(
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -313,14 +308,7 @@ public Model parsePersistedConfigWithSecrets(
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- modelId,
- taskType,
- serviceSettingsMap,
- chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(modelId, NAME)
- );
+ return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap);
}
private LlamaModel createModelFromPersistent(
@@ -328,8 +316,7 @@ private LlamaModel createModelFromPersistent(
TaskType taskType,
Map serviceSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -337,7 +324,6 @@ private LlamaModel createModelFromPersistent(
serviceSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -352,14 +338,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
switch (taskType) {
@@ -299,7 +281,7 @@ private static MistralModel createModel(
case CHAT_COMPLETION, COMPLETION:
return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
default:
- throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ throw createInvalidTaskTypeException(modelId, NAME, taskType, context);
}
}
@@ -308,8 +290,7 @@ private MistralModel createModelFromPersistent(
TaskType taskType,
Map serviceSettings,
ChunkingSettings chunkingSettings,
- Map secretSettings,
- String failureMessage
+ Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -317,7 +298,6 @@ private MistralModel createModelFromPersistent(
serviceSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
index ae49f5dcef13b..9e77c0d63e336 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
@@ -65,7 +65,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -139,7 +139,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -159,8 +158,7 @@ private static OpenAiModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -169,7 +167,6 @@ private static OpenAiModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -181,7 +178,6 @@ private static OpenAiModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -204,7 +200,7 @@ private static OpenAiModel createModel(
secretSettings,
context
);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -232,8 +228,7 @@ public OpenAiModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -249,15 +244,7 @@ public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap);
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java
index 88840cc1202ac..6d340320c655c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java
@@ -47,7 +47,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
// The rate limit for usage tier 1 is 500 request per minute for most of the completion models
// To find this information you need to access your account's limits https://platform.openai.com/account/limits
// 500 requests per minute
- private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
+ public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
public static OpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
index 20b9bf931290c..e9b6a7c77a5fa 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
@@ -50,11 +50,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
public static final String NAME = "openai_service_settings";
- static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
+ public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
// The rate limit for usage tier 1 is 3000 request per minute for the text embedding models
// To find this information you need to access your account's limits https://platform.openai.com/account/limits
// 3000 requests per minute
- private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
+ public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
public static OpenAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) {
return switch (context) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
index c69aeec203e4c..f2db85f263f9f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
@@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.services.voyageai;
-import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
@@ -31,7 +30,6 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -56,7 +54,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -152,7 +150,6 @@ public void parseRequestConfig(
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
- TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
@@ -172,8 +169,7 @@ private static VoyageAIModel createModelFromPersistent(
Map serviceSettings,
Map taskSettings,
ChunkingSettings chunkingSettings,
- @Nullable Map secretSettings,
- String failureMessage
+ @Nullable Map secretSettings
) {
return createModel(
inferenceEntityId,
@@ -182,7 +178,6 @@ private static VoyageAIModel createModelFromPersistent(
taskSettings,
chunkingSettings,
secretSettings,
- failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@@ -194,7 +189,6 @@ private static VoyageAIModel createModel(
Map taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map secretSettings,
- String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
@@ -208,7 +202,7 @@ private static VoyageAIModel createModel(
context
);
case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context);
- default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
};
}
@@ -234,8 +228,7 @@ public VoyageAIModel parsePersistedConfigWithSecrets(
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
- secretSettingsMap,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ secretSettingsMap
);
}
@@ -249,15 +242,7 @@ public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType tas
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
- return createModelFromPersistent(
- inferenceEntityId,
- taskType,
- serviceSettingsMap,
- taskSettingsMap,
- chunkingSettings,
- null,
- parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
- );
+ return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null);
}
@Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java
new file mode 100644
index 0000000000000..7657f7e79cf25
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.RerankingInferenceService;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.EnumSet;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.mockito.Mockito.mock;
+
+public abstract class AbstractInferenceServiceBaseTests extends InferenceServiceTestCase {
+ protected final TestConfiguration testConfiguration;
+
+ protected final MockWebServer webServer = new MockWebServer();
+ protected ThreadPool threadPool;
+ protected HttpClientManager clientManager;
+ protected AbstractInferenceServiceParameterizedTests.TestCase testCase;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ webServer.start();
+ threadPool = createThreadPool(inferenceUtilityExecutors());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @Override
+ @After
+ public void tearDown() throws Exception {
+ super.tearDown();
+ clientManager.close();
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ public AbstractInferenceServiceBaseTests(TestConfiguration testConfiguration) {
+ this.testConfiguration = Objects.requireNonNull(testConfiguration);
+ }
+
+ /**
+ * Main configurations for the tests
+ */
+ public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) {
+ public static class Builder {
+ private final CommonConfig commonConfig;
+ private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS;
+
+ public Builder(CommonConfig commonConfig) {
+ this.commonConfig = commonConfig;
+ }
+
+ public TestConfiguration.Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) {
+ this.updateModelConfiguration = updateModelConfiguration;
+ return this;
+ }
+
+ public TestConfiguration build() {
+ return new TestConfiguration(commonConfig, updateModelConfiguration);
+ }
+ }
+ }
+
+ /**
+ * Configurations that are useful for most tests
+ */
+ public abstract static class CommonConfig {
+
+ private final TaskType targetTaskType;
+ private final TaskType unsupportedTaskType;
+ private final EnumSet supportedTaskTypes;
+
+ public CommonConfig(TaskType targetTaskType, @Nullable TaskType unsupportedTaskType, EnumSet supportedTaskTypes) {
+ this.targetTaskType = Objects.requireNonNull(targetTaskType);
+ this.unsupportedTaskType = unsupportedTaskType;
+ this.supportedTaskTypes = Objects.requireNonNull(supportedTaskTypes);
+ }
+
+ public TaskType targetTaskType() {
+ return targetTaskType;
+ }
+
+ public TaskType unsupportedTaskType() {
+ return unsupportedTaskType;
+ }
+
+ public EnumSet supportedTaskTypes() {
+ return supportedTaskTypes;
+ }
+
+ protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);
+
+ protected abstract Map createServiceSettingsMap(TaskType taskType);
+
+ protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ return createServiceSettingsMap(taskType);
+ }
+
+ protected abstract Map createTaskSettingsMap();
+
+ protected abstract Map createSecretSettingsMap();
+
+ protected abstract void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets);
+
+ protected void assertModel(Model model, TaskType taskType) {
+ assertModel(model, taskType, true);
+ }
+
+ protected abstract EnumSet supportedStreamingTasks();
+
+ /**
+ * Override this method if the service support reranking. This method won't be called if the service doesn't support reranking.
+ */
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ fail("Reranking services should override this test method to verify window size");
+ }
+ }
+
+ /**
+ * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests
+ */
+ public abstract static class UpdateModelConfiguration {
+
+ public boolean isEnabled() {
+ return true;
+ }
+
+ protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
+ }
+
+ private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
+ @Override
+ public boolean isEnabled() {
+ return false;
+ }
+
+ @Override
+ protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+ throw new UnsupportedOperationException("Update model tests are disabled");
+ }
+ };
+
+ @Override
+ public InferenceService createInferenceService() {
+ return testConfiguration.commonConfig.createService(threadPool, clientManager);
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ testConfiguration.commonConfig.assertRerankerWindowSize(rerankingInferenceService);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
new file mode 100644
index 0000000000000..37ba073407df8
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedTests.java
@@ -0,0 +1,387 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.Utils;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.junit.Assume;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.function.Function;
+
+import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Base class for testing inference services using parameterized tests.
+ */
+public abstract class AbstractInferenceServiceParameterizedTests extends AbstractInferenceServiceBaseTests {
+
+ public AbstractInferenceServiceParameterizedTests(
+ AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration,
+ TestCase testCase
+ ) {
+ super(testConfiguration);
+ this.testCase = testCase;
+ }
+
+ @Override
+ public InferenceService createInferenceService() {
+ return testConfiguration.commonConfig().createService(threadPool, clientManager);
+ }
+
+ public record TestCase(
+ String description,
+ Function createPersistedConfig,
+ ServiceParser serviceParser,
+ TaskType expectedTaskType,
+ boolean modelIncludesSecrets,
+ boolean expectFailure
+ ) {}
+
+ private record ServiceParserParams(
+ SenderService service,
+ Utils.PersistedConfig persistedConfig,
+ AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration
+ ) {}
+
+ @FunctionalInterface
+ private interface ServiceParser {
+ Model parseConfigs(ServiceParserParams params);
+ }
+
+ private static class TestCaseBuilder {
+ private final String description;
+ private final Function createPersistedConfig;
+ private final ServiceParser serviceParser;
+ private final TaskType expectedTaskType;
+ private boolean modelIncludesSecrets;
+ private boolean expectFailure;
+
+ TestCaseBuilder(
+ String description,
+ Function createPersistedConfig,
+ ServiceParser serviceParser,
+ TaskType expectedTaskType
+ ) {
+ this.description = description;
+ this.createPersistedConfig = createPersistedConfig;
+ this.serviceParser = serviceParser;
+ this.expectedTaskType = expectedTaskType;
+ }
+
+ public TestCaseBuilder withSecrets() {
+ this.modelIncludesSecrets = true;
+ return this;
+ }
+
+ public TestCaseBuilder expectFailure() {
+ this.expectFailure = true;
+ return this;
+ }
+
+ public TestCase build() {
+ return new TestCase(description, createPersistedConfig, serviceParser, expectedTaskType, modelIncludesSecrets, expectFailure);
+ }
+ }
+
+ @ParametersFactory
+ public static Iterable parameters() throws IOException {
+ return Arrays.asList(
+ new TestCase[][] {
+ // Test cases for parsePersistedConfig method
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config without chunking settings",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ null
+ ),
+ (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()),
+ TaskType.TEXT_EMBEDDING
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with chunking settings",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ createRandomChunkingSettingsMap(),
+ null
+ ),
+ (params) -> params.service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, params.persistedConfig.config()),
+ TaskType.TEXT_EMBEDDING
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config does not throw when an extra key exists in config",
+ testConfiguration -> {
+ var persistedConfigMap = getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ null
+ );
+ persistedConfigMap.config().put("extra_key", "value");
+ return persistedConfigMap;
+ },
+ (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()),
+ TaskType.COMPLETION
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config does not throw when an extra key exists in service settings",
+ testConfiguration -> {
+ var serviceSettings = testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT);
+ serviceSettings.put("extra_key", "value");
+
+ return getPersistedConfigMap(serviceSettings, testConfiguration.commonConfig().createTaskSettingsMap(), null);
+ },
+ (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()),
+ TaskType.COMPLETION
+ ).build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config does not throw when an extra key exists in task settings",
+ testConfiguration -> {
+ var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap();
+ taskSettingsMap.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ taskSettingsMap,
+ null
+ );
+ },
+ (params) -> params.service.parsePersistedConfig("id", TaskType.COMPLETION, params.persistedConfig.config()),
+ TaskType.COMPLETION
+ ).build() },
+ // Test cases for parsePersistedConfigWithSecrets method
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with secrets creates an embeddings model",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.TEXT_EMBEDDING
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets creates an embeddings "
+ + "model when chunking settings are provided",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ createRandomChunkingSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.TEXT_EMBEDDING
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets creates a completion "
+ + "model when chunking settings are not provided",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets throws exception for unsupported task type",
+ testConfiguration -> getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ ),
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ params.testConfiguration.commonConfig().unsupportedTaskType(),
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().expectFailure().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in config",
+ testConfiguration -> {
+ var persistedConfigMap = getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ );
+ persistedConfigMap.config().put("extra_key", "value");
+ return persistedConfigMap;
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in service settings",
+ testConfiguration -> {
+ var serviceSettings = testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT);
+ serviceSettings.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ serviceSettings,
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ );
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in task settings",
+ testConfiguration -> {
+ var taskSettingsMap = testConfiguration.commonConfig().createTaskSettingsMap();
+ taskSettingsMap.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ taskSettingsMap,
+ testConfiguration.commonConfig().createSecretSettingsMap()
+ );
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() },
+ {
+ new TestCaseBuilder(
+ "Test parsing persisted config with with secrets does not throw when an extra key exists in secret settings",
+ testConfiguration -> {
+ var secretSettingsMap = testConfiguration.commonConfig().createSecretSettingsMap();
+ secretSettingsMap.put("extra_key", "value");
+
+ return getPersistedConfigMap(
+ testConfiguration.commonConfig()
+ .createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.PERSISTENT),
+ testConfiguration.commonConfig().createTaskSettingsMap(),
+ secretSettingsMap
+ );
+ },
+ (params) -> params.service.parsePersistedConfigWithSecrets(
+ "id",
+ TaskType.COMPLETION,
+ params.persistedConfig.config(),
+ params.persistedConfig.secrets()
+ ),
+ TaskType.COMPLETION
+ ).withSecrets().build() } }
+ );
+ }
+
+ public void testPersistedConfig() throws Exception {
+ // If the service doesn't support the expected task type, then skip the test
+ Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(testCase.expectedTaskType));
+
+ var parseConfigTestConfig = testConfiguration.commonConfig();
+ var persistedConfig = testCase.createPersistedConfig.apply(testConfiguration);
+
+ try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
+
+ if (testCase.expectFailure) {
+ assertFailedParse(service, persistedConfig);
+ } else {
+ assertSuccessfulParse(service, persistedConfig);
+ }
+ }
+ }
+
+ private void assertFailedParse(SenderService service, Utils.PersistedConfig persistedConfig) {
+ var exception = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration))
+ );
+
+ assertThat(
+ exception.getMessage(),
+ containsString(
+ Strings.format("service does not support task type [%s]", testConfiguration.commonConfig().unsupportedTaskType())
+ )
+ );
+ }
+
+ private void assertSuccessfulParse(SenderService service, Utils.PersistedConfig persistedConfig) throws Exception {
+ var model = testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration));
+
+ if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) {
+ @SuppressWarnings("unchecked")
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(
+ (Map) persistedConfig.config().get(ModelConfigurations.CHUNKING_SETTINGS)
+ );
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ }
+
+ testConfiguration.commonConfig().assertModel(model, testCase.expectedTaskType, testCase.modelIncludesSecrets);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
index ec07d7b547004..9479a8ea6e10d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
@@ -9,39 +9,27 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.test.http.MockWebServer;
-import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.junit.After;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.junit.Assume;
-import org.junit.Before;
import java.io.IOException;
-import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
-import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
-import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
-import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
-import static org.mockito.Mockito.mock;
/**
* Base class for testing inference services.
@@ -51,132 +39,64 @@
* To use this class, extend it and pass the constructor a configuration.
*
*/
-public abstract class AbstractInferenceServiceTests extends InferenceServiceTestCase {
-
- protected final MockWebServer webServer = new MockWebServer();
- protected ThreadPool threadPool;
- protected HttpClientManager clientManager;
-
- @Override
- @Before
- public void setUp() throws Exception {
- super.setUp();
- webServer.start();
- threadPool = createThreadPool(inferenceUtilityExecutors());
- clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
- }
-
- @Override
- @After
- public void tearDown() throws Exception {
- super.tearDown();
- clientManager.close();
- terminate(threadPool);
- webServer.close();
- }
-
- private final TestConfiguration testConfiguration;
+public abstract class AbstractInferenceServiceTests extends AbstractInferenceServiceBaseTests {
public AbstractInferenceServiceTests(TestConfiguration testConfiguration) {
- this.testConfiguration = Objects.requireNonNull(testConfiguration);
- }
-
- /**
- * Main configurations for the tests
- */
- public record TestConfiguration(CommonConfig commonConfig, UpdateModelConfiguration updateModelConfiguration) {
- public static class Builder {
- private final CommonConfig commonConfig;
- private UpdateModelConfiguration updateModelConfiguration = DISABLED_UPDATE_MODEL_TESTS;
-
- public Builder(CommonConfig commonConfig) {
- this.commonConfig = commonConfig;
- }
-
- public Builder enableUpdateModelTests(UpdateModelConfiguration updateModelConfiguration) {
- this.updateModelConfiguration = updateModelConfiguration;
- return this;
- }
-
- public TestConfiguration build() {
- return new TestConfiguration(commonConfig, updateModelConfiguration);
- }
- }
+ super(testConfiguration);
}
- /**
- * Configurations that useful for most tests
- */
- public abstract static class CommonConfig {
-
- private final TaskType taskType;
- private final TaskType unsupportedTaskType;
-
- public CommonConfig(TaskType taskType, @Nullable TaskType unsupportedTaskType) {
- this.taskType = Objects.requireNonNull(taskType);
- this.unsupportedTaskType = unsupportedTaskType;
- }
-
- protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager);
-
- protected abstract Map createServiceSettingsMap(TaskType taskType);
-
- protected abstract Map createTaskSettingsMap();
-
- protected abstract Map createSecretSettingsMap();
+ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception {
+ Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING));
- protected abstract void assertModel(Model model, TaskType taskType);
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
- protected abstract EnumSet supportedStreamingTasks();
- }
+ try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
+ var config = getRequestConfigMap(
+ parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST),
+ parseRequestConfigTestConfig.createTaskSettingsMap(),
+ parseRequestConfigTestConfig.createSecretSettingsMap()
+ );
- /**
- * Configurations specific to the {@link SenderService#updateModelWithEmbeddingDetails(Model, int)} tests
- */
- public abstract static class UpdateModelConfiguration {
+ var listener = new PlainActionFuture();
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
- public boolean isEnabled() {
- return true;
+ var model = listener.actionGet(TIMEOUT);
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of());
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
}
-
- protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure);
}
- private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() {
- @Override
- public boolean isEnabled() {
- return false;
- }
+ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws Exception {
+ Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.TEXT_EMBEDDING));
- @Override
- protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) {
- throw new UnsupportedOperationException("Update model tests are disabled");
- }
- };
-
- public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
+ var chunkingSettingsMap = createRandomChunkingSettingsMap();
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+ parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST),
parseRequestConfigTestConfig.createTaskSettingsMap(),
+ chunkingSettingsMap,
parseRequestConfigTestConfig.createSecretSettingsMap()
);
var listener = new PlainActionFuture();
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
- parseRequestConfigTestConfig.assertModel(listener.actionGet(TIMEOUT), TaskType.TEXT_EMBEDDING);
+ var model = listener.actionGet(TIMEOUT);
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+ parseRequestConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
}
}
public void testParseRequestConfig_CreatesACompletionModel() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION),
+ parseRequestConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION, ConfigurationParseContext.REQUEST),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
@@ -189,39 +109,47 @@ public void testParseRequestConfig_CreatesACompletionModel() throws Exception {
}
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws Exception {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.targetTaskType(),
+ ConfigurationParseContext.REQUEST
+ ),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.unsupportedTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
- containsString(Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType))
+ containsString(
+ Strings.format("service does not support task type [%s]", parseRequestConfigTestConfig.unsupportedTaskType())
+ )
);
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.targetTaskType(),
+ ConfigurationParseContext.REQUEST
+ ),
parseRequestConfigTestConfig.createTaskSettingsMap(),
parseRequestConfigTestConfig.createSecretSettingsMap()
);
config.put("extra_key", "value");
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
@@ -229,9 +157,12 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
- var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType);
+ var serviceSettings = parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.targetTaskType(),
+ ConfigurationParseContext.REQUEST
+ );
serviceSettings.put("extra_key", "value");
var config = getRequestConfigMap(
serviceSettings,
@@ -240,7 +171,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
@@ -248,18 +179,21 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var taskSettings = parseRequestConfigTestConfig.createTaskSettingsMap();
taskSettings.put("extra_key", "value");
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.targetTaskType(),
+ ConfigurationParseContext.REQUEST
+ ),
taskSettings,
parseRequestConfigTestConfig.createSecretSettingsMap()
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
@@ -267,179 +201,29 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap()
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- var parseRequestConfigTestConfig = testConfiguration.commonConfig;
+ var parseRequestConfigTestConfig = testConfiguration.commonConfig();
try (var service = parseRequestConfigTestConfig.createService(threadPool, clientManager)) {
var secretSettingsMap = parseRequestConfigTestConfig.createSecretSettingsMap();
secretSettingsMap.put("extra_key", "value");
var config = getRequestConfigMap(
- parseRequestConfigTestConfig.createServiceSettingsMap(parseRequestConfigTestConfig.taskType),
+ parseRequestConfigTestConfig.createServiceSettingsMap(
+ parseRequestConfigTestConfig.targetTaskType(),
+ ConfigurationParseContext.REQUEST
+ ),
parseRequestConfigTestConfig.createTaskSettingsMap(),
secretSettingsMap
);
var listener = new PlainActionFuture();
- service.parseRequestConfig("id", parseRequestConfigTestConfig.taskType, config, listener);
+ service.parseRequestConfig("id", parseRequestConfigTestConfig.targetTaskType(), config, listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Configuration contains settings [{extra_key=value}]"));
}
}
- // parsePersistedConfigWithSecrets
-
- public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
- parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(TaskType.COMPLETION),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.COMPLETION,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
- parseConfigTestConfig.assertModel(model, TaskType.COMPLETION);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() throws Exception {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var exception = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.unsupportedTaskType,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- )
- );
-
- assertThat(
- exception.getMessage(),
- containsString(
- Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType)
- )
- );
- }
- }
-
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "service does not support task type [%s]";
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
-
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var persistedConfigMap = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
- persistedConfigMap.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.taskType,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var serviceSettings = parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType);
- serviceSettings.put("extra_key", "value");
- var persistedConfigMap = getPersistedConfigMap(
- serviceSettings,
- parseConfigTestConfig.createTaskSettingsMap(),
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- parseConfigTestConfig.taskType,
- persistedConfigMap.config(),
- persistedConfigMap.secrets()
- );
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var taskSettings = parseConfigTestConfig.createTaskSettingsMap();
- taskSettings.put("extra_key", "value");
- var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
- taskSettings,
- parseConfigTestConfig.createSecretSettingsMap()
- );
-
- var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- var parseConfigTestConfig = testConfiguration.commonConfig;
- try (var service = parseConfigTestConfig.createService(threadPool, clientManager)) {
- var secretSettingsMap = parseConfigTestConfig.createSecretSettingsMap();
- secretSettingsMap.put("extra_key", "value");
- var config = getPersistedConfigMap(
- parseConfigTestConfig.createServiceSettingsMap(parseConfigTestConfig.taskType),
- parseConfigTestConfig.createTaskSettingsMap(),
- secretSettingsMap
- );
-
- var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
-
- parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
- }
- }
-
public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var listener = new PlainActionFuture();
service.infer(
@@ -464,9 +248,9 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException {
}
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var exception = expectThrows(
ElasticsearchStatusException.class,
() -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt())
@@ -477,11 +261,11 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO
}
public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var embeddingSize = randomNonNegativeInt();
- var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(null);
+ var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(null);
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
@@ -491,11 +275,11 @@ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel()
}
public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
- Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
+ Assume.assumeTrue(testConfiguration.updateModelConfiguration().isEnabled());
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
var embeddingSize = randomNonNegativeInt();
- var model = testConfiguration.updateModelConfiguration.createEmbeddingModel(SimilarityMeasure.COSINE);
+ var model = testConfiguration.updateModelConfiguration().createEmbeddingModel(SimilarityMeasure.COSINE);
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
@@ -506,8 +290,8 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel
// streaming tests
public void testSupportedStreamingTasks() throws Exception {
- try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
- assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig.supportedStreamingTasks()));
+ try (var service = testConfiguration.commonConfig().createService(threadPool, clientManager)) {
+ assertThat(service.supportedStreamingTasks(), is(testConfiguration.commonConfig().supportedStreamingTasks()));
assertFalse(service.canStream(TaskType.ANY));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java
new file mode 100644
index 0000000000000..1254610f32745
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceParameterizedTests.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ai21;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+import static org.elasticsearch.xpack.inference.services.ai21.Ai21ServiceTests.createTestConfiguration;
+
+public class Ai21ServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public Ai21ServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) {
+ super(createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
index cbb119d3e5710..8a1f8d806a45b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java
@@ -19,7 +19,6 @@
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.EmptyTaskSettings;
-import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@@ -86,9 +85,13 @@ public Ai21ServiceTests() {
super(createTestConfiguration());
}
- private static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() {
+ public static AbstractInferenceServiceTests.TestConfiguration createTestConfiguration() {
return new AbstractInferenceServiceTests.TestConfiguration.Builder(
- new AbstractInferenceServiceTests.CommonConfig(TaskType.COMPLETION, TaskType.TEXT_EMBEDDING) {
+ new AbstractInferenceServiceTests.CommonConfig(
+ TaskType.COMPLETION,
+ TaskType.TEXT_EMBEDDING,
+ EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
+ ) {
@Override
protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
@@ -111,8 +114,8 @@ protected Map createSecretSettingsMap() {
}
@Override
- protected void assertModel(Model model, TaskType taskType) {
- Ai21ServiceTests.assertModel(model, taskType);
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ Ai21ServiceTests.assertModel(model, taskType, modelIncludesSecrets);
}
@Override
@@ -123,32 +126,35 @@ protected EnumSet supportedStreamingTasks() {
).build();
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case COMPLETION -> assertCompletionModel(model);
- case CHAT_COMPLETION -> assertChatCompletionModel(model);
+ case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
+ case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type [" + taskType + "]");
}
}
- private static Ai21Model assertCommonModelFields(Model model) {
+ private static Ai21Model assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(Ai21Model.class));
var customModel = (Ai21Model) model;
assertThat(customModel.uri.toString(), Matchers.is("https://api.ai21.com/studio/v1/chat/completions"));
assertThat(customModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE));
- assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray())));
+
+ if (modelIncludesSecrets) {
+ assertThat(customModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray())));
+ }
return customModel;
}
- private static void assertCompletionModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), Matchers.is(TaskType.COMPLETION));
}
- private static void assertChatCompletionModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION));
}
@@ -165,10 +171,6 @@ private static Map createSecretSettingsMap() {
return new HashMap<>(Map.of("api_key", "secret"));
}
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "Failed to parse stored model [id] for [ai21] service, please delete and add the service again";
- }
-
@Before
public void init() throws Exception {
webServer.start();
@@ -183,16 +185,6 @@ public void shutdown() throws IOException {
webServer.close();
}
- @Override
- public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() {
- // The Ai21Service does not support Text Embedding, so this test is not applicable.
- }
-
- @Override
- public void testParseRequestConfig_CreatesAnEmbeddingsModel() {
- // The Ai21Service does not support Text Embedding, so this test is not applicable.
- }
-
public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException {
var url = "url";
var model = "model";
@@ -561,9 +553,4 @@ private Map getRequestConfigMap(Map serviceSetti
return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
}
-
- @Override
- public InferenceService createInferenceService() {
- return createService();
- }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
index d7eb32861da92..36a4ce7b18c46 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
@@ -628,9 +628,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service"));
assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again")
+ containsString("The [amazonbedrock] service does not support task type [sparse_embedding]")
);
}
}
@@ -876,9 +877,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service"));
assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again")
+ containsString("The [amazonbedrock] service does not support task type [sparse_embedding]")
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index f1531929db8c3..127b7d1c4cfae 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -807,9 +807,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
() -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets())
);
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureaistudio] service"));
assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [azureaistudio] service, please delete and add the service again")
+ containsString("The [azureaistudio] service does not support task type [sparse_embedding]")
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index 55b10f2e2b9d7..cf3ac6979b8e3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -435,9 +435,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service"));
assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again")
+ containsString("The [azureopenai] service does not support task type [sparse_embedding]")
);
}
}
@@ -668,9 +669,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [azureopenai] service"));
assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again")
+ containsString("The [azureopenai] service does not support task type [sparse_embedding]")
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 67f545a8104de..9642e7b85cdc7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -450,10 +450,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- MatcherAssert.assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service"));
+ assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]"));
}
}
@@ -687,10 +685,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- MatcherAssert.assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [cohere] service, please delete and add the service again")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service"));
+ assertThat(thrownException.getMessage(), containsString("The [cohere] service does not support task type [sparse_embedding]"));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java
new file mode 100644
index 0000000000000..a71a107f3b9d1
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceParameterizedTests.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+public class CustomServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public CustomServiceParameterizedTests(TestCase testCase) {
+ super(CustomServiceTests.createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index 55bb98705a2a3..44bf17c3ac96d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -16,7 +16,6 @@
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
-import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
@@ -74,38 +73,52 @@ public CustomServiceTests() {
super(createTestConfiguration());
}
- private static TestConfiguration createTestConfiguration() {
- return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION) {
- @Override
- protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
- return CustomServiceTests.createService(threadPool, clientManager);
- }
+ public static TestConfiguration createTestConfiguration() {
+ return new TestConfiguration.Builder(
+ new CommonConfig(
+ TaskType.TEXT_EMBEDDING,
+ TaskType.CHAT_COMPLETION,
+ EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.COMPLETION)
+ ) {
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return CustomServiceTests.createService(threadPool, clientManager);
+ }
- @Override
- protected Map createServiceSettingsMap(TaskType taskType) {
- return CustomServiceTests.createServiceSettingsMap(taskType);
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return CustomServiceTests.createServiceSettingsMap(taskType);
+ }
- @Override
- protected Map createTaskSettingsMap() {
- return CustomServiceTests.createTaskSettingsMap();
- }
+ @Override
+ protected Map createTaskSettingsMap() {
+ return CustomServiceTests.createTaskSettingsMap();
+ }
- @Override
- protected Map createSecretSettingsMap() {
- return CustomServiceTests.createSecretSettingsMap();
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return CustomServiceTests.createSecretSettingsMap();
+ }
- @Override
- protected void assertModel(Model model, TaskType taskType) {
- CustomServiceTests.assertModel(model, taskType);
- }
+ @Override
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ CustomServiceTests.assertModel(model, taskType, modelIncludesSecrets);
+ }
- @Override
- protected EnumSet supportedStreamingTasks() {
- return EnumSet.noneOf(TaskType.class);
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.noneOf(TaskType.class);
+ }
+
+ @Override
+ protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
+ assertThat(
+ rerankingInferenceService.rerankerWindowSize("any model"),
+ CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
+ );
+ }
}
- }).enableUpdateModelTests(new UpdateModelConfiguration() {
+ ).enableUpdateModelTests(new UpdateModelConfiguration() {
@Override
protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(similarityMeasure);
@@ -113,38 +126,40 @@ protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure)
}).build();
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
- case COMPLETION -> assertCompletionModel(model);
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets);
+ case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type [" + taskType + "]");
}
}
- private static void assertTextEmbeddingModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING));
assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class));
}
- private static CustomModel assertCommonModelFields(Model model) {
+ private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(CustomModel.class));
var customModel = (CustomModel) model;
assertThat(customModel.getServiceSettings().getUrl(), is("http://www.abc.com"));
assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value")));
- assertThat(
- customModel.getSecretSettings().getSecretParameters(),
- is(Map.of("test_key", new SecureString("test_value".toCharArray())))
- );
+ if (modelIncludesSecrets) {
+ assertThat(
+ customModel.getSecretSettings().getSecretParameters(),
+ is(Map.of("test_key", new SecureString("test_value".toCharArray())))
+ );
+ }
return customModel;
}
- private static void assertCompletionModel(Model model) {
- var customModel = assertCommonModelFields(model);
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), is(TaskType.COMPLETION));
assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class));
}
@@ -807,17 +822,4 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
assertThat(requestMap.get("input"), is(List.of("a")));
}
}
-
- @Override
- public InferenceService createInferenceService() {
- return createService(threadPool, clientManager);
- }
-
- @Override
- protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
- assertThat(
- rerankingInferenceService.rerankerWindowSize("any model"),
- CoreMatchers.is(RerankingInferenceService.CONSERVATIVE_DEFAULT_WINDOW_SIZE)
- );
- }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index fc50acdbd39b6..998559d102ab7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -443,10 +443,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- MatcherAssert.assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service"));
+ assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]"));
}
}
@@ -683,10 +681,8 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- MatcherAssert.assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service"));
+ assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]"));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java
new file mode 100644
index 0000000000000..04afbc28af10a
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceParameterizedTests.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+public class LlamaServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public LlamaServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) {
+ super(LlamaServiceTests.createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index 243235211a7de..f6a0232db529c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -77,6 +77,9 @@
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
+import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
+import static org.elasticsearch.inference.TaskType.COMPLETION;
+import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -106,39 +109,41 @@ public LlamaServiceTests() {
super(createTestConfiguration());
}
- private static TestConfiguration createTestConfiguration() {
- return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) {
+ public static TestConfiguration createTestConfiguration() {
+ return new TestConfiguration.Builder(
+ new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) {
- @Override
- protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
- return LlamaServiceTests.createService(threadPool, clientManager);
- }
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return LlamaServiceTests.createService(threadPool, clientManager);
+ }
- @Override
- protected Map createServiceSettingsMap(TaskType taskType) {
- return LlamaServiceTests.createServiceSettingsMap(taskType);
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return LlamaServiceTests.createServiceSettingsMap(taskType);
+ }
- @Override
- protected Map createTaskSettingsMap() {
- return new HashMap<>();
- }
+ @Override
+ protected Map createTaskSettingsMap() {
+ return new HashMap<>();
+ }
- @Override
- protected Map createSecretSettingsMap() {
- return LlamaServiceTests.createSecretSettingsMap();
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return LlamaServiceTests.createSecretSettingsMap();
+ }
- @Override
- protected void assertModel(Model model, TaskType taskType) {
- LlamaServiceTests.assertModel(model, taskType);
- }
+ @Override
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ LlamaServiceTests.assertModel(model, taskType, modelIncludesSecrets);
+ }
- @Override
- protected EnumSet supportedStreamingTasks() {
- return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ }
}
- }).enableUpdateModelTests(new UpdateModelConfiguration() {
+ ).enableUpdateModelTests(new UpdateModelConfiguration() {
@Override
protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(similarityMeasure);
@@ -146,43 +151,46 @@ protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarity
}).build();
}
- private static void assertModel(Model model, TaskType taskType) {
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
switch (taskType) {
- case TEXT_EMBEDDING -> assertTextEmbeddingModel(model);
- case COMPLETION -> assertCompletionModel(model);
- case CHAT_COMPLETION -> assertChatCompletionModel(model);
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets);
+ case COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
+ case CHAT_COMPLETION -> assertChatCompletionModel(model, modelIncludesSecrets);
default -> fail("unexpected task type [" + taskType + "]");
}
}
- private static void assertTextEmbeddingModel(Model model) {
- var llamaModel = assertCommonModelFields(model);
+ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
+ var llamaModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING));
}
- private static LlamaModel assertCommonModelFields(Model model) {
+ private static LlamaModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
assertThat(model, instanceOf(LlamaModel.class));
var llamaModel = (LlamaModel) model;
assertThat(llamaModel.getServiceSettings().modelId(), is("model_id"));
assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com"));
assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE));
- assertThat(
- ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(),
- Matchers.is(new SecureString("secret".toCharArray()))
- );
+
+ if (modelIncludesSecrets) {
+ assertThat(
+ ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(),
+ Matchers.is(new SecureString("secret".toCharArray()))
+ );
+ }
return llamaModel;
}
- private static void assertCompletionModel(Model model) {
- var llamaModel = assertCommonModelFields(model);
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var llamaModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION));
}
- private static void assertChatCompletionModel(Model model) {
- var llamaModel = assertCommonModelFields(model);
+ private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) {
+ var llamaModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION));
}
@@ -236,10 +244,6 @@ private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable Simil
);
}
- protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
- return "Failed to parse stored model [id] for [llama] service, please delete and add the service again";
- }
-
@Before
public void init() throws Exception {
webServer.start();
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 50731811e4164..c8f6d0ee0e2fe 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -742,10 +742,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
() -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets())
);
- assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [mistral] service, please delete and add the service again")
- );
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [mistral] service"));
+ assertThat(thrownException.getMessage(), containsString("The [mistral] service does not support task type [sparse_embedding]"));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java
new file mode 100644
index 0000000000000..941dab6699575
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceParameterizedTests.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.openai;
+
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedTests;
+
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceTests.createTestConfiguration;
+
+public class OpenAiServiceParameterizedTests extends AbstractInferenceServiceParameterizedTests {
+ public OpenAiServiceParameterizedTests(AbstractInferenceServiceParameterizedTests.TestCase testCase) {
+ super(createTestConfiguration(), testCase);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 676dca2778141..c2cda175831ed 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -18,8 +18,10 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
@@ -49,24 +51,33 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
-import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase;
+import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
+import java.net.URI;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@@ -76,11 +87,9 @@
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
-import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
-import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -102,8 +111,22 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
-public class OpenAiServiceTests extends InferenceServiceTestCase {
+public class OpenAiServiceTests extends AbstractInferenceServiceTests {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private static final String MODEL = "model";
+ private static final String URL = "http://www.elastic.co";
+ private static final String ORGANIZATION = "org";
+ private static final int MAX_INPUT_TOKENS = 123;
+ private static final SimilarityMeasure SIMILARITY = SimilarityMeasure.DOT_PRODUCT;
+ private static final int DIMENSIONS = 100;
+ private static final boolean DIMENSIONS_SET_BY_USER = true;
+ private static final String USER = "user";
+ private static final String HEADER_KEY = "header_key";
+ private static final String HEADER_VALUE = "header_value";
+ private static final Map HEADERS = Map.of(HEADER_KEY, HEADER_VALUE);
+ private static final String SECRET = "secret";
+ private static final String INFERENCE_ID = "id";
+
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@@ -122,221 +145,184 @@ public void shutdown() throws IOException {
webServer.close();
}
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
+ public OpenAiServiceTests() {
+ super(createTestConfiguration());
}
- public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() throws IOException {
- var url = "url";
- var organization = "org";
- var model = "model";
- var user = "user";
- var secret = "secret";
-
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(m -> {
- assertThat(m, instanceOf(OpenAiChatCompletionModel.class));
-
- var completionsModel = (OpenAiChatCompletionModel) m;
-
- assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
- assertThat(completionsModel.getServiceSettings().organizationId(), is(organization));
- assertThat(completionsModel.getServiceSettings().modelId(), is(model));
- assertThat(completionsModel.getTaskSettings().user(), is(user));
- assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
+ public static TestConfiguration createTestConfiguration() {
+ return new TestConfiguration.Builder(
+ new CommonConfig(
+ TaskType.TEXT_EMBEDDING,
+ TaskType.RERANK,
+ EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
+ ) {
+ @Override
+ protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ return OpenAiServiceTests.createService(threadPool, clientManager);
+ }
- }, exception -> fail("Unexpected exception: " + exception));
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType) {
+ return createServiceSettingsMap(taskType, ConfigurationParseContext.REQUEST);
+ }
- service.parseRequestConfig(
- "id",
- TaskType.COMPLETION,
- getRequestConfigMap(
- getServiceSettingsMap(model, url, organization),
- getOpenAiTaskSettingsMap(user),
- getSecretSettingsMap(secret)
- ),
- modelVerificationListener
- );
- }
- }
+ @Override
+ protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ return OpenAiServiceTests.createServiceSettingsMap(taskType, parseContext);
+ }
- public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(
- model -> fail("Expected exception, but got model: " + model),
- exception -> {
- assertThat(exception, instanceOf(ElasticsearchStatusException.class));
- assertThat(exception.getMessage(), is("The [openai] service does not support task type [sparse_embedding]"));
+ @Override
+ protected Map createTaskSettingsMap() {
+ return OpenAiServiceTests.createTaskSettingsMap();
}
- );
- service.parseRequestConfig(
- "id",
- TaskType.SPARSE_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
- }
+ @Override
+ protected Map createSecretSettingsMap() {
+ return getSecretSettingsMap(SECRET);
+ }
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
- try (var service = createOpenAiService()) {
- var config = getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
- config.put("extra_key", "value");
-
- ActionListener modelVerificationListener = ActionListener.wrap(
- model -> fail("Expected exception, but got model: " + model),
- exception -> {
- assertThat(exception, instanceOf(ElasticsearchStatusException.class));
- assertThat(
- exception.getMessage(),
- is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service")
- );
+ @Override
+ protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ OpenAiServiceTests.assertModel(model, taskType, modelIncludesSecrets);
}
- );
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
- }
+ @Override
+ protected EnumSet supportedStreamingTasks() {
+ return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+ }
+ }
+ ).enableUpdateModelTests(new UpdateModelConfiguration() {
+ @Override
+ protected OpenAiEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) {
+ return createInternalEmbeddingModel(similarityMeasure, null);
+ }
+ }).build();
}
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
- try (var service = createOpenAiService()) {
- var serviceSettings = getServiceSettingsMap("model", "url", "org");
- serviceSettings.put("extra_key", "value");
-
- var config = getRequestConfigMap(serviceSettings, getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+ private static Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) {
+ var settingsMap = new HashMap(
+ Map.of(
+ ServiceFields.MODEL_ID,
+ MODEL,
+ ServiceFields.URL,
+ URL,
+ OpenAiServiceFields.ORGANIZATION,
+ ORGANIZATION,
+ ServiceFields.MAX_INPUT_TOKENS,
+ MAX_INPUT_TOKENS
+ )
+ );
- ActionListener modelVerificationListener = ActionListener.wrap((model) -> {
- fail("Expected exception, but got model: " + model);
- }, e -> {
- assertThat(e, instanceOf(ElasticsearchStatusException.class));
- assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service"));
- });
+ if (taskType == TaskType.TEXT_EMBEDDING) {
+ settingsMap.putAll(Map.of(ServiceFields.SIMILARITY, SIMILARITY.toString(), ServiceFields.DIMENSIONS, DIMENSIONS));
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+ if (parseContext == ConfigurationParseContext.PERSISTENT) {
+ settingsMap.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, DIMENSIONS_SET_BY_USER);
+ }
}
- }
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
- try (var service = createOpenAiService()) {
- var taskSettingsMap = getOpenAiTaskSettingsMap("user");
- taskSettingsMap.put("extra_key", "value");
-
- var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), taskSettingsMap, getSecretSettingsMap("secret"));
+ return settingsMap;
+ }
- ActionListener modelVerificationListener = ActionListener.wrap((model) -> {
- fail("Expected exception, but got model: " + model);
- }, e -> {
- assertThat(e, instanceOf(ElasticsearchStatusException.class));
- assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service"));
- });
+ private static Map createTaskSettingsMap() {
+ return new HashMap<>(Map.of(OpenAiServiceFields.USER, USER, OpenAiServiceFields.HEADERS, HEADERS));
+ }
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+ private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) {
+ switch (taskType) {
+ case TEXT_EMBEDDING -> assertTextEmbeddingModel(model, modelIncludesSecrets);
+ case COMPLETION, CHAT_COMPLETION -> assertCompletionModel(model, modelIncludesSecrets);
+ default -> fail("unexpected task type: " + taskType);
}
}
- public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
- try (var service = createOpenAiService()) {
- var secretSettingsMap = getSecretSettingsMap("secret");
- secretSettingsMap.put("extra_key", "value");
-
- var config = getRequestConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- secretSettingsMap
- );
+ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
+ assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
- ActionListener modelVerificationListener = ActionListener.wrap((model) -> {
- fail("Expected exception, but got model: " + model);
- }, e -> {
- assertThat(e, instanceOf(ElasticsearchStatusException.class));
- assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [openai] service"));
- });
+ var embeddingsModel = (OpenAiEmbeddingsModel) model;
+ assertThat(
+ embeddingsModel.getServiceSettings(),
+ is(
+ new OpenAiEmbeddingsServiceSettings(
+ MODEL,
+ URI.create(URL),
+ ORGANIZATION,
+ SIMILARITY,
+ DIMENSIONS,
+ MAX_INPUT_TOKENS,
+ DIMENSIONS_SET_BY_USER,
+ OpenAiEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
+ )
+ )
+ );
- service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+ assertThat(embeddingsModel.getTaskSettings(), is(new OpenAiEmbeddingsTaskSettings(USER, HEADERS)));
+ if (modelIncludesSecrets) {
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(SECRET));
+ } else {
+ assertNull(embeddingsModel.getSecretSettings());
}
}
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
+ private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
+ assertThat(model, instanceOf(OpenAiChatCompletionModel.class));
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
+ var completionModel = (OpenAiChatCompletionModel) model;
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", null, null),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
- }
+ assertThat(
+ completionModel.getServiceSettings(),
+ is(
+ new OpenAiChatCompletionServiceSettings(
+ MODEL,
+ URI.create(URL),
+ ORGANIZATION,
+ MAX_INPUT_TOKENS,
+ OpenAiChatCompletionServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
+ )
+ )
+ );
- public void testParseRequestConfig_CreatesAnOpenAiChatCompletionsModelWithoutUserWithoutUserUrlOrganization() throws IOException {
- var model = "model";
- var secret = "secret";
+ assertThat(completionModel.getTaskSettings(), is(new OpenAiChatCompletionTaskSettings(USER, HEADERS)));
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(m -> {
- assertThat(m, instanceOf(OpenAiChatCompletionModel.class));
+ assertSecrets(completionModel.getSecretSettings(), modelIncludesSecrets);
+ }
- var completionsModel = (OpenAiChatCompletionModel) m;
- assertNull(completionsModel.getServiceSettings().uri());
- assertNull(completionsModel.getServiceSettings().organizationId());
- assertThat(completionsModel.getServiceSettings().modelId(), is(model));
- assertNull(completionsModel.getTaskSettings().user());
- assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
+ private static void assertSecrets(DefaultSecretSettings secretSettings, boolean modelIncludesSecrets) {
+ if (modelIncludesSecrets) {
+ assertThat(secretSettings.apiKey().toString(), is(SECRET));
+ }
+ }
- }, exception -> fail("Unexpected exception: " + exception));
+ private static OpenAiEmbeddingsModel createInternalEmbeddingModel(
+ SimilarityMeasure similarityMeasure,
+ @Nullable ChunkingSettings chunkingSettings
+ ) {
+ return createInternalEmbeddingModel(similarityMeasure, URL, chunkingSettings);
+ }
- service.parseRequestConfig(
- "id",
- TaskType.COMPLETION,
- getRequestConfigMap(getServiceSettingsMap(model, null, null), getOpenAiTaskSettingsMap(null), getSecretSettingsMap(secret)),
- modelVerificationListener
- );
- }
+ private static OpenAiEmbeddingsModel createInternalEmbeddingModel(
+ SimilarityMeasure similarityMeasure,
+ @Nullable String url,
+ @Nullable ChunkingSettings chunkingSettings
+ ) {
+ return new OpenAiEmbeddingsModel(
+ INFERENCE_ID,
+ TaskType.TEXT_EMBEDDING,
+ "service",
+ new OpenAiEmbeddingsServiceSettings(
+ MODEL,
+ url == null ? null : URI.create(url),
+ ORGANIZATION,
+ similarityMeasure,
+ DIMENSIONS,
+ DIMENSIONS,
+ false,
+ null
+ ),
+ new OpenAiEmbeddingsTaskSettings(USER, HEADERS),
+ chunkingSettings,
+ new DefaultSecretSettings(new SecureString(SECRET.toCharArray()))
+ );
}
public void testParseRequestConfig_MovesModel() throws IOException {
@@ -365,497 +351,6 @@ public void testParseRequestConfig_MovesModel() throws IOException {
}
}
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", null, null),
- getOpenAiTaskSettingsMap(null),
- createRandomChunkingSettingsMap(),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
- }
-
- public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
- try (var service = createOpenAiService()) {
- ActionListener modelVerificationListener = ActionListener.wrap(model -> {
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }, exception -> fail("Unexpected exception: " + exception));
-
- service.parseRequestConfig(
- "id",
- TaskType.TEXT_EMBEDDING,
- getRequestConfigMap(
- getServiceSettingsMap("model", null, null),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- ),
- modelVerificationListener
- );
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", 100, null, false),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org"),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
-
- var thrownException = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.SPARSE_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- )
- );
-
- assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [openai] service, please delete and add the service again")
- );
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- createRandomChunkingSettingsMap(),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
- persistedConfig.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var secretSettingsMap = getSecretSettingsMap("secret");
- secretSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user"),
- secretSettingsMap
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
- persistedConfig.secrets().put("extra_key", "value");
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true);
- serviceSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(
- serviceSettingsMap,
- getOpenAiTaskSettingsMap("user"),
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var taskSettingsMap = getOpenAiTaskSettingsMap("user");
- taskSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- taskSettingsMap,
- getSecretSettingsMap("secret")
- );
-
- var model = service.parsePersistedConfigWithSecrets(
- "id",
- TaskType.TEXT_EMBEDDING,
- persistedConfig.config(),
- persistedConfig.secrets()
- );
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user")
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getOpenAiTaskSettingsMap("user"));
-
- var thrownException = expectThrows(
- ElasticsearchStatusException.class,
- () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
- );
-
- assertThat(
- thrownException.getMessage(),
- is("Failed to parse stored model [id] for [openai] service, please delete and add the service again")
- );
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null)
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null),
- createRandomChunkingSettingsMap()
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", null, null, null, null, true),
- getOpenAiTaskSettingsMap(null)
- );
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertNull(embeddingsModel.getServiceSettings().uri());
- assertNull(embeddingsModel.getServiceSettings().organizationId());
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertNull(embeddingsModel.getTaskSettings().user());
- assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
- try (var service = createOpenAiService()) {
- var persistedConfig = getPersistedConfigMap(
- getServiceSettingsMap("model", "url", "org", null, null, true),
- getOpenAiTaskSettingsMap("user")
- );
- persistedConfig.config().put("extra_key", "value");
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true);
- serviceSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getOpenAiTaskSettingsMap("user"));
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
- public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
- try (var service = createOpenAiService()) {
- var taskSettingsMap = getOpenAiTaskSettingsMap("user");
- taskSettingsMap.put("extra_key", "value");
-
- var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap);
-
- var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
-
- assertThat(model, instanceOf(OpenAiEmbeddingsModel.class));
-
- var embeddingsModel = (OpenAiEmbeddingsModel) model;
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
- assertThat(embeddingsModel.getServiceSettings().organizationId(), is("org"));
- assertThat(embeddingsModel.getServiceSettings().modelId(), is("model"));
- assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
- assertNull(embeddingsModel.getSecretSettings());
- }
- }
-
public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException {
var sender = mock(Sender.class);
@@ -1681,6 +1176,11 @@ public void testGetConfiguration() throws Exception {
}
}
+ private static OpenAiService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+ return new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty());
+ }
+
private OpenAiService createOpenAiService() {
return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 8cad8cbad208a..d7f5726af85e0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -409,9 +409,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM
)
);
- MatcherAssert.assertThat(
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service"));
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again")
+ containsString("The [voyageai] service does not support task type [sparse_embedding]")
);
}
}
@@ -624,9 +625,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro
() -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
);
- MatcherAssert.assertThat(
+ assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service"));
+ assertThat(
thrownException.getMessage(),
- is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again")
+ containsString("The [voyageai] service does not support task type [sparse_embedding]")
);
}
}