Skip to content

Commit

Permalink
Misc improvements to TransportBroadcastByNodeAction (TBbNA) tests (#93435)
Browse files Browse the repository at this point in the history
Similar to #92983, this commit reworks the tests in
`TransportBroadcastByNodeActionTests` to use the `ReachabilityChecker`
to check that the shard results and exceptions received before a cancellation
are released when the task is cancelled, and adds a test showing the
cancellation behaviour of the shard-level operations.
  • Loading branch information
DaveCTurner committed Feb 2, 2023
1 parent 221c935 commit b0c380d
Showing 1 changed file with 146 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
import org.elasticsearch.cluster.routing.ShardsIterator;
import org.elasticsearch.cluster.routing.TestShardRouting;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
Expand All @@ -50,6 +52,7 @@
import org.elasticsearch.tasks.TaskCancelHelper;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.ReachabilityChecker;
import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -63,6 +66,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -71,6 +75,7 @@
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
Expand All @@ -79,6 +84,7 @@
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.object.HasToString.hasToString;

public class TransportBroadcastByNodeActionTests extends ESTestCase {
Expand All @@ -91,9 +97,9 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
private CapturingTransport transport;

private TestTransportBroadcastByNodeAction action;
private TransportService transportService;

public static class Request extends BroadcastRequest<Request> {

public Request(StreamInput in) throws IOException {
super(in);
}
Expand All @@ -113,29 +119,36 @@ public Response(int totalShards, int successfulShards, int failedShards, List<De
}
}

class TestTransportBroadcastByNodeAction extends TransportBroadcastByNodeAction<
Request,
Response,
TransportBroadcastByNodeAction.EmptyResult> {
/**
 * Per-shard result used by these tests. It carries no data, but it is deliberately NOT a singleton:
 * every shard-level response is a distinct instance, so a test can check that each individual
 * instance is released on cancellation.
 */
public static class ShardResult implements Writeable {

    public ShardResult() {
        // no state to initialise
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // no fields, so nothing is written to the stream
    }
}

class TestTransportBroadcastByNodeAction extends TransportBroadcastByNodeAction<Request, Response, ShardResult> {
private final Map<ShardRouting, Object> shards = new HashMap<>();

TestTransportBroadcastByNodeAction(
TransportService transportService,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
Writeable.Reader<Request> request,
String executor
) {
super("indices:admin/test", clusterService, transportService, actionFilters, indexNameExpressionResolver, request, executor);
TestTransportBroadcastByNodeAction(String actionName) {
super(
actionName,
clusterService,
transportService,
new ActionFilters(Set.of()),
new MyResolver(),
Request::new,
ThreadPool.Names.SAME
);
}

@Override
protected EmptyResult readShardResult(StreamInput in) {
return EmptyResult.INSTANCE;
protected ShardResult readShardResult(StreamInput in) {
return new ShardResult();
}

@Override
protected ResponseFactory<Response, EmptyResult> getResponseFactory(Request request, ClusterState clusterState) {
protected ResponseFactory<Response, ShardResult> getResponseFactory(Request request, ClusterState clusterState) {
return (totalShards, successfulShards, failedShards, emptyResults, shardFailures) -> new Response(
totalShards,
successfulShards,
Expand All @@ -150,11 +163,11 @@ protected Request readRequestFrom(StreamInput in) throws IOException {
}

@Override
protected void shardOperation(Request request, ShardRouting shardRouting, Task task, ActionListener<EmptyResult> listener) {
protected void shardOperation(Request request, ShardRouting shardRouting, Task task, ActionListener<ShardResult> listener) {
ActionListener.completeWith(listener, () -> {
if (rarely()) {
shards.put(shardRouting, Boolean.TRUE);
return EmptyResult.INSTANCE;
return new ShardResult();
} else {
ElasticsearchException e = new ElasticsearchException("operation failed");
shards.put(shardRouting, e);
Expand All @@ -181,6 +194,7 @@ protected ClusterBlockException checkRequestBlock(ClusterState state, Request re
public Map<ShardRouting, Object> getResults() {
return shards;
}

}

static class MyResolver extends IndexNameExpressionResolver {
Expand All @@ -204,7 +218,7 @@ public void setUp() throws Exception {
super.setUp();
transport = new CapturingTransport();
clusterService = createClusterService(THREAD_POOL);
TransportService transportService = transport.createTransportService(
transportService = transport.createTransportService(
clusterService.getSettings(),
THREAD_POOL,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
Expand All @@ -215,13 +229,7 @@ public void setUp() throws Exception {
transportService.start();
transportService.acceptIncomingRequests();
setClusterState(clusterService);
action = new TestTransportBroadcastByNodeAction(
transportService,
new ActionFilters(new HashSet<>()),
new MyResolver(),
Request::new,
ThreadPool.Names.SAME
);
action = new TestTransportBroadcastByNodeAction("indices:admin/test");
}

@After
Expand Down Expand Up @@ -348,16 +356,15 @@ public void testNoShardOperationsExecutedIfTaskCancelled() throws Exception {
shards.add(shard);
}
}
final TransportBroadcastByNodeAction<
Request,
Response,
TransportBroadcastByNodeAction.EmptyResult>.BroadcastByNodeTransportRequestHandler handler =
action.new BroadcastByNodeTransportRequestHandler();
final TransportBroadcastByNodeAction<Request, Response, ShardResult>.BroadcastByNodeTransportRequestHandler handler =
action.new BroadcastByNodeTransportRequestHandler();

final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
TestTransportChannel channel = new TestTransportChannel(future);

handler.messageReceived(action.new NodeRequest(new Request(), new ArrayList<>(shards), nodeId), channel, cancelledTask());
final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
TaskCancelHelper.cancel(cancellableTask, "simulated");
handler.messageReceived(action.new NodeRequest(new Request(), new ArrayList<>(shards), nodeId), channel, cancellableTask);
expectThrows(TaskCancelledException.class, future::actionGet);

assertThat(action.getResults(), anEmptyMap());
Expand Down Expand Up @@ -410,11 +417,8 @@ public void testOperationExecution() throws Exception {
shards.add(shard);
}
}
final TransportBroadcastByNodeAction<
Request,
Response,
TransportBroadcastByNodeAction.EmptyResult>.BroadcastByNodeTransportRequestHandler handler =
action.new BroadcastByNodeTransportRequestHandler();
final TransportBroadcastByNodeAction<Request, Response, ShardResult>.BroadcastByNodeTransportRequestHandler handler =
action.new BroadcastByNodeTransportRequestHandler();

final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
TestTransportChannel channel = new TestTransportChannel(future);
Expand All @@ -441,13 +445,12 @@ public void testOperationExecution() throws Exception {
failedShards++;
}
}

// check the operation results
assertEquals("successful shards", successfulShards, nodeResponse.getSuccessfulShards());
assertEquals("total shards", action.getResults().size(), nodeResponse.getTotalShards());
assertEquals("failed shards", failedShards, nodeResponse.getExceptions().size());
@SuppressWarnings("unchecked")
List<BroadcastShardOperationFailedException> exceptions = nodeResponse.getExceptions();
assertEquals("exceptions count", failedShards, exceptions.size());
for (BroadcastShardOperationFailedException exception : exceptions) {
assertThat(exception.getMessage(), is("operation indices:admin/test failed"));
assertThat(exception.getCause(), hasToString(containsString("operation failed")));
Expand Down Expand Up @@ -495,20 +498,21 @@ public void testResultAggregation() throws ExecutionException, InterruptedExcept
transport.handleRemoteError(requestId, new Exception());
} else {
List<ShardRouting> shards = map.get(entry.getKey());
List<TransportBroadcastByNodeAction.EmptyResult> shardResults = new ArrayList<>();
List<ShardResult> shardResults = new ArrayList<>();
for (ShardRouting shard : shards) {
totalShards++;
if (rarely()) {
// simulate operation failure
totalFailedShards++;
exceptions.add(new BroadcastShardOperationFailedException(shard.shardId(), "operation indices:admin/test failed"));
} else {
shardResults.add(TransportBroadcastByNodeAction.EmptyResult.INSTANCE);
shardResults.add(new ShardResult());
}
}
totalSuccessfulShards += shardResults.size();
TransportBroadcastByNodeAction<Request, Response, TransportBroadcastByNodeAction.EmptyResult>.NodeResponse nodeResponse =
action.new NodeResponse(entry.getKey(), shards.size(), shardResults, exceptions);
TransportBroadcastByNodeAction<Request, Response, ShardResult>.NodeResponse nodeResponse = action.new NodeResponse(
entry.getKey(), shards.size(), shardResults, exceptions
);
transport.handleResponse(requestId, nodeResponse);
}
}
Expand All @@ -523,33 +527,110 @@ public void testResultAggregation() throws ExecutionException, InterruptedExcept
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
}

public void testNoResultAggregationIfTaskCancelled() {
Request request = new Request(TEST_INDEX);
PlainActionFuture<Response> listener = new PlainActionFuture<>();
final CancellableTask task = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
TransportBroadcastByNodeAction<Request, Response, TransportBroadcastByNodeAction.EmptyResult>.AsyncAction asyncAction =
action.new AsyncAction(task, request, clusterService.state(), null, listener);
asyncAction.start();
Map<String, List<CapturingTransport.CapturedRequest>> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
int cancelAt = randomIntBetween(0, Math.max(0, capturedRequests.size() - 2));
int i = 0;
for (Map.Entry<String, List<CapturingTransport.CapturedRequest>> entry : capturedRequests.entrySet()) {
if (cancelAt == i) {
TaskCancelHelper.cancel(task, "simulated");
public void testResponsesReleasedOnCancellation() {
final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
final PlainActionFuture<Response> listener = new PlainActionFuture<>();
action.execute(cancellableTask, new Request(TEST_INDEX), listener);

final List<CapturingTransport.CapturedRequest> capturedRequests = new ArrayList<>(
Arrays.asList(transport.getCapturedRequestsAndClear())
);
Randomness.shuffle(capturedRequests);

final ReachabilityChecker reachabilityChecker = new ReachabilityChecker();
final Runnable nextRequestProcessor = () -> {
final var capturedRequest = capturedRequests.remove(0);
if (randomBoolean()) {
// transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
transport.getTransportResponseHandler(capturedRequest.requestId())
.handleResponse(
action.new NodeResponse(
capturedRequest.node().getId(), 1, List.of(reachabilityChecker.register(new ShardResult())), List.of()
)
);
} else {
// handleRemoteError may de/serialize the exception, releasing it early, so just use handleLocalError
transport.handleLocalError(
capturedRequest.requestId(),
reachabilityChecker.register(new ElasticsearchException("simulated"))
);
}
transport.handleRemoteError(entry.getValue().get(0).requestId(), new ElasticsearchException("simulated"));
i++;
};

assertThat(capturedRequests.size(), greaterThan(2));
final var responsesBeforeCancellation = between(1, capturedRequests.size() - 2);
for (int i = 0; i < responsesBeforeCancellation; i++) {
nextRequestProcessor.run();
}

reachabilityChecker.checkReachable();
TaskCancelHelper.cancel(cancellableTask, "simulated");

// responses captured before cancellation are now unreachable
reachabilityChecker.ensureUnreachable();

while (capturedRequests.size() > 0) {
// a response sent after cancellation is dropped immediately
assertFalse(listener.isDone());
nextRequestProcessor.run();
reachabilityChecker.ensureUnreachable();
}

assertTrue(listener.isDone());
assertTrue(asyncAction.getNodeResponseTracker().responsesDiscarded());
expectThrows(ExecutionException.class, TaskCancelledException.class, listener::get);
expectThrows(TaskCancelledException.class, () -> listener.actionGet(10, TimeUnit.SECONDS));
}

private static Task cancelledTask() {
final CancellableTask task = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
TaskCancelHelper.cancel(task, "simulated");
return task;
/**
 * Checks that a node-level request stops running shard-level operations once its task is cancelled.
 * <p>
 * The test replaces {@code action} with an anonymous subclass whose {@code shardOperation} cancels the
 * task while handling shard 1 and fails the test if any later shard is executed. It then sends a node
 * request covering shards 0, 1 and 2 straight to a {@code BroadcastByNodeTransportRequestHandler} and
 * asserts that the node-level response has completed exceptionally with a cancellation error.
 */
public void testShardLevelOperationsStopOnCancellation() throws Exception {
action = new TestTransportBroadcastByNodeAction("indices:admin/shard_level_test") {
// id of the next shard we expect to be asked to process; incremented on every shardOperation call
int expectedShardId;

@Override
protected void shardOperation(Request request, ShardRouting shardRouting, Task task, ActionListener<ShardResult> listener) {
// this test runs a node-level operation on three shards, cancelling the task some time during the execution on the second
if (task instanceof CancellableTask cancellableTask) {
// shards must be visited in ascending id order, one call per shard
assertEquals(expectedShardId++, shardRouting.shardId().id());
switch (shardRouting.shardId().id()) {
case 0 -> {
// first shard: task is still live, complete normally
assertFalse(cancellableTask.isCancelled());
listener.onResponse(new ShardResult());
}
case 1 -> {
// second shard: cancel the task mid-operation, then complete the shard-level listener one of two ways
assertFalse(cancellableTask.isCancelled());
TaskCancelHelper.cancel(cancellableTask, "simulated");
if (randomBoolean()) {
// the shard operation finished successfully despite the cancellation
listener.onResponse(new ShardResult());
} else {
// the shard operation itself noticed the cancellation and notified the listener
assertTrue(cancellableTask.notifyIfCancelled(listener));
}
}
// the third shard (id 2) must never be executed because the task was cancelled during shard 1
default -> fail("unexpected shard execution: " + shardRouting);
}
} else {
fail("task was not cancellable");
}
}
};

final PlainActionFuture<TransportResponse> nodeResponseFuture = new PlainActionFuture<>();

// deliver a node-level request for shard ids 0..2 on "node-id" directly to the handler, with a fresh (not yet cancelled) task
action.new BroadcastByNodeTransportRequestHandler().messageReceived(
action.new NodeRequest(
new Request(), IntStream.range(0, 3)
.mapToObj(shardId -> TestShardRouting.newShardRouting(TEST_INDEX, shardId, "node-id", true, ShardRoutingState.STARTED))
.toList(), "node-id"
),
new TestTransportChannel(nodeResponseFuture),
new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap())
);

// the node-level response is expected to be complete by the time messageReceived returns
assertTrue(nodeResponseFuture.isDone());
// NOTE(review): the two-class expectThrows overload presumably returns the wrapped TaskCancelledException, whose
// message carries the "simulated" reason passed to TaskCancelHelper.cancel above - confirm against ESTestCase
assertEquals(
"task cancelled [simulated]",
expectThrows(
java.util.concurrent.ExecutionException.class,
org.elasticsearch.tasks.TaskCancelledException.class,
nodeResponseFuture::get
).getMessage()
);
}

}

0 comments on commit b0c380d

Please sign in to comment.