Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException {
);
}

/**
 * Starts an ML node deployment, creates a sparse embedding inference endpoint from
 * {@code endpointConfig(deploymentId)}, and checks the reported service settings both
 * right after creation and after the deployment's allocation count is updated to 2.
 */
public void testNumAllocationsIsUpdated() throws IOException {
    var modelId = "update_num_allocations";
    var deploymentId = modelId;
    var inferenceId = "test_num_allocations_updated";

    CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
    assertOkOrCreated(startMlNodeDeploymemnt(modelId, deploymentId));

    // Settings echoed back by the PUT; the full response is used as the failure message.
    var createdEndpoint = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
    assertThat(
        createdEndpoint.toString(),
        createdEndpoint.get("service_settings"),
        is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId))
    );

    // Change the allocation count on the deployment itself, not through the inference endpoint.
    var updateResponse = updateMlNodeDeploymemnt(deploymentId, 2);
    assertOkOrCreated(updateResponse);

    // A subsequent GET of the endpoint must report the new allocation count.
    var settingsAfterUpdate = getModel(inferenceId).get("service_settings");
    assertThat(
        settingsAfterUpdate.toString(),
        settingsAfterUpdate,
        is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId))
    );
}

private String endpointConfig(String deploymentId) {
return Strings.format("""
{
Expand Down Expand Up @@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
return client().performRequest(request);
}

/**
 * Sends a POST to {@code /_ml/trained_models/<deploymentId>/deployment/_update} whose JSON
 * body sets {@code number_of_allocations} to the given value.
 *
 * @param deploymentId   id placed in the request path
 * @param numAllocations value written into the {@code number_of_allocations} field
 * @return the response returned by {@code client().performRequest}
 * @throws IOException if performing the request fails
 */
private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
    var updateRequest = new Request("POST", "/_ml/trained_models/" + deploymentId + "/deployment/_update");
    updateRequest.setJsonEntity(Strings.format("""
        {
        "number_of_allocations": %d
        }
        """, numAllocations));
    return client().performRequest(updateRequest);
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.stream.Stream;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception {
deleteModel(modelId);
}
}

/**
 * Requests every endpoint ({@code _all}) of the rerank task type and expects the
 * returned collection to be empty.
 */
public void testGetZeroModels() throws IOException {
    assertThat(getModels("_all", TaskType.RERANK), empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceM
}

private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
if (unparsedModels.isEmpty()) {
listener.onResponse(new GetInferenceModelAction.Response(List.of()));
return;
}

var parsedModelsByService = new HashMap<String, List<Model>>();
try {
for (var unparsedModel : unparsedModels) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,8 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
}

public void updateNumAllocation(Integer numAllocations) {
this.internalServiceSettings = new ElasticsearchInternalServiceSettings(
numAllocations,
this.internalServiceSettings.getNumThreads(),
this.internalServiceSettings.modelId(),
this.internalServiceSettings.getAdaptiveAllocationsSettings()
);
public void updateNumAllocations(Integer numAllocations) {
this.internalServiceSettings.setNumAllocations(numAllocations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,17 @@ public List<DefaultConfigId> defaultConfigIds() {

@Override
public void updateModelsWithDynamicFields(List<Model> models, ActionListener<List<Model>> listener) {
if (models.isEmpty()) {
listener.onResponse(models);
return;
}

var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
for (var model : models) {
assert model instanceof ElasticsearchInternalModel;

if (model instanceof ElasticsearchInternalModel esModel) {
modelsByDeploymentIds.put(esModel.internalServiceSettings.deloymentId(), esModel);
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -794,19 +801,14 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
}
}

if (modelsByDeploymentIds.isEmpty()) {
listener.onResponse(models);
return;
}

String deploymentIds = String.join(",", modelsByDeploymentIds.keySet());
client.execute(
GetDeploymentStatsAction.INSTANCE,
new GetDeploymentStatsAction.Request(deploymentIds),
ActionListener.wrap(stats -> {
for (var deploymentStats : stats.getStats().results()) {
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
model.updateNumAllocation(deploymentStats.getNumberOfAllocations());
model.updateNumAllocations(deploymentStats.getNumberOfAllocations());
}
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
}, e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
public static final String DEPLOYMENT_ID = "deployment_id";
public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations";

private final Integer numAllocations;
private Integer numAllocations;
private final int numThreads;
private final String modelId;
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
Expand Down Expand Up @@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
: null;
}

/**
 * Overwrites the stored number of allocations on this settings object.
 * The argument is assigned as-is with no validation, so passing {@code null}
 * stores {@code null}.
 *
 * @param numAllocations the new value for the {@code numAllocations} field
 */
public void setNumAllocations(Integer numAllocations) {
this.numAllocations = numAllocations;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public void testUpdateNumAllocation() {
null
);

model.updateNumAllocation(1);
assertEquals(1, model.internalServiceSettings.getNumAllocations().intValue());
model.updateNumAllocations(1);
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());

model.updateNumAllocation(null);
assertNull(model.internalServiceSettings.getNumAllocations());
model.updateNumAllocations(null);
assertNull(model.getServiceSettings().getNumAllocations());
}
}