diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 8be107b19197c..3d0f07bdd8671 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -432,15 +432,12 @@ public boolean nodeConnected(DiscoveryNode node) { @Override public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) { connectionProfile = connectionProfile == null ? defaultConnectionProfile : connectionProfile; - if (!lifecycle.started()) { - throw new IllegalStateException("can't add nodes to a stopped transport"); - } if (node == null) { throw new ConnectTransportException(null, "can't connect to a null node"); } - globalLock.readLock().lock(); + globalLock.readLock().lock(); // ensure we don't open connections while we are closing try { - + ensureOpen(); try (Releasable ignored = connectionLock.acquire(node.getId())) { if (!lifecycle.started()) { throw new IllegalStateException("can't add nodes to a stopped transport"); @@ -477,31 +474,40 @@ public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfil @Override public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException { + if (node == null) { + throw new ConnectTransportException(null, "can't open connection to a null node"); + } boolean success = false; NodeChannels nodeChannels = null; + globalLock.readLock().lock(); // ensure we don't open connections while we are closing try { - nodeChannels = connectToChannels(node, connectionProfile); - final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile - final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ? - defaultConnectionProfile.getConnectTimeout() : - connectionProfile.getConnectTimeout(); - final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ? - connectTimeout : connectionProfile.getHandshakeTimeout(); - final Version version = executeHandshake(node, channel, handshakeTimeout); - transportServiceAdapter.onConnectionOpened(node); - nodeChannels = new NodeChannels(nodeChannels, version);// clone the channels - we now have the correct version - success = true; - return nodeChannels; - } catch (ConnectTransportException e) { - throw e; - } catch (Exception e) { - // ConnectTransportExceptions are handled specifically on the caller end - we wrap the actual exception to ensure - // only relevant exceptions are logged on the caller end.. this is the same as in connectToNode - throw new ConnectTransportException(node, "general node connection failure", e); - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(nodeChannels); + ensureOpen(); + try { + nodeChannels = connectToChannels(node, connectionProfile); + final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile + final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ? + defaultConnectionProfile.getConnectTimeout() : + connectionProfile.getConnectTimeout(); + final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ? + connectTimeout : connectionProfile.getHandshakeTimeout(); + final Version version = executeHandshake(node, channel, handshakeTimeout); + transportServiceAdapter.onConnectionOpened(node); + nodeChannels = new NodeChannels(nodeChannels, version);// clone the channels - we now have the correct version + success = true; + return nodeChannels; + } catch (ConnectTransportException e) { + throw e; + } catch (Exception e) { + // ConnectTransportExceptions are handled specifically on the caller end - we wrap the actual exception to ensure + // only relevant exceptions are logged on the caller end.. this is the same as in connectToNode + throw new ConnectTransportException(node, "general node connection failure", e); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(nodeChannels); + } } + } finally { + globalLock.readLock().unlock(); } } @@ -1577,4 +1583,14 @@ protected final void onChannelClosed(Channel channel) { } } } + + /** + * Ensures this transport is still started / open + * @throws IllegalStateException if the transport is not started / open + */ + protected final void ensureOpen() { + if (lifecycle.started() == false) { + throw new IllegalStateException("transport has been stopped"); + } + } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 51774edcba57b..b4931ef8847f6 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -49,10 +49,10 @@ import java.net.Socket; import java.net.SocketException; import java.net.SocketTimeoutException; -import java.util.HashMap; -import java.util.IdentityHashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -76,7 +76,7 @@ public class MockTcpTransport extends TcpTransport */ public static final ConnectionProfile LIGHT_PROFILE; - private final Map openChannels = new IdentityHashMap<>(); + private final Set openChannels = new HashSet<>(); static { ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); @@ -289,7 +289,7 @@ public MockChannel(Socket socket, InetSocketAddress localAddress, String profile this.profile = profile; this.onClose = () -> onClose.accept(this); synchronized (openChannels) { - openChannels.put(this, Boolean.TRUE); + openChannels.add(this); } } @@ -305,6 +305,9 @@ public MockChannel(ServerSocket serverSocket, String profile) { this.profile = profile; this.activeChannel = null; this.onClose = null; + synchronized (openChannels) { + openChannels.add(this); + } } public void accept(Executor executor) throws IOException { @@ -313,10 +316,10 @@ public void accept(Executor executor) throws IOException { MockChannel incomingChannel = null; try { configureSocket(incomingSocket); - incomingChannel = new MockChannel(incomingSocket, localAddress, profile, workerChannels::remove); - //establish a happens-before edge between closing and accepting a new connection synchronized (this) { if (isOpen.get()) { + incomingChannel = new MockChannel(incomingSocket, localAddress, profile, workerChannels::remove); + //establish a happens-before edge between closing and accepting a new connection workerChannels.put(incomingChannel, Boolean.TRUE); // this spawns a new thread immediately, so OK under lock incomingChannel.loopRead(executor); @@ -360,7 +363,7 @@ protected void doRun() throws Exception { @Override public void close() throws IOException { if (isOpen.compareAndSet(true, false)) { - final Boolean removedChannel; + final boolean removedChannel; synchronized (openChannels) { removedChannel = openChannels.remove(this); } @@ -370,9 +373,19 @@ public void close() throws IOException { IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels.keySet()), () -> cancellableThreads.cancel("channel closed"), onClose); } - assert removedChannel : "Channel was not removed or removed twice?"; + assert removedChannel: "Channel was not removed or removed twice?"; } } + + @Override + public String toString() { + return "MockChannel{" + + "profile='" + profile + '\'' + + ", isOpen=" + isOpen + + ", localAddress=" + localAddress + + ", isServerSocket=" + (serverSocket != null) + + '}'; + } }