Skip to content

Commit

Permalink
Remove parent-task bans on channels disconnect (elastic#66066)
Browse files Browse the repository at this point in the history
Like elastic#56620, this change relies on channel disconnect instead of node
leave events to remove parent-task ban markers.

Relates elastic#65443
Relates elastic#56620
  • Loading branch information
dnhatn committed Dec 16, 2020
1 parent 07d2d72 commit e55f5d3
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 34 deletions.
Expand Up @@ -55,6 +55,7 @@
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -307,6 +308,58 @@ public void testCancelOrphanedTasks() throws Exception {
}
}

public void testRemoveBanParentsOnDisconnect() throws Exception {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4));
client().execute(TransportTestAction.ACTION, rootRequest);
Set<TestRequest> pendingRequests = allowPartialRequest(rootRequest);
TaskId rootTaskId = getRootTaskId(rootRequest);
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks()
.setTaskId(rootTaskId).waitForCompletion(true).execute();
try {
assertBusy(() -> {
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
Set<TaskId> expectedBans = new HashSet<>();
for (TestRequest req : pendingRequests) {
if (req.node.equals(node)) {
List<Task> childTasks = taskManager.getTasks().values().stream()
.filter(t -> t.getParentTaskId() != null && t.getDescription().equals(req.taskDescription()))
.collect(Collectors.toList());
assertThat(childTasks, hasSize(1));
CancellableTask childTask = (CancellableTask) childTasks.get(0);
assertTrue(childTask.isCancelled());
expectedBans.add(childTask.getParentTaskId());
}
}
assertThat(taskManager.getBannedTaskIds(), equalTo(expectedBans));
}
}, 30, TimeUnit.SECONDS);

final Set<TaskId> bannedParents = new HashSet<>();
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
bannedParents.addAll(taskManager.getBannedTaskIds());
}
// Disconnect some outstanding child connections
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
for (TaskId bannedParent : bannedParents) {
if (bannedParent.getNodeId().equals(node.getId()) && randomBoolean()) {
Collection<Transport.Connection> childConns = taskManager.startBanOnChildTasks(bannedParent.getId(), () -> {});
for (Transport.Connection connection : randomSubsetOf(childConns)) {
connection.close();
}
}
}
}
} finally {
allowEntireRequest(rootRequest);
cancelFuture.actionGet();
ensureAllBansRemoved();
}
}

static TaskId getRootTaskId(TestRequest request) throws Exception {
SetOnce<TaskId> taskId = new SetOnce<>();
assertBusy(() -> {
Expand All @@ -326,6 +379,7 @@ static void waitForRootTask(ActionFuture<TestResponse> rootTask) {
rootTask.actionGet();
} catch (Exception e) {
final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
assertNotNull(cause);
assertThat(cause.getMessage(), anyOf(
equalTo("The parent task was cancelled, shouldn't start any child tasks"),
containsString("Task cancelled before it started:"),
Expand Down
Expand Up @@ -242,7 +242,7 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
if (request.ban) {
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
localNodeId(), request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason, channel);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request).map(r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
Expand Down
136 changes: 109 additions & 27 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
Expand All @@ -45,15 +46,19 @@
import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TaskTransportChannel;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransportChannel;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -63,6 +68,8 @@
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

Expand All @@ -89,7 +96,7 @@ public class TaskManager implements ClusterStateApplier {

private final AtomicLong taskIdGenerator = new AtomicLong();

private final Map<TaskId, String> banedParents = new ConcurrentHashMap<>();
private final Map<TaskId, Ban> bannedParents = new ConcurrentHashMap<>();

private TaskResultsService taskResultsService;

Expand Down Expand Up @@ -154,13 +161,13 @@ private void registerCancellableTask(Task task) {
CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder);
assert oldHolder == null;
// Check if this task was banned before we start it. The empty check is used to avoid
// computing the hash code of the parent taskId as most of the time banedParents is empty.
if (task.getParentTaskId().isSet() && banedParents.isEmpty() == false) {
String reason = banedParents.get(task.getParentTaskId());
if (reason != null) {
// computing the hash code of the parent taskId as most of the time bannedParents is empty.
if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
final Ban ban = bannedParents.get(task.getParentTaskId());
if (ban != null) {
try {
holder.cancel(reason);
throw new TaskCancelledException("Task cancelled before it started: " + reason);
holder.cancel(ban.reason);
throw new TaskCancelledException("Task cancelled before it started: " + ban.reason);
} finally {
// let's clean up the registration
unregister(task);
Expand Down Expand Up @@ -345,7 +352,7 @@ public CancellableTask getCancellableTask(long id) {
* Will be used in task manager stats and for debugging.
*/
public int getBanCount() {
return banedParents.size();
return bannedParents.size();
}

/**
Expand All @@ -354,14 +361,27 @@ public int getBanCount() {
* This method is called when a parent task that has children is cancelled.
* @return a list of pending cancellable child tasks
*/
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
public List<CancellableTask> setBan(TaskId parentTaskId, String reason, TransportChannel channel) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);

// Set the ban first, so the newly created tasks cannot be registered
synchronized (banedParents) {
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
// Only set the ban if the node is the part of the cluster
banedParents.put(parentTaskId, reason);
synchronized (bannedParents) {
if (channel.getVersion().onOrAfter(Version.V_7_12_0)) {
final Ban ban = bannedParents.computeIfAbsent(parentTaskId, k -> new Ban(reason, true));
assert ban.perChannel : "not a ban per channel";
while (channel instanceof TaskTransportChannel) {
channel = ((TaskTransportChannel) channel).getChannel();
}
if (channel instanceof TcpTransportChannel) {
startTrackingChannel(((TcpTransportChannel) channel).getChannel(), ban::registerChannel);
} else {
assert channel.getChannelType().equals("direct") : "expect direct channel; got [" + channel + "]";
ban.registerChannel(DIRECT_CHANNEL_TRACKER);
}
} else {
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
// Only set the ban if the node is the part of the cluster
final Ban existing = bannedParents.put(parentTaskId, new Ban(reason, false));
assert existing == null || existing.perChannel == false : "not a ban per node";
}
}
}
return cancellableTasks.values().stream()
Expand All @@ -377,12 +397,52 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
*/
public void removeBan(TaskId parentTaskId) {
logger.trace("removing ban for the parent task {}", parentTaskId);
banedParents.remove(parentTaskId);
bannedParents.remove(parentTaskId);
}

// for testing
public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(banedParents.keySet());
return Collections.unmodifiableSet(bannedParents.keySet());
}

private class Ban {
final String reason;
final boolean perChannel; // TODO: Remove this in 8.0
final Set<ChannelPendingTaskTracker> channels;

Ban(String reason, boolean perChannel) {
assert Thread.holdsLock(bannedParents);
this.reason = reason;
this.perChannel = perChannel;
if (perChannel) {
this.channels = new HashSet<>();
} else {
this.channels = Collections.emptySet();
}
}

void registerChannel(ChannelPendingTaskTracker channel) {
assert Thread.holdsLock(bannedParents);
assert perChannel : "not a ban per channel";
channels.add(channel);
}

boolean unregisterChannel(ChannelPendingTaskTracker channel) {
assert Thread.holdsLock(bannedParents);
assert perChannel : "not a ban per channel";
return channels.remove(channel);
}

int registeredChannels() {
assert Thread.holdsLock(bannedParents);
assert perChannel : "not a ban per channel";
return channels.size();
}

@Override
public String toString() {
return "Ban{" + "reason='" + reason + '\'' + ", perChannel=" + perChannel + ", channels=" + channels + '}';
}
}

/**
Expand All @@ -406,15 +466,15 @@ public Collection<Transport.Connection> startBanOnChildTasks(long taskId, Runnab
public void applyClusterState(ClusterChangedEvent event) {
lastDiscoveryNodes = event.state().getNodes();
if (event.nodesRemoved()) {
synchronized (banedParents) {
synchronized (bannedParents) {
lastDiscoveryNodes = event.state().getNodes();
// Remove all bans that were registered by nodes that are no longer in the cluster state
Iterator<TaskId> banIterator = banedParents.keySet().iterator();
final Iterator<Map.Entry<TaskId, Ban>> banIterator = bannedParents.entrySet().iterator();
while (banIterator.hasNext()) {
TaskId taskId = banIterator.next();
if (lastDiscoveryNodes.nodeExists(taskId.getNodeId()) == false) {
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone", taskId,
event.state().getNodes().getLocalNode());
final Map.Entry<TaskId, Ban> ban = banIterator.next();
if (ban.getValue().perChannel == false && lastDiscoveryNodes.nodeExists(ban.getKey().getNodeId()) == false) {
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone",
ban.getKey(), event.state().getNodes().getLocalNode());
banIterator.remove();
}
}
Expand Down Expand Up @@ -581,32 +641,39 @@ Set<Transport.Connection> startBan(Runnable onChildTasksCompleted) {
*/
public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task));
return () -> tracker.removeTask(task);
}

private ChannelPendingTaskTracker startTrackingChannel(TcpChannel channel, Consumer<ChannelPendingTaskTracker> onRegister) {
final ChannelPendingTaskTracker tracker = channelPendingTaskTrackers.compute(channel, (k, curr) -> {
if (curr == null) {
curr = new ChannelPendingTaskTracker();
}
curr.addTask(task);
onRegister.accept(curr);
return curr;
});
if (tracker.registered.compareAndSet(false, true)) {
channel.addCloseListener(ActionListener.wrap(
r -> {
final ChannelPendingTaskTracker removedTracker = channelPendingTaskTrackers.remove(channel);
assert removedTracker == tracker;
cancelTasksOnChannelClosed(tracker.drainTasks());
onChannelClosed(tracker);
},
e -> {
assert false : new AssertionError("must not be here", e);
}));
}
return () -> tracker.removeTask(task);
return tracker;
}

// for testing
final int numberOfChannelPendingTaskTrackers() {
return channelPendingTaskTrackers.size();
}

private static final ChannelPendingTaskTracker DIRECT_CHANNEL_TRACKER = new ChannelPendingTaskTracker();

private static class ChannelPendingTaskTracker {
final AtomicBoolean registered = new AtomicBoolean();
final Semaphore permits = Assertions.ENABLED ? new Semaphore(Integer.MAX_VALUE) : null;
Expand Down Expand Up @@ -640,7 +707,8 @@ void removeTask(CancellableTask task) {
}
}

private void cancelTasksOnChannelClosed(Set<CancellableTask> tasks) {
private void onChannelClosed(ChannelPendingTaskTracker channel) {
final Set<CancellableTask> tasks = channel.drainTasks();
if (tasks.isEmpty() == false) {
threadPool.generic().execute(new AbstractRunnable() {
@Override
Expand All @@ -656,6 +724,20 @@ protected void doRun() {
}
});
}

// Unregister the closing channel and remove bans whose has no registered channels
synchronized (bannedParents) {
final Iterator<Map.Entry<TaskId, Ban>> iterator = bannedParents.entrySet().iterator();
while (iterator.hasNext()) {
final Map.Entry<TaskId, Ban> entry = iterator.next();
final Ban ban = entry.getValue();
if (ban.perChannel) {
if (ban.unregisterChannel(channel) && entry.getValue().registeredChannels() == 0) {
iterator.remove();
}
}
}
}
}

public void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
Expand Down
Expand Up @@ -20,6 +20,7 @@

import com.carrotsearch.randomizedtesting.RandomizedContext;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
Expand All @@ -42,6 +43,9 @@
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.FakeTcpChannel;
import org.elasticsearch.transport.TestTransportChannels;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
Expand Down Expand Up @@ -360,7 +364,10 @@ public void testRegisterAndExecuteChildTaskWhileParentTaskIsBeingCanceled() thro
CancellableNodesRequest parentRequest = new CancellableNodesRequest("parent");
final Task parentTask = taskManager.register("test", "test", parentRequest);
final TaskId parentTaskId = parentTask.taskInfo(testNodes[0].getNodeId(), false).getTaskId();
taskManager.setBan(new TaskId(testNodes[0].getNodeId(), parentTask.getId()), "test");
taskManager.setBan(new TaskId(testNodes[0].getNodeId(), parentTask.getId()), "test",
TestTransportChannels.newFakeTcpTransportChannel(
testNodes[0].getNodeId(), new FakeTcpChannel(), threadPool,
"test", randomNonNegativeLong(), Version.CURRENT));
CancellableNodesRequest childRequest = new CancellableNodesRequest("child");
childRequest.setParentTask(parentTaskId);
CancellableTestNodesAction testAction = new CancellableTestNodesAction("internal:testAction", threadPool, testNodes[1]
Expand Down

0 comments on commit e55f5d3

Please sign in to comment.