From 71d8a0bda45c09831a8a2554f39c3c7c67e8ce79 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 21 May 2026 09:59:16 +0100 Subject: [PATCH 1/3] Close SocketChannel on pre-registration failure in TlsChannelStream Resolve address before opening SocketChannel to avoid FD leak on DNS failure. Close the channel in catch blocks if opened but not yet registered with the selector monitor. Cancel pending registration action to prevent races with the timeout scheduler. JAVA-6216 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../TlsChannelStreamFactoryFactory.java | 48 ++++++++++---- .../TlsChannelStreamFunctionalTest.java | 63 +++++++++++++++++++ 2 files changed, 100 insertions(+), 11 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java index b0fae1d044d..f369ec243cd 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java @@ -37,6 +37,7 @@ import javax.net.ssl.SSLParameters; import java.io.Closeable; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.StandardSocketOptions; import java.nio.ByteBuffer; import java.nio.channels.CompletionHandler; @@ -209,35 +210,60 @@ private static class TlsChannelStream extends AsynchronousChannelStream { @Override public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler handler) { isTrue("unopened", getChannel() == null); + SocketChannel socketChannel = null; + SelectorMonitor.SocketRegistration socketRegistration = null; try { - SocketChannel socketChannel = SocketChannel.open(); - socketChannel.configureBlocking(false); + // getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeoutException. + int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); + InetSocketAddress socketAddress = getSocketAddresses(getServerAddress(), inetAddressResolver).get(0); + SocketChannel openedSocketChannel = SocketChannel.open(); + socketChannel = openedSocketChannel; + openedSocketChannel.configureBlocking(false); - socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true); - socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true); + openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true); + openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true); if (getSettings().getReceiveBufferSize() > 0) { - socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize()); + openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize()); } if (getSettings().getSendBufferSize() > 0) { - socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); + openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); } - //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception. - int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); - socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0)); - SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration( - socketChannel, () -> initializeTslChannel(handler, socketChannel)); + openedSocketChannel.connect(socketAddress); + socketRegistration = new SelectorMonitor.SocketRegistration( + openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel)); if (connectTimeoutMs > 0) { scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs); } selectorMonitor.register(socketRegistration); } catch (IOException e) { + closeSocketChannel(socketChannel, socketRegistration, e); handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e)); } catch (Throwable t) { + closeSocketChannel(socketChannel, socketRegistration, t); handler.failed(t); } } + private void closeSocketChannel(@Nullable final SocketChannel socketChannel, + @Nullable final SelectorMonitor.SocketRegistration socketRegistration, + final Throwable failure) { + if (socketRegistration != null) { + try { + socketRegistration.tryCancelPendingConnection(); + } catch (Throwable t) { + failure.addSuppressed(t); + } + } + if (socketChannel != null) { + try { + socketChannel.close(); + } catch (Throwable e) { + failure.addSuppressed(e); + } + } + } + private void scheduleTimeoutInterruption(final AsyncCompletionHandler handler, final SelectorMonitor.SocketRegistration socketRegistration, final int connectTimeoutMs) { diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index 3af1eaa33e1..7af01924eaa 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -17,18 +17,22 @@ package com.mongodb.internal.connection; import com.mongodb.ClusterFixture; +import com.mongodb.MongoSocketException; import com.mongodb.MongoSocketOpenException; import com.mongodb.ServerAddress; +import com.mongodb.connection.AsyncCompletionHandler; import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import com.mongodb.spi.dns.InetAddressResolver; import org.bson.ByteBuf; import org.bson.ByteBufNIO; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -37,6 +41,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import java.io.IOException; +import java.net.InetAddress; import java.net.ServerSocket; import java.nio.ByteBuffer; import java.nio.channels.InterruptedByTimeoutException; @@ -52,10 +57,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; @@ -68,6 +75,62 @@ class TlsChannelStreamFunctionalTest { private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1"; private static final int UNREACHABLE_PORT = 65333; + @Test + void shouldNotOpenSocketChannelIfNameResolutionFails() { + //given + MongoSocketException resolverException = new MongoSocketException("Temporary failure in name resolution", new ServerAddress()); + InetAddressResolver inetAddressResolver = host -> { throw resolverException; }; + + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(100, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + Stream stream = streamFactory.create(new ServerAddress()); + @SuppressWarnings("unchecked") + AsyncCompletionHandler handler = Mockito.mock(AsyncCompletionHandler.class); + + //when + stream.openAsync(createOperationContext(100), handler); + + //then + verify(handler).failed(resolverException); + verify(handler, times(0)).completed(null); + socketChannelMockedStatic.verify(SocketChannel::open, times(0)); + } + } + + @Test + void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException { + //given + IOException connectException = new IOException("connect failed"); + InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress()); + + try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open()); + StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel); + Mockito.doThrow(connectException).when(socketChannel).connect(any()); + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(100, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + Stream stream = streamFactory.create(new ServerAddress()); + @SuppressWarnings("unchecked") + AsyncCompletionHandler handler = Mockito.mock(AsyncCompletionHandler.class); + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Throwable.class); + + //when + stream.openAsync(createOperationContext(100), handler); + + //then + verify(handler).failed(failureCaptor.capture()); + MongoSocketOpenException actual = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue()); + assertSame(connectException, actual.getCause()); + verify(handler, times(0)).completed(null); + verify(socketChannel).close(); + } + } + @ParameterizedTest @ValueSource(ints = {500, 1000, 2000}) void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException { From 49af76f6ac0242738348120b25c0b12faa02bb73 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 21 May 2026 10:58:49 +0100 Subject: [PATCH 2/3] Fix checkstyle nit --- .../internal/connection/TlsChannelStreamFunctionalTest.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index 7af01924eaa..a7eb86b8d31 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -79,7 +79,9 @@ class TlsChannelStreamFunctionalTest { void shouldNotOpenSocketChannelIfNameResolutionFails() { //given MongoSocketException resolverException = new MongoSocketException("Temporary failure in name resolution", new ServerAddress()); - InetAddressResolver inetAddressResolver = host -> { throw resolverException; }; + InetAddressResolver inetAddressResolver = host -> { + throw resolverException; + }; try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { From bc809ae6f190b5afc8465f2ab429362f6a4150f9 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 21 May 2026 13:48:19 +0100 Subject: [PATCH 3/3] Updated AGENTS.md to help future agents --- AGENTS.md | 1 + driver-core/AGENTS.md | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index b25f3f1e875..4ba04f76d32 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,6 +52,7 @@ See [`.agents/references/style-reference`](.agents/references/style-reference.md - No `System.out.println` / `System.err.println` — use SLF4J - No `e.printStackTrace()` — use proper error handling +- Prefer lambdas over SAM (Single Abstract Method) anonymous class instantiation - Copyright header required: `Copyright 2008-present MongoDB, Inc.` - Every public package must have a `package-info.java` diff --git a/driver-core/AGENTS.md b/driver-core/AGENTS.md index d1ad09af42f..0f9eebdff94 100644 --- a/driver-core/AGENTS.md +++ b/driver-core/AGENTS.md @@ -24,6 +24,10 @@ Largest and most complex module. ./gradlew :driver-core:generateMongoDriverVersion # If MongoDriverVersion is missing ``` +## Important + +- Async code MUST handle errors and they MUST be handled via callbacks or handlers. + ## Notes - Most extensive test suite — JUnit 5 + Spock + Mockito.