Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support task cancellation cross clusters #55779

Closed
wants to merge 15 commits into from
Closed
1 change: 1 addition & 0 deletions server/build.gradle
Expand Up @@ -173,6 +173,7 @@ testingConventions {
IT {
baseClass "org.elasticsearch.test.ESIntegTestCase"
baseClass "org.elasticsearch.test.ESSingleNodeTestCase"
baseClass "org.elasticsearch.test.AbstractMultiClustersTestCase"
}
}
}
Expand Down
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.action.admin.cluster.node.tasks.cancel;

import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
Expand All @@ -42,17 +43,26 @@
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportActionProxy;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

/**
* Transport action that can be used to cancel currently running cancellable tasks.
Expand All @@ -70,6 +80,7 @@ public TransportCancelTasksAction(ClusterService clusterService, TransportServic
CancelTasksRequest::new, CancelTasksResponse::new, TaskInfo::new, ThreadPool.Names.MANAGEMENT);
transportService.registerRequestHandler(BAN_PARENT_ACTION_NAME, ThreadPool.Names.SAME, BanParentTaskRequest::new,
new BanParentRequestHandler());
TransportActionProxy.registerProxyAction(transportService, BAN_PARENT_ACTION_NAME, in -> TransportResponse.Empty.INSTANCE);
}

@Override
Expand Down Expand Up @@ -116,17 +127,17 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
if (task.shouldCancelChildrenOnCancellation()) {
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes =
Map<String, List<DiscoveryNode>> childConnections =
taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(task, reason, () -> groupedListener.onResponse(null));

StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
setBanOnNodes(reason, waitForCompletion, task, childConnections, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
final Runnable removeBansRunnable = transportService.getThreadPool().getThreadContext()
.preserveContext(() -> removeBanOnNodes(task, childrenNodes));
.preserveContext(() -> removeBanOnNodes(task, childConnections));
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
completedListener.whenComplete(r -> removeBansRunnable.run(), e -> removeBansRunnable.run());
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
Expand All @@ -148,47 +159,74 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
}

private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
Map<String, List<DiscoveryNode>> childConnections, ActionListener<Void> listener) {
final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
sendBanParentRequests(childConnections, listener,
subNodes -> new BanParentTaskRequest(parentTaskId, reason, waitForCompletion, subNodes));
}

/**
 * Sends a "remove ban" request for the given parent task over every child connection.
 * The parent task id is built from the local node id and the task's own id. The outcome of
 * the requests is not propagated anywhere: the listener passed along only wraps an empty runnable.
 *
 * @param task             the cancelled parent task whose ban should be lifted
 * @param childConnections the child nodes to notify, grouped by cluster alias
 */
private void removeBanOnNodes(CancellableTask task, Map<String, List<DiscoveryNode>> childConnections) {
    final String localNodeId = clusterService.localNode().getId();
    final TaskId parentTaskId = new TaskId(localNodeId, task.getId());
    final ActionListener<Void> ignoreOutcome = ActionListener.wrap(() -> {});
    // The constructor without a reason creates the unban variant of the request.
    final Function<List<DiscoveryNode>, BanParentTaskRequest> unbanRequestFor =
        subNodes -> new BanParentTaskRequest(parentTaskId, subNodes);
    sendBanParentRequests(childConnections, ignoreOutcome, unbanRequestFor);
}

private void sendBanParentRequests(Map<String, List<DiscoveryNode>> childConnections, ActionListener<Void> listener,
Function<List<DiscoveryNode>, BanParentTaskRequest> requestGenerator) {
if (childConnections.isEmpty()) {
listener.onResponse(null);
return;
}
logger.trace("cancelling task {} on child nodes {}", task.getId(), childNodes);
GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
groupedListener.onResponse(null);
}

@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", banRequest.parentTaskId, node);
groupedListener.onFailure(exp);
}
});
final int groupSize = childConnections.entrySet().stream()
.mapToInt(e -> e.getKey().equals(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) ? e.getValue().size() : 1)
.sum();
final GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), groupSize);
for (Map.Entry<String, List<DiscoveryNode>> entry : childConnections.entrySet()) {
final String clusterAlias = entry.getKey();
if (clusterAlias.equals(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY)) {
final BanParentTaskRequest request = requestGenerator.apply(List.of());
for (DiscoveryNode node : entry.getValue()) {
sendBanParentRequest(() -> transportService.getConnection(node), request, groupedListener);
}
} else {
final ArrayList<DiscoveryNode> subNodes = new ArrayList<>(entry.getValue());
final DiscoveryNode targetNode = subNodes.remove(0);
dnhatn marked this conversation as resolved.
Show resolved Hide resolved
if (targetNode.getVersion().onOrAfter(Version.V_8_0_0)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of "proxy mode" connections (see also ProxyConnectionStrategy), targetNode.getVersion() will always return Version.CURRENT.minimumCompatibilityVersion() AFAICS, which means that this request won't be sent to those nodes.

The actual version of the node on the other side is available using channel.getVersion().
On the other hand channel.getNode will return the (possibly fake) DiscoveryNode object that was used to create the connection.

/cc: @tbrooks8 This is quite trappy, anything we can change in the transport in the short term to avoid this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've modified the handshaking to update the version of the target node with its actual version. I also exposed the proxy node so that we can check the version of the proxy node. I think we are good here.

BanParentTaskRequest request = requestGenerator.apply(subNodes);
sendBanParentRequest(() -> transportService.getRemoteClusterService().getConnection(targetNode, clusterAlias),
request, groupedListener);
} else {
groupedListener.onResponse(null); // old versions do not support cancellation cross clusters
}
}
}
}

private void removeBanOnNodes(CancellableTask task, Collection<DiscoveryNode> childNodes) {
final BanParentTaskRequest request =
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()));
for (DiscoveryNode node : childNodes) {
logger.trace("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node);
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, request, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
private void sendBanParentRequest(Supplier<Transport.Connection> connectionSupplier,
BanParentTaskRequest request, ActionListener<Void> listener) {
final Transport.Connection connection;
try {
connection = connectionSupplier.get();
} catch (Exception e) {
listener.onFailure(e);
return;
}
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, request, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
listener.onResponse(null);
}

@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.info("failed to remove the parent ban for task {} on node {}", request.parentTaskId, node);
logger.warn(new ParameterizedMessage("failed to send {} request for task {} to node {}",
request.ban ? "ban" : "unban", request.parentTaskId, connection.getNode()), exp);
listener.onFailure(exp);
}
});
}
}

private static class BanParentTaskRequest extends TransportRequest {
Expand All @@ -197,27 +235,22 @@ private static class BanParentTaskRequest extends TransportRequest {
private final boolean ban;
private final boolean waitForCompletion;
private final String reason;
private final List<DiscoveryNode> subNodes; // forward the request to these sub nodes

static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
}

static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
return new BanParentTaskRequest(parentTaskId);
}

private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion, List<DiscoveryNode> subNodes) {
this.parentTaskId = parentTaskId;
this.ban = true;
this.reason = reason;
this.waitForCompletion = waitForCompletion;
this.subNodes = Objects.requireNonNull(subNodes);
}

private BanParentTaskRequest(TaskId parentTaskId) {
private BanParentTaskRequest(TaskId parentTaskId, List<DiscoveryNode> subNodes) {
this.parentTaskId = parentTaskId;
this.ban = false;
this.reason = null;
this.waitForCompletion = false;
this.subNodes = Objects.requireNonNull(subNodes);
}

private BanParentTaskRequest(StreamInput in) throws IOException {
Expand All @@ -230,6 +263,11 @@ private BanParentTaskRequest(StreamInput in) throws IOException {
} else {
waitForCompletion = false;
}
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
subNodes = in.readList(DiscoveryNode::new);
} else {
subNodes = Collections.emptyList();
}
}

@Override
Expand All @@ -243,6 +281,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeBoolean(waitForCompletion);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeList(subNodes);
}
}
}

Expand All @@ -253,18 +294,31 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
clusterService.localNode().getId(), request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final int groupSize = childTasks.size() + request.subNodes.size() + 1;
final GroupedActionListener<Void> listener = new GroupedActionListener<>(ActionListener.map(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
groupSize);
for (CancellableTask childTask : childTasks) {
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
}
final BanParentTaskRequest subRequest = new BanParentTaskRequest(
request.parentTaskId, request.reason, request.waitForCompletion, List.of());
for (DiscoveryNode subNode : request.subNodes) {
sendBanParentRequest(() -> transportService.getConnection(subNode), subRequest, listener);
}
listener.onResponse(null);
} else {
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId,
clusterService.localNode().getId());
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId, clusterService.localNode().getId());
taskManager.removeBan(request.parentTaskId);
channel.sendResponse(TransportResponse.Empty.INSTANCE);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(
ActionListener.map(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
request.subNodes.size() + 1);
final BanParentTaskRequest subRequest = new BanParentTaskRequest(request.parentTaskId, List.of());
for (DiscoveryNode subNode : request.subNodes) {
sendBanParentRequest(() -> transportService.getConnection(subNode), subRequest, listener);
}
listener.onResponse(null);
}
}
}
Expand Down
Expand Up @@ -57,6 +57,7 @@
import org.elasticsearch.search.profile.ProfileShardResult;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.RemoteClusterService;
Expand Down Expand Up @@ -213,7 +214,8 @@ protected void doExecute(Task task, SearchRequest searchRequest, ActionListener<
executeLocalSearch(task, timeProvider, searchRequest, localIndices, clusterState, listener);
} else {
if (shouldMinimizeRoundtrips(searchRequest)) {
ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider,
final TaskId taskId = new TaskId(clusterState.nodes().getLocalNodeId(), task.getId());
ccsRemoteReduce(taskId, searchRequest, localIndices, remoteClusterIndices, timeProvider,
searchService.aggReduceContextBuilder(searchRequest),
remoteClusterService, threadPool, listener,
(r, l) -> executeLocalSearch(task, timeProvider, r, localIndices, clusterState, l));
Expand Down Expand Up @@ -261,8 +263,9 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) {
source.collapse().getInnerHits().isEmpty();
}

static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices,
SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
static void ccsRemoteReduce(TaskId parentTaskId, SearchRequest searchRequest, OriginalIndices localIndices,
Map<String, OriginalIndices> remoteIndices, SearchTimeProvider timeProvider,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener,
BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {

Expand All @@ -275,6 +278,7 @@ static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIn
OriginalIndices indices = entry.getValue();
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
clusterAlias, timeProvider.getAbsoluteStartMillis(), true);
ccsSearchRequest.setParentTask(parentTaskId);
dnhatn marked this conversation as resolved.
Show resolved Hide resolved
Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
remoteClusterClient.search(ccsSearchRequest, new ActionListener<SearchResponse>() {
@Override
Expand Down Expand Up @@ -312,6 +316,7 @@ public void onFailure(Exception e) {
OriginalIndices indices = entry.getValue();
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
clusterAlias, timeProvider.getAbsoluteStartMillis(), false);
ccsSearchRequest.setParentTask(parentTaskId);
dnhatn marked this conversation as resolved.
Show resolved Hide resolved
ActionListener<SearchResponse> ccsListener = createCCSListener(clusterAlias, skipUnavailable, countDown,
skippedClusters, exceptions, searchResponseMerger, totalClusters, listener);
Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
Expand Down