Skip to content

Commit

Permalink
Fix context leak in list tasks API (#93431)
Browse files Browse the repository at this point in the history
In #90977 we made the list tasks API fully async, but failed to notice
that if we waited for a task to complete then we would respond in the
thread context of the last-completing task. This commit fixes the
problem by restoring the context of the list-tasks task before
responding.

Closes #93428
  • Loading branch information
DaveCTurner committed Feb 2, 2023
1 parent b0c380d commit 8d44c9a
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 11 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/93431.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 93431
summary: Fix context leak in list tasks API
area: Task Management
type: bug
issues:
- 93428
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,38 @@

package org.elasticsearch.action.admin.cluster.tasks;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.Collection;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;

public class ListTasksIT extends ESIntegTestCase {
public class ListTasksIT extends ESSingleNodeTestCase {

public void testListTasksFilteredByDescription() {

Expand All @@ -40,4 +63,155 @@ public void testListTasksValidation() {
assertThat(ex.getMessage(), containsString("matching on descriptions is not available when [detailed] is false"));

}

public void testWaitForCompletion() throws Exception {
final var threadPool = getInstanceFromNode(ThreadPool.class);
final var threadContext = threadPool.getThreadContext();

final var barrier = new CyclicBarrier(2);
getInstanceFromNode(PluginsService.class).filterPlugins(TestPlugin.class).get(0).barrier = barrier;

final var testActionFuture = new PlainActionFuture<ActionResponse.Empty>();
client().execute(TEST_ACTION, new TestRequest(), testActionFuture.map(r -> {
assertThat(threadContext.getResponseHeaders().get(TestTransportAction.HEADER_NAME), hasItem(TestTransportAction.HEADER_VALUE));
return r;
}));

barrier.await(10, TimeUnit.SECONDS);

final var listTasksResponse = client().admin().cluster().prepareListTasks().setActions(TestTransportAction.NAME).get();
assertThat(listTasksResponse.getNodeFailures(), empty());
assertEquals(1, listTasksResponse.getTasks().size());
final var task = listTasksResponse.getTasks().get(0);
assertEquals(TestTransportAction.NAME, task.action());

final var listWaitFuture = new PlainActionFuture<Void>();
client().admin()
.cluster()
.prepareListTasks()
.setTargetTaskId(task.taskId())
.setWaitForCompletion(true)
.execute(listWaitFuture.delegateFailure((l, listResult) -> {
assertEquals(1, listResult.getTasks().size());
assertEquals(task.taskId(), listResult.getTasks().get(0).taskId());
// the task must now be complete:
client().admin().cluster().prepareListTasks().setActions(TestTransportAction.NAME).execute(l.map(listAfterWaitResult -> {
assertThat(listAfterWaitResult.getTasks(), empty());
assertThat(listAfterWaitResult.getNodeFailures(), empty());
assertThat(listAfterWaitResult.getTaskFailures(), empty());
return null;
}));
// and we must not see its header:
assertNull(threadContext.getResponseHeaders().get(TestTransportAction.HEADER_NAME));
}));

// briefly fill up the management pool so that (a) we know the wait has started and (b) we know it's not blocking
flushThreadPool(threadPool, ThreadPool.Names.MANAGEMENT);

final var getWaitFuture = new PlainActionFuture<Void>();
client().admin()
.cluster()
.prepareGetTask(task.taskId())
.setWaitForCompletion(true)
.execute(getWaitFuture.delegateFailure((l, getResult) -> {
assertTrue(getResult.getTask().isCompleted());
assertEquals(task.taskId(), getResult.getTask().getTask().taskId());
// the task must now be complete:
client().admin().cluster().prepareListTasks().setActions(TestTransportAction.NAME).execute(l.map(listAfterWaitResult -> {
assertThat(listAfterWaitResult.getTasks(), empty());
assertThat(listAfterWaitResult.getNodeFailures(), empty());
assertThat(listAfterWaitResult.getTaskFailures(), empty());
return null;
}));
// and we must not see its header:
assertNull(threadContext.getResponseHeaders().get(TestTransportAction.HEADER_NAME));
}));

// briefly fill up the generic pool so that (a) we know the wait has started and (b) we know it's not blocking
// flushThreadPool(threadPool, ThreadPool.Names.GENERIC); // TODO it _is_ blocking right now!!, unmute this in #93375

assertFalse(listWaitFuture.isDone());
assertFalse(testActionFuture.isDone());
barrier.await(10, TimeUnit.SECONDS);
testActionFuture.get(10, TimeUnit.SECONDS);
listWaitFuture.get(10, TimeUnit.SECONDS);
getWaitFuture.get(10, TimeUnit.SECONDS);
}

private void flushThreadPool(ThreadPool threadPool, String executor) throws InterruptedException, BrokenBarrierException,
TimeoutException {
var maxThreads = threadPool.info(executor).getMax();
var barrier = new CyclicBarrier(maxThreads + 1);
for (int i = 0; i < maxThreads; i++) {
threadPool.executor(executor).execute(() -> {
try {
barrier.await(10, TimeUnit.SECONDS);
} catch (Exception e) {
throw new AssertionError(e);
}
});
}
barrier.await(10, TimeUnit.SECONDS);
}

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(TestPlugin.class);
}

private static final ActionType<ActionResponse.Empty> TEST_ACTION = new ActionType<>(
TestTransportAction.NAME,
in -> ActionResponse.Empty.INSTANCE
);

public static class TestPlugin extends Plugin implements ActionPlugin {
volatile CyclicBarrier barrier;

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return List.of(new ActionHandler<>(TEST_ACTION, TestTransportAction.class));
}
}

public static class TestRequest extends ActionRequest {
@Override
public ActionRequestValidationException validate() {
return null;
}
}

public static class TestTransportAction extends HandledTransportAction<TestRequest, ActionResponse.Empty> {

static final String NAME = "internal:test/action";

static final String HEADER_NAME = "HEADER_NAME";
static final String HEADER_VALUE = "HEADER_VALUE";

private final TestPlugin testPlugin;
private final ThreadPool threadPool;

@Inject
public TestTransportAction(
TransportService transportService,
ActionFilters actionFilters,
PluginsService pluginsService,
ThreadPool threadPool
) {
super(NAME, transportService, actionFilters, in -> new TestRequest());
testPlugin = pluginsService.filterPlugins(TestPlugin.class).get(0);
this.threadPool = threadPool;
}

@Override
protected void doExecute(Task task, TestRequest request, ActionListener<ActionResponse.Empty> listener) {
final var barrier = testPlugin.barrier;
assertNotNull(barrier);
threadPool.generic().execute(ActionRunnable.run(listener, () -> {
barrier.await(10, TimeUnit.SECONDS);
threadPool.getThreadContext().addResponseHeader(HEADER_NAME, HEADER_VALUE);
barrier.await(10, TimeUnit.SECONDS);
}));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.ListenableActionFuture;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
Expand Down Expand Up @@ -126,19 +127,21 @@ protected void processTasks(ListTasksRequest request, Consumer<Task> operation,
// No tasks to wait, we can run nodeOperation in the management pool
allMatchedTasksRemovedListener.onResponse(null);
} else {
final var threadPool = clusterService.threadPool();
future.addListener(
new ThreadedActionListener<>(
clusterService.threadPool().executor(ThreadPool.Names.MANAGEMENT),
false,
allMatchedTasksRemovedListener
threadPool.executor(ThreadPool.Names.MANAGEMENT),
new ContextPreservingActionListener<>(
threadPool.getThreadContext().newRestorableContext(false),
allMatchedTasksRemovedListener
)
)
);
var cancellable = clusterService.threadPool()
.schedule(
() -> future.onFailure(new ElasticsearchTimeoutException("Timed out waiting for completion of tasks")),
requireNonNullElse(request.getTimeout(), DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT),
ThreadPool.Names.SAME
);
var cancellable = threadPool.schedule(
() -> future.onFailure(new ElasticsearchTimeoutException("Timed out waiting for completion of tasks")),
requireNonNullElse(request.getTimeout(), DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT),
ThreadPool.Names.SAME
);
future.addListener(ActionListener.wrap(cancellable::cancel));
}
} else {
Expand Down

0 comments on commit 8d44c9a

Please sign in to comment.