Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,8 @@ default List<DefaultConfigId> defaultConfigIds() {
default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
defaultsListener.onResponse(List.of());
}

default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
listener.onResponse(model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException {
);
}

public void testNumAllocationsIsUpdated() throws IOException {
var modelId = "update_num_allocations";
var deploymentId = modelId;

CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
var response = startMlNodeDeploymemnt(modelId, deploymentId);
assertOkOrCreated(response);

var inferenceId = "test_num_allocations_updated";
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
var serviceSettings = putModel.get("service_settings");
assertThat(
putModel.toString(),
serviceSettings,
is(
Map.of(
"num_allocations",
1,
"num_threads",
1,
"model_id",
"update_num_allocations",
"deployment_id",
"update_num_allocations"
)
)
);

assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));

var updatedServiceSettings = getModel(inferenceId).get("service_settings");
assertThat(
updatedServiceSettings.toString(),
updatedServiceSettings,
is(
Map.of(
"num_allocations",
2,
"num_threads",
1,
"model_id",
"update_num_allocations",
"deployment_id",
"update_num_allocations"
)
)
);
}

private String endpointConfig(String deploymentId) {
return Strings.format("""
{
Expand Down Expand Up @@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
return client().performRequest(request);
}

private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";

var body = Strings.format("""
{
"number_of_allocations": %d
}
""", numAllocations);

Request request = new Request("POST", endPoint);
request.setJsonEntity(body);
return client().performRequest(request);
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.stream.Stream;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception {
deleteModel(modelId);
}
}

public void testGetZeroModels() throws IOException {
var models = getModels("_all", TaskType.RERANK);
assertThat(models, empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
Expand All @@ -29,8 +29,11 @@
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;

public class TransportGetInferenceModelAction extends HandledTransportAction<
GetInferenceModelAction.Request,
Expand Down Expand Up @@ -96,39 +99,77 @@ private void getSingleModel(

var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
delegate.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations())));

service.get()
.updateModelsWithDynamicFields(
List.of(model),
delegate.delegateFailureAndWrap(
(l2, updatedModels) -> l2.onResponse(
new GetInferenceModelAction.Response(
updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList())
)
)
)
);
}));
}

private void getAllModels(boolean persistDefaultEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
modelRegistry.getAllModels(
persistDefaultEndpoints,
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
);
}

private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
modelRegistry.getModelsByTaskType(
taskType,
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
);
}

private GetInferenceModelAction.Response parseModels(List<UnparsedModel> unparsedModels) {
var parsedModels = new ArrayList<ModelConfigurations>();
private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
if (unparsedModels.isEmpty()) {
listener.onResponse(new GetInferenceModelAction.Response(List.of()));
return;
}

for (var unparsedModel : unparsedModels) {
var service = serviceRegistry.getService(unparsedModel.service());
if (service.isEmpty()) {
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
var parsedModelsByService = new HashMap<String, List<Model>>();
try {
for (var unparsedModel : unparsedModels) {
var service = serviceRegistry.getService(unparsedModel.service());
if (service.isEmpty()) {
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
}
var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>());
list.add(
service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
);
}
parsedModels.add(
service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
.getConfigurations()

var groupedListener = new GroupedActionListener<List<Model>>(
parsedModelsByService.entrySet().size(),
listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> {
var modifiable = new ArrayList<Model>();
for (var l : listOfListOfModels) {
modifiable.addAll(l);
}
modifiable.sort(Comparator.comparing(Model::getInferenceEntityId));
delegate.onResponse(
new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList()))
);
})
);

for (var entry : parsedModelsByService.entrySet()) {
serviceRegistry.getService(entry.getKey())
.get() // must be non-null to get this far
.updateModelsWithDynamicFields(entry.getValue(), groupedListener);
}
} catch (Exception e) {
listener.onFailure(e);
}
return new GetInferenceModelAction.Response(parsedModels);
}

private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

public abstract class ElasticsearchInternalModel extends Model {

protected final ElasticsearchInternalServiceSettings internalServiceSettings;
protected ElasticsearchInternalServiceSettings internalServiceSettings;

public ElasticsearchInternalModel(
String inferenceEntityId,
Expand Down Expand Up @@ -91,6 +91,10 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
}

public void updateNumAllocations(Integer numAllocations) {
this.internalServiceSettings.setNumAllocations(numAllocations);
}

@Override
public String toString() {
return Strings.toString(this.getConfigurations());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
Expand All @@ -56,6 +57,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -786,11 +788,50 @@ public List<DefaultConfigId> defaultConfigIds() {
);
}

/**
* Default configurations that can be out of the box without creating an endpoint first.
* @param defaultsListener Config listener
*/
@Override
public void updateModelsWithDynamicFields(List<Model> models, ActionListener<List<Model>> listener) {

if (models.isEmpty()) {
listener.onResponse(models);
return;
}

var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
for (var model : models) {
assert model instanceof ElasticsearchInternalModel;

if (model instanceof ElasticsearchInternalModel esModel) {
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
} else {
listener.onFailure(
new ElasticsearchStatusException(
"Cannot update model [{}] as it is not an Elasticsearch service model",
RestStatus.INTERNAL_SERVER_ERROR,
model.getInferenceEntityId()
)
);
return;
}
}

String deploymentIds = String.join(",", modelsByDeploymentIds.keySet());
client.execute(
GetDeploymentStatsAction.INSTANCE,
new GetDeploymentStatsAction.Request(deploymentIds),
ActionListener.wrap(stats -> {
for (var deploymentStats : stats.getStats().results()) {
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
model.updateNumAllocations(deploymentStats.getNumberOfAllocations());
}
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
}, e -> {
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
// continue with the original response
listener.onResponse(models);
})
);
}

public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
public static final String DEPLOYMENT_ID = "deployment_id";
public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations";

private final Integer numAllocations;
private Integer numAllocations;
private final int numThreads;
private final String modelId;
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
Expand Down Expand Up @@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
: null;
}

public void setNumAllocations(Integer numAllocations) {
this.numAllocations = numAllocations;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
Expand All @@ -194,6 +198,10 @@ public String modelId() {
return modelId;
}

public String deloymentId() {
return modelId;
}

public Integer getNumAllocations() {
return numAllocations;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;

public class ElserInternalModelTests extends ESTestCase {
public void testUpdateNumAllocation() {
var model = new ElserInternalModel(
"foo",
TaskType.SPARSE_EMBEDDING,
ElasticsearchInternalService.NAME,
new ElserInternalServiceSettings(null, 1, "elser", null),
new ElserMlNodeTaskSettings(),
null
);

model.updateNumAllocations(1);
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());

model.updateNumAllocations(null);
assertNull(model.getServiceSettings().getNumAllocations());
}
}