Skip to content

Commit

Permalink
Correct context for ClusterConnManager listener (#83035) (#83184)
Browse files (browse the repository at this point in the history)
* Correct context for ClusterConnManager listener (#83035)

Today `ClusterConnectionManager#connectToNode` completes its listeners
in the thread context in which the connection completes, which may not
be the correct context if there are multiple concurrent connection
attempts. With this commit we make sure to complete each listener in the
context in which it was passed to the corresponding call to
`connectToNode`.

Co-authored-by: ievgen.degtiarenko <ievgen.degtiarenko@elastic.co>

* Missing import

* Fix up tests

Co-authored-by: ievgen.degtiarenko <ievgen.degtiarenko@elastic.co>
  • Loading branch information
DaveCTurner and idegtiarenko committed Jan 27, 2022
1 parent 5ad8ea2 commit 1f1046d
Show file tree
Hide file tree
Showing 32 changed files with 246 additions and 108 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/83035.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 83035
summary: Correct context for `ClusterConnManager` listener
area: Network
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
Expand Down Expand Up @@ -44,18 +46,20 @@ public class ClusterConnectionManager implements ConnectionManager {
private final AbstractRefCounted connectingRefCounter = AbstractRefCounted.of(this::pendingConnectionsComplete);

private final Transport transport;
private final ThreadContext threadContext;
private final ConnectionProfile defaultProfile;
private final AtomicBoolean closing = new AtomicBoolean(false);
private final CountDownLatch closeLatch = new CountDownLatch(1);
private final DelegatingNodeConnectionListener connectionListener = new DelegatingNodeConnectionListener();

public ClusterConnectionManager(Settings settings, Transport transport) {
this(ConnectionProfile.buildDefaultConnectionProfile(settings), transport);
public ClusterConnectionManager(Settings settings, Transport transport, ThreadContext threadContext) {
this(ConnectionProfile.buildDefaultConnectionProfile(settings), transport, threadContext);
}

public ClusterConnectionManager(ConnectionProfile connectionProfile, Transport transport) {
public ClusterConnectionManager(ConnectionProfile connectionProfile, Transport transport, ThreadContext threadContext) {
this.transport = transport;
this.defaultProfile = connectionProfile;
this.threadContext = threadContext;
}

@Override
Expand Down Expand Up @@ -91,7 +95,13 @@ public void connectToNode(
ConnectionValidator connectionValidator,
ActionListener<Releasable> listener
) throws ConnectTransportException {
connectToNodeOrRetry(node, connectionProfile, connectionValidator, 0, listener);
connectToNodeOrRetry(
node,
connectionProfile,
connectionValidator,
0,
ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ int getNumNodesConnected() {
}

private static ConnectionManager createConnectionManager(ConnectionProfile connectionProfile, TransportService transportService) {
return new ClusterConnectionManager(connectionProfile, transportService.transport);
return new ClusterConnectionManager(connectionProfile, transportService.transport, transportService.threadPool.getThreadContext());
}

ConnectionManager getConnectionManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ public TransportService(
localNodeFactory,
clusterSettings,
taskHeaders,
new ClusterConnectionManager(settings, transport)
new ClusterConnectionManager(settings, transport, threadPool.getThreadContext())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;

Expand Down Expand Up @@ -77,7 +78,7 @@ public void testMainActionClusterAvailable() {
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
null,
mock(ThreadPool.class),
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ private TransportMultiSearchAction createTransportMultiSearchAction(boolean cont
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
null,
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()),
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ protected TestAction(boolean withDocumentFailureOnPrimary, boolean withDocumentF
new TransportService(
Settings.EMPTY,
mock(Transport.class),
null,
TransportWriteActionTests.threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.test.transport.CapturingTransport.CapturedRequest;
import org.elasticsearch.test.transport.MockTransport;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ClusterConnectionManager;
import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.TransportException;
Expand Down Expand Up @@ -53,15 +54,16 @@ public void testJoinDeduplication() {
DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();
CapturingTransport capturingTransport = new HandshakingCapturingTransport();
DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
final ThreadPool threadPool = deterministicTaskQueue.getThreadPool();
TransportService transportService = new TransportService(
Settings.EMPTY,
capturingTransport,
deterministicTaskQueue.getThreadPool(),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> localNode,
null,
Collections.emptySet(),
new ClusterConnectionManager(Settings.EMPTY, capturingTransport)
new ClusterConnectionManager(Settings.EMPTY, capturingTransport, threadPool.getThreadContext())
);
JoinHelper joinHelper = new JoinHelper(
Settings.EMPTY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.test.transport.CapturingTransport.CapturedRequest;
import org.elasticsearch.test.transport.StubbableConnectionManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ClusterConnectionManager;
import org.elasticsearch.transport.ConnectionManager;
import org.elasticsearch.transport.TransportException;
Expand Down Expand Up @@ -210,7 +211,13 @@ public void setup() {

localNode = newDiscoveryNode("local-node");

ConnectionManager innerConnectionManager = new ClusterConnectionManager(settings, capturingTransport);
final ThreadPool threadPool = deterministicTaskQueue.getThreadPool();

final ConnectionManager innerConnectionManager = new ClusterConnectionManager(
settings,
capturingTransport,
threadPool.getThreadContext()
);
StubbableConnectionManager connectionManager = new StubbableConnectionManager(innerConnectionManager);
connectionManager.setDefaultNodeConnectedBehavior((cm, discoveryNode) -> {
final boolean isConnected = connectedNodes.contains(discoveryNode);
Expand All @@ -222,7 +229,7 @@ public void setup() {
transportService = new TransportService(
settings,
capturingTransport,
deterministicTaskQueue.getThreadPool(),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundTransportAddress -> localNode,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
Expand All @@ -45,6 +46,7 @@
import java.util.function.Supplier;

import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -63,7 +65,7 @@ public void createConnectionManager() {
Settings settings = Settings.builder().put("node.name", ClusterConnectionManagerTests.class.getSimpleName()).build();
threadPool = new ThreadPool(settings);
transport = mock(Transport.class);
connectionManager = new ClusterConnectionManager(settings, transport);
connectionManager = new ClusterConnectionManager(settings, transport, threadPool.getThreadContext());
TimeValue oneSecond = new TimeValue(1000);
TimeValue oneMinute = TimeValue.timeValueMinutes(1);
connectionProfile = ConnectionProfile.buildSingleChannelProfile(
Expand Down Expand Up @@ -254,6 +256,9 @@ public void testConcurrentConnects() throws Exception {
int threadCount = between(1, 10);
Releasable[] releasables = new Releasable[threadCount];

final ThreadContext threadContext = threadPool.getThreadContext();
final String contextHeader = "test-context-header";

CyclicBarrier barrier = new CyclicBarrier(threadCount + 1);
Semaphore pendingCloses = new Semaphore(threadCount);
for (int i = 0; i < threadCount; i++) {
Expand All @@ -265,27 +270,33 @@ public void testConcurrentConnects() throws Exception {
throw new RuntimeException(e);
}
CountDownLatch latch = new CountDownLatch(1);
connectionManager.connectToNode(node, connectionProfile, validator, ActionListener.wrap(c -> {
assert connectionManager.nodeConnected(node);

assertTrue(pendingCloses.tryAcquire());
connectionManager.getConnection(node).addRemovedListener(ActionListener.wrap(pendingCloses::release));

if (randomBoolean()) {
releasables[threadIndex] = c;
nodeConnectedCount.incrementAndGet();
} else {
Releasables.close(c);
nodeClosedCount.incrementAndGet();
}

assert latch.getCount() == 1;
latch.countDown();
}, e -> {
nodeFailureCount.incrementAndGet();
assert latch.getCount() == 1;
latch.countDown();
}));
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
final String contextValue = randomAlphaOfLength(10);
threadContext.putHeader(contextHeader, contextValue);
connectionManager.connectToNode(node, connectionProfile, validator, ActionListener.wrap(c -> {
assert connectionManager.nodeConnected(node);
assertThat(threadContext.getHeader(contextHeader), equalTo(contextValue));

assertTrue(pendingCloses.tryAcquire());
connectionManager.getConnection(node).addRemovedListener(ActionListener.wrap(pendingCloses::release));

if (randomBoolean()) {
releasables[threadIndex] = c;
nodeConnectedCount.incrementAndGet();
} else {
Releasables.close(c);
nodeClosedCount.incrementAndGet();
}

assert latch.getCount() == 1;
latch.countDown();
}, e -> {
assertThat(threadContext.getHeader(contextHeader), equalTo(contextValue));
nodeFailureCount.incrementAndGet();
assert latch.getCount() == 1;
latch.countDown();
}));
}
try {
latch.await();
} catch (InterruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ public void testProxyStrategyWillOpenExpectedNumberOfConnectionsToAddress() {
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -124,7 +128,11 @@ public void testProxyStrategyWillOpenNewConnectionsOnDisconnect() throws Excepti
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);

AtomicBoolean useAddress1 = new AtomicBoolean(true);
Expand Down Expand Up @@ -186,7 +194,11 @@ public void testConnectFailsWithIncompatibleNodes() {
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -226,7 +238,11 @@ public void testClusterNameValidationPreventConnectingToDifferentClusters() thro
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);

AtomicBoolean useAddress1 = new AtomicBoolean(true);
Expand Down Expand Up @@ -289,7 +305,11 @@ public void testProxyStrategyWillResolveAddressesEachConnect() throws Exception
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -324,7 +344,11 @@ public void testProxyStrategyWillNeedToBeRebuiltIfNumOfSocketsOrAddressesOrServe
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -429,7 +453,11 @@ public void testServerNameAttributes() {

String address = "localhost:" + address1.getPort();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;

import java.net.InetAddress;
Expand All @@ -35,7 +36,10 @@ public class RemoteConnectionManagerTests extends ESTestCase {
public void setUp() throws Exception {
super.setUp();
transport = mock(Transport.class);
remoteConnectionManager = new RemoteConnectionManager("remote-cluster", new ClusterConnectionManager(Settings.EMPTY, transport));
remoteConnectionManager = new RemoteConnectionManager(
"remote-cluster",
new ClusterConnectionManager(Settings.EMPTY, transport, new ThreadContext(Settings.EMPTY))
);
}

@SuppressWarnings("unchecked")
Expand Down

0 comments on commit 1f1046d

Please sign in to comment.