diff --git a/ambry-network/src/main/java/com.github.ambry.network/NetworkMetrics.java b/ambry-network/src/main/java/com.github.ambry.network/NetworkMetrics.java index 3058c36e55..63008f9212 100644 --- a/ambry-network/src/main/java/com.github.ambry.network/NetworkMetrics.java +++ b/ambry-network/src/main/java/com.github.ambry.network/NetworkMetrics.java @@ -44,7 +44,7 @@ public class NetworkMetrics { public final Counter selectorKeyOperationErrorCount; public final Counter selectorCloseKeyErrorCount; public final Counter selectorCloseSocketErrorCount; - public Gauge selectorActiveConnections; + public Gauge numActiveConnections; public final Map selectorNodeMetricMap; // Plaintext metrics @@ -121,8 +121,12 @@ public NetworkMetrics(MetricRegistry registry) { selectorNodeMetricMap = new HashMap(); } - public void initializeSelectorMetricsIfRequired(final AtomicLong activeConnections) { - selectorActiveConnections = new Gauge() { + /** + * Initializes a few network metrics for the selector + * @param activeConnections count of current active connections + */ + void initializeSelectorMetrics(final AtomicLong activeConnections) { + numActiveConnections = new Gauge() { @Override public Long getValue() { return activeConnections.get(); diff --git a/ambry-network/src/main/java/com.github.ambry.network/PlainTextTransmission.java b/ambry-network/src/main/java/com.github.ambry.network/PlainTextTransmission.java index 625fb141b2..d5ff4826e7 100644 --- a/ambry-network/src/main/java/com.github.ambry.network/PlainTextTransmission.java +++ b/ambry-network/src/main/java/com.github.ambry.network/PlainTextTransmission.java @@ -26,7 +26,7 @@ * Transmission used to speak plain text to the underlying channel. */ public class PlainTextTransmission extends Transmission { - private static final Logger logger = LoggerFactory.getLogger(SSLTransmission.class); + private static final Logger logger = LoggerFactory.getLogger(PlainTextTransmission.class); public PlainTextTransmission(String connectionId, SocketChannel socketChannel, SelectionKey key, Time time, NetworkMetrics metrics) { diff --git a/ambry-network/src/main/java/com.github.ambry.network/Selector.java b/ambry-network/src/main/java/com.github.ambry.network/Selector.java index d4750c1160..452d3fbf90 100644 --- a/ambry-network/src/main/java/com.github.ambry.network/Selector.java +++ b/ambry-network/src/main/java/com.github.ambry.network/Selector.java @@ -26,6 +26,7 @@ import java.nio.channels.UnresolvedAddressException; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -72,10 +73,11 @@ public class Selector implements Selectable { private final List completedReceives; private final List disconnected; private final List connected; + private final Set unreadyConnections; private final Time time; private final NetworkMetrics metrics; private final AtomicLong IdGenerator; - private AtomicLong activeConnections; + private final AtomicLong numActiveConnections; private final SSLFactory sslFactory; /** @@ -92,8 +94,9 @@ public Selector(NetworkMetrics metrics, Time time, SSLFactory sslFactory) this.disconnected = new ArrayList(); this.metrics = metrics; this.IdGenerator = new AtomicLong(0); - this.activeConnections = new AtomicLong(0); - this.metrics.initializeSelectorMetricsIfRequired(activeConnections); + numActiveConnections = new AtomicLong(0); + unreadyConnections = new HashSet<>(); + metrics.initializeSelectorMetrics(numActiveConnections); this.sslFactory = sslFactory; } @@ -163,7 +166,7 @@ public String connect(InetSocketAddress address, int sendBufferSize, int receive } key.attach(transmission); this.keyMap.put(connectionId, key); - activeConnections.set(this.keyMap.size()); + numActiveConnections.set(this.keyMap.size()); return connectionId; } @@ -190,7 +193,7 @@ public String register(SocketChannel channel, PortType portType) } key.attach(transmission); this.keyMap.put(connectionId, key); - activeConnections.set(this.keyMap.size()); + numActiveConnections.set(this.keyMap.size()); return connectionId; } @@ -315,7 +318,13 @@ public void poll(long timeoutMs, List sends) Transmission transmission = getTransmission(key); try { if (key.isConnectable()) { - handleConnect(key, transmission); + transmission.finishConnect(); + if (transmission.ready()) { + connected.add(transmission.getConnectionId()); + metrics.selectorConnectionCreated.inc(); + } else { + unreadyConnections.add(transmission.getConnectionId()); + } } /* if channel is not ready, finish prepare */ @@ -347,12 +356,28 @@ public void poll(long timeoutMs, List sends) close(key); } } + checkUnreadyConnectionsStatus(); this.metrics.selectorIORate.inc(); } long endIo = time.milliseconds(); this.metrics.selectorIOTime.update(endIo - endSelect); } + /** + * Check readiness for unready connections and add to completed list if ready + */ + private void checkUnreadyConnectionsStatus() { + Iterator iterator = unreadyConnections.iterator(); + while (iterator.hasNext()) { + String connId = iterator.next(); + if (isChannelReady(connId)) { + connected.add(connId); + iterator.remove(); + metrics.selectorConnectionCreated.inc(); + } + } + } + /** * Generate the description for a SocketChannel */ @@ -368,7 +393,7 @@ private String socketDescription(SocketChannel channel) { } /** - * Returns true if channel is ready after completing handshake to accept reads/writes + * Returns {@code true} if channel is ready to send or receive data, {@code false} otherwise * @param connectionId upon which readiness is checked for * @return true if channel is ready to accept reads/writes, false otherwise */ @@ -397,8 +422,8 @@ public List connected() { return this.connected; } - public long getActiveConnections() { - return activeConnections.get(); + public long getNumActiveConnections() { + return numActiveConnections.get(); } /** @@ -452,7 +477,8 @@ private void close(SelectionKey key) { logger.debug("Closing connection from {}", transmission.getConnectionId()); this.disconnected.add(transmission.getConnectionId()); this.keyMap.remove(transmission.getConnectionId()); - activeConnections.set(this.keyMap.size()); + numActiveConnections.set(keyMap.size()); + unreadyConnections.remove(transmission.getConnectionId()); try { transmission.close(); } catch (IOException e) { @@ -483,16 +509,6 @@ private SelectionKey keyForId(String id) { return this.keyMap.get(id); } - /** - * Process connections that have finished their handshake - */ - private void handleConnect(SelectionKey key, Transmission transmission) - throws IOException { - transmission.finishConnect(); - this.connected.add(transmission.getConnectionId()); - this.metrics.selectorConnectionCreated.inc(); - } - /** * Process reads from ready sockets */ diff --git a/ambry-network/src/test/java/com.github.ambry.network/SSLSelectorTest.java b/ambry-network/src/test/java/com.github.ambry.network/SSLSelectorTest.java index af0e89c2a8..36444f2a4c 100644 --- a/ambry-network/src/test/java/com.github.ambry.network/SSLSelectorTest.java +++ b/ambry-network/src/test/java/com.github.ambry.network/SSLSelectorTest.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Random; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -47,18 +48,16 @@ public void setup() TestSSLUtils.createSSLConfig("DC1,DC2,DC3", SSLFactory.Mode.CLIENT, trustStoreFile, "client"); SSLFactory serverSSLFactory = new SSLFactory(sslConfig); SSLFactory clientSSLFactory = new SSLFactory(clientSSLConfig); - this.server = new EchoServer(serverSSLFactory, 18383); - this.server.start(); - this.selector = - new Selector(new NetworkMetrics(new MetricRegistry()), SystemTime.getInstance(), - clientSSLFactory); + server = new EchoServer(serverSSLFactory, 18383); + server.start(); + selector = new Selector(new NetworkMetrics(new MetricRegistry()), SystemTime.getInstance(), clientSSLFactory); } @After public void teardown() throws Exception { - this.selector.close(); - this.server.close(); + selector.close(); + server.close(); } /** @@ -72,7 +71,7 @@ public void testServerDisconnect() assertEquals("hello", blockingRequest(connectionId, "hello")); // disconnect - this.server.closeConnections(); + server.closeConnections(); while (!selector.disconnected().contains(connectionId)) { selector.poll(1000L); } @@ -203,6 +202,27 @@ public void testEmptyRequest() assertEquals("", blockingRequest(connectionId, "")); } + @Test + public void testSSLConnect() + throws IOException { + String connectionId = + selector.connect(new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE, PortType.SSL); + while (!selector.connected().contains(connectionId)) { + selector.poll(10000L); + } + Assert.assertTrue("Channel should have been ready by now ", selector.isChannelReady(connectionId)); + } + + @Test + public void testCloseAfterConnectCall() + throws IOException { + String connectionId = + selector.connect(new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE, PortType.SSL); + selector.close(connectionId); + Assert.assertTrue("Channel should have been added to disconnected list", + selector.disconnected().contains(connectionId)); + } + private String blockingRequest(String connectionId, String s) throws Exception { selector.poll(1000L, asList(SelectorTest.createSend(connectionId, s))); @@ -224,10 +244,6 @@ private String blockingSSLConnect() while (!selector.connected().contains(connectionId)) { selector.poll(10000L); } - //finish the handshake as well - while (!selector.isChannelReady(connectionId)) { - selector.poll(10000L); - } return connectionId; } }