Skip to content

Commit

Permalink
[ML] make trained model rest APIs cancellable (#88009)
Browse files Browse the repository at this point in the history
This change makes all the trained model APIs cancellable, and addresses the handful of APIs that rely on our abstract resource structure.

closes: #87931
  • Loading branch information
benwtrent committed Jun 24, 2022
1 parent f153c2a commit 71ab4c4
Show file tree
Hide file tree
Showing 34 changed files with 311 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.action.util.PageParams;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public abstract class AbstractGetResourcesRequest extends ActionRequest {
Expand Down Expand Up @@ -93,5 +97,12 @@ public boolean equals(Object obj) {
&& allowNoResources == other.allowNoResources;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getCancelableTaskDescription(), parentTaskId, headers);
}

public abstract String getCancelableTaskDescription();

public abstract String getResourceIdField();
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -73,7 +74,7 @@ protected AbstractTransportGetResourcesAction(
this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
}

protected void searchResources(AbstractGetResourcesRequest request, ActionListener<QueryPage<Resource>> listener) {
protected void searchResources(AbstractGetResourcesRequest request, TaskId parentTaskId, ActionListener<QueryPage<Resource>> listener) {
String[] tokens = Strings.tokenizeToStringArray(request.getResourceId(), ",");
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().sort(
SortBuilders.fieldSort(request.getResourceIdField())
Expand All @@ -96,6 +97,7 @@ protected void searchResources(AbstractGetResourcesRequest request, ActionListen
indicesOptions
)
).source(customSearchOptions(sourceBuilder));
searchRequest.setParentTask(parentTaskId);

executeAsyncWithOrigin(
client.threadPool().getThreadContext(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import java.io.IOException;

import static org.elasticsearch.core.Strings.format;

public class GetDataFrameAnalyticsAction extends ActionType<GetDataFrameAnalyticsAction.Response> {

public static final GetDataFrameAnalyticsAction INSTANCE = new GetDataFrameAnalyticsAction();
Expand Down Expand Up @@ -46,6 +48,11 @@ public Request(StreamInput in) throws IOException {
public String getResourceIdField() {
return DataFrameAnalyticsConfig.ID.getPreferredName();
}

@Override
public String getCancelableTaskDescription() {
return format("get_data_frame_analytics[%s]", getResourceId());
}
}

public static class Response extends AbstractGetResourcesResponse<DataFrameAnalyticsConfig> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.io.IOException;

import static org.elasticsearch.core.Strings.format;

public class GetFiltersAction extends ActionType<GetFiltersAction.Response> {

public static final GetFiltersAction INSTANCE = new GetFiltersAction();
Expand All @@ -41,6 +43,11 @@ public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public String getCancelableTaskDescription() {
return format("get_filters[%s]", getResourceId());
}

@Override
public String getResourceIdField() {
return MlFilter.ID.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.Objects;
import java.util.Set;

import static org.elasticsearch.core.Strings.format;

public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();
Expand Down Expand Up @@ -118,7 +120,6 @@ public int hashCode() {
public static class Request extends AbstractGetResourcesRequest {

public static final ParseField INCLUDE = new ParseField("include");
public static final String DEFINITION = "definition";
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");

Expand Down Expand Up @@ -178,6 +179,11 @@ public boolean equals(Object obj) {
Request other = (Request) obj;
return super.equals(obj) && this.includes.equals(other.includes) && Objects.equals(tags, other.tags);
}

@Override
public String getCancelableTaskDescription() {
return format("get_trained_models[%s]", getResourceId());
}
}

public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.Set;

import static org.elasticsearch.core.RestApiVersion.onOrAfter;
import static org.elasticsearch.core.Strings.format;

public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {

Expand Down Expand Up @@ -68,6 +69,11 @@ public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public String getCancelableTaskDescription() {
return format("get_trained_model_stats[%s]", getResourceId());
}

@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -31,6 +34,8 @@
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;

public class InferModelAction extends ActionType<InferModelAction.Response> {
public static final String NAME = "cluster:internal/xpack/ml/inference/infer";
public static final String EXTERNAL_NAME = "cluster:monitor/xpack/ml/inference/infer";
Expand Down Expand Up @@ -176,6 +181,11 @@ public boolean equals(Object o) {
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, format("infer_trained_model[%s]", modelId), parentTaskId, headers);
}

@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed, timeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -36,6 +38,7 @@
import java.util.Optional;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.core.Strings.format;

public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedModelDeploymentAction.Response> {

Expand Down Expand Up @@ -192,6 +195,11 @@ public int hashCode() {
return Objects.hash(deploymentId, update, docs, inferenceTimeout);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, format("infer_trained_model_deployment[%s]", deploymentId), parentTaskId, headers);
}

public static class Builder {

private String deploymentId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Objects;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.core.Strings.format;

public class GetTransformAction extends ActionType<GetTransformAction.Response> {

Expand Down Expand Up @@ -76,6 +77,11 @@ public ActionRequestValidationException validate() {
return exception;
}

@Override
public String getCancelableTaskDescription() {
return format("get_transforms[%s]", getResourceId());
}

@Override
public String getResourceIdField() {
return TransformField.ID.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,15 @@ public void testStoreModelViaChunkedPersister() throws IOException {
PageParams.defaultParams(),
Collections.emptySet(),
ModelAliasMetadata.EMPTY,
null,
getIdsFuture
);
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
assertThat(ids.v1(), equalTo(1L));
String inferenceModelId = ids.v2().keySet().iterator().next();

PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture);
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), null, getTrainedModelFuture);

TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
Expand All @@ -128,7 +129,7 @@ public void testStoreModelViaChunkedPersister() throws IOException {
assertThat(storedConfig.getMetadata(), hasKey("hyperparameters"));

PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), null, getTrainedMetadataFuture);

TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
assertThat(storedMetadata.getModelId(), startsWith(modelId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void testGetTrainedModelConfig() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand All @@ -132,7 +132,7 @@ public void testGetTrainedModelConfig() throws Exception {

getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.all(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.all(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -204,7 +204,7 @@ public void testGetTrainedModelConfigWithMultiDocDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -248,7 +248,7 @@ public void testGetTrainedModelConfigWithoutDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand All @@ -263,7 +263,7 @@ public void testGetMissingTrainingModelConfig() throws Exception {
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand All @@ -288,7 +288,7 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -335,7 +335,7 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down Expand Up @@ -388,7 +388,7 @@ public void testGetTruncatedModelDefinition() throws Exception {

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), null, listener),
getConfigHolder,
exceptionHolder
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -31,12 +33,14 @@ public class TransportGetDataFrameAnalyticsAction extends AbstractTransportGetRe
DataFrameAnalyticsConfig,
GetDataFrameAnalyticsAction.Request,
GetDataFrameAnalyticsAction.Response> {
private final ClusterService clusterService;

@Inject
public TransportGetDataFrameAnalyticsAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ClusterService clusterService,
NamedXContentRegistry xContentRegistry
) {
super(
Expand All @@ -47,6 +51,7 @@ public TransportGetDataFrameAnalyticsAction(
client,
xContentRegistry
);
this.clusterService = clusterService;
}

@Override
Expand Down Expand Up @@ -77,6 +82,7 @@ protected void doExecute(
) {
searchResources(
request,
new TaskId(clusterService.localNode().getId(), task.getId()),
ActionListener.wrap(queryPage -> listener.onResponse(new GetDataFrameAnalyticsAction.Response(queryPage)), listener::onFailure)
);
}
Expand Down

0 comments on commit 71ab4c4

Please sign in to comment.