docs/changelog/121231.yaml: 6 additions, 0 deletions
@@ -0,0 +1,6 @@
pr: 121231
summary: Fix inference update API calls with `task_type` in body or `deployment_id` defined
area: Machine Learning
type: bug
issues: []
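
For context, a minimal sketch of the two call shapes this fix covers, written against the `updateEndpoint` test helpers in this PR (the endpoint id and settings values are illustrative; the HTTP method follows the `putRequest` helper these utilities use):

// Task type supplied in the URL path (_inference/{task_type}/{inference_id}/_update):
updateEndpoint("my-endpoint", """
    { "service_settings": { "num_allocations": 2 } }
    """, TaskType.SPARSE_EMBEDDING);

// Inference id only (_inference/{inference_id}/_update); the task type is resolved from the stored endpoint:
updateEndpoint("my-endpoint", """
    { "service_settings": { "num_allocations": 2 } }
    """);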
@@ -43,6 +43,24 @@ public void testAttachToDeployment() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

var updatedNumAllocations = randomIntBetween(1, 10);
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
assertThat(
updatedEndpointConfig.get("service_settings"),
is(
Map.of(
"num_allocations",
updatedNumAllocations,
"num_threads",
1,
"model_id",
"attach_to_deployment",
"deployment_id",
"existing_deployment"
)
)
);

deleteModel(inferenceId);
// assert deployment not stopped
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
@@ -83,6 +101,24 @@ public void testAttachWithModelId() throws IOException {
var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

var updatedNumAllocations = randomIntBetween(1, 10);
var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
assertThat(
updatedEndpointConfig.get("service_settings"),
is(
Map.of(
"num_allocations",
updatedNumAllocations,
"num_threads",
1,
"model_id",
"attach_with_model_id",
"deployment_id",
"existing_deployment_with_model_id"
)
)
);

stopMlNodeDeployment(deploymentId);
}

@@ -189,6 +225,16 @@ private String endpointConfig(String modelId, String deploymentId) {
""", modelId, deploymentId);
}

private String updatedEndpointConfig(int numAllocations) {
return Strings.format("""
{
"service_settings": {
"num_allocations": %d
}
}
""", numAllocations);
}

private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
String endPoint = "/_ml/trained_models/"
+ modelId
@@ -238,6 +238,11 @@ static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig
return putRequest(endpoint, modelConfig);
}

static Map<String, Object> updateEndpoint(String inferenceID, String modelConfig) throws IOException {
String endpoint = Strings.format("_inference/%s/_update", inferenceID);
return putRequest(endpoint, modelConfig);
}

protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
String body = """
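
The new two-argument overload above builds the task-type-free path, matching the URL form the rest handler change at the end of this diff now accepts. A usage sketch (endpoint id illustrative):

// Goes through putRequest to _inference/my-endpoint/_update, no task type in the path.
Map<String, Object> updated = updateEndpoint("my-endpoint", "{\"service_settings\":{\"num_allocations\":2}}");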
@@ -369,6 +369,61 @@ public void testUnifiedCompletionInference() throws Exception {
}
}

public void testUpdateEndpointWithWrongTaskTypeInURL() throws IOException {
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var e = expectThrows(
ResponseException.class,
() -> updateEndpoint(
"sparse_embedding_model",
updateConfig(null, randomAlphaOfLength(10), randomIntBetween(1, 10)),
TaskType.TEXT_EMBEDDING
)
);
assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
}

public void testUpdateEndpointWithWrongTaskTypeInBody() throws IOException {
putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var e = expectThrows(
ResponseException.class,
() -> updateEndpoint(
"sparse_embedding_model",
updateConfig(TaskType.TEXT_EMBEDDING, randomAlphaOfLength(10), randomIntBetween(1, 10))
)
);
assertThat(e.getMessage(), containsString("Task type must match the task type of the existing endpoint"));
}

public void testUpdateEndpointWithTaskTypeInURL() throws IOException {
testUpdateEndpoint(false, true);
}

public void testUpdateEndpointWithTaskTypeInBody() throws IOException {
testUpdateEndpoint(true, false);
}

public void testUpdateEndpointWithTaskTypeInBodyAndURL() throws IOException {
testUpdateEndpoint(true, true);
}

@SuppressWarnings("unchecked")
private void testUpdateEndpoint(boolean taskTypeInBody, boolean taskTypeInURL) throws IOException {
String endpointId = "sparse_embedding_model";
putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);

int temperature = randomIntBetween(1, 10);
var expectedConfig = updateConfig(taskTypeInBody ? TaskType.SPARSE_EMBEDDING : null, randomAlphaOfLength(1), temperature);
Map<String, Object> updatedEndpoint;
if (taskTypeInURL) {
updatedEndpoint = updateEndpoint(endpointId, expectedConfig, TaskType.SPARSE_EMBEDDING);
} else {
updatedEndpoint = updateEndpoint(endpointId, expectedConfig);
}

Map<String, Object> updatedTaskSettings = (Map<String, Object>) updatedEndpoint.get("task_settings");
assertEquals(temperature, updatedTaskSettings.get("temperature"));
}

private static Iterator<String> expectedResultsIterator(List<String> input) {
// The Locale needs to be ROOT to match what the test service is going to respond with
return Stream.concat(input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]"))
@@ -21,6 +21,7 @@
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
@@ -51,6 +52,7 @@
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;

@@ -257,22 +259,22 @@ private void updateInClusterEndpoint(
ActionListener<Boolean> listener
) throws IOException {
// The model we are trying to update must have a trained model associated with it if it is an in-cluster deployment
throwIfTrainedModelDoesntExist(request);
var deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);

Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {

UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(
request.getInferenceEntityId()
);
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(numAllocations);

var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
});

logger.info(
"Updating trained model deployment for inference entity [{}] with [{}] num_allocations",
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
deploymentId,
request.getInferenceEntityId(),
numAllocations
);
@@ -295,12 +297,26 @@ private boolean isInClusterService(String name) {
return List.of(ElasticsearchInternalService.NAME, ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME).contains(name);
}

private void throwIfTrainedModelDoesntExist(UpdateInferenceModelAction.Request request) throws ElasticsearchStatusException {
var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getInferenceEntityId(), clusterService.state());
private String getDeploymentIdForInClusterEndpoint(Model model) {
if (model instanceof ElasticsearchInternalModel esModel) {
return esModel.mlNodeDeploymentId();
} else {
throw new IllegalStateException(
Strings.format(
"Cannot update inference endpoint [%s]. Class [%s] is not an Elasticsearch internal model",
model.getInferenceEntityId(),
model.getClass().getSimpleName()
)
);
}
}

private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String deploymentId) throws ElasticsearchStatusException {
var assignments = TrainedModelAssignmentUtils.modelAssignments(deploymentId, clusterService.state());
if ((assignments == null || assignments.isEmpty())) {
throw ExceptionsHelper.entityNotFoundException(
Messages.MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE,
request.getInferenceEntityId()
inferenceEntityId
);
}
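
The substance of the transport-action change above: the deployment to update is resolved from the parsed model instead of being assumed to equal the inference entity id, which matters for endpoints attached to a pre-existing deployment. A compressed sketch of that flow (not the PR's exact code):

// Resolve the ML-node deployment backing this in-cluster endpoint (esModel.mlNodeDeploymentId()).
String deploymentId = getDeploymentIdForInClusterEndpoint(existingParsedModel);
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);

// The deployment update then targets deploymentId, not the inference entity id.
var updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(numAllocations);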
@@ -7,13 +7,11 @@

package org.elasticsearch.xpack.inference.rest;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.RestUtils;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
@@ -48,7 +46,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
inferenceEntityId = restRequest.param(INFERENCE_ID);
taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID));
} else {
throw new ElasticsearchStatusException("Inference ID must be provided in the path", RestStatus.BAD_REQUEST);
inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID);
taskType = TaskType.ANY;
}

var content = restRequest.requiredContent();
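
Net effect of the handler change above: omitting the task type from the path is no longer rejected with a 400. Assuming the path param names shown in the diff, the dispatch becomes:

// _inference/{task_type}/{inference_id}/_update -> taskType parsed from TASK_TYPE_OR_INFERENCE_ID
// _inference/{inference_id}/_update             -> taskType = TaskType.ANY; the "Task type must match the
//                                                  task type of the existing endpoint" check asserted in the
//                                                  tests earlier in this diff still applies downstream.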