Skip to content

Commit

Permalink
[7.17] Preserve context in ResultDeduplicator (#84038) (#96868)
Browse files Browse the repository at this point in the history
Today the `ResultDeduplicator` may complete a collection of listeners in
contexts different from the ones in which they were submitted. This
commit makes sure that the context is preserved in the listener.

Co-authored-by: David Turner <david.turner@elastic.co>
  • Loading branch information
arteam and DaveCTurner committed Jul 5, 2023
1 parent 9bb69a2 commit 7d11e41
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 20 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/84038.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 84038
summary: Preserve context in `ResultDeduplicator`
area: Infra/Core
type: bug
issues:
- 84036
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

package org.elasticsearch.action;

import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -22,8 +24,13 @@
*/
public final class ResultDeduplicator<T, R> {

private final ThreadContext threadContext;
private final ConcurrentMap<T, CompositeListener> requests = ConcurrentCollections.newConcurrentMap();

public ResultDeduplicator(ThreadContext threadContext) {
this.threadContext = threadContext;
}

/**
* Ensures a given request not executed multiple times when another equal request is already in-flight.
* If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener}
Expand All @@ -35,7 +42,8 @@ public final class ResultDeduplicator<T, R> {
* @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator
*/
public void executeOnce(T request, ActionListener<R> listener, BiConsumer<T, ActionListener<R>> callback) {
ActionListener<R> completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener);
ActionListener<R> completionListener = requests.computeIfAbsent(request, CompositeListener::new)
.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
if (completionListener != null) {
callback.accept(request, completionListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ private static Priority parseReroutePriority(String priorityString) {

// a list of shards that failed during replication
// we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
private final ResultDeduplicator<FailedShardEntry, Void> remoteFailedShardsDeduplicator = new ResultDeduplicator<>();
private final ResultDeduplicator<FailedShardEntry, Void> remoteFailedShardsDeduplicator;

@Inject
public ShardStateAction(
Expand All @@ -131,6 +131,7 @@ public ShardStateAction(
this.transportService = transportService;
this.clusterService = clusterService;
this.threadPool = threadPool;
remoteFailedShardsDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());

followUpRerouteTaskPriority = FOLLOW_UP_REROUTE_PRIORITY_SETTING.get(clusterService.getSettings());
clusterService.getClusterSettings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
private final Map<Snapshot, Map<ShardId, IndexShardSnapshotStatus>> shardSnapshots = new HashMap<>();

// A map of snapshots to the shardIds that we already reported to the master as failed
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator =
new ResultDeduplicator<>();
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator;

public SnapshotShardsService(
Settings settings,
Expand All @@ -100,6 +99,7 @@ public SnapshotShardsService(
this.transportService = transportService;
this.clusterService = clusterService;
this.threadPool = transportService.getThreadPool();
this.remoteFailedRequestDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
if (DiscoveryNode.canContainData(settings)) {
// this is only useful on the nodes that can hold data
clusterService.addListener(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ public class TaskCancellationService {
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
private final ResultDeduplicator<CancelRequest, Void> deduplicator = new ResultDeduplicator<>();
private final ResultDeduplicator<CancelRequest, Void> deduplicator;

public TaskCancellationService(TransportService transportService) {
this.transportService = transportService;
this.taskManager = transportService.getTaskManager();
this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
transportService.registerRequestHandler(
BAN_PARENT_ACTION_NAME,
ThreadPool.Names.SAME,
Expand Down
13 changes: 10 additions & 3 deletions server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.in;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TaskManagerTests extends ESTestCase {
private ThreadPool threadPool;
Expand Down Expand Up @@ -76,7 +77,9 @@ public void testResultsServiceRetryTotalTime() {
public void testTrackingChannelTask() throws Exception {
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
Set<Task> cancelledTasks = ConcurrentCollections.newConcurrentSet();
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
final TransportService transportServiceMock = mock(TransportService.class);
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
@Override
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
assertThat(reason, equalTo("channel was closed"));
Expand Down Expand Up @@ -124,7 +127,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
public void testTrackingTaskAndCloseChannelConcurrently() throws Exception {
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
Set<CancellableTask> cancelledTasks = ConcurrentCollections.newConcurrentSet();
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
final TransportService transportServiceMock = mock(TransportService.class);
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
@Override
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
assertTrue("task [" + task + "] was cancelled already", cancelledTasks.add(task));
Expand Down Expand Up @@ -180,7 +185,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF

public void testRemoveBansOnChannelDisconnects() throws Exception {
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
final TransportService transportServiceMock = mock(TransportService.class);
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
@Override
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;

Expand All @@ -29,27 +31,36 @@ public void testRequestDeduplication() throws Exception {
@Override
public void setParentTask(final TaskId taskId) {}
};
final ResultDeduplicator<TransportRequest, Void> deduplicator = new ResultDeduplicator<>();
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final ResultDeduplicator<TransportRequest, Void> deduplicator = new ResultDeduplicator<>(threadContext);
final SetOnce<ActionListener<Void>> listenerHolder = new SetOnce<>();
final String headerName = "thread-context-header";
final AtomicInteger headerGenerator = new AtomicInteger();
int iterationsPerThread = scaledRandomIntBetween(100, 1000);
Thread[] threads = new Thread[between(1, 4)];
Phaser barrier = new Phaser(threads.length + 1);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
barrier.arriveAndAwaitAdvance();
for (int n = 0; n < iterationsPerThread; n++) {
deduplicator.executeOnce(request, new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
successCount.incrementAndGet();
}
final String headerValue = Integer.toString(headerGenerator.incrementAndGet());
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader(headerName, headerValue);
deduplicator.executeOnce(request, new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
successCount.incrementAndGet();
}

@Override
public void onFailure(Exception e) {
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
}, (req, reqListener) -> listenerHolder.set(reqListener));
@Override
public void onFailure(Exception e) {
assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
}, (req, reqListener) -> listenerHolder.set(reqListener));
}
}
});
threads[i].start();
Expand Down

0 comments on commit 7d11e41

Please sign in to comment.