Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128584.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128584
summary: '`InferenceService` supports aliases'
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ default void init(Client client) {}

String name();

/**
 * The aliases that map to {@link #name()}. {@link InferenceServiceRegistry} allows users to create and use inference services by one
 * of their aliases.
 *
 * @return alternative names that resolve to this service; empty by default, meaning the service is addressable only by {@link #name()}
 */
default List<String> aliases() {
    return List.of();
}

/**
* Parse model configuration from the {@code config map} from a request and return
* the parsed {@link Model}. This requires that both the secrets and service settings be contained in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,22 @@
public class InferenceServiceRegistry implements Closeable {

private final Map<String, InferenceService> services;
private final Map<String, String> aliases;
private final List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();

/**
 * Builds the registry from all plugin-provided service factories.
 * <p>
 * {@link Collectors#toMap} throws {@link IllegalStateException} on duplicate keys, which enforces that service
 * names are unique and that no two services declare the same alias. Additionally, an alias is rejected if it
 * collides with a registered service name, since {@link #getService(String)} would otherwise silently redirect
 * lookups of that service to the aliased one.
 *
 * @param inferenceServicePlugins the plugins contributing inference service factories
 * @param factoryContext          context handed to each factory when creating its service
 */
public InferenceServiceRegistry(
    List<InferenceServiceExtension> inferenceServicePlugins,
    InferenceServiceExtension.InferenceServiceFactoryContext factoryContext
) {
    // toMap verifies that the names are unique
    services = inferenceServicePlugins.stream()
        .flatMap(r -> r.getInferenceServiceFactories().stream())
        .map(factory -> factory.create(factoryContext))
        .collect(Collectors.toMap(InferenceService::name, Function.identity()));
    // toMap verifies that the aliases are unique across services; distinct() tolerates
    // a service repeating its own alias
    aliases = services.values()
        .stream()
        .flatMap(service -> service.aliases().stream().distinct().map(alias -> Map.entry(alias, service.name())))
        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    // An alias must not shadow a real service name, or getService would resolve the
    // shadowed service to the wrong implementation
    for (var alias : aliases.keySet()) {
        if (services.containsKey(alias)) {
            throw new IllegalArgumentException(
                "Inference service alias [" + alias + "] for service [" + aliases.get(alias) + "] collides with a registered service name"
            );
        }
    }
}

public void init(Client client) {
Expand All @@ -56,13 +61,8 @@ public Map<String, InferenceService> getServices() {
}

/**
 * Looks up a service by its canonical name or by one of its registered aliases.
 * <p>
 * Diff residue removed: the hard-coded "elser" → "elasticsearch" branch was superseded by the
 * alias map and made the alias-map lookup unreachable; alias resolution now handles it generically.
 *
 * @param serviceName a service name or alias
 * @return the matching service, or empty if neither a service nor an alias matches
 */
public Optional<InferenceService> getService(String serviceName) {
    // Resolve a known alias to its canonical service name; unknown names pass through unchanged
    var serviceKey = aliases.getOrDefault(serviceName, serviceName);
    return Optional.ofNullable(services.get(serviceKey));
}

public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void testDefaultModels() throws IOException {
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
assertDefaultRerankConfig(rerankModel);

putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
var registeredModels = getMinimalConfigs();
assertThat(registeredModels.size(), equalTo(1));
assertTrue(registeredModels.containsKey("my-model"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int
""", taskType, apiKey, temperature);
}

static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody, String service) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
{
%s
"service": "streaming_completion_test_service",
"service": "%s",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
Expand All @@ -133,7 +133,7 @@ static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody
"temperature": 3
}
}
""", taskType);
""", taskType, service);
}

static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws

public void testUnsupportedStream() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
Expand All @@ -326,8 +326,16 @@ public void testUnsupportedStream() throws Exception {
}

// Exercises streaming completion through the service's canonical name.
public void testSupportedStream() throws Exception {
    testSupportedStream("streaming_completion_test_service");
}

// Exercises streaming completion through the service's alias, verifying that
// alias resolution reaches the same service as the canonical name.
public void testSupportedStreamForAlias() throws Exception {
    testSupportedStream("streaming_completion_test_service_alias");
}

private void testSupportedStream(String serviceName) throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION, serviceName));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
Expand All @@ -352,7 +360,7 @@ public void testSupportedStream() throws Exception {

public void testUnifiedCompletionInference() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION));
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION, "streaming_completion_test_service"));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
"amazon_sagemaker"
).toArray()
)
);
Expand Down Expand Up @@ -93,7 +93,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
"amazon_sagemaker"
).toArray()
)
);
Expand Down Expand Up @@ -143,7 +143,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"openai",
"streaming_completion_test_service",
"hugging_face",
"sagemaker"
"amazon_sagemaker"
).toArray()
)
);
Expand All @@ -158,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
assertThat(
providers,
containsInAnyOrder(
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public List<Factory> getInferenceServiceFactories() {

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "streaming_completion_test_service";
private static final String ALIAS = "streaming_completion_test_service_alias";
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
Expand All @@ -75,6 +76,11 @@ public String name() {
return NAME;
}

@Override
public List<String> aliases() {
    // Expose a single alias so tests can exercise registry alias resolution.
    return List.of(ALIAS);
}

@Override
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,11 @@ public String name() {
return NAME;
}

@Override
public List<String> aliases() {
    // Keep the legacy ELSER service name working as an alias for this service.
    return List.of(OLD_ELSER_SERVICE_NAME);
}

private RankedDocsResults textSimilarityResultsToRankedDocs(
List<? extends InferenceResults> results,
Function<Integer, String> inputSupplier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;

public class SageMakerService implements InferenceService {
public static final String NAME = "sagemaker";
public static final String NAME = "amazon_sagemaker";
private static final String DISPLAY_NAME = "Amazon SageMaker";
private static final List<String> ALIASES = List.of("sagemaker", "amazonsagemaker");
private static final int DEFAULT_BATCH_SIZE = 256;
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
private final SageMakerModelBuilder modelBuilder;
Expand All @@ -67,7 +69,7 @@ public SageMakerService(
this.threadPool = threadPool;
this.configuration = new LazyInitializable<>(
() -> new InferenceServiceConfiguration.Builder().setService(NAME)
.setName("Amazon SageMaker")
.setName(DISPLAY_NAME)
.setTaskTypes(supportedTaskTypes())
.setConfigurations(configurationMap.get())
.build()
Expand All @@ -79,6 +81,11 @@ public String name() {
return NAME;
}

@Override
public List<String> aliases() {
    // ALIASES carries the legacy names ("sagemaker", "amazonsagemaker") retained
    // after the rename to "amazon_sagemaker".
    return ALIASES;
}

@Override
public void parseRequestConfig(
String modelId,
Expand Down