Skip to content

Commit

Permalink
Child requests proactively cancel children tasks (#92588)
Browse files Browse the repository at this point in the history
To make this possible we modify the CancellableTasksTracker to track children tasks by the Request ID as well. That way, we can send an Action to cancel a child based on the parent task and the Request ID.

This is especially useful when a parent's child requests time out on the parent's side.

Fixes #90353
Relates #66992
  • Loading branch information
kingherc committed Apr 3, 2023
1 parent f353be2 commit 400b7ec
Show file tree
Hide file tree
Showing 25 changed files with 435 additions and 95 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/92588.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 92588
summary: Failed tasks proactively cancel children tasks
area: Snapshot/Restore
type: enhancement
issues:
- 90353

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,16 @@ public TaskId getParentTask() {
return request.getParentTask();
}

// Delegates to the wrapped request: the request ID is stored on it rather than on this wrapper.
@Override
public void setRequestId(long requestId) {
request.setRequestId(requestId);
}

// Delegates to the wrapped request and returns whatever request ID it reports.
@Override
public long getRequestId() {
return request.getRequestId();
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return request.createTask(id, type, action, parentTaskId, headers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ private <T extends ClusterStateTaskListener> void executeAndPublishBatch(
@Override
public void setParentTask(TaskId taskId) {}

// No-op: this implementation discards the supplied request ID.
@Override
public void setRequestId(long requestId) {}

@Override
public TaskId getParentTask() {
return TaskId.EMPTY_TASK_ID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ public void setParentTask(TaskId taskId) {
throw new UnsupportedOperationException("parent task if for persistent tasks shouldn't change");
}

// Unsupported here: every call throws UnsupportedOperationException and the argument is never stored.
@Override
public void setRequestId(long requestId) {
throw new UnsupportedOperationException("does not have a request ID");
}

@Override
public TaskId getParentTask() {
return parentTaskId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,45 @@ public CancellableTasksTracker(T[] empty) {
}

private final Map<Long, T> byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
private final Map<TaskId, T[]> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
private final Map<TaskId, Map<Long, T[]>> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();

/**
 * Looks up the cancellable children that were registered under the given parent task and the given child request ID.
 *
 * Note: children of non-positive request IDs (e.g., -1) may be grouped together.
 *
 * @param parentTaskId   the parent task whose children are wanted
 * @param childRequestId the request ID the children were registered with
 * @return a stream over the matching children, or an empty stream if nothing is tracked for the parent or for the request ID
 */
public Stream<T> getChildrenByRequestId(TaskId parentTaskId, long childRequestId) {
    final Map<Long, T[]> requestIdToChildren = byParentTaskId.get(parentTaskId);
    if (requestIdToChildren == null) {
        // nothing is tracked for this parent at all
        return Stream.empty();
    }
    final T[] trackedChildren = requestIdToChildren.get(childRequestId);
    return trackedChildren == null ? Stream.empty() : Arrays.stream(trackedChildren);
}

/**
* Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too.
*/
public void put(Task task, T item) {
public void put(Task task, long requestId, T item) {
final long taskId = task.getId();
if (task.getParentTaskId().isSet()) {
byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
if (oldValue == null) {
oldValue = empty;
byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> {
if (oldRequestIdMap == null) {
oldRequestIdMap = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
}
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1);
newValue[oldValue.length] = item;
return newValue;

oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> {
if (oldValue == null) {
oldValue = empty;
}
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1);
newValue[oldValue.length] = item;
return newValue;
});

return oldRequestIdMap;
});
}
final T oldItem = byTaskId.put(taskId, item);
Expand All @@ -60,36 +84,50 @@ public T get(long id) {
}

/**
* Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times
* for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is
* actually being completed by a concurrent call that's still ongoing.
* Remove (and return) the item that corresponds with the given task. The task's request ID is not needed: every request ID tracked
* under the task's parent is searched for the item. Return {@code null} if not present. Safe to call multiple times for each task.
* However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is actually being
* completed by a concurrent call that's still ongoing.
*/
public T remove(Task task) {
final long taskId = task.getId();
final T oldItem = byTaskId.remove(taskId);
if (oldItem != null && task.getParentTaskId().isSet()) {
byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
if (oldValue == null) {
byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> {
if (oldRequestIdMap == null) {
return null;
}
if (oldValue.length == 1) {
if (oldValue[0] == oldItem) {
return null;
} else {

for (Long requestId : oldRequestIdMap.keySet()) {
oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> {
if (oldValue == null) {
return null;
}
if (oldValue.length == 1) {
if (oldValue[0] == oldItem) {
return null;
} else {
return oldValue;
}
}
if (oldValue[0] == oldItem) {
return Arrays.copyOfRange(oldValue, 1, oldValue.length);
}
for (int i = 1; i < oldValue.length; i++) {
if (oldValue[i] == oldItem) {
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1);
System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1);
return newValue;
}
}
return oldValue;
}
}
if (oldValue[0] == oldItem) {
return Arrays.copyOfRange(oldValue, 1, oldValue.length);
});
}
for (int i = 1; i < oldValue.length; i++) {
if (oldValue[i] == oldItem) {
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1);
System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1);
return newValue;
}

if (oldRequestIdMap.keySet().isEmpty()) {
return null;
}
return oldValue;

return oldRequestIdMap;
});
}
return oldItem;
Expand All @@ -109,11 +147,11 @@ public Collection<T> values() {
* started before this method was called have not completed.
*/
public Stream<T> getByParent(TaskId parentTaskId) {
final T[] byParent = byParentTaskId.get(parentTaskId);
final Map<Long, T[]> byParent = byParentTaskId.get(parentTaskId);
if (byParent == null) {
return Stream.empty();
}
return Arrays.stream(byParent);
return byParent.values().stream().flatMap(Stream::of);
}

// assertion for tests, not an invariant but should eventually be true
Expand All @@ -123,12 +161,14 @@ boolean assertConsistent() {

// every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent
final Set<T> byTaskValues = new HashSet<>(byTaskId.values());
for (T[] byParent : byParentTaskId.values()) {
assert byParent.length > 0;
for (T t : byParent) {
assert byTaskValues.contains(t);
}
}
byParentTaskId.values().forEach(byParentMap -> {
byParentMap.forEach((requestId, byParentArray) -> {
assert byParentArray.length > 0;
for (T t : byParentArray) {
assert byTaskValues.contains(t);
}
});
});

return true;
}
Expand Down
12 changes: 12 additions & 0 deletions server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ default void setParentTask(String parentTaskNode, long parentTaskId) {
*/
void setParentTask(TaskId taskId);

/**
* Gets the request ID. Defaults to -1, meaning "no request ID is set".
*
* @return the request ID, or -1 when an implementation does not override this default
*/
default long getRequestId() {
return -1;
}

/**
* Set the request ID related to this task.
*
* @param requestId the request ID to associate with this request
*/
void setRequestId(long requestId);

/**
* Get a reference to the task that created this request. Implementers should default to
* {@link TaskId#EMPTY_TASK_ID}, meaning "there is no parent".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -44,6 +45,8 @@

public class TaskCancellationService {
public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
public static final String CANCEL_CHILD_ACTION_NAME = "internal:admin/tasks/cancel_child";
public static final TransportVersion VERSION_SUPPORTING_CANCEL_CHILD_ACTION = TransportVersion.V_8_8_0;
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
Expand All @@ -59,6 +62,12 @@ public TaskCancellationService(TransportService transportService) {
BanParentTaskRequest::new,
new BanParentRequestHandler()
);
transportService.registerRequestHandler(
CANCEL_CHILD_ACTION_NAME,
ThreadPool.Names.SAME,
CancelChildRequest::new,
new CancelChildRequestHandler()
);
}

private String localNodeId() {
Expand Down Expand Up @@ -341,4 +350,69 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
}
}
}

/**
 * Transport request that identifies child task(s) to cancel by the parent task and the ID of the child request,
 * together with a human-readable reason for the cancellation.
 */
private static class CancelChildRequest extends TransportRequest {

    private final TaskId parentTaskId;
    private final long childRequestId;
    private final String reason;

    /**
     * Builds a request from its three components; use {@link #createCancelChildRequest} from outside this class.
     */
    private CancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) {
        this.parentTaskId = parentTaskId;
        this.childRequestId = childRequestId;
        this.reason = reason;
    }

    /**
     * Deserializes a request. Fields are read in the same order in which {@link #writeTo} writes them.
     */
    private CancelChildRequest(StreamInput in) throws IOException {
        super(in);
        this.parentTaskId = TaskId.readFromStream(in);
        this.childRequestId = in.readLong();
        this.reason = in.readString();
    }

    /**
     * Static factory for a new request.
     *
     * @param parentTaskId   the parent task of the child to cancel
     * @param childRequestId the ID of the request that created the child
     * @param reason         why the child is being cancelled
     */
    static CancelChildRequest createCancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) {
        return new CancelChildRequest(parentTaskId, childRequestId, reason);
    }

    /**
     * Serializes the superclass state followed by the parent task ID, the child request ID and the reason.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        parentTaskId.writeTo(out);
        out.writeLong(childRequestId);
        out.writeString(reason);
    }
}

/**
* Transport handler for {@code CancelChildRequest}: forwards the request's parent task ID, child request ID and reason to
* {@code taskManager.cancelChildLocal} and then replies with an empty response.
*/
private class CancelChildRequestHandler implements TransportRequestHandler<CancelChildRequest> {
@Override
public void messageReceived(final CancelChildRequest request, final TransportChannel channel, Task task) throws Exception {
// the incoming task parameter is not used; only the fields carried by the request are
taskManager.cancelChildLocal(request.parentTaskId, request.childRequestId, request.reason);
// the acknowledgement is sent only after cancelChildLocal has returned
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
}

/**
 * Sends an action to cancel a child task, identified by its parent task and by the ID of the request that created it.
 * <p>
 * Nothing is sent (and nothing is logged) when the connection's transport version is older than
 * {@link #VERSION_SUPPORTING_CANCEL_CHILD_ACTION}.
 *
 * @param parentTask      the parent task of the child to cancel
 * @param childRequestId  the ID of the request that created the child
 * @param childConnection the connection whose node should receive the cancellation
 * @param reason          why the child is being cancelled; included in the request and in the debug log
 */
public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) {
    final boolean childSupportsAction = childConnection.getTransportVersion().onOrAfter(VERSION_SUPPORTING_CANCEL_CHILD_ACTION);
    if (childSupportsAction == false) {
        return;
    }

    final DiscoveryNode targetNode = childConnection.getNode();
    logger.debug(
        "sending cancellation of child of parent task [{}] with request ID [{}] to node [{}] because of [{}]",
        parentTask,
        childRequestId,
        targetNode,
        reason
    );
    transportService.sendRequest(
        targetNode,
        CANCEL_CHILD_ACTION_NAME,
        CancelChildRequest.createCancelChildRequest(parentTask, childRequestId, reason),
        TransportRequestOptions.EMPTY,
        EmptyTransportResponseHandler.INSTANCE_SAME
    );
}

}

0 comments on commit 400b7ec

Please sign in to comment.