Skip to content

Commit

Permalink
Ensure the correct threadContext for RemoteClusterNodesAction (#101050)
Browse files Browse the repository at this point in the history
RemoteClusterNodesAction fetches NodesInfo with system context. It must
restore the original caller's context when respond back. This PR ensures
that.
  • Loading branch information
ywangd committed Oct 19, 2023
1 parent f06f582 commit c9835b8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/101050.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 101050
summary: Ensure the correct `threadContext` for `RemoteClusterNodesAction`
area: Network
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.admin.cluster.node.info.NodesInfoRequest;
import org.elasticsearch.action.admin.cluster.node.info.TransportNodesInfoAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
import org.elasticsearch.client.internal.Client;
Expand Down Expand Up @@ -100,6 +101,14 @@ public TransportAction(TransportService transportService, ActionFilters actionFi
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
final ThreadContext threadContext = client.threadPool().getThreadContext();
executeWithSystemContext(
request,
threadContext,
ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)
);
}

private void executeWithSystemContext(Request request, ThreadContext threadContext, ActionListener<Response> listener) {
try (var ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
if (request.remoteClusterServer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -114,6 +115,7 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
Request request,
ActionListener<Response> listener
) {
assertThat(threadContext.isSystemContext(), is(true));
assertSame(TransportNodesInfoAction.TYPE, action);
assertThat(
asInstanceOf(NodesInfoRequest.class, request).requestedMetrics(),
Expand All @@ -128,7 +130,10 @@ public void close() {}
);

final PlainActionFuture<RemoteClusterNodesAction.Response> future = new PlainActionFuture<>();
action.doExecute(mock(Task.class), RemoteClusterNodesAction.Request.REMOTE_CLUSTER_SERVER_NODES, future);
action.doExecute(mock(Task.class), RemoteClusterNodesAction.Request.REMOTE_CLUSTER_SERVER_NODES, ActionListener.wrap(response -> {
assertThat(threadContext.isSystemContext(), is(false));
future.onResponse(response);
}, future::onFailure));

final List<DiscoveryNode> actualNodes = future.actionGet().getNodes();
assertThat(Set.copyOf(actualNodes), equalTo(expectedRemoteServerNodes));
Expand Down Expand Up @@ -191,6 +196,7 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
Request request,
ActionListener<Response> listener
) {
assertThat(threadContext.isSystemContext(), is(true));
assertSame(TransportNodesInfoAction.TYPE, action);
assertThat(asInstanceOf(NodesInfoRequest.class, request).requestedMetrics(), empty());
listener.onResponse((Response) nodesInfoResponse);
Expand All @@ -202,7 +208,10 @@ public void close() {}
);

final PlainActionFuture<RemoteClusterNodesAction.Response> future = new PlainActionFuture<>();
action.doExecute(mock(Task.class), RemoteClusterNodesAction.Request.ALL_NODES, future);
action.doExecute(mock(Task.class), RemoteClusterNodesAction.Request.ALL_NODES, ActionListener.wrap(response -> {
assertThat(threadContext.isSystemContext(), is(false));
future.onResponse(response);
}, future::onFailure));

final List<DiscoveryNode> actualNodes = future.actionGet().getNodes();
assertThat(Set.copyOf(actualNodes), equalTo(expectedRemoteNodes));
Expand Down

0 comments on commit c9835b8

Please sign in to comment.