diff --git a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java index c58621d03ce8f..00aa44ac73acb 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.set.Sets; @@ -44,7 +45,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.LongSupplier; public class RestCancellableNodeClientTests extends ESTestCase { @@ -79,7 +79,9 @@ public void testCompletedTasks() throws Exception { for (int j = 0; j < numTasks; j++) { PlainActionFuture actionFuture = new PlainActionFuture<>(); RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); - threadPool.generic().submit(() -> client.execute(TransportSearchAction.TYPE, new SearchRequest(), actionFuture)); + futures.add( + threadPool.generic().submit(() -> client.execute(TransportSearchAction.TYPE, new SearchRequest(), actionFuture)) + ); futures.add(actionFuture); } } @@ -150,7 +152,7 @@ public void testChannelAlreadyClosed() { assertEquals(totalSearches, testClient.cancelledTasks.size()); } - public void testConcurrentExecuteAndClose() throws Exception { + public void testConcurrentExecuteAndClose() { final var testClient = new TestClient(Settings.EMPTY, threadPool, true); int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); int numTasks = randomIntBetween(1, 30); @@ -254,7 +256,7 @@ public String getLocalNodeId() { private class TestHttpChannel implements HttpChannel { private final AtomicBoolean open = new AtomicBoolean(true); - private final AtomicReference> closeListener = new AtomicReference<>(); + private final SubscribableListener> closeListener = new SubscribableListener<>(); private final CountDownLatch closeLatch = new CountDownLatch(1); @Override @@ -273,8 +275,7 @@ public InetSocketAddress getRemoteAddress() { @Override public void close() { assertTrue("HttpChannel is already closed", open.compareAndSet(true, false)); - ActionListener listener = closeListener.get(); - if (listener != null) { + closeListener.andThenAccept(listener -> { boolean failure = randomBoolean(); threadPool.generic().submit(() -> { if (failure) { @@ -284,11 +285,10 @@ public void close() { } closeLatch.countDown(); }); - } + }); } private void awaitClose() throws InterruptedException { - assertNotNull("must set closeListener before calling awaitClose", closeListener.get()); close(); closeLatch.await(); } @@ -304,9 +304,8 @@ public void addCloseListener(ActionListener listener) { if (open.get() == false) { listener.onResponse(null); } else { - if (closeListener.compareAndSet(null, listener) == false) { - throw new AssertionError("close listener already set, only one is allowed!"); - } + assertFalse("close listener already set, only one is allowed!", closeListener.isDone()); + closeListener.onResponse(ActionListener.assertOnce(listener)); } } }