Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support task cancellation cross clusters #55779

Closed
wants to merge 15 commits into from
Closed
1 change: 1 addition & 0 deletions server/build.gradle
Expand Up @@ -132,6 +132,7 @@ testingConventions {
IT {
baseClass "org.elasticsearch.test.ESIntegTestCase"
baseClass "org.elasticsearch.test.ESSingleNodeTestCase"
baseClass "org.elasticsearch.test.AbstractMultiClustersTestCase"
}
}
}
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

126 changes: 79 additions & 47 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
Expand All @@ -46,9 +47,12 @@
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.NodeAndClusterAlias;
import org.elasticsearch.transport.TcpChannel;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -73,7 +77,7 @@
/**
* Task Manager service for keeping track of currently running tasks on the nodes
*/
public class TaskManager implements ClusterStateApplier {
public class TaskManager implements ClusterStateApplier, Closeable {

private static final Logger logger = LogManager.getLogger(TaskManager.class);

Expand All @@ -90,19 +94,25 @@ public class TaskManager implements ClusterStateApplier {

private final AtomicLong taskIdGenerator = new AtomicLong();

private final Map<TaskId, String> banedParents = new ConcurrentHashMap<>();
private final Map<TaskId, BanReason> bannedParents = new ConcurrentHashMap<>();

private TaskResultsService taskResultsService;

private DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES;
private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES;

private final ByteSizeValue maxHeaderSize;
private final Map<TcpChannel, ChannelPendingTaskTracker> channelPendingTaskTrackers = ConcurrentCollections.newConcurrentMap();
private final SetOnce<TaskCancellationService> cancellationService = new SetOnce<>();
private final long banParentRetainingIntervalInMillis;

public TaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders) {
this(settings, threadPool, TimeValue.timeValueSeconds(60), taskHeaders);
}

TaskManager(Settings settings, ThreadPool threadPool, TimeValue banParentRetainingInterval, Set<String> taskHeaders) {
this.threadPool = threadPool;
this.taskHeaders = new ArrayList<>(taskHeaders);
this.banParentRetainingIntervalInMillis = banParentRetainingInterval.millis();
this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings);
}

Expand All @@ -115,6 +125,11 @@ public void setTaskCancellationService(TaskCancellationService taskCancellationS
this.cancellationService.set(taskCancellationService);
}

@Override
public void close() throws IOException {
IOUtils.close(cancellationService.get());
}

/**
* Registers a task without parent task
*/
Expand Down Expand Up @@ -154,7 +169,7 @@ Task registerAndExecute(String type, TransportAction<Request, Response> action,
BiConsumer<Task, Response> onResponse, BiConsumer<Task, Exception> onFailure) {
final Releasable unregisterChildNode;
if (request.getParentTask().isSet()) {
unregisterChildNode = registerChildNode(request.getParentTask().getId(), lastDiscoveryNodes.getLocalNode());
unregisterChildNode = registerChildNode(request.getParentTask().getId(), lastDiscoveryNodes.getLocalNode(), null);
} else {
unregisterChildNode = () -> {};
}
Expand Down Expand Up @@ -195,12 +210,12 @@ private void registerCancellableTask(Task task) {
assert oldHolder == null;
// Check if this task was banned before we start it. The empty check is used to avoid
// computing the hash code of the parent taskId as most of the time banedParents is empty.
if (task.getParentTaskId().isSet() && banedParents.isEmpty() == false) {
String reason = banedParents.get(task.getParentTaskId());
if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
BanReason reason = bannedParents.get(task.getParentTaskId());
if (reason != null) {
try {
holder.cancel(reason);
throw new TaskCancelledException("Task cancelled before it started: " + reason);
holder.cancel(reason.reason);
throw new TaskCancelledException("Task cancelled before it started: " + reason.reason);
} finally {
// let's clean up the registration
unregister(task);
Expand Down Expand Up @@ -248,15 +263,12 @@ public Task unregister(Task task) {
* Register a node on which a child task will execute. The returned {@link Releasable} must be called
* to unregister the child node once the child task is completed or failed.
*/
public Releasable registerChildNode(long taskId, DiscoveryNode node) {
public Releasable registerChildNode(long taskId, DiscoveryNode node, String clusterAlias) {
final CancellableTaskHolder holder = cancellableTasks.get(taskId);
if (holder != null) {
logger.trace("register child node [{}] task [{}]", node, taskId);
holder.registerChildNode(node);
return Releasables.releaseOnce(() -> {
logger.trace("unregister child node [{}] task [{}]", node, taskId);
holder.unregisterChildNode(node);
});
final NodeAndClusterAlias nodeAndClusterAlias = new NodeAndClusterAlias(node, clusterAlias);
holder.registerChildNode(nodeAndClusterAlias);
return Releasables.releaseOnce(() -> holder.unregisterChildNode(nodeAndClusterAlias));
}
return () -> {};
}
Expand Down Expand Up @@ -380,7 +392,7 @@ public CancellableTask getCancellableTask(long id) {
* Will be used in task manager stats and for debugging.
*/
public int getBanCount() {
return banedParents.size();
return bannedParents.size();
}

/**
Expand All @@ -391,14 +403,8 @@ public int getBanCount() {
*/
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);

// Set the ban first, so the newly created tasks cannot be registered
synchronized (banedParents) {
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
// Only set the ban if the node is the part of the cluster
banedParents.put(parentTaskId, reason);
}
}
cleanupOldBanMarkers();
bannedParents.put(parentTaskId, new BanReason(reason, threadPool.relativeTimeInMillis()));
return cancellableTasks.values().stream()
.filter(t -> t.hasParent(parentTaskId))
.map(t -> t.task)
Expand All @@ -412,12 +418,42 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
*/
public void removeBan(TaskId parentTaskId) {
logger.trace("removing ban for the parent task {}", parentTaskId);
banedParents.remove(parentTaskId);
bannedParents.remove(parentTaskId);
cleanupOldBanMarkers();
}

// for testing
public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(banedParents.keySet());
return Collections.unmodifiableSet(bannedParents.keySet());
}

/**
* If a node that installed some ban parent markers disconnects before it can remove them,
* then we need to manually clean up those markers after the retaining interval elapsed.
*/
private void cleanupOldBanMarkers() {
final Iterator<Map.Entry<TaskId, BanReason>> iterator = bannedParents.entrySet().iterator();
final long timeInMillis = threadPool.relativeTimeInMillis();
while (iterator.hasNext()) {
final Map.Entry<TaskId, BanReason> entry = iterator.next();
final long elapsed = entry.getValue().lastUpdatedInMillis - timeInMillis;
if (elapsed > banParentRetainingIntervalInMillis) {
// If the ban marker is from an old node, then do not remove it as the old node does not send heartbeats.
final DiscoveryNode parentNode = lastDiscoveryNodes.get(entry.getKey().getNodeId());
if (parentNode != null && parentNode.getVersion().before(Version.V_8_0_0)) {
continue;
}
logger.debug("Clean up ban for the parent task [{}] after [{}]", entry.getKey(), TimeValue.timeValueMillis(elapsed));
iterator.remove();
}
}
}

void updateBanMarkerTimestamp(TaskId taskId) {
final BanReason banReason = bannedParents.get(taskId);
if (banReason != null) {
banReason.lastUpdatedInMillis = threadPool.relativeTimeInMillis();
}
}

/**
Expand All @@ -427,7 +463,7 @@ public Set<TaskId> getBannedTaskIds() {
* @param onChildTasksCompleted called when all child tasks are completed or failed
* @return the set of current nodes that have outstanding child tasks
*/
public Collection<DiscoveryNode> startBanOnChildrenNodes(long taskId, Runnable onChildTasksCompleted) {
public Collection<NodeAndClusterAlias> startBanOnChildrenNodes(long taskId, Runnable onChildTasksCompleted) {
final CancellableTaskHolder holder = cancellableTasks.get(taskId);
if (holder != null) {
return holder.startBan(onChildTasksCompleted);
Expand All @@ -440,21 +476,6 @@ public Collection<DiscoveryNode> startBanOnChildrenNodes(long taskId, Runnable o
@Override
public void applyClusterState(ClusterChangedEvent event) {
lastDiscoveryNodes = event.state().getNodes();
if (event.nodesRemoved()) {
synchronized (banedParents) {
lastDiscoveryNodes = event.state().getNodes();
dnhatn marked this conversation as resolved.
Show resolved Hide resolved
// Remove all bans that were registered by nodes that are no longer in the cluster state
Iterator<TaskId> banIterator = banedParents.keySet().iterator();
while (banIterator.hasNext()) {
TaskId taskId = banIterator.next();
if (lastDiscoveryNodes.nodeExists(taskId.getNodeId()) == false) {
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone", taskId,
event.state().getNodes().getLocalNode());
banIterator.remove();
}
}
}
}
}

/**
Expand All @@ -478,7 +499,7 @@ private static class CancellableTaskHolder {
private final CancellableTask task;
private boolean finished = false;
private List<Runnable> cancellationListeners = null;
private ObjectIntMap<DiscoveryNode> childTasksPerNode = null;
private ObjectIntMap<NodeAndClusterAlias> childTasksPerNode = null;
private boolean banChildren = false;
private List<Runnable> childTaskCompletedListeners = null;

Expand Down Expand Up @@ -555,7 +576,7 @@ public CancellableTask getTask() {
return task;
}

synchronized void registerChildNode(DiscoveryNode node) {
synchronized void registerChildNode(NodeAndClusterAlias node) {
if (banChildren) {
throw new TaskCancelledException("The parent task was cancelled, shouldn't start any child tasks");
}
Expand All @@ -565,7 +586,7 @@ synchronized void registerChildNode(DiscoveryNode node) {
childTasksPerNode.addTo(node, 1);
}

void unregisterChildNode(DiscoveryNode node) {
void unregisterChildNode(NodeAndClusterAlias node) {
final List<Runnable> listeners;
synchronized (this) {
if (childTasksPerNode.addTo(node, -1) == 0) {
Expand All @@ -581,8 +602,8 @@ void unregisterChildNode(DiscoveryNode node) {
notifyListeners(listeners);
}

Set<DiscoveryNode> startBan(Runnable onChildTasksCompleted) {
final Set<DiscoveryNode> pendingChildNodes;
Set<NodeAndClusterAlias> startBan(Runnable onChildTasksCompleted) {
final Set<NodeAndClusterAlias> pendingChildNodes;
final Runnable toRun;
synchronized (this) {
banChildren = true;
Expand Down Expand Up @@ -613,6 +634,7 @@ Set<DiscoveryNode> startBan(Runnable onChildTasksCompleted) {
* pending tasks associated that channel and cancel them as these results won't be retrieved by the parent task.
*
* @return a releasable that should be called when this pending task is completed
* TODO: support cancellation when the proxy connection gets disconnected
*/
public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
Expand Down Expand Up @@ -702,4 +724,14 @@ public void cancelTaskAndDescendants(CancellableTask task, String reason, boolea
throw new IllegalStateException("TaskCancellationService is not initialized");
}
}

static final class BanReason {
volatile long lastUpdatedInMillis;
final String reason;

BanReason(String reason, long nowInMillis) {
this.lastUpdatedInMillis = nowInMillis;
this.reason = reason;
}
}
}
@@ -0,0 +1,60 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.transport;

import org.elasticsearch.cluster.node.DiscoveryNode;

import java.util.Objects;

public final class NodeAndClusterAlias {
private final DiscoveryNode node;
private final String clusterAlias;

public NodeAndClusterAlias(DiscoveryNode node, String clusterAlias) {
this.node = Objects.requireNonNull(node);
this.clusterAlias = clusterAlias;
}

public DiscoveryNode getNode() {
return node;
}

public String getClusterAlias() {
return clusterAlias;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NodeAndClusterAlias that = (NodeAndClusterAlias) o;
return node.equals(that.node) && Objects.equals(clusterAlias, that.clusterAlias);
}

@Override
public int hashCode() {
return node.hashCode() * 31 + (clusterAlias == null ? 0 : clusterAlias.hashCode());
}

@Override
public String toString() {
return "cluster[" + clusterAlias + "] node [" + node + "]";
}
}
Expand Up @@ -47,6 +47,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -174,7 +175,9 @@ public Set<String> getRegisteredRemoteClusterNames() {
* @throws IllegalArgumentException if the remote cluster is unknown
*/
public Transport.Connection getConnection(DiscoveryNode node, String cluster) {
return getRemoteClusterConnection(cluster).getConnection(node);
final Transport.Connection connection = getRemoteClusterConnection(cluster).getConnection(node);
assert Objects.equals(connection.clusterAlias(), cluster) : connection.clusterAlias() + " != " + cluster;
return connection;
}

/**
Expand All @@ -193,7 +196,9 @@ public boolean isSkipUnavailable(String clusterAlias) {
}

public Transport.Connection getConnection(String cluster) {
return getRemoteClusterConnection(cluster).getConnection();
final Transport.Connection connection = getRemoteClusterConnection(cluster).getConnection();
assert Objects.equals(connection.clusterAlias(), cluster) : connection.clusterAlias() + " != " + cluster;
return connection;
}

RemoteClusterConnection getRemoteClusterConnection(String cluster) {
Expand Down