diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index d437533a8603d..2c99563955746 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -210,4 +210,8 @@ default List defaultConfigIds() { default void defaultConfigs(ActionListener> defaultsListener) { defaultsListener.onResponse(List.of()); } + + default void updateModelsWithDynamicFields(List model, ActionListener> listener) { + listener.onResponse(model); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index f81ebc25dc860..0bfb6e9e43b03 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException { ); } + public void testNumAllocationsIsUpdated() throws IOException { + var modelId = "update_num_allocations"; + var deploymentId = modelId; + + CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client()); + var response = startMlNodeDeploymemnt(modelId, deploymentId); + assertOkOrCreated(response); + + var inferenceId = "test_num_allocations_updated"; + var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING); + var serviceSettings = putModel.get("service_settings"); + assertThat( + putModel.toString(), + serviceSettings, + is( + Map.of( + "num_allocations", + 1, + "num_threads", + 1, + "model_id", + "update_num_allocations", + "deployment_id", + "update_num_allocations" + ) + ) + ); + + assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2)); + + var updatedServiceSettings = getModel(inferenceId).get("service_settings"); + assertThat( + updatedServiceSettings.toString(), + updatedServiceSettings, + is( + Map.of( + "num_allocations", + 2, + "num_threads", + 1, + "model_id", + "update_num_allocations", + "deployment_id", + "update_num_allocations" + ) + ) + ); + } + private String endpointConfig(String deploymentId) { return Strings.format(""" { @@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr return client().performRequest(request); } + private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException { + String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update"; + + var body = Strings.format(""" + { + "number_of_allocations": %d + } + """, numAllocations); + + Request request = new Request("POST", endPoint); + request.setJsonEntity(body); + return client().performRequest(request); + } + protected void stopMlNodeDeployment(String deploymentId) throws IOException { String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; Request request = new Request("POST", endpoint); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index cbc50c361e3b5..37de2caadb475 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -24,6 +24,7 @@ import java.util.stream.Stream; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; @@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception { deleteModel(modelId); } } + + public void testGetZeroModels() throws IOException { + var models = getModels("_all", TaskType.RERANK); + assertThat(models, empty()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index edcec45b50a16..01e663df4a3ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -9,13 +9,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; @@ -29,8 +29,11 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.concurrent.Executor; +import java.util.stream.Collectors; public class TransportGetInferenceModelAction extends HandledTransportAction< GetInferenceModelAction.Request, @@ -96,39 +99,77 @@ private void getSingleModel( var model = service.get() .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); - delegate.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations()))); + + service.get() + .updateModelsWithDynamicFields( + List.of(model), + delegate.delegateFailureAndWrap( + (l2, updatedModels) -> l2.onResponse( + new GetInferenceModelAction.Response( + updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList()) + ) + ) + ) + ); })); } private void getAllModels(boolean persistDefaultEndpoints, ActionListener listener) { modelRegistry.getAllModels( persistDefaultEndpoints, - listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models)))) + listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))) ); } private void getModelsByTaskType(TaskType taskType, ActionListener listener) { modelRegistry.getModelsByTaskType( taskType, - listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models)))) + listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))) ); } - private GetInferenceModelAction.Response parseModels(List unparsedModels) { - var parsedModels = new ArrayList(); + private void parseModels(List unparsedModels, ActionListener listener) { + if (unparsedModels.isEmpty()) { + listener.onResponse(new GetInferenceModelAction.Response(List.of())); + return; + } - for (var unparsedModel : unparsedModels) { - var service = serviceRegistry.getService(unparsedModel.service()); - if (service.isEmpty()) { - throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId()); + var parsedModelsByService = new HashMap>(); + try { + for (var unparsedModel : unparsedModels) { + var service = serviceRegistry.getService(unparsedModel.service()); + if (service.isEmpty()) { + throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId()); + } + var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>()); + list.add( + service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()) + ); } - parsedModels.add( - service.get() - .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()) - .getConfigurations() + + var groupedListener = new GroupedActionListener>( + parsedModelsByService.entrySet().size(), + listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> { + var modifiable = new ArrayList(); + for (var l : listOfListOfModels) { + modifiable.addAll(l); + } + modifiable.sort(Comparator.comparing(Model::getInferenceEntityId)); + delegate.onResponse( + new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList())) + ); + }) ); + + for (var entry : parsedModelsByService.entrySet()) { + serviceRegistry.getService(entry.getKey()) + .get() // must be non-null to get this far + .updateModelsWithDynamicFields(entry.getValue(), groupedListener); + } + } catch (Exception e) { + listener.onFailure(e); } - return new GetInferenceModelAction.Response(parsedModels); } private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index d38def8dca47f..8b2969c39b7ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -21,7 +21,7 @@ public abstract class ElasticsearchInternalModel extends Model { - protected final ElasticsearchInternalServiceSettings internalServiceSettings; + protected ElasticsearchInternalServiceSettings internalServiceSettings; public ElasticsearchInternalModel( String inferenceEntityId, @@ -91,6 +91,10 @@ public ElasticsearchInternalServiceSettings getServiceSettings() { return (ElasticsearchInternalServiceSettings) super.getServiceSettings(); } + public void updateNumAllocations(Integer numAllocations) { + this.internalServiceSettings.setNumAllocations(numAllocations); + } + @Override public String toString() { return Strings.toString(this.getConfigurations()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 389a9fa369c21..49919fda9f89d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -56,6 +57,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -786,11 +788,50 @@ public List defaultConfigIds() { ); } - /** - * Default configurations that can be out of the box without creating an endpoint first. - * @param defaultsListener Config listener - */ @Override + public void updateModelsWithDynamicFields(List models, ActionListener> listener) { + + if (models.isEmpty()) { + listener.onResponse(models); + return; + } + + var modelsByDeploymentIds = new HashMap(); + for (var model : models) { + assert model instanceof ElasticsearchInternalModel; + + if (model instanceof ElasticsearchInternalModel esModel) { + modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel); + } else { + listener.onFailure( + new ElasticsearchStatusException( + "Cannot update model [{}] as it is not an Elasticsearch service model", + RestStatus.INTERNAL_SERVER_ERROR, + model.getInferenceEntityId() + ) + ); + return; + } + } + + String deploymentIds = String.join(",", modelsByDeploymentIds.keySet()); + client.execute( + GetDeploymentStatsAction.INSTANCE, + new GetDeploymentStatsAction.Request(deploymentIds), + ActionListener.wrap(stats -> { + for (var deploymentStats : stats.getStats().results()) { + var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId()); + model.updateNumAllocations(deploymentStats.getNumberOfAllocations()); + } + listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values())); + }, e -> { + logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e); + // continue with the original response + listener.onResponse(models); + }) + ); + } + public void defaultConfigs(ActionListener> defaultsListener) { preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> { if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index fedf48fb583a3..962c939146ef2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings { public static final String DEPLOYMENT_ID = "deployment_id"; public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations"; - private final Integer numAllocations; + private Integer numAllocations; private final int numThreads; private final String modelId; private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; @@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { : null; } + public void setNumAllocations(Integer numAllocations) { + this.numAllocations = numAllocations; + } + @Override public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { @@ -194,6 +198,10 @@ public String modelId() { return modelId; } + public String deloymentId() { + return modelId; + } + public Integer getNumAllocations() { return numAllocations; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java new file mode 100644 index 0000000000000..96cd42efa42f5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; + +public class ElserInternalModelTests extends ESTestCase { + public void testUpdateNumAllocation() { + var model = new ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElasticsearchInternalService.NAME, + new ElserInternalServiceSettings(null, 1, "elser", null), + new ElserMlNodeTaskSettings(), + null + ); + + model.updateNumAllocations(1); + assertEquals(1, model.getServiceSettings().getNumAllocations().intValue()); + + model.updateNumAllocations(null); + assertNull(model.getServiceSettings().getNumAllocations()); + } +}