[ML] Start, stop and infer with deployment ID (#95168)
A trained model deployment can be started with an optional deployment ID.
Deployment IDs and model IDs are treated as a single namespace and must be
unique: a deployment ID cannot match any other deployment or model ID unless
it is the ID of the model being deployed. Likewise, when creating a new model,
its ID cannot match any existing model or deployment ID.
davidkyle committed Apr 18, 2023
1 parent bc6b5c9 commit 267e74f
Showing 17 changed files with 405 additions and 167 deletions.
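For orientation before the diff: a minimal end-to-end sketch of the new workflow using the low-level Java REST client. It assumes a running local cluster and an already-created model; the ids my-model and my-model-for-search and the text_field document are illustrative, not taken from this commit.

import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class DeploymentIdExample {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // Start a deployment of "my-model" under its own deployment id.
            Request start = new Request("POST", "/_ml/trained_models/my-model/deployment/_start");
            start.addParameter("deployment_id", "my-model-for-search");
            start.addParameter("wait_for", "started");
            client.performRequest(start);

            // Inference is now addressed to the deployment id rather than the model id.
            Request infer = new Request("POST", "/_ml/trained_models/my-model-for-search/_infer");
            infer.setJsonEntity("{\"docs\":[{\"text_field\":\"my words\"}]}");
            Response response = client.performRequest(infer);
            System.out.println(response.getStatusLine());

            // The deployment is also stopped by its deployment id.
            client.performRequest(new Request("POST", "/_ml/trained_models/my-model-for-search/deployment/_stop"));
        }
    }
}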
docs/changelog/95168.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 95168
summary: "Start, stop and infer of a trained model can now optionally use a deployment ID that is different to the model ID"
area: Machine Learning
type: enhancement
issues: []
@@ -33,6 +33,11 @@
"description": "A byte-size value for configuring the inference cache size. For example, 20mb.",
"required": false
},
"deployment_id":{
"type":"string",
"description": "The Id of the new deployment. Defaults to the model_id if not set.",
"required": false
},
"number_of_allocations":{
"type":"int",
"description": "The total number of allocations this model is assigned across machine learning nodes.",
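As the spec notes, deployment_id is optional and defaults to the model id. Because deployment and model ids share one namespace, reusing an id that already belongs to another model or deployment is rejected. A hedged sketch of that failure mode (ids illustrative; ResponseException is how the client surfaces the error status):

import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.RestClient;

public class DeploymentIdClashExample {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // Assumes a model "other-model" already exists; reusing its id as a
            // deployment id for "my-model" violates the shared id namespace.
            Request start = new Request("POST", "/_ml/trained_models/my-model/deployment/_start");
            start.addParameter("deployment_id", "other-model");
            try {
                client.performRequest(start);
            } catch (ResponseException e) {
                // Expected: the server rejects the clashing deployment id.
                System.out.println(e.getResponse().getStatusLine());
            }
        }
    }
}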
@@ -75,7 +75,7 @@ public static class Request extends BaseTasksRequest<Request> {

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
- PARSER.declareString(Request.Builder::setModelId, InferModelAction.Request.DEPLOYMENT_ID);
+ PARSER.declareString(Request.Builder::setId, InferModelAction.Request.DEPLOYMENT_ID);
PARSER.declareObjectArray(Request.Builder::setDocs, (p, c) -> p.mapOrdered(), DOCS);
PARSER.declareString(Request.Builder::setInferenceTimeout, TIMEOUT);
PARSER.declareNamedObject(
@@ -88,12 +88,12 @@ public static class Request extends BaseTasksRequest<Request> {
public static Request.Builder parseRequest(String modelId, XContentParser parser) {
Request.Builder builder = PARSER.apply(parser, null);
if (modelId != null) {
- builder.setModelId(modelId);
+ builder.setId(modelId);
}
return builder;
}

- private String modelId;
+ private String id;
private final List<Map<String, Object>> docs;
private final InferenceConfigUpdate update;
private final TimeValue inferenceTimeout;
@@ -144,7 +144,7 @@ public static Request forTextInput(
boolean highPriority,
TimeValue inferenceTimeout
) {
- this.modelId = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.DEPLOYMENT_ID);
+ this.id = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.DEPLOYMENT_ID);
this.docs = docs;
this.textInput = textInput;
this.update = update;
@@ -154,7 +154,7 @@

public Request(StreamInput in) throws IOException {
super(in);
- modelId = in.readString();
+ id = in.readString();
docs = in.readImmutableList(StreamInput::readMap);
update = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
inferenceTimeout = in.readOptionalTimeValue();
@@ -168,8 +168,8 @@ public Request(StreamInput in) throws IOException {
}
}

- public String getModelId() {
- return modelId;
+ public String getId() {
+ return id;
}

public List<Map<String, Object>> getDocs() {
@@ -188,8 +188,8 @@ public TimeValue getInferenceTimeout() {
return inferenceTimeout == null ? DEFAULT_TIMEOUT : inferenceTimeout;
}

- public void setModelId(String modelId) {
- this.modelId = modelId;
+ public void setId(String id) {
+ this.id = id;
}

/**
@@ -226,7 +226,7 @@ public ActionRequestValidationException validate() {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
- out.writeString(modelId);
+ out.writeString(id);
out.writeCollection(docs, StreamOutput::writeGenericMap);
out.writeOptionalNamedWriteable(update);
out.writeOptionalTimeValue(inferenceTimeout);
@@ -240,15 +240,15 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public boolean match(Task task) {
- return StartTrainedModelDeploymentAction.TaskMatcher.match(task, modelId);
+ return StartTrainedModelDeploymentAction.TaskMatcher.match(task, id);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferTrainedModelDeploymentAction.Request that = (InferTrainedModelDeploymentAction.Request) o;
- return Objects.equals(modelId, that.modelId)
+ return Objects.equals(id, that.id)
&& Objects.equals(docs, that.docs)
&& Objects.equals(update, that.update)
&& Objects.equals(inferenceTimeout, that.inferenceTimeout)
@@ -258,17 +258,17 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
- return Objects.hash(modelId, update, docs, inferenceTimeout, highPriority, textInput);
+ return Objects.hash(id, update, docs, inferenceTimeout, highPriority, textInput);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
- return new CancellableTask(id, type, action, format("infer_trained_model_deployment[%s]", modelId), parentTaskId, headers);
+ return new CancellableTask(id, type, action, format("infer_trained_model_deployment[%s]", this.id), parentTaskId, headers);
}

public static class Builder {

- private String modelId;
+ private String id;
private List<Map<String, Object>> docs;
private TimeValue timeout;
private InferenceConfigUpdate update;
@@ -277,8 +277,8 @@ public static class Builder {

private Builder() {}

- public Builder setModelId(String modelId) {
- this.modelId = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.DEPLOYMENT_ID);
+ public Builder setId(String id) {
+ this.id = ExceptionsHelper.requireNonNull(id, InferModelAction.Request.DEPLOYMENT_ID);
return this;
}

@@ -312,7 +312,7 @@ public Builder setSkipQueue(boolean skipQueue) {
}

public Request build() {
- return new Request(modelId, update, docs, textInput, skipQueue, timeout);
+ return new Request(id, update, docs, textInput, skipQueue, timeout);
}
}
}
@@ -198,6 +198,10 @@ public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocations
return List.of();
}

+ if (nodeIds.size() == 1) {
+ return List.of(new Tuple<>(nodeIds.get(0), numberOfRequests));
+ }

if (allocationSum == 0) {
// If we are in a mixed cluster where there are assignments prior to introducing allocation distribution
// we could have a zero-sum of allocations. We fall back to returning a random started node.
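The added short-circuit routes every request to the only started node, skipping the weighted draw entirely. For context, a minimal sketch of what allocation-weighted random selection can look like, assuming allocationSum > 0; the names are hypothetical, not the actual selectRandomStartedNodesWeighedOnAllocations implementation:

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class WeightedNodeSelection {
    // Pick a node for each request with probability proportional to its allocation count.
    static List<String> assignRequests(List<String> nodeIds, int[] allocations, int numberOfRequests, Random random) {
        int allocationSum = 0;
        for (int a : allocations) {
            allocationSum += a;
        }
        List<String> picks = new ArrayList<>(numberOfRequests);
        for (int r = 0; r < numberOfRequests; r++) {
            int roll = random.nextInt(allocationSum); // assumes allocationSum > 0
            for (int i = 0; i < nodeIds.size(); i++) {
                roll -= allocations[i];
                if (roll < 0) {
                    picks.add(nodeIds.get(i));
                    break;
                }
            }
        }
        return picks;
    }

    public static void main(String[] args) {
        // node-a holds 3 allocations, node-b holds 1, so node-a receives ~75% of requests.
        System.out.println(assignRequests(List.of("node-a", "node-b"), new int[] { 3, 1 }, 8, new Random(42)));
    }
}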
@@ -46,6 +46,10 @@ public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(Str
return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id);
}

+ public static ResourceNotFoundException missingModelDeployment(String deploymentId) {
+ return new ResourceNotFoundException("No known model deployment with id [{}]", deploymentId);
+ }

public static ResourceNotFoundException missingTrainedModel(String modelId) {
return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId);
}
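For context, a hedged sketch of the kind of lookup such a helper typically guards; the registry map and requireAssignment method here are hypothetical:

import java.util.Map;

import org.elasticsearch.ResourceNotFoundException;

public class DeploymentLookup {
    // Hypothetical registry of deployment id -> the node assignment it resolves to.
    static String requireAssignment(Map<String, String> assignmentsByDeploymentId, String deploymentId) {
        String assignment = assignmentsByDeploymentId.get(deploymentId);
        if (assignment == null) {
            // Mirrors ExceptionsHelper.missingModelDeployment from the hunk above.
            throw new ResourceNotFoundException("No known model deployment with id [{}]", deploymentId);
        }
        return assignment;
    }

    public static void main(String[] args) {
        System.out.println(requireAssignment(Map.of("for-search", "node-a"), "for-search"));
    }
}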
x-pack/plugin/ml/qa/ml-with-security/build.gradle (4 additions, 0 deletions)
@@ -30,6 +30,10 @@ tasks.named("yamlRestTest").configure {
'ml/3rd_party_deployment/Test start deployment fails with missing model definition',
'ml/3rd_party_deployment/Test start deployment with low priority and multiple allocations',
'ml/3rd_party_deployment/Test start deployment with low priority and multiple threads per allocation',
+ 'ml/3rd_party_deployment/Test stop deployments with allow_no_match',
+ 'ml/3rd_party_deployment/Test cannot start 2 deployments with the same Id',
+ 'ml/3rd_party_deployment/Test cannot start when deployment Id matches a different model',
+ 'ml/3rd_party_deployment/Test cannot create model with a deployment Id',
'ml/calendar_crud/Test get calendar given missing',
'ml/calendar_crud/Test cannot create calendar with name _all',
'ml/calendar_crud/Test PageParams with ID is invalid',
@@ -0,0 +1,49 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.client.Response;

import java.io.IOException;
import java.util.List;

public class MultipleDeploymentsIT extends PyTorchModelRestTestCase {

@SuppressWarnings("unchecked")
public void testDeployModelMultipleTimes() throws IOException {
String baseModelId = "base-model";
createPassThroughModel(baseModelId);
putModelDefinition(baseModelId);
putVocabulary(List.of("these", "are", "my", "words"), baseModelId);

String forSearch = "for-search";
startWithDeploymentId(baseModelId, forSearch);

Response inference = infer("my words", forSearch);
assertOK(inference);

String forIngest = "for-ingest";
startWithDeploymentId(baseModelId, forIngest);

inference = infer("my words", forIngest);
assertOK(inference);
inference = infer("my words", forIngest);
assertOK(inference);

// TODO
// assertInferenceCount(1, forSearch);
// assertInferenceCount(2, forIngest);

stopDeployment(forSearch);
stopDeployment(forIngest);
}

private void putModelDefinition(String modelId) throws IOException {
putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
}
}
@@ -667,7 +667,7 @@ public void testStopUsedDeploymentByIngestProcessor() throws IOException {
assertThat(
EntityUtils.toString(ex.getResponse().getEntity()),
containsString(
"Cannot stop deployment for model [test_stop_used_deployment_by_ingest_processor] as it is referenced by"
"Cannot stop deployment [test_stop_used_deployment_by_ingest_processor] as it is referenced by"
+ " ingest processors; use force to stop the deployment"
)
);
@@ -699,7 +699,7 @@ public void testStopWithModelAliasUsedDeploymentByIngestProcessor() throws IOExc
assertThat(
EntityUtils.toString(ex.getResponse().getEntity()),
containsString(
"Cannot stop deployment for model [test_stop_model_alias_used_deployment_by_ingest_processor] as it has a "
"Cannot stop deployment [test_stop_model_alias_used_deployment_by_ingest_processor] as it has a "
+ "model_alias [used_model_alias] that is still referenced"
+ " by ingest processors; use force to stop the deployment"
)
@@ -846,8 +846,8 @@ public void testStoppingDeploymentShouldTriggerRebalance() throws Exception {
putModelDefinition(modelId2);
putVocabulary(List.of("these", "are", "my", "words"), modelId2);

- startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
- startDeployment(modelId2, AllocationStatus.State.STARTING.toString(), 1, 1, Priority.NORMAL);
+ startDeployment(modelId1, modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
+ startDeployment(modelId2, modelId2, AllocationStatus.State.STARTING.toString(), 1, 1, Priority.NORMAL);

// Check second model did not get any allocations
assertAllocationCount(modelId2, 0);
@@ -888,7 +888,7 @@ public void testStartDeployment_TooManyAllocations() throws IOException {

ResponseException ex = expectThrows(
ResponseException.class,
- () -> startDeployment(modelId, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL)
+ () -> startDeployment(modelId, modelId, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL)
);
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(429));
assertThat(
@@ -924,7 +924,7 @@ public void testStartDeployment_GivenNoProcessorsLeft_AndLazyStartEnabled() thro
putModelDefinition(modelId2);
putVocabulary(List.of("these", "are", "my", "words"), modelId2);

- startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
+ startDeployment(modelId1, modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);

{
Request request = new Request(
@@ -1033,7 +1033,7 @@ public void testUpdateDeployment_GivenAllocationsAreDecreased() throws Exception
createPassThroughModel(modelId);
putModelDefinition(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
startDeployment(modelId, "started", 2, 1, Priority.NORMAL);
startDeployment(modelId, modelId, "started", 2, 1, Priority.NORMAL);

assertBusy(() -> assertAllocationCount(modelId, 2));

@@ -1051,7 +1051,7 @@ public void testStartMultipleLowPriorityDeployments() throws Exception {
createPassThroughModel(modelId);
putModelDefinition(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
startDeployment(modelId, "started", 1, 1, Priority.LOW);
startDeployment(modelId, modelId, "started", 1, 1, Priority.LOW);
assertAllocationCount(modelId, 1);
}
}
@@ -120,6 +120,19 @@ protected void assertAllocationCount(String modelId, int expectedAllocationCount
assertThat(allocations, equalTo(expectedAllocationCount));
}

+ @SuppressWarnings("unchecked")
+ protected void assertInferenceCount(int expectedCount, String deploymentId) throws IOException {
+ Response noInferenceCallsStatsResponse = getTrainedModelStats(deploymentId);
+ Map<String, Object> stats = entityAsMap(noInferenceCallsStatsResponse);
+
+ List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
+ "trained_model_stats.0.deployment_stats.nodes",
+ stats
+ );
+ int inferenceCount = sumInferenceCountOnNodes(nodes);
+ assertEquals(expectedCount, inferenceCount);
+ }

protected int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
int inferenceCount = 0;
for (var node : nodes) {
@@ -192,30 +205,38 @@ protected Response startDeployment(String modelId) throws IOException {
return startDeployment(modelId, AllocationStatus.State.STARTED.toString());
}

+ protected Response startWithDeploymentId(String modelId, String deploymentId) throws IOException {
+ return startDeployment(modelId, deploymentId, AllocationStatus.State.STARTED.toString(), 1, 1, Priority.NORMAL);
+ }

protected Response startDeployment(String modelId, String waitForState) throws IOException {
- return startDeployment(modelId, waitForState, 1, 1, Priority.NORMAL);
+ return startDeployment(modelId, null, waitForState, 1, 1, Priority.NORMAL);
}

protected Response startDeployment(
String modelId,
+ String deploymentId,
String waitForState,
int numberOfAllocations,
int threadsPerAllocation,
Priority priority
) throws IOException {
- Request request = new Request(
-     "POST",
-     "/_ml/trained_models/"
-         + modelId
-         + "/deployment/_start?timeout=40s&wait_for="
-         + waitForState
-         + "&threads_per_allocation="
-         + threadsPerAllocation
-         + "&number_of_allocations="
-         + numberOfAllocations
-         + "&priority="
-         + priority
- );
+ String endPoint = "/_ml/trained_models/"
+     + modelId
+     + "/deployment/_start?timeout=40s&wait_for="
+     + waitForState
+     + "&threads_per_allocation="
+     + threadsPerAllocation
+     + "&number_of_allocations="
+     + numberOfAllocations
+     + "&priority="
+     + priority;
+
+ if (deploymentId != null) {
+     endPoint = endPoint + "&deployment_id=" + deploymentId;
+ }
+
+ Request request = new Request("POST", endPoint);
return client().performRequest(request);
}

