Skip to content

Commit

Permalink
Refactor the way BaseCluster.selectServer deals with the race condi…
Browse files Browse the repository at this point in the history
…tion

The new approach allows us to later refactor all other logic inside one or more `ServerSelector`s.

See the comment left in the code for more details on the new approach.

JAVA-4254
  • Loading branch information
stIncMale committed May 6, 2024
1 parent 4d883c1 commit d25010d
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -122,14 +124,13 @@ public void close() {
}

@Override
public ClusterableServer getServer(final ServerAddress serverAddress) {
public ServersSnapshot getServersSnapshot() {
isTrue("is open", !isClosed());

ServerTuple serverTuple = addressToServerTupleMap.get(serverAddress);
if (serverTuple == null) {
return null;
}
return serverTuple.server;
Map<ServerAddress, ServerTuple> nonAtomicSnapshot = new HashMap<>(addressToServerTupleMap);
return serverAddress -> {
ServerTuple serverTuple = nonAtomicSnapshot.get(serverAddress);
return serverTuple == null ? null : serverTuple.server;
};
}

void onChange(final Collection<ServerAddress> newHosts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.connection.ServerDescription.MAX_DRIVER_WIRE_VERSION;
Expand Down Expand Up @@ -314,16 +315,35 @@ private boolean handleServerSelectionRequest(final ServerSelectionRequest reques
@Nullable
private ServerTuple selectServer(final ServerSelector serverSelector,
final ClusterDescription clusterDescription) {
return selectServer(serverSelector, clusterDescription, this::getServer);
return selectServer(serverSelector, clusterDescription, getServersSnapshot());
}

@Nullable
@VisibleForTesting(otherwise = PRIVATE)
static ServerTuple selectServer(final ServerSelector serverSelector, final ClusterDescription clusterDescription,
final Function<ServerAddress, Server> serverCatalog) {
return atMostNRandom(new ArrayList<>(serverSelector.select(clusterDescription)), 2, serverDescription -> {
Server server = serverCatalog.apply(serverDescription.getAddress());
return server == null ? null : new ServerTuple(server, serverDescription);
final ServersSnapshot serversSnapshot) {
// The set of `Server`s maintained by the `Cluster` is updated concurrently with `clusterDescription` being read.
// Additionally, that set of servers continues to be concurrently updated while `serverSelector` selects.
// This race condition means that we are not guaranteed not observe all the servers from `clusterDescription`
// among the `Server`s maintained by the `Cluster`.
// To deal with this race condition, we take `serversSnapshot` of that set of `Server`s
// (the snapshot itself does not have to be atomic) non-atomically with reading `clusterDescription`
// (this means, `serversSnapshot` and `clusterDescription` are not guaranteed to be consistent with each other),
// and do pre-filtering to make sure that the only `ServerDescription`s we may select,
// are of those `Server`s that are known to both `clusterDescription` and `serversSnapshot`.
// This way we are guaranteed to successfully get `Server`s from `serversSnapshot` based on the selected `ServerDescription`s.
//
// The pre-filtering we do to deal with the race condition described above is achieved by this `ServerSelector`.
ServerSelector raceConditionPreFiltering = clusterDescriptionPotentiallyInconsistentWithServerSnapshot ->
clusterDescriptionPotentiallyInconsistentWithServerSnapshot.getServerDescriptions()
.stream()
.filter(serverDescription -> serversSnapshot.containsServer(serverDescription.getAddress()))
.collect(toList());
List<ServerDescription> intermediateResult = new CompositeServerSelector(asList(raceConditionPreFiltering, serverSelector))
.select(clusterDescription);
return atMostNRandom(new ArrayList<>(intermediateResult), 2, serverDescription -> {
Server server = assertNotNull(serversSnapshot.getServer(serverDescription.getAddress()));
return new ServerTuple(server, serverDescription);
}).stream()
.min(comparingInt(serverTuple -> serverTuple.getServer().operationCount()))
.orElse(null);
Expand All @@ -345,18 +365,16 @@ private static List<ServerTuple> atMostNRandom(final ArrayList<ServerDescription
List<ServerTuple> result = new ArrayList<>(n);
for (int i = list.size() - 1; i >= 0 && result.size() < n; i--) {
Collections.swap(list, i, random.nextInt(i + 1));
ServerTuple serverTuple = transformer.apply(list.get(i));
if (serverTuple != null) {
result.add(serverTuple);
}
ServerTuple serverTuple = assertNotNull(transformer.apply(list.get(i)));
result.add(serverTuple);
}
return result;
}

private ServerSelector getCompleteServerSelector(final ServerSelector serverSelector, final ServerDeprioritization serverDeprioritization) {
List<ServerSelector> selectors = Stream.of(
serverSelector,
settings.getServerSelector(),
settings.getServerSelector(), // may be null
new LatencyMinimizingServerSelector(settings.getLocalThreshold(MILLISECONDS), MILLISECONDS)
).filter(Objects::nonNull).collect(toList());
return serverDeprioritization.apply(new CompositeServerSelector(selectors));
Expand Down
21 changes: 15 additions & 6 deletions driver-core/src/main/com/mongodb/internal/connection/Cluster.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@


import com.mongodb.ServerAddress;
import com.mongodb.annotations.ThreadSafe;
import com.mongodb.connection.ClusterId;
import com.mongodb.event.ServerDescriptionChangedEvent;
import com.mongodb.internal.VisibleForTesting;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.connection.ClusterDescription;
import com.mongodb.connection.ClusterSettings;
Expand All @@ -29,8 +29,6 @@

import java.io.Closeable;

import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;

/**
* Represents a cluster of MongoDB servers. Implementations can define the behaviour depending upon the type of cluster.
*
Expand All @@ -43,9 +41,7 @@ public interface Cluster extends Closeable {

ClusterId getClusterId();

@Nullable
@VisibleForTesting(otherwise = PRIVATE)
ClusterableServer getServer(ServerAddress serverAddress);
ServersSnapshot getServersSnapshot();

/**
* Get the current description of this cluster.
Expand Down Expand Up @@ -89,4 +85,17 @@ void selectServerAsync(ServerSelector serverSelector, OperationContext operation
* Server Discovery And Monitoring</a> specification.
*/
void onChange(ServerDescriptionChangedEvent event);

/**
* A non-atomic snapshot of the servers in a {@link Cluster}.
*/
@ThreadSafe
interface ServersSnapshot {
@Nullable
Server getServer(ServerAddress serverAddress);

default boolean containsServer(final ServerAddress serverAddress) {
return getServer(serverAddress) != null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,11 @@ public ClusterId getClusterId() {
}

@Override
public ClusterableServer getServer(final ServerAddress serverAddress) {
public ServersSnapshot getServersSnapshot() {
isTrue("open", !isClosed());
waitForSrv();
return assertNotNull(server);
ClusterableServer server = assertNotNull(this.server);
return serverAddress -> server;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.mongodb.internal.connection;

import com.mongodb.MongoConfigurationException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.ClusterConnectionMode;
import com.mongodb.connection.ClusterDescription;
import com.mongodb.connection.ClusterId;
Expand Down Expand Up @@ -69,9 +68,10 @@ protected void connect() {
}

@Override
public ClusterableServer getServer(final ServerAddress serverAddress) {
public ServersSnapshot getServersSnapshot() {
isTrue("open", !isClosed());
return assertNotNull(server.get());
ClusterableServer server = assertNotNull(this.server.get());
return serverAddress -> server;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ protected void applyResponse(final BsonArray response) {
protected void applyApplicationError(final BsonDocument applicationError) {
ServerAddress serverAddress = new ServerAddress(applicationError.getString("address").getValue());
int errorGeneration = applicationError.getNumber("generation",
new BsonInt32(((DefaultServer) getCluster().getServer(serverAddress)).getConnectionPool().getGeneration())).intValue();
new BsonInt32(((DefaultServer) getCluster().getServersSnapshot().getServer(serverAddress))
.getConnectionPool().getGeneration())).intValue();
int maxWireVersion = applicationError.getNumber("maxWireVersion").intValue();
String when = applicationError.getString("when").getValue();
String type = applicationError.getString("type").getValue();

DefaultServer server = (DefaultServer) cluster.getServer(serverAddress);
DefaultServer server = (DefaultServer) cluster.getServersSnapshot().getServer(serverAddress);
RuntimeException exception;

switch (type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ class BaseClusterSpecification extends Specification {
}

@Override
ClusterableServer getServer(final ServerAddress serverAddress) {
throw new UnsupportedOperationException()
Cluster.ServersSnapshot getServersSnapshot() {
Cluster.ServersSnapshot result = serverAddress -> {
throw new UnsupportedOperationException()
}
return result
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,11 @@ class DefaultServerSpecification extends Specification {
}

@Override
ClusterableServer getServer(final ServerAddress serverAddress) {
throw new UnsupportedOperationException()
Cluster.ServersSnapshot getServersSnapshot() {
Cluster.ServersSnapshot result = serverAddress -> {
throw new UnsupportedOperationException()
}
return result
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ class MultiServerClusterSpecification extends Specification {
cluster.getCurrentDescription().connectionMode == MULTIPLE
}

def 'should not get server when closed'() {
def 'should not get servers snapshot when closed'() {
given:
def cluster = new MultiServerCluster(CLUSTER_ID, ClusterSettings.builder().hosts(Arrays.asList(firstServer)).mode(MULTIPLE).build(),
factory)
cluster.close()

when:
cluster.getServer(firstServer)
cluster.getServersSnapshot()

then:
thrown(IllegalStateException)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ private void assertServer(final String serverName, final BsonDocument expectedSe

if (expectedServerDescriptionDocument.isDocument("pool")) {
int expectedGeneration = expectedServerDescriptionDocument.getDocument("pool").getNumber("generation").intValue();
DefaultServer server = (DefaultServer) getCluster().getServer(new ServerAddress(serverName));
DefaultServer server = (DefaultServer) getCluster().getServersSnapshot().getServer(new ServerAddress(serverName));
assertEquals(expectedGeneration, server.getConnectionPool().getGeneration());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
@RunWith(Parameterized.class)
public class ServerSelectionWithinLatencyWindowTest {
private final ClusterDescription clusterDescription;
private final Map<ServerAddress, Server> serverCatalog;
private final Cluster.ServersSnapshot serversSnapshot;
private final int iterations;
private final Outcome outcome;

Expand All @@ -65,7 +65,7 @@ public ServerSelectionWithinLatencyWindowTest(
@SuppressWarnings("unused") final String description,
final BsonDocument definition) {
clusterDescription = buildClusterDescription(definition.getDocument("topology_description"), null);
serverCatalog = serverCatalog(definition.getArray("mocked_topology_state"));
serversSnapshot = serverCatalog(definition.getArray("mocked_topology_state"));
iterations = definition.getInt32("iterations").getValue();
outcome = Outcome.parse(definition.getDocument("outcome"));
}
Expand All @@ -74,8 +74,7 @@ public ServerSelectionWithinLatencyWindowTest(
public void shouldPassAllOutcomes() {
ServerSelector selector = new ReadPreferenceServerSelector(ReadPreference.nearest());
Map<ServerAddress, List<ServerTuple>> selectionResultsGroupedByServerAddress = IntStream.range(0, iterations)
.mapToObj(i -> BaseCluster.selectServer(selector, clusterDescription,
address -> Assertions.assertNotNull(serverCatalog.get(address))))
.mapToObj(i -> BaseCluster.selectServer(selector, clusterDescription, serversSnapshot))
.collect(groupingBy(serverTuple -> serverTuple.getServerDescription().getAddress()));
Map<ServerAddress, BigDecimal> selectionFrequencies = selectionResultsGroupedByServerAddress.entrySet()
.stream()
Expand All @@ -97,8 +96,8 @@ public static Collection<Object[]> data() {
.collect(toList());
}

private static Map<ServerAddress, Server> serverCatalog(final BsonArray mockedTopologyState) {
return mockedTopologyState.stream()
private static Cluster.ServersSnapshot serverCatalog(final BsonArray mockedTopologyState) {
Map<ServerAddress, Server> serverMap = mockedTopologyState.stream()
.map(BsonValue::asDocument)
.collect(toMap(
el -> new ServerAddress(el.getString("address").getValue()),
Expand All @@ -108,6 +107,7 @@ private static Map<ServerAddress, Server> serverCatalog(final BsonArray mockedTo
when(server.operationCount()).thenReturn(operationCount);
return server;
}));
return serverAddress -> Assertions.assertNotNull(serverMap.get(serverAddress));
}

private static final class Outcome {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,21 @@ class SingleServerClusterSpecification extends Specification {
sendNotification(firstServer, STANDALONE)

then:
cluster.getServer(firstServer) == factory.getServer(firstServer)
cluster.getServersSnapshot().getServer(firstServer) == factory.getServer(firstServer)

cleanup:
cluster?.close()
}


def 'should not get server when closed'() {
def 'should not get servers snapshot when closed'() {
given:
def cluster = new SingleServerCluster(CLUSTER_ID,
ClusterSettings.builder().mode(SINGLE).hosts(Arrays.asList(firstServer)).build(), factory)
cluster.close()

when:
cluster.getServer(firstServer)
cluster.getServersSnapshot()

then:
thrown(IllegalStateException)
Expand Down

0 comments on commit d25010d

Please sign in to comment.