diff --git a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java index 931f733b0a09d..560db3208b4d9 100644 --- a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java @@ -58,7 +58,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -378,28 +377,24 @@ protected void newResponseAsync( public void testConcurrentlyCompletionAndCancellation() throws InterruptedException { final var action = getTestTransportNodesAction(); - final CountDownLatch onCancelledLatch = new CountDownLatch(1); - final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) { - @Override - protected void onCancelled() { - onCancelledLatch.countDown(); - } - }; + final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()); final PlainActionFuture future = new PlainActionFuture<>(); action.execute(cancellableTask, new TestNodesRequest(), future); final List nodeResponses = new ArrayList<>(); final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear(); + // Complete all but the last request for racing completion with cancellation for (int i = 0; i < capturedRequests.length - 1; i++) { final var capturedRequest = capturedRequests[i]; nodeResponses.add(completeOneRequest(capturedRequest)); } final var raceBarrier = new CyclicBarrier(3); + final var lastResponseFuture = new PlainActionFuture(); final Thread completeThread = new Thread(() -> { safeAwait(raceBarrier); - nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1])); + lastResponseFuture.onResponse(completeOneRequest(capturedRequests[capturedRequests.length - 1])); }); final Thread cancelThread = new Thread(() -> { safeAwait(raceBarrier); @@ -419,8 +414,11 @@ protected void onCancelled() { assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException); assertThat(e.getMessage(), containsString("task cancelled [simulated]")); assertTrue(cancellableTask.isCancelled()); - safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it + // All previously captured responses are released due to cancellation assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false)); + // Wait for the last response to be gathered and assert it is also released by either the concurrent cancellation or + // not tracked in onItemResponse at all due to already cancelled + assertFalse(safeGet(lastResponseFuture).hasReferences()); } completeThread.join(10_000);