Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ See [`.agents/references/style-reference`](.agents/references/style-reference.md

- No `System.out.println` / `System.err.println` — use SLF4J
- No `e.printStackTrace()` — use proper error handling
- Prefer lambdas over SAM (Single Abstract Method) anonymous class instantiation
- Copyright header required: `Copyright 2008-present MongoDB, Inc.`
- Every public package must have a `package-info.java`

Expand Down
4 changes: 4 additions & 0 deletions driver-core/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ Largest and most complex module.
./gradlew :driver-core:generateMongoDriverVersion # If MongoDriverVersion is missing
```

## Important

- Async code MUST handle errors and they MUST be handled via callbacks or handlers.

## Notes

- Most extensive test suite — JUnit 5 + Spock + Mockito.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import javax.net.ssl.SSLParameters;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
Expand Down Expand Up @@ -209,35 +210,60 @@ private static class TlsChannelStream extends AsynchronousChannelStream {
@Override
public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler<Void> handler) {
isTrue("unopened", getChannel() == null);
SocketChannel socketChannel = null;
SelectorMonitor.SocketRegistration socketRegistration = null;
try {
SocketChannel socketChannel = SocketChannel.open();
socketChannel.configureBlocking(false);
// getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeoutException.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
InetSocketAddress socketAddress = getSocketAddresses(getServerAddress(), inetAddressResolver).get(0);
SocketChannel openedSocketChannel = SocketChannel.open();
socketChannel = openedSocketChannel;
openedSocketChannel.configureBlocking(false);

socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
if (getSettings().getReceiveBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
}
if (getSettings().getSendBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
}
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
socketChannel, () -> initializeTslChannel(handler, socketChannel));
openedSocketChannel.connect(socketAddress);
socketRegistration = new SelectorMonitor.SocketRegistration(
openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel));

if (connectTimeoutMs > 0) {
scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs);
}
selectorMonitor.register(socketRegistration);
} catch (IOException e) {
closeSocketChannel(socketChannel, socketRegistration, e);
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
closeSocketChannel(socketChannel, socketRegistration, t);
handler.failed(t);
}
}

private void closeSocketChannel(@Nullable final SocketChannel socketChannel,
@Nullable final SelectorMonitor.SocketRegistration socketRegistration,
final Throwable failure) {
if (socketRegistration != null) {
try {
socketRegistration.tryCancelPendingConnection();
} catch (Throwable t) {
failure.addSuppressed(t);
}
}
if (socketChannel != null) {
try {
socketChannel.close();
} catch (Throwable e) {
failure.addSuppressed(e);
}
}
}

private void scheduleTimeoutInterruption(final AsyncCompletionHandler<Void> handler,
final SelectorMonitor.SocketRegistration socketRegistration,
final int connectTimeoutMs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@
package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import com.mongodb.spi.dns.InetAddressResolver;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
Expand All @@ -37,6 +41,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.InterruptedByTimeoutException;
Expand All @@ -52,10 +57,12 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
Expand All @@ -68,6 +75,64 @@ class TlsChannelStreamFunctionalTest {
private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1";
private static final int UNREACHABLE_PORT = 65333;

@Test
void shouldNotOpenSocketChannelIfNameResolutionFails() {
//given
MongoSocketException resolverException = new MongoSocketException("Temporary failure in name resolution", new ServerAddress());
InetAddressResolver inetAddressResolver = host -> {
throw resolverException;
};

try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.build(), SSL_SETTINGS);
Stream stream = streamFactory.create(new ServerAddress());
@SuppressWarnings("unchecked")
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);

//when
stream.openAsync(createOperationContext(100), handler);

//then
verify(handler).failed(resolverException);
verify(handler, times(0)).completed(null);
socketChannelMockedStatic.verify(SocketChannel::open, times(0));
}
}

@Test
void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException {
//given
IOException connectException = new IOException("connect failed");
InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress());

try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open());
StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel);
Mockito.doThrow(connectException).when(socketChannel).connect(any());
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.build(), SSL_SETTINGS);
Stream stream = streamFactory.create(new ServerAddress());
@SuppressWarnings("unchecked")
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);
ArgumentCaptor<Throwable> failureCaptor = ArgumentCaptor.forClass(Throwable.class);

//when
stream.openAsync(createOperationContext(100), handler);

//then
verify(handler).failed(failureCaptor.capture());
MongoSocketOpenException actual = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue());
assertSame(connectException, actual.getCause());
verify(handler, times(0)).completed(null);
verify(socketChannel).close();
}
}

@ParameterizedTest
@ValueSource(ints = {500, 1000, 2000})
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {
Expand Down