From 68b9a152059ac590efc3594dfb3375fd6281e6c7 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Wed, 28 May 2025 16:36:21 -0400 Subject: [PATCH] [ML] InferenceService support aliases (#128584) "elser" is an alias for "elasticsearch", and "sagemaker" is an alias for "amazon_sagemaker". Users can continue to create and use providers by their alias. Elasticsearch will continue to support the alias when it reads the configuration from the internal index. --- docs/changelog/128584.yaml | 5 +++++ .../inference/InferenceService.java | 8 ++++++++ .../inference/InferenceServiceRegistry.java | 16 ++++++++-------- .../xpack/inference/DefaultEndPointsIT.java | 2 +- .../xpack/inference/InferenceBaseRestTest.java | 6 +++--- .../xpack/inference/InferenceCrudIT.java | 14 +++++++++++--- .../xpack/inference/InferenceGetServicesIT.java | 8 ++++---- .../TestStreamingCompletionServiceExtension.java | 6 ++++++ .../ElasticsearchInternalService.java | 5 +++++ .../services/sagemaker/SageMakerService.java | 11 +++++++++-- 10 files changed, 60 insertions(+), 21 deletions(-) create mode 100644 docs/changelog/128584.yaml diff --git a/docs/changelog/128584.yaml b/docs/changelog/128584.yaml new file mode 100644 index 0000000000000..e5e380559786d --- /dev/null +++ b/docs/changelog/128584.yaml @@ -0,0 +1,5 @@ +pr: 128584 +summary: '`InferenceService` support aliases' +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index d85acb021506a..e3e9abf7dc3f2 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -27,6 +27,14 @@ default void init(Client client) {} String name(); + /** + * The aliases that map to {@link #name()}. {@link InferenceServiceRegistry} allows users to create and use inference services by one + * of their aliases. 
+ */ + default List aliases() { + return List.of(); + } + /** * Parse model configuration from the {@code config map} from a request and return * the parsed {@link Model}. This requires that both the secrets and service settings be contained in the diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index 8ef9b59f5545a..deed0610c258f 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -24,17 +24,22 @@ public class InferenceServiceRegistry implements Closeable { private final Map services; + private final Map aliases; private final List namedWriteables = new ArrayList<>(); public InferenceServiceRegistry( List inferenceServicePlugins, InferenceServiceExtension.InferenceServiceFactoryContext factoryContext ) { - // TODO check names are unique + // toMap throws on duplicate keys: names are unique and aliases are unique, but an alias equal to another name is not detected services = inferenceServicePlugins.stream() .flatMap(r -> r.getInferenceServiceFactories().stream()) .map(factory -> factory.create(factoryContext)) .collect(Collectors.toMap(InferenceService::name, Function.identity())); + aliases = services.values() + .stream() + .flatMap(service -> service.aliases().stream().distinct().map(alias -> Map.entry(alias, service.name()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } public void init(Client client) { @@ -56,13 +61,8 @@ public Map getServices() { } public Optional getService(String serviceName) { - - if ("elser".equals(serviceName)) { // ElserService.NAME before removal - // here we are aliasing the elser service to use the elasticsearch service instead - return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME - } else { - return Optional.ofNullable(services.get(serviceName)); - } + var serviceKey = 
aliases.getOrDefault(serviceName, serviceName); + return Optional.ofNullable(services.get(serviceKey)); } public List getNamedWriteables() { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java index 5d4aa583d4310..fb2eaf3c26b8e 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java @@ -65,7 +65,7 @@ public void testDefaultModels() throws IOException { var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID); assertDefaultRerankConfig(rerankModel); - putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING)); + putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service")); var registeredModels = getMinimalConfigs(); assertThat(registeredModels.size(), equalTo(1)); assertTrue(registeredModels.containsKey("my-model")); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index d76d76a8d516c..69256d49fe1d2 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -119,12 +119,12 @@ static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int """, taskType, 
apiKey, temperature); } - static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) { + static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody, String service) { var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\","; return Strings.format(""" { %s - "service": "streaming_completion_test_service", + "service": "%s", "service_settings": { "model": "my_model", "api_key": "abc64" @@ -133,7 +133,7 @@ static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody "temperature": 3 } } - """, taskType); + """, taskType, service); } static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index f85874d6e580c..0a98787514010 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -305,7 +305,7 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws public void testUnsupportedStream() throws Exception { String modelId = "streaming"; - putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING)); + putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service")); var singleModel = getModel(modelId); assertEquals(modelId, singleModel.get("inference_id")); assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type")); @@ -326,8 +326,16 @@ public void testUnsupportedStream() throws Exception { } public void 
testSupportedStream() throws Exception { + testSupportedStream("streaming_completion_test_service"); + } + + public void testSupportedStreamForAlias() throws Exception { + testSupportedStream("streaming_completion_test_service_alias"); + } + + private void testSupportedStream(String serviceName) throws Exception { String modelId = "streaming"; - putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION, serviceName)); var singleModel = getModel(modelId); assertEquals(modelId, singleModel.get("inference_id")); assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); @@ -352,7 +360,7 @@ public void testSupportedStream() throws Exception { public void testUnifiedCompletionInference() throws Exception { String modelId = "streaming"; - putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION)); + putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION, "streaming_completion_test_service")); var singleModel = getModel(modelId); assertEquals(modelId, singleModel.get("inference_id")); assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type")); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 1160002a60ac0..ff9bc83f741f1 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -54,7 +54,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "text_embedding_test_service", "voyageai", "watsonxai", - "sagemaker" + 
"amazon_sagemaker" ).toArray() ) ); @@ -93,7 +93,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "text_embedding_test_service", "voyageai", "watsonxai", - "sagemaker" + "amazon_sagemaker" ).toArray() ) ); @@ -143,7 +143,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "openai", "streaming_completion_test_service", "hugging_face", - "sagemaker" + "amazon_sagemaker" ).toArray() ) ); @@ -158,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { assertThat( providers, containsInAnyOrder( - List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray() + List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray() ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 917e8e128256f..f711f47aa560d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -60,6 +60,7 @@ public List getInferenceServiceFactories() { public static class TestInferenceService extends AbstractTestInferenceService { private static final String NAME = "streaming_completion_test_service"; + private static final String ALIAS = "streaming_completion_test_service_alias"; private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); private static final EnumSet supportedTaskTypes = EnumSet.of( @@ -75,6 +76,11 @@ public String name() { return 
NAME; } + @Override + public List aliases() { + return List.of(ALIAS); + } + @Override protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { return TestServiceSettings.fromMap(serviceSettingsMap); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 8232240b2c9ba..629dafbc5b52a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -778,6 +778,11 @@ public String name() { return NAME; } + @Override + public List aliases() { + return List.of(OLD_ELSER_SERVICE_NAME); + } + private RankedDocsResults textSimilarityResultsToRankedDocs( List results, Function inputSupplier, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 581108140a7b8..f1962dc107c1a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -45,7 +45,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails; public class SageMakerService implements InferenceService { - public static final String NAME = "sagemaker"; + public static final String NAME = "amazon_sagemaker"; + private static final String DISPLAY_NAME = "Amazon SageMaker"; + private static final List ALIASES = List.of("sagemaker", 
"amazonsagemaker"); private static final int DEFAULT_BATCH_SIZE = 256; private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS; private final SageMakerModelBuilder modelBuilder; @@ -67,7 +69,7 @@ public SageMakerService( this.threadPool = threadPool; this.configuration = new LazyInitializable<>( () -> new InferenceServiceConfiguration.Builder().setService(NAME) - .setName("Amazon SageMaker") + .setName(DISPLAY_NAME) .setTaskTypes(supportedTaskTypes()) .setConfigurations(configurationMap.get()) .build() @@ -79,6 +81,11 @@ public String name() { return NAME; } + @Override + public List aliases() { + return ALIASES; + } + @Override public void parseRequestConfig( String modelId,