diff --git a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
index 915e900b9ddf1..f690efa4c9a0c 100644
--- a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
+++ b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
@@ -25,10 +25,10 @@
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.ClusterChangedEvent;
-import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateObserver;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateTaskListener;
 import org.elasticsearch.cluster.MasterNodeChangePredicate;
 import org.elasticsearch.cluster.NotMasterException;
@@ -48,6 +48,7 @@
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.discovery.Discovery;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.node.NodeClosedException;
@@ -68,7 +69,9 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.ConcurrentMap;
 import java.util.function.Predicate;
 
 public class ShardStateAction extends AbstractComponent {
@@ -80,6 +83,10 @@ public class ShardStateAction extends AbstractComponent {
     private final ClusterService clusterService;
     private final ThreadPool threadPool;
 
+    // a list of shards that failed during replication
+    // we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
+    private final ConcurrentMap<FailedShardEntry, CompositeListener> remoteFailedShardsCache = ConcurrentCollections.newConcurrentMap();
+
     @Inject
     public ShardStateAction(Settings settings, ClusterService clusterService, TransportService transportService,
                             AllocationService allocationService, RoutingService routingService, ThreadPool threadPool) {
@@ -146,8 +153,35 @@ private static boolean isMasterChannelException(TransportException exp) {
      */
     public void remoteShardFailed(final ShardId shardId, String allocationId, long primaryTerm, boolean markAsStale, final String message, @Nullable final Exception failure, Listener listener) {
         assert primaryTerm > 0L : "primary term should be strictly positive";
-        FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
-        sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, listener);
+        final FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
+        final CompositeListener compositeListener = new CompositeListener(listener);
+        final CompositeListener existingListener = remoteFailedShardsCache.putIfAbsent(shardEntry, compositeListener);
+        if (existingListener == null) {
+            sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, new Listener() {
+                @Override
+                public void onSuccess() {
+                    try {
+                        compositeListener.onSuccess();
+                    } finally {
+                        remoteFailedShardsCache.remove(shardEntry);
+                    }
+                }
+                @Override
+                public void onFailure(Exception e) {
+                    try {
+                        compositeListener.onFailure(e);
+                    } finally {
+                        remoteFailedShardsCache.remove(shardEntry);
+                    }
+                }
+            });
+        } else {
+            existingListener.addListener(listener);
+        }
+    }
+
+    int remoteShardFailedCacheSize() {
+        return remoteFailedShardsCache.size();
     }
 
     /**
@@ -414,6 +448,23 @@ public String toString() {
             components.add("markAsStale [" + markAsStale + "]");
             return String.join(", ", components);
         }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            FailedShardEntry that = (FailedShardEntry) o;
+            // Exclude message and exception from equals and hashCode
+            return Objects.equals(this.shardId, that.shardId) &&
+                Objects.equals(this.allocationId, that.allocationId) &&
+                primaryTerm == that.primaryTerm &&
+                markAsStale == that.markAsStale;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(shardId, allocationId, primaryTerm, markAsStale);
+        }
     }
 
     public void shardStarted(final ShardRouting shardRouting, final String message, Listener listener) {
@@ -585,6 +636,72 @@ default void onFailure(final Exception e) {
         }
     }
 
+    /**
+     * A composite listener that allows registering multiple listeners dynamically.
+     */
+    static final class CompositeListener implements Listener {
+        private boolean isNotified = false;
+        private Exception failure = null;
+        private final List<Listener> listeners = new ArrayList<>();
+
+        CompositeListener(Listener listener) {
+            listeners.add(listener);
+        }
+
+        void addListener(Listener listener) {
+            final boolean ready;
+            synchronized (this) {
+                ready = this.isNotified;
+                if (ready == false) {
+                    listeners.add(listener);
+                }
+            }
+            if (ready) {
+                if (failure != null) {
+                    listener.onFailure(failure);
+                } else {
+                    listener.onSuccess();
+                }
+            }
+        }
+
+        private void onCompleted(Exception failure) {
+            synchronized (this) {
+                this.failure = failure;
+                this.isNotified = true;
+            }
+            RuntimeException firstException = null;
+            for (Listener listener : listeners) {
+                try {
+                    if (failure != null) {
+                        listener.onFailure(failure);
+                    } else {
+                        listener.onSuccess();
+                    }
+                } catch (RuntimeException innerEx) {
+                    if (firstException == null) {
+                        firstException = innerEx;
+                    } else {
+                        firstException.addSuppressed(innerEx);
+                    }
+                }
+            }
+            if (firstException != null) {
+                throw firstException;
+            }
+        }
+
+        @Override
+        public void onSuccess() {
+            onCompleted(null);
+        }
+
+        @Override
+        public void onFailure(Exception failure) {
+            onCompleted(failure);
+        }
+    }
+
     public static class NoLongerPrimaryShardException extends ElasticsearchException {
 
         public NoLongerPrimaryShardException(ShardId shardId, String msg) {
diff --git a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardFailedClusterStateTaskExecutorTests.java b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardFailedClusterStateTaskExecutorTests.java
index 9eeef54dfd796..01d0c518c1be7 100644
--- a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardFailedClusterStateTaskExecutorTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardFailedClusterStateTaskExecutorTests.java
@@ -22,11 +22,11 @@
 import com.carrotsearch.hppc.cursors.ObjectCursor;
 import org.apache.lucene.index.CorruptIndexException;
 import org.elasticsearch.Version;
-import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MetaData;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -43,6 +43,7 @@
 import org.elasticsearch.cluster.routing.allocation.StaleShard;
 import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.index.Index;
@@ -53,9 +54,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -131,10 +130,15 @@ ClusterState applyFailedShards(ClusterState currentState, List<FailedShardEntry> fail
         tasks.addAll(failingTasks);
         tasks.addAll(nonExistentTasks);
         ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result = failingExecutor.execute(currentState, tasks);
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            failingTasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.failure(new RuntimeException("simulated applyFailedShards failure"))));
-        taskResultMap.putAll(nonExistentTasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.success())));
-        assertTaskResults(taskResultMap, result, currentState, false);
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = new ArrayList<>();
+        for (FailedShardEntry failingTask : failingTasks) {
+            taskResultList.add(Tuple.tuple(failingTask,
+                ClusterStateTaskExecutor.TaskResult.failure(new RuntimeException("simulated applyFailedShards failure"))));
+        }
+        for (FailedShardEntry nonExistentTask : nonExistentTasks) {
+            taskResultList.add(Tuple.tuple(nonExistentTask, ClusterStateTaskExecutor.TaskResult.success()));
+        }
+        assertTaskResults(taskResultList, result, currentState, false);
     }
 
     public void testIllegalShardFailureRequests() throws Exception {
@@ -147,14 +151,14 @@ public void testIllegalShardFailureRequests() throws Exception {
             tasks.add(new FailedShardEntry(failingTask.shardId, failingTask.allocationId,
                 randomIntBetween(1, (int) primaryTerm - 1), failingTask.message, failingTask.failure, randomBoolean()));
         }
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            tasks.stream().collect(Collectors.toMap(
-                Function.identity(),
-                task -> ClusterStateTaskExecutor.TaskResult.failure(new ShardStateAction.NoLongerPrimaryShardException(task.shardId,
-                    "primary term [" + task.primaryTerm + "] did not match current primary term [" +
-                        currentState.metaData().index(task.shardId.getIndex()).primaryTerm(task.shardId.id()) + "]"))));
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = tasks.stream()
+            .map(task -> Tuple.tuple(task, ClusterStateTaskExecutor.TaskResult.failure(
+                new ShardStateAction.NoLongerPrimaryShardException(task.shardId, "primary term [" +
+                    task.primaryTerm + "] did not match current primary term [" +
+                    currentState.metaData().index(task.shardId.getIndex()).primaryTerm(task.shardId.id()) + "]"))))
+            .collect(Collectors.toList());
         ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result = executor.execute(currentState, tasks);
-        assertTaskResults(taskResultMap, result, currentState, false);
+        assertTaskResults(taskResultList, result, currentState, false);
     }
 
     public void testMarkAsStaleWhenFailingShard() throws Exception {
@@ -251,44 +255,44 @@ private static void assertTasksSuccessful(
         ClusterState clusterState,
         boolean clusterStateChanged
     ) {
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            tasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.success()));
-        assertTaskResults(taskResultMap, result, clusterState, clusterStateChanged);
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = tasks.stream()
+            .map(t -> Tuple.tuple(t, ClusterStateTaskExecutor.TaskResult.success())).collect(Collectors.toList());
+        assertTaskResults(taskResultList, result, clusterState, clusterStateChanged);
     }
 
     private static void assertTaskResults(
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap,
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList,
         ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result,
         ClusterState clusterState,
         boolean clusterStateChanged
     ) {
         // there should be as many task results as tasks
-        assertEquals(taskResultMap.size(), result.executionResults.size());
+        assertEquals(taskResultList.size(), result.executionResults.size());
 
-        for (Map.Entry<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultMap.entrySet()) {
+        for (Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultList) {
             // every task should have a corresponding task result
-            assertTrue(result.executionResults.containsKey(entry.getKey()));
+            assertTrue(result.executionResults.containsKey(entry.v1()));
 
             // the task results are as expected
-            assertEquals(entry.getKey().toString(), entry.getValue().isSuccess(), result.executionResults.get(entry.getKey()).isSuccess());
+            assertEquals(entry.v1().toString(), entry.v2().isSuccess(), result.executionResults.get(entry.v1()).isSuccess());
         }
 
         List<ShardRouting> shards = clusterState.getRoutingTable().allShards();
-        for (Map.Entry<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultMap.entrySet()) {
-            if (entry.getValue().isSuccess()) {
+        for (Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultList) {
+            if (entry.v2().isSuccess()) {
                 // the shard was successfully failed and so should not be in the routing table
                 for (ShardRouting shard : shards) {
                     if (shard.assignedToNode()) {
-                        assertFalse("entry key " + entry.getKey() + ", shard routing " + shard,
-                            entry.getKey().getShardId().equals(shard.shardId()) &&
-                                entry.getKey().getAllocationId().equals(shard.allocationId().getId()));
+                        assertFalse("entry key " + entry.v1() + ", shard routing " + shard,
+                            entry.v1().getShardId().equals(shard.shardId()) &&
+                                entry.v1().getAllocationId().equals(shard.allocationId().getId()));
                     }
                 }
             } else {
                 // check we saw the expected failure
-                ClusterStateTaskExecutor.TaskResult actualResult = result.executionResults.get(entry.getKey());
-                assertThat(actualResult.getFailure(), instanceOf(entry.getValue().getFailure().getClass()));
-                assertThat(actualResult.getFailure().getMessage(), equalTo(entry.getValue().getFailure().getMessage()));
+                ClusterStateTaskExecutor.TaskResult actualResult = result.executionResults.get(entry.v1());
+                assertThat(actualResult.getFailure(), instanceOf(entry.v2().getFailure().getClass()));
+                assertThat(actualResult.getFailure().getMessage(), equalTo(entry.v2().getFailure().getMessage()));
             }
         }
 
diff --git a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java
index bbd326ff2fedb..1d78cdeb98374 100644
--- a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java
@@ -59,9 +59,10 @@
 import org.junit.BeforeClass;
 
 import java.io.IOException;
-import java.util.UUID;
 import java.util.Collections;
+import java.util.UUID;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Phaser;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -73,6 +74,8 @@
 import static org.elasticsearch.test.ClusterServiceUtils.setState;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.sameInstance;
+import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
@@ -138,6 +141,7 @@ public void tearDown() throws Exception {
         clusterService.close();
         transportService.close();
         super.tearDown();
+        assertThat(shardStateAction.remoteShardFailedCacheSize(), equalTo(0));
     }
 
     @AfterClass
@@ -381,6 +385,89 @@ public void onFailure(Exception e) {
         assertThat(failure.get().getMessage(), equalTo(catastrophicError.getMessage()));
     }
 
+    public void testCacheRemoteShardFailed() throws Exception {
+        final String index = "test";
+        setState(clusterService, ClusterStateCreationUtils.stateWithActivePrimary(index, true, randomInt(5)));
+        ShardRouting failedShard = getRandomShardRouting(index);
+        boolean markAsStale = randomBoolean();
+        int numListeners = between(1, 100);
+        CountDownLatch latch = new CountDownLatch(numListeners);
+        long primaryTerm = randomLongBetween(1, Long.MAX_VALUE);
+        for (int i = 0; i < numListeners; i++) {
+            shardStateAction.remoteShardFailed(failedShard.shardId(), failedShard.allocationId().getId(),
+                primaryTerm, markAsStale, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
+                    @Override
+                    public void onSuccess() {
+                        latch.countDown();
+                    }
+                    @Override
+                    public void onFailure(Exception e) {
+                        latch.countDown();
+                    }
+                });
+        }
+        CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
+        assertThat(capturedRequests, arrayWithSize(1));
+        transport.handleResponse(capturedRequests[0].requestId, TransportResponse.Empty.INSTANCE);
+        latch.await();
+        assertThat(transport.capturedRequests(), arrayWithSize(0));
+    }
+
+    public void testRemoteShardFailedConcurrently() throws Exception {
+        final String index = "test";
+        final AtomicBoolean shutdown = new AtomicBoolean(false);
+        setState(clusterService, ClusterStateCreationUtils.stateWithActivePrimary(index, true, randomInt(5)));
+        ShardRouting[] failedShards = new ShardRouting[between(1, 5)];
+        for (int i = 0; i < failedShards.length; i++) {
+            failedShards[i] = getRandomShardRouting(index);
+        }
+        Thread[] clientThreads = new Thread[between(1, 6)];
+        int iterationsPerThread = scaledRandomIntBetween(50, 500);
+        Phaser barrier = new Phaser(clientThreads.length + 2); // one for master thread, one for the main thread
+        Thread masterThread = new Thread(() -> {
+            barrier.arriveAndAwaitAdvance();
+            while (shutdown.get() == false) {
+                for (CapturingTransport.CapturedRequest request : transport.getCapturedRequestsAndClear()) {
+                    if (randomBoolean()) {
+                        transport.handleResponse(request.requestId, TransportResponse.Empty.INSTANCE);
+                    } else {
+                        transport.handleRemoteError(request.requestId, randomFrom(getSimulatedFailure()));
+                    }
+                }
+            }
+        });
+        masterThread.start();
+
+        AtomicInteger notifiedResponses = new AtomicInteger();
+        for (int t = 0; t < clientThreads.length; t++) {
+            clientThreads[t] = new Thread(() -> {
+                barrier.arriveAndAwaitAdvance();
+                for (int i = 0; i < iterationsPerThread; i++) {
+                    ShardRouting failedShard = randomFrom(failedShards);
+                    shardStateAction.remoteShardFailed(failedShard.shardId(), failedShard.allocationId().getId(),
+                        randomLongBetween(1, Long.MAX_VALUE), randomBoolean(), "test", getSimulatedFailure(), new ShardStateAction.Listener() {
+                            @Override
+                            public void onSuccess() {
+                                notifiedResponses.incrementAndGet();
+                            }
+                            @Override
+                            public void onFailure(Exception e) {
+                                notifiedResponses.incrementAndGet();
+                            }
+                        });
+                }
+            });
+            clientThreads[t].start();
+        }
+        barrier.arriveAndAwaitAdvance();
+        for (Thread t : clientThreads) {
+            t.join();
+        }
+        assertBusy(() -> assertThat(notifiedResponses.get(), equalTo(clientThreads.length * iterationsPerThread)));
+        shutdown.set(true);
+        masterThread.join();
+    }
+
     private ShardRouting getRandomShardRouting(String index) {
         IndexRoutingTable indexRoutingTable = clusterService.state().routingTable().index(index);
         ShardsIterator shardsIterator = indexRoutingTable.randomAllActiveShardsIt();
@@ -452,4 +539,61 @@ BytesReference serialize(Writeable writeable, Version version) throws IOExceptio
             return out.bytes();
         }
     }
+
+    public void testCompositeListener() throws Exception {
+        AtomicInteger successCount = new AtomicInteger();
+        AtomicInteger failureCount = new AtomicInteger();
+        Exception failure = randomBoolean() ? getSimulatedFailure() : null;
+        ShardStateAction.CompositeListener compositeListener = new ShardStateAction.CompositeListener(new ShardStateAction.Listener() {
+            @Override
+            public void onSuccess() {
+                successCount.incrementAndGet();
+            }
+            @Override
+            public void onFailure(Exception e) {
+                assertThat(e, sameInstance(failure));
+                failureCount.incrementAndGet();
+            }
+        });
+        int iterationsPerThread = scaledRandomIntBetween(100, 1000);
+        Thread[] threads = new Thread[between(1, 4)];
+        Phaser barrier = new Phaser(threads.length + 1);
+        for (int i = 0; i < threads.length; i++) {
+            threads[i] = new Thread(() -> {
+                barrier.arriveAndAwaitAdvance();
+                for (int n = 0; n < iterationsPerThread; n++) {
+                    compositeListener.addListener(new ShardStateAction.Listener() {
+                        @Override
+                        public void onSuccess() {
+                            successCount.incrementAndGet();
+                        }
+                        @Override
+                        public void onFailure(Exception e) {
+                            assertThat(e, sameInstance(failure));
+                            failureCount.incrementAndGet();
+                        }
+                    });
+                }
+            });
+            threads[i].start();
+        }
+        barrier.arriveAndAwaitAdvance();
+        if (failure != null) {
+            compositeListener.onFailure(failure);
+        } else {
+            compositeListener.onSuccess();
+        }
+        for (Thread t : threads) {
+            t.join();
+        }
+        assertBusy(() -> {
+            if (failure != null) {
+                assertThat(successCount.get(), equalTo(0));
+                assertThat(failureCount.get(), equalTo(threads.length * iterationsPerThread + 1));
+            } else {
+                assertThat(successCount.get(), equalTo(threads.length * iterationsPerThread + 1));
+                assertThat(failureCount.get(), equalTo(0));
+            }
+        });
+    }
 }