[7.17] Update YAML Rest tests to check for product header on all responses (#83290) (#83996)

* Update YAML Rest tests to check for product header on all responses (#83290)

This PR fixes a number of ThreadContext misuses, mostly from stashing the thread context
without preserving it for listeners that are completed later. The YAML REST test changes in
the original PR were not backported, since they break BWC testing further back into 7.x.
jbaiera committed Feb 16, 2022
1 parent d5290f5 commit 92b3d0d
Showing 11 changed files with 105 additions and 75 deletions.
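
Before the per-file diffs, a minimal sketch of the pattern these changes apply throughout (class and method names here are illustrative, not from the commit): wrap a listener with ContextPreservingActionListener before stashing the thread context, so that completing the listener restores the caller's context — and with it response headers such as the product header the YAML tests assert on. All of the Elasticsearch types used are ones that appear in the diffs below.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;

public final class PreserveContextExample {

    private PreserveContextExample() {}

    static void completeUnderSystemContext(ThreadContext threadContext, ActionListener<String> listener) {
        // Wrap first: the wrapper snapshots the current (caller) thread context.
        final ActionListener<String> preserved =
            ContextPreservingActionListener.wrapPreservingContext(listener, threadContext);
        try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
            // Privileged work runs in the clean/system context here...
            // Completing the wrapped listener restores the caller's context first,
            // so its response headers are not wiped.
            preserved.onResponse("done");
        }
    }
}

Completing the unwrapped listener inside the try block would run it in the stashed (empty) context, which is exactly the bug class this commit fixes.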
5 changes: 5 additions & 0 deletions docs/changelog/83290.yaml
@@ -0,0 +1,5 @@
+pr: 83290
+summary: Update YAML Rest tests to check for product header on all responses
+area: Infra/REST API
+type: enhancement
+issues: []
@@ -557,7 +557,7 @@ protected void finishHim(Exception failure) {
*/
protected void finishHim(Exception failure, List<Failure> indexingFailures, List<SearchFailure> searchFailures, boolean timedOut) {
logger.debug("[{}]: finishing without any catastrophic failures", task.getId());
-scrollSource.close(() -> {
+scrollSource.close(threadPool.getThreadContext().preserveContext(() -> {
if (failure == null) {
BulkByScrollResponse response = buildResponse(
timeValueNanos(System.nanoTime() - startTime.get()),
@@ -569,7 +569,7 @@ protected void finishHim(Exception failure, List<Failure> indexingFailures, List
} else {
listener.onFailure(failure);
}
-});
+}));
}

/**
@@ -16,10 +16,13 @@
import org.elasticsearch.cluster.RestoreInProgress;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.ImmutableOpenMap;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.snapshots.RestoreInfo;
import org.elasticsearch.snapshots.RestoreService;

+import java.util.function.Supplier;

import static org.elasticsearch.snapshots.RestoreService.restoreInProgress;

public class RestoreClusterStateListener implements ClusterStateListener {
@@ -29,43 +32,48 @@ public class RestoreClusterStateListener implements ClusterStateListener {
private final ClusterService clusterService;
private final String uuid;
private final ActionListener<RestoreSnapshotResponse> listener;
+private final Supplier<ThreadContext.StoredContext> contextSupplier;

private RestoreClusterStateListener(
ClusterService clusterService,
RestoreService.RestoreCompletionResponse response,
-ActionListener<RestoreSnapshotResponse> listener
+ActionListener<RestoreSnapshotResponse> listener,
+Supplier<ThreadContext.StoredContext> contextSupplier
) {
this.clusterService = clusterService;
this.uuid = response.getUuid();
this.listener = listener;
+this.contextSupplier = contextSupplier;
}

@Override
public void clusterChanged(ClusterChangedEvent changedEvent) {
-final RestoreInProgress.Entry prevEntry = restoreInProgress(changedEvent.previousState(), uuid);
-final RestoreInProgress.Entry newEntry = restoreInProgress(changedEvent.state(), uuid);
-if (prevEntry == null) {
-// When there is a master failure after a restore has been started, this listener might not be registered
-// on the current master and as such it might miss some intermediary cluster states due to batching.
-// Clean up listener in that case and acknowledge completion of restore operation to client.
-clusterService.removeListener(this);
-listener.onResponse(new RestoreSnapshotResponse((RestoreInfo) null));
-} else if (newEntry == null) {
-clusterService.removeListener(this);
-ImmutableOpenMap<ShardId, RestoreInProgress.ShardRestoreStatus> shards = prevEntry.shards();
-assert prevEntry.state().completed() : "expected completed snapshot state but was " + prevEntry.state();
-assert RestoreService.completed(shards) : "expected all restore entries to be completed";
-RestoreInfo ri = new RestoreInfo(
-prevEntry.snapshot().getSnapshotId().getName(),
-prevEntry.indices(),
-shards.size(),
-shards.size() - RestoreService.failedShards(shards)
-);
-RestoreSnapshotResponse response = new RestoreSnapshotResponse(ri);
-logger.debug("restore of [{}] completed", prevEntry.snapshot().getSnapshotId());
-listener.onResponse(response);
-} else {
-// restore not completed yet, wait for next cluster state update
+try (ThreadContext.StoredContext stored = contextSupplier.get()) {
+final RestoreInProgress.Entry prevEntry = restoreInProgress(changedEvent.previousState(), uuid);
+final RestoreInProgress.Entry newEntry = restoreInProgress(changedEvent.state(), uuid);
+if (prevEntry == null) {
+// When there is a master failure after a restore has been started, this listener might not be registered
+// on the current master and as such it might miss some intermediary cluster states due to batching.
+// Clean up listener in that case and acknowledge completion of restore operation to client.
+clusterService.removeListener(this);
+listener.onResponse(new RestoreSnapshotResponse((RestoreInfo) null));
+} else if (newEntry == null) {
+clusterService.removeListener(this);
+ImmutableOpenMap<ShardId, RestoreInProgress.ShardRestoreStatus> shards = prevEntry.shards();
+assert prevEntry.state().completed() : "expected completed snapshot state but was " + prevEntry.state();
+assert RestoreService.completed(shards) : "expected all restore entries to be completed";
+RestoreInfo ri = new RestoreInfo(
+prevEntry.snapshot().getSnapshotId().getName(),
+prevEntry.indices(),
+shards.size(),
+shards.size() - RestoreService.failedShards(shards)
+);
+RestoreSnapshotResponse response = new RestoreSnapshotResponse(ri);
+logger.debug("restore of [{}] completed", prevEntry.snapshot().getSnapshotId());
+listener.onResponse(response);
+} else {
+// restore not completed yet, wait for next cluster state update
+}
}
}

@@ -76,8 +84,11 @@ public void clusterChanged(ClusterChangedEvent changedEvent) {
public static void createAndRegisterListener(
ClusterService clusterService,
RestoreService.RestoreCompletionResponse response,
-ActionListener<RestoreSnapshotResponse> listener
+ActionListener<RestoreSnapshotResponse> listener,
+ThreadContext threadContext
) {
-clusterService.addListener(new RestoreClusterStateListener(clusterService, response, listener));
+clusterService.addListener(
+new RestoreClusterStateListener(clusterService, response, listener, threadContext.newRestorableContext(true))
+);
}
}
@@ -70,7 +70,12 @@ protected void masterOperation(
) {
restoreService.restoreSnapshot(request, listener.delegateFailure((delegatedListener, restoreCompletionResponse) -> {
if (restoreCompletionResponse.getRestoreInfo() == null && request.waitForCompletion()) {
-RestoreClusterStateListener.createAndRegisterListener(clusterService, restoreCompletionResponse, delegatedListener);
+RestoreClusterStateListener.createAndRegisterListener(
+clusterService,
+restoreCompletionResponse,
+delegatedListener,
+threadPool.getThreadContext()
+);
} else {
delegatedListener.onResponse(new RestoreSnapshotResponse(restoreCompletionResponse.getRestoreInfo()));
}
@@ -213,7 +213,7 @@ public void removeApplier(ClusterStateApplier applier) {
}

/**
-* Add a listener for updated cluster states
+* Add a listener for updated cluster states. Listeners are executed in the system thread context.
*/
public void addListener(ClusterStateListener listener) {
clusterStateListeners.add(listener);
@@ -222,7 +222,7 @@ public void addListener(ClusterStateListener listener) {
/**
* Removes a listener for updated cluster states.
*/
-public void removeListener(ClusterStateListener listener) {
+public void removeListener(final ClusterStateListener listener) {
clusterStateListeners.remove(listener);
}

@@ -23,6 +23,7 @@
import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotRequest;
import org.elasticsearch.action.admin.cluster.snapshots.delete.DeleteSnapshotRequest;
import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.cluster.ClusterChangedEvent;
@@ -3600,7 +3601,8 @@ static Map<String, DataStreamAlias> filterDataStreamAliases(
* @param listener listener
*/
private void addListener(Snapshot snapshot, ActionListener<Tuple<RepositoryData, SnapshotInfo>> listener) {
-snapshotCompletionListeners.computeIfAbsent(snapshot, k -> new CopyOnWriteArrayList<>()).add(listener);
+snapshotCompletionListeners.computeIfAbsent(snapshot, k -> new CopyOnWriteArrayList<>())
+.add(ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()));
}

@Override
@@ -72,13 +72,20 @@ public String getReasonPhrase() {
* Get a list of all of the values of all warning headers returned in the response.
*/
public List<String> getWarningHeaders() {
-List<String> warningHeaders = new ArrayList<>();
+return getHeaders("Warning");
+}
+
+/**
+* Get a list of all the values of a given header returned in the response.
+*/
+public List<String> getHeaders(String name) {
+List<String> headers = new ArrayList<>();
for (Header header : response.getHeaders()) {
-if (header.getName().equals("Warning")) {
-warningHeaders.add(header.getValue());
+if (header.getName().equalsIgnoreCase(name)) {
+headers.add(header.getValue());
}
}
-return warningHeaders;
+return headers;
}

/**
@@ -268,7 +268,8 @@ public void onFailure(Exception e) {
assert restoreInfo.failedShards() > 0 : "Should have failed shards";
delegatedListener.onResponse(new PutFollowAction.Response(true, false, false));
}
-})
+}),
+threadPool.getThreadContext()
);
}

@@ -16,6 +16,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
@@ -179,10 +180,16 @@ private void removeRetentionLeaseForShard(
) {
logger.trace("{} removing retention lease [{}] while unfollowing leader index", followerShardId, retentionLeaseId);
final ThreadContext threadContext = threadPool.getThreadContext();
+// We're about to stash the thread context for this retention lease removal. The listener will be completed while the
+// context is stashed. The context needs to be restored in the listener when it is completing or else it is simply wiped.
+final ActionListener<ActionResponse.Empty> preservedListener = new ContextPreservingActionListener<>(
+threadContext.newRestorableContext(true),
+listener
+);
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
// we have to execute under the system context so that if security is enabled the removal is authorized
threadContext.markAsSystemContext();
-CcrRetentionLeases.asyncRemoveRetentionLease(leaderShardId, retentionLeaseId, remoteClient, listener);
+CcrRetentionLeases.asyncRemoveRetentionLease(leaderShardId, retentionLeaseId, remoteClient, preservedListener);
}
}

@@ -22,17 +22,12 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.xpack.core.XPackFeatureSet;
-import org.elasticsearch.xpack.core.XPackFeatureSet.Usage;
-import org.elasticsearch.xpack.core.common.IteratingActionListener;

import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CancellationException;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReferenceArray;
-import java.util.function.BiConsumer;

public class TransportXPackUsageAction extends TransportMasterNodeAction<XPackUsageRequest, XPackUsageResponse> {

@@ -72,37 +67,29 @@ protected void masterOperation(

@Override
protected void masterOperation(Task task, XPackUsageRequest request, ClusterState state, ActionListener<XPackUsageResponse> listener) {
-final ActionListener<List<XPackFeatureSet.Usage>> usageActionListener = listener.delegateFailure(
-(l, usages) -> l.onResponse(new XPackUsageResponse(usages))
-);
-final AtomicReferenceArray<Usage> featureSetUsages = new AtomicReferenceArray<>(featureSets.size());
-final AtomicInteger position = new AtomicInteger(0);
-final BiConsumer<XPackFeatureSet, ActionListener<List<Usage>>> consumer = (featureSet, iteratingListener) -> {
-assert Transports.assertNotTransportThread("calculating usage can be more expensive than we allow on transport threads");
-if (task instanceof CancellableTask && ((CancellableTask) task).isCancelled()) {
-throw new CancellationException("Task cancelled");
-}
+new ActionRunnable<XPackUsageResponse>(listener) {
+final List<XPackFeatureSet.Usage> responses = new ArrayList<>(featureSets.size());
+
+@Override
+protected void doRun() throws Exception {
+if (responses.size() < featureSets.size()) {
+assert Transports.assertNotTransportThread(
+"calculating usage can be more expensive than we allow on transport threads"
+);
+if (task instanceof CancellableTask && ((CancellableTask) task).isCancelled()) {
+throw new CancellationException("Task cancelled");
+}
+
-featureSet.usage(iteratingListener.delegateFailure((l, usage) -> {
-featureSetUsages.set(position.getAndIncrement(), usage);
-threadPool.executor(ThreadPool.Names.MANAGEMENT).execute(ActionRunnable.supply(iteratingListener, Collections::emptyList));
-}));
-};
-IteratingActionListener<List<XPackFeatureSet.Usage>, XPackFeatureSet> iteratingActionListener = new IteratingActionListener<>(
-usageActionListener,
-consumer,
-featureSets,
-threadPool.getThreadContext(),
-(ignore) -> {
-final List<Usage> usageList = new ArrayList<>(featureSetUsages.length());
-for (int i = 0; i < featureSetUsages.length(); i++) {
-usageList.add(featureSetUsages.get(i));
+featureSets.get(responses.size()).usage(listener.delegateFailure((l, usage) -> {
+responses.add(usage);
+threadPool.executor(ThreadPool.Names.MANAGEMENT).execute(this);
+}));
+} else {
+assert responses.size() == featureSets.size() : responses.size() + " vs " + featureSets.size();
+listener.onResponse(new XPackUsageResponse(responses));
+}
-return usageList;
-},
-(ignore) -> true
-);
-iteratingActionListener.run();
+}
+}.run();
}

@Override
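
The TransportXPackUsageAction rewrite above replaces IteratingActionListener with a self-rescheduling ActionRunnable: each asynchronous step records its result, then re-submits the same runnable until every feature set has reported. A minimal standalone sketch of that pattern, using generic names and plain java.util.concurrent types standing in for Elasticsearch's thread pool (all names here are hypothetical):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

final class SequentialCollector implements Runnable {
    // One async step; it reports a single result via the callback.
    interface Step { void run(Consumer<String> onResult); }

    private final List<Step> steps;
    private final List<String> results = new ArrayList<>();
    private final ExecutorService executor;
    private final Consumer<List<String>> onComplete;

    SequentialCollector(List<Step> steps, ExecutorService executor, Consumer<List<String>> onComplete) {
        this.steps = steps;
        this.executor = executor;
        this.onComplete = onComplete;
    }

    @Override
    public void run() {
        if (results.size() < steps.size()) {
            // Run the next step; when it calls back, record the result and
            // re-submit this same runnable on the executor (mirroring the
            // MANAGEMENT thread-pool dispatch in the diff above).
            steps.get(results.size()).run(result -> {
                results.add(result);
                executor.execute(this);
            });
        } else {
            onComplete.accept(results);
        }
    }
}

Submitting the collector once starts the chain; hopping back onto the executor after every callback keeps the per-step work off whichever thread completed the previous step.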
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.watcher;

import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
@@ -67,6 +68,10 @@ public Map<String, Object> nativeCodeInfo() {
@Override
public void usage(ActionListener<XPackFeatureSet.Usage> listener) {
if (enabled) {
+ActionListener<XPackFeatureSet.Usage> preservingListener = ContextPreservingActionListener.wrapPreservingContext(
+listener,
+client.threadPool().getThreadContext()
+);
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(WATCHER_ORIGIN)) {
WatcherClient watcherClient = new WatcherClient(client);
WatcherStatsRequest request = new WatcherStatsRequest();
@@ -78,8 +83,8 @@ public void usage(ActionListener<XPackFeatureSet.Usage> listener) {
.filter(Objects::nonNull)
.collect(Collectors.toList());
Counters mergedCounters = Counters.merge(countersPerNode);
-listener.onResponse(new WatcherFeatureSetUsage(available(), enabled(), mergedCounters.toNestedMap()));
-}, listener::onFailure));
+preservingListener.onResponse(new WatcherFeatureSetUsage(available(), enabled(), mergedCounters.toNestedMap()));
+}, preservingListener::onFailure));
}
} else {
listener.onResponse(new WatcherFeatureSetUsage(available(), enabled(), Collections.emptyMap()));
