diff --git a/AGENTS.md b/AGENTS.md index b25f3f1e87..4ba04f76d3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,6 +52,7 @@ See [`.agents/references/style-reference`](.agents/references/style-reference.md - No `System.out.println` / `System.err.println` — use SLF4J - No `e.printStackTrace()` — use proper error handling +- Prefer lambdas over SAM (Single Abstract Method) anonymous class instantiation - Copyright header required: `Copyright 2008-present MongoDB, Inc.` - Every public package must have a `package-info.java` diff --git a/driver-core/AGENTS.md b/driver-core/AGENTS.md index d1ad09af42..0f9eebdff9 100644 --- a/driver-core/AGENTS.md +++ b/driver-core/AGENTS.md @@ -24,6 +24,10 @@ Largest and most complex module. ./gradlew :driver-core:generateMongoDriverVersion # If MongoDriverVersion is missing ``` +## Important + +- Async code MUST handle errors and they MUST be handled via callbacks or handlers. + ## Notes - Most extensive test suite — JUnit 5 + Spock + Mockito. diff --git a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java index b0fae1d044..f369ec243c 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java @@ -37,6 +37,7 @@ import javax.net.ssl.SSLParameters; import java.io.Closeable; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.StandardSocketOptions; import java.nio.ByteBuffer; import java.nio.channels.CompletionHandler; @@ -209,35 +210,60 @@ private static class TlsChannelStream extends AsynchronousChannelStream { @Override public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler handler) { isTrue("unopened", getChannel() == null); + SocketChannel socketChannel = null; + SelectorMonitor.SocketRegistration socketRegistration = null; try { - SocketChannel socketChannel = SocketChannel.open(); - socketChannel.configureBlocking(false); + // getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeoutException. + int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); + InetSocketAddress socketAddress = getSocketAddresses(getServerAddress(), inetAddressResolver).get(0); + SocketChannel openedSocketChannel = SocketChannel.open(); + socketChannel = openedSocketChannel; + openedSocketChannel.configureBlocking(false); - socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true); - socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true); + openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true); + openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true); if (getSettings().getReceiveBufferSize() > 0) { - socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize()); + openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize()); } if (getSettings().getSendBufferSize() > 0) { - socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); + openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); } - //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception. - int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); - socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0)); - SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration( - socketChannel, () -> initializeTslChannel(handler, socketChannel)); + openedSocketChannel.connect(socketAddress); + socketRegistration = new SelectorMonitor.SocketRegistration( + openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel)); if (connectTimeoutMs > 0) { scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs); } selectorMonitor.register(socketRegistration); } catch (IOException e) { + closeSocketChannel(socketChannel, socketRegistration, e); handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e)); } catch (Throwable t) { + closeSocketChannel(socketChannel, socketRegistration, t); handler.failed(t); } } + private void closeSocketChannel(@Nullable final SocketChannel socketChannel, + @Nullable final SelectorMonitor.SocketRegistration socketRegistration, + final Throwable failure) { + if (socketRegistration != null) { + try { + socketRegistration.tryCancelPendingConnection(); + } catch (Throwable t) { + failure.addSuppressed(t); + } + } + if (socketChannel != null) { + try { + socketChannel.close(); + } catch (Throwable e) { + failure.addSuppressed(e); + } + } + } + private void scheduleTimeoutInterruption(final AsyncCompletionHandler handler, final SelectorMonitor.SocketRegistration socketRegistration, final int connectTimeoutMs) { diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index 3af1eaa33e..a7eb86b8d3 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -17,18 +17,22 @@ package com.mongodb.internal.connection; import com.mongodb.ClusterFixture; +import com.mongodb.MongoSocketException; import com.mongodb.MongoSocketOpenException; import com.mongodb.ServerAddress; +import com.mongodb.connection.AsyncCompletionHandler; import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import com.mongodb.spi.dns.InetAddressResolver; import org.bson.ByteBuf; import org.bson.ByteBufNIO; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -37,6 +41,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import java.io.IOException; +import java.net.InetAddress; import java.net.ServerSocket; import java.nio.ByteBuffer; import java.nio.channels.InterruptedByTimeoutException; @@ -52,10 +57,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; @@ -68,6 +75,64 @@ class TlsChannelStreamFunctionalTest { private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1"; private static final int UNREACHABLE_PORT = 65333; + @Test + void shouldNotOpenSocketChannelIfNameResolutionFails() { + //given + MongoSocketException resolverException = new MongoSocketException("Temporary failure in name resolution", new ServerAddress()); + InetAddressResolver inetAddressResolver = host -> { + throw resolverException; + }; + + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(100, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + Stream stream = streamFactory.create(new ServerAddress()); + @SuppressWarnings("unchecked") + AsyncCompletionHandler handler = Mockito.mock(AsyncCompletionHandler.class); + + //when + stream.openAsync(createOperationContext(100), handler); + + //then + verify(handler).failed(resolverException); + verify(handler, times(0)).completed(null); + socketChannelMockedStatic.verify(SocketChannel::open, times(0)); + } + } + + @Test + void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException { + //given + IOException connectException = new IOException("connect failed"); + InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress()); + + try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open()); + StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel); + Mockito.doThrow(connectException).when(socketChannel).connect(any()); + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(100, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + Stream stream = streamFactory.create(new ServerAddress()); + @SuppressWarnings("unchecked") + AsyncCompletionHandler handler = Mockito.mock(AsyncCompletionHandler.class); + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Throwable.class); + + //when + stream.openAsync(createOperationContext(100), handler); + + //then + verify(handler).failed(failureCaptor.capture()); + MongoSocketOpenException actual = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue()); + assertSame(connectException, actual.getCause()); + verify(handler, times(0)).completed(null); + verify(socketChannel).close(); + } + } + @ParameterizedTest @ValueSource(ints = {500, 1000, 2000}) void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {