diff --git a/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java b/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java index 4eef70a53ef00..287b86925a8a2 100644 --- a/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java +++ b/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.internal.io.IOUtils; @@ -202,6 +203,7 @@ private void connectToNodeOrRetry( ActionListener.wrap( conn -> connectionValidator.validate(conn, resolvedProfile, ActionListener.runAfter(ActionListener.wrap(ignored -> { assert Transports.assertNotTransportThread("connection validator success"); + final RefCounted managerRefs = AbstractRefCounted.of(conn::onRemoved); try { if (connectedNodes.putIfAbsent(node, conn) != null) { assert false : "redundant connection to " + node; @@ -209,13 +211,14 @@ private void connectToNodeOrRetry( IOUtils.closeWhileHandlingException(conn); } else { logger.debug("connected to node [{}]", node); + managerRefs.incRef(); try { connectionListener.onNodeConnected(node, conn); } finally { conn.addCloseListener(ActionListener.wrap(() -> { connectedNodes.remove(node, conn); connectionListener.onNodeDisconnected(node, conn); - conn.onRemoved(); + managerRefs.decRef(); })); conn.addCloseListener(ActionListener.wrap(() -> { @@ -236,6 +239,7 @@ private void connectToNodeOrRetry( } finally { ListenableFuture future = pendingConnections.remove(node); assert future == currentListener : "Listener in pending map is different than the expected listener"; + managerRefs.decRef(); releaseOnce.run(); future.onResponse(conn); } diff --git a/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java b/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java index bf639a9776ef5..e754e40249ba8 100644 --- a/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java @@ -23,6 +23,8 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; @@ -116,15 +118,16 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti assertFalse(connectionManager.nodeConnected(node)); - AtomicReference connectionRef = new AtomicReference<>(); + AtomicReference validatedConnectionRef = new AtomicReference(); ConnectionManager.ConnectionValidator validator = (c, p, l) -> { - connectionRef.set(c); + validatedConnectionRef.set(c); l.onResponse(null); }; PlainActionFuture.get(fut -> connectionManager.connectToNode(node, connectionProfile, validator, fut.map(x -> null))); assertFalse(connection.isClosed()); assertTrue(connectionManager.nodeConnected(node)); + assertSame(connection, validatedConnectionRef.get()); assertSame(connection, connectionManager.getConnection(node)); assertEquals(1, connectionManager.size()); assertEquals(1, nodeConnectedCount.get()); @@ -494,14 +497,17 @@ public void testConcurrentConnectsAndDisconnects() throws Exception { }); }; - final Semaphore pendingConnections = new Semaphore(between(1, 1000)); + final int connectionCount = between(1, 1000); + final int disconnectionCount = randomFrom(connectionCount, connectionCount - 1, between(0, connectionCount - 1)); + final Semaphore connectionPermits = new Semaphore(connectionCount); + final Semaphore disconnectionPermits = new Semaphore(disconnectionCount); final int threadCount = between(1, 10); final CountDownLatch countDownLatch = new CountDownLatch(threadCount); final Runnable action = new Runnable() { @Override public void run() { - if (pendingConnections.tryAcquire()) { + if (connectionPermits.tryAcquire()) { connectionManager.connectToNode(node, null, validator, new ActionListener() { @Override public void onResponse(Releasable releasable) { @@ -509,20 +515,26 @@ public void onResponse(Releasable releasable) { final String description = releasable.toString(); fail(description); } - Releasables.close(releasable); - threadPool.generic().execute(() -> run()); + if (disconnectionPermits.tryAcquire()) { + Releasables.close(releasable); + } + runAgain(); } @Override public void onFailure(Exception e) { if (e instanceof ConnectTransportException && e.getMessage().contains("concurrently connecting and disconnecting")) { - pendingConnections.release(); - threadPool.generic().execute(() -> run()); + connectionPermits.release(); + runAgain(); } else { throw new AssertionError("unexpected", e); } } + + private void runAgain() { + threadPool.generic().execute(() -> run()); + } }); } else { countDownLatch.countDown(); @@ -536,7 +548,116 @@ public void onFailure(Exception e) { assertTrue("threads did not all complete", countDownLatch.await(10, TimeUnit.SECONDS)); assertTrue("validatorPermits not all released", validatorPermits.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS)); - assertFalse("node still connected", connectionManager.nodeConnected(node)); + assertEquals("node still connected", disconnectionCount < connectionCount, connectionManager.nodeConnected(node)); + connectionManager.close(); + } + + @TestLogging(reason = "ignore copious 'closed by remote' messages", value = "org.elasticsearch.transport.ClusterConnectionManager:WARN") + public void testConcurrentConnectsAndCloses() throws Exception { + final DiscoveryNode node = new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(new TestConnect(node)); + return null; + }).when(transport).openConnection(eq(node), any(), anyActionListener()); + + final Semaphore validatorPermits = new Semaphore(Integer.MAX_VALUE); + + final ConnectionManager.ConnectionValidator validator = (c, p, l) -> { + assertTrue(validatorPermits.tryAcquire()); + threadPool.executor(randomFrom(ThreadPool.Names.GENERIC, ThreadPool.Names.SAME)).execute(() -> { + try { + l.onResponse(null); + } finally { + validatorPermits.release(); + } + }); + }; + + final Semaphore closePermits = new Semaphore(between(1, 1000)); + final int connectThreadCount = between(1, 3); + final int closeThreadCount = between(1, 3); + final CountDownLatch countDownLatch = new CountDownLatch(connectThreadCount + closeThreadCount); + + final PlainActionFuture cleanlyOpenedConnectionFuture = new PlainActionFuture<>(); + final RefCounted closingRefs = AbstractRefCounted.of( + () -> connectionManager.connectToNode( + node, + null, + validator, + cleanlyOpenedConnectionFuture.map(r -> connectionManager.nodeConnected(node)) + ) + ); + + final Runnable connectAction = new Runnable() { + private void runAgain() { + threadPool.generic().execute(this); + } + + @Override + public void run() { + if (cleanlyOpenedConnectionFuture.isDone() == false) { + connectionManager.connectToNode(node, null, validator, new ActionListener() { + @Override + public void onResponse(Releasable releasable) { + runAgain(); + } + + @Override + public void onFailure(Exception e) { + if (e instanceof ConnectTransportException + && e.getMessage().contains("concurrently connecting and disconnecting")) { + runAgain(); + } else { + throw new AssertionError("unexpected", e); + } + } + + }); + } else { + countDownLatch.countDown(); + } + } + }; + + final Runnable closeAction = new Runnable() { + private void runAgain() { + threadPool.generic().execute(this); + } + + @Override + public void run() { + closingRefs.decRef(); + if (closePermits.tryAcquire() && closingRefs.tryIncRef()) { + try { + Transport.Connection connection = connectionManager.getConnection(node); + connection.addRemovedListener(ActionListener.wrap(this::runAgain)); + connection.close(); + } catch (NodeNotConnectedException e) { + closePermits.release(); + runAgain(); + } + } else { + countDownLatch.countDown(); + } + } + }; + + for (int i = 0; i < connectThreadCount; i++) { + connectAction.run(); + } + for (int i = 0; i < closeThreadCount; i++) { + closingRefs.incRef(); + closeAction.run(); + } + closingRefs.decRef(); + + assertTrue("threads did not all complete", countDownLatch.await(10, TimeUnit.SECONDS)); + assertFalse(closingRefs.hasReferences()); + assertTrue(cleanlyOpenedConnectionFuture.actionGet(0, TimeUnit.SECONDS)); + + assertTrue("validatorPermits not all released", validatorPermits.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS)); connectionManager.close(); }