diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/ShardFlushRequest.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/ShardFlushRequest.java index 62f01561fc644..8d520544dd8a4 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/ShardFlushRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/ShardFlushRequest.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.ActiveShardCount; import org.elasticsearch.action.support.replication.ReplicationRequest; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.shard.ShardId; @@ -21,8 +22,13 @@ public class ShardFlushRequest extends ReplicationRequest { private final FlushRequest request; - public ShardFlushRequest(FlushRequest request, ShardId shardId) { - super(shardId); + /** + * Creates a request for a resolved shard id and SplitShardCountSummary (used + * to determine if the request needs to be executed on a split shard not yet seen by the + * coordinator that sent the request) + */ + public ShardFlushRequest(FlushRequest request, ShardId shardId, SplitShardCountSummary reshardSplitShardCountSummary) { + super(shardId, reshardSplitShardCountSummary); this.request = request; this.waitForActiveShards = ActiveShardCount.NONE; // don't wait for any active shards before proceeding, by default } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java index 6771dc445fc15..8d82b7945d90b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.injection.guice.Inject; @@ -59,8 +60,8 @@ public TransportFlushAction( } @Override - protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) { - return new ShardFlushRequest(request, shardId); + protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId, SplitShardCountSummary shardCountSummary) { + return new ShardFlushRequest(request, shardId, shardCountSummary); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java index c0a3e568ffeeb..1312f32c1918e 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportShardFlushAction.java @@ -13,13 +13,16 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.replication.ReplicationRequestSplitHelper; import 
org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.action.shard.ShardStateAction; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesService; @@ -32,12 +35,15 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; +import java.util.Map; public class TransportShardFlushAction extends TransportReplicationAction { public static final String NAME = FlushAction.NAME + "[s]"; public static final ActionType TYPE = new ActionType<>(NAME); + private final ProjectResolver projectResolver; + @Inject public TransportShardFlushAction( Settings settings, @@ -46,7 +52,8 @@ public TransportShardFlushAction( IndicesService indicesService, ThreadPool threadPool, ShardStateAction shardStateAction, - ActionFilters actionFilters + ActionFilters actionFilters, + ProjectResolver projectResolver ) { super( settings, @@ -64,6 +71,7 @@ public TransportShardFlushAction( PrimaryActionExecution.RejectOnOverload, ReplicaActionExecution.SubjectToCircuitBreaker ); + this.projectResolver = projectResolver; transportService.registerRequestHandler( PRE_SYNCED_FLUSH_ACTION_NAME, threadPool.executor(ThreadPool.Names.FLUSH), @@ -89,6 +97,27 @@ protected void shardOperationOnPrimary( })); } + // We are here because there was a mismatch between the SplitShardCountSummary in the request + // and that on the primary shard node. We assume that the request is exactly 1 reshard split behind + // the current state. 
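For intuition, the fan-out that the splitRequestOnPrimary override below performs can be sketched in isolation. This is a minimal sketch under the same assumption stated above (a reshard split exactly doubles the shard count, so each source shard has a single target shard); the ShardKey record and fanOut helper are hypothetical stand-ins, not the Elasticsearch types touched in this diff.

import java.util.LinkedHashMap;
import java.util.Map;

final class ReshardSplitSketch {
    // Hypothetical stand-in for org.elasticsearch.index.shard.ShardId.
    record ShardKey(String index, int id) {}

    // With a doubling split, the single target of shard s is s + shardCountBefore.
    static Map<ShardKey, String> fanOut(ShardKey source, int shardCountBefore, String request) {
        Map<ShardKey, String> requestsByShard = new LinkedHashMap<>();
        requestsByShard.put(source, request); // the source shard keeps the original request
        ShardKey target = new ShardKey(source.index(), source.id() + shardCountBefore);
        requestsByShard.put(target, request); // the target gets a copy built with the current shard-count summary
        return requestsByShard;
    }

    public static void main(String[] args) {
        // Shard 1 of an index split from 2 to 4 shards fans out to shards 1 and 3.
        System.out.println(fanOut(new ShardKey("idx", 1), 2, "flush"));
    }
}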
+ @Override + protected Map splitRequestOnPrimary(ShardFlushRequest request) { + return ReplicationRequestSplitHelper.splitRequest( + request, + projectResolver.getProjectMetadata(clusterService.state()), + (targetShard, shardCountSummary) -> new ShardFlushRequest(request.getRequest(), targetShard, shardCountSummary) + ); + } + + @Override + protected Tuple combineSplitResponses( + ShardFlushRequest originalRequest, + Map splitRequests, + Map> responses + ) { + return ReplicationRequestSplitHelper.combineSplitResponses(originalRequest, splitRequests, responses); + } + @Override protected void shardOperationOnReplica(ShardFlushRequest request, IndexShard replica, ActionListener listener) { replica.flush(request.getRequest(), listener.map(flushed -> { diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java index f804dc9ffe907..2856112c05177 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.injection.guice.Inject; @@ -60,8 +61,8 @@ public TransportRefreshAction( } @Override - protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) { - BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId); + protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId, SplitShardCountSummary shardCountSummary) { + BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId, shardCountSummary); replicationRequest.waitForActiveShards(ActiveShardCount.NONE); return replicationRequest; } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java index 3dc3e19dcb979..7aee16376f9a1 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java @@ -16,14 +16,18 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.replication.BasicReplicationRequest; import org.elasticsearch.action.support.replication.ReplicationOperation; +import org.elasticsearch.action.support.replication.ReplicationRequestSplitHelper; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.action.shard.ShardStateAction; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Tuple; import 
org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.logging.LogManager; @@ -32,6 +36,7 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; +import java.util.Map; import java.util.concurrent.Executor; public class TransportShardRefreshAction extends TransportReplicationAction< @@ -46,6 +51,7 @@ public class TransportShardRefreshAction extends TransportReplicationAction< public static final String SOURCE_API = "api"; private final Executor refreshExecutor; + private final ProjectResolver projectResolver; @Inject public TransportShardRefreshAction( @@ -55,7 +61,8 @@ public TransportShardRefreshAction( IndicesService indicesService, ThreadPool threadPool, ShardStateAction shardStateAction, - ActionFilters actionFilters + ActionFilters actionFilters, + ProjectResolver projectResolver ) { super( settings, @@ -73,6 +80,7 @@ public TransportShardRefreshAction( PrimaryActionExecution.RejectOnOverload, ReplicaActionExecution.SubjectToCircuitBreaker ); + this.projectResolver = projectResolver; // registers the unpromotable version of shard refresh action new TransportUnpromotableShardRefreshAction( clusterService, @@ -104,6 +112,27 @@ protected void shardOperationOnPrimary( })); } + // We are here because there was mismatch between the SplitShardCountSummary in the request + // and that on the primary shard node. We assume that the request is exactly 1 reshard split behind + // the current state. + @Override + protected Map splitRequestOnPrimary(BasicReplicationRequest request) { + return ReplicationRequestSplitHelper.splitRequest( + request, + projectResolver.getProjectMetadata(clusterService.state()), + (targetShard, shardCountSummary) -> new BasicReplicationRequest(targetShard, shardCountSummary) + ); + } + + @Override + protected Tuple combineSplitResponses( + BasicReplicationRequest originalRequest, + Map splitRequests, + Map> responses + ) { + return ReplicationRequestSplitHelper.combineSplitResponses(originalRequest, splitRequests, responses); + } + @Override protected void shardOperationOnReplica(ShardRefreshReplicaRequest request, IndexShard replica, ActionListener listener) { replica.externalRefresh(SOURCE_API, listener.safeMap(refreshResult -> { diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 69492c67f6ba1..3f3e3a9127526 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -403,11 +403,9 @@ private void executeBulkRequestsByShard( final List requests = entry.getValue(); // Get effective shardCount for shardId and pass it on as parameter to new BulkShardRequest - var indexMetadata = project.index(shardId.getIndexName()); - SplitShardCountSummary reshardSplitShardCountSummary = SplitShardCountSummary.UNSET; - if (indexMetadata != null) { - reshardSplitShardCountSummary = SplitShardCountSummary.forIndexing(indexMetadata, shardId.getId()); - } + var indexMetadata = project.getIndexSafe(shardId.getIndex()); + SplitShardCountSummary reshardSplitShardCountSummary = SplitShardCountSummary.forIndexing(indexMetadata, shardId.getId()); + BulkShardRequest bulkShardRequest = new BulkShardRequest( shardId, reshardSplitShardCountSummary, @@ -416,7 +414,7 @@ private void executeBulkRequestsByShard( 
bulkRequest.isSimulated() ); - if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) { + if (indexMetadata.getInferenceFields().isEmpty() == false) { bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields()); } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/ShardBulkSplitHelper.java b/server/src/main/java/org/elasticsearch/action/bulk/ShardBulkSplitHelper.java index fa8c32057d34e..3762943e85d4c 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/ShardBulkSplitHelper.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/ShardBulkSplitHelper.java @@ -10,8 +10,10 @@ package org.elasticsearch.action.bulk; import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.routing.IndexRouting; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.Index; import org.elasticsearch.index.shard.ShardId; @@ -33,7 +35,9 @@ private ShardBulkSplitHelper() {} public static Map splitRequests(BulkShardRequest request, ProjectMetadata project) { final ShardId sourceShardId = request.shardId(); final Index index = sourceShardId.getIndex(); - IndexRouting indexRouting = IndexRouting.fromIndexMetadata(project.getIndexSafe(index)); + IndexMetadata indexMetadata = project.getIndexSafe(index); + IndexRouting indexRouting = IndexRouting.fromIndexMetadata(indexMetadata); + SplitShardCountSummary shardCountSummary = SplitShardCountSummary.forIndexing(indexMetadata, request.shardId().getId()); Map> requestsByShard = new HashMap<>(); Map bulkRequestsPerShard = new HashMap<>(); @@ -57,17 +61,26 @@ public static Map splitRequests(BulkShardRequest requ // All items belong to either the source shard or target shard. if (requestsByShard.size() == 1) { - // Return the original request if no items were split to target. + // Return the original request if no items were split to target. Note that + // this original request still contains the stale SplitShardCountSummary. + // This is alright because we hold primary indexing permits while calling this split + // method and we execute this request on the primary without letting go of the indexing permits. + // This means that a second split cannot occur in the meantime. if (requestsByShard.containsKey(sourceShardId)) { return Map.of(sourceShardId, request); } } + // Create a new BulkShardRequest(s) with the updated SplitShardCountSummary. This is because + // we do not hold primary permits on the target shard, and hence it can proceed with + // a second split operation while this request is still pending. We must verify the + // SplitShardCountSummary again on the target. 
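As a rough illustration of the bucketing that splitRequests performs before the loop that follows: doc routing is recomputed against the post-split shard count, and every item that used to belong to the source shard lands either back on the source or on its single split target. The sketch below uses plain modulo routing as a hypothetical stand-in for IndexRouting, and its names are invented for illustration.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class BulkSplitSketch {
    // Hypothetical stand-in for IndexRouting: route a doc id to a shard by hash modulo shard count.
    static int route(String docId, int shardCount) {
        return Math.floorMod(docId.hashCode(), shardCount);
    }

    // Re-bucket items that previously routed to sourceShard under the doubled shard count.
    // With modulo routing each item lands either on sourceShard or on sourceShard + shardCountBefore.
    static Map<Integer, List<String>> rebucket(List<String> docIds, int sourceShard, int shardCountBefore) {
        Map<Integer, List<String>> requestsByShard = new HashMap<>();
        for (String docId : docIds) {
            if (route(docId, shardCountBefore) == sourceShard) {
                int shard = route(docId, 2 * shardCountBefore);
                requestsByShard.computeIfAbsent(shard, k -> new ArrayList<>()).add(docId);
            }
        }
        return requestsByShard;
    }
}

If the resulting map has a single key equal to the source shard, the original request can be reused as described above; otherwise each bucket becomes its own per-shard request carrying the fresh summary.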
for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); BulkShardRequest bulkShardRequest = new BulkShardRequest( shardId, + shardCountSummary, request.getRefreshPolicy(), requests.toArray(new BulkItemRequest[0]), request.isSimulated() diff --git a/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java b/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java index 806b55e6ad7c9..65a0a8751cb5a 100644 --- a/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java +++ b/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java @@ -174,6 +174,7 @@ private void handleMultiGetOnUnpromotableShard( ShardId shardId = indexShard.shardId(); if (request.refresh()) { logger.trace("send refresh action for shard {}", shardId); + // TODO: Do we need to pass in shardCountSummary here ? var refreshRequest = new BasicReplicationRequest(shardId); refreshRequest.setParentTask(request.getParentTask()); client.executeLocally( diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/BasicReplicationRequest.java b/server/src/main/java/org/elasticsearch/action/support/replication/BasicReplicationRequest.java index 8d04a31101d0c..c1a023ff5419c 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/BasicReplicationRequest.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/BasicReplicationRequest.java @@ -9,6 +9,7 @@ package org.elasticsearch.action.support.replication; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.index.shard.ShardId; @@ -24,10 +25,20 @@ public class BasicReplicationRequest extends ReplicationRequest> Map splitRequest( + T request, + ProjectMetadata project, + BiFunction targetRequestFactory + ) { + final ShardId sourceShard = request.shardId(); + IndexMetadata indexMetadata = project.getIndexSafe(sourceShard.getIndex()); + SplitShardCountSummary shardCountSummary = SplitShardCountSummary.forIndexing(indexMetadata, sourceShard.getId()); + + Map requestsByShard = new HashMap<>(); + requestsByShard.put(sourceShard, request); + + // Create a request for original source shard and for each target shard. + // New requests that are to be handled by target shards should contain the + // latest ShardCountSummary. + // TODO: This will not work if the reshard metadata is gone + int targetShardId = indexMetadata.getReshardingMetadata().getSplit().targetShard(sourceShard.id()); + ShardId targetShard = new ShardId(sourceShard.getIndex(), targetShardId); + + requestsByShard.put(targetShard, targetRequestFactory.apply(targetShard, shardCountSummary)); + return requestsByShard; + } + + public static > Tuple combineSplitResponses( + T originalRequest, + Map splitRequests, + Map> responses + ) { + int failed = 0; + int successful = 0; + int total = 0; + List failures = new ArrayList<>(); + + // If the action fails on either one of the shards, we return an exception. 
+ // Case 1: Both source and target shards return a response: Add up total, successful, failures + // Case 2: Both source and target shards return an exception : return exception + // Case 3: One shard returns a response, the other returns an exception : return exception + for (Map.Entry> entry : responses.entrySet()) { + Tuple value = entry.getValue(); + Exception exception = value.v2(); + + if (exception != null) { + return new Tuple<>(null, exception); + } + + ReplicationResponse response = value.v1(); + failed += response.getShardInfo().getFailed(); + successful += response.getShardInfo().getSuccessful(); + total += response.getShardInfo().getTotal(); + Collections.addAll(failures, response.getShardInfo().getFailures()); + } + + ReplicationResponse.ShardInfo.Failure[] failureArray = failures.toArray(new ReplicationResponse.ShardInfo.Failure[0]); + assert failureArray.length == failed; + + ReplicationResponse.ShardInfo shardInfo = ReplicationResponse.ShardInfo.of(total, successful, failureArray); + + ReplicationResponse response = new ReplicationResponse(); + response.setShardInfo(shardInfo); + return new Tuple<>(response, null); + } +} diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index aeb44696e5134..b69c0ef1ee662 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.OperationRouting; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -62,6 +63,8 @@ public abstract class TransportBroadcastReplicationAction< private final Executor executor; private final ProjectResolver projectResolver; + protected record ShardRecord(ShardId shardId, SplitShardCountSummary splitSummary) {} + public TransportBroadcastReplicationAction( String name, Writeable.Reader requestReader, @@ -102,19 +105,22 @@ public void accept(ActionListener listener) { final ClusterState clusterState = clusterService.state(); final ProjectState projectState = projectResolver.getProjectState(clusterState); final ProjectMetadata project = projectState.metadata(); - final List shards = shards(request, projectState); + final List shards = shards(request, projectState); final Map indexMetadataByName = project.indices(); try (var refs = new RefCountingRunnable(() -> finish(listener))) { - for (final ShardId shardId : shards) { + shards.forEach(shardRecord -> { + ShardId shardId = shardRecord.shardId(); + SplitShardCountSummary shardCountSummary = shardRecord.splitSummary(); // NB This sends O(#shards) requests in a tight loop; TODO add some throttling here? 
shardExecute( task, request, shardId, + shardCountSummary, ActionListener.releaseAfter(new ReplicationResponseActionListener(shardId, indexMetadataByName), refs.acquire()) ); - } + }); } } @@ -178,9 +184,15 @@ public void onFailure(Exception e) { }; } - protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { + protected void shardExecute( + Task task, + Request request, + ShardId shardId, + SplitShardCountSummary shardCountSummary, + ActionListener shardActionListener + ) { assert Transports.assertNotTransportThread("may hit all the shards"); - ShardRequest shardRequest = newShardRequest(request, shardId); + ShardRequest shardRequest = newShardRequest(request, shardId, shardCountSummary); shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener); } @@ -188,23 +200,29 @@ protected void shardExecute(Task task, Request request, ShardId shardId, ActionL /** * @return all shard ids the request should run on */ - protected List shards(Request request, ProjectState projectState) { + protected List shards(Request request, ProjectState projectState) { assert Transports.assertNotTransportThread("may hit all the shards"); - List shardIds = new ArrayList<>(); + + List shards = new ArrayList<>(); OperationRouting operationRouting = clusterService.operationRouting(); + ProjectMetadata project = projectState.metadata(); - String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(projectState.metadata(), request); + String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(project, request); for (String index : concreteIndices) { Iterator iterator = operationRouting.allWritableShards(projectState, index); + IndexMetadata indexMetadata = project.index(index); + while (iterator.hasNext()) { - shardIds.add(iterator.next().shardId()); + ShardId shardId = iterator.next().shardId(); + SplitShardCountSummary splitSummary = SplitShardCountSummary.forIndexing(indexMetadata, shardId.getId()); + shards.add(new ShardRecord(shardId, splitSummary)); } } - return shardIds; + return shards; } - protected abstract ShardRequest newShardRequest(Request request, ShardId shardId); + protected abstract ShardRequest newShardRequest(Request request, ShardId shardId, SplitShardCountSummary shardCountSummary); protected abstract Response newResponse( int successfulShards, diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java index 2b134b43f339d..3eed9724997b5 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java @@ -338,6 +338,9 @@ protected abstract void shardOperationOnReplica( /** * During Resharding, we might need to split the primary request. + * We are here because there was a mismatch between the SplitShardCountSummary in the request + * and that on the primary shard node. We assume that the request is exactly 1 reshard split behind + * the current state.
*/ protected Map splitRequestOnPrimary(Request request) { return Map.of(request.shardId(), request); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexReshardingState.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexReshardingState.java index 2a9ea56b17c47..9b9faa9b574f0 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexReshardingState.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexReshardingState.java @@ -268,10 +268,24 @@ TargetShardState[] targetShards() { return targetShards.clone(); } + /** Return the source shard from which this target shard was split + * @param targetShard target shard id + * @return source shard id + */ public int sourceShard(int targetShard) { return targetShard % shardCountBefore(); } + /** Return the new target shard that is split from the given source shard + * This calculation assumes we only always double the number of shards in + * a reshard split operation, so that only one target shard is created per source shard. + * @param sourceShard source shard id + * @return target shard id + */ + public int targetShard(int sourceShard) { + return (sourceShard + shardCountBefore()); + } + /** * Create resharding metadata representing a new split operation * Split only supports updating an index to a multiple of its current shard count diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index 15b66ed32dbad..98c83ba796220 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.cluster.project.TestProjectResolvers; import org.elasticsearch.cluster.routing.GlobalRoutingTable; import org.elasticsearch.cluster.routing.ShardRoutingState; +import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -246,12 +247,12 @@ public void testShardsList() throws InterruptedException, ExecutionException { index ); logger.debug("--> using initial state:\n{}", clusterService.state()); - List shards = broadcastReplicationAction.shards( + List shards = broadcastReplicationAction.shards( new DummyBroadcastRequest().indices(shardId.getIndexName()), clusterState.projectState(projectId) ); assertThat(shards.size(), equalTo(1)); - assertThat(shards.get(0), equalTo(shardId)); + assertThat(shards.get(0).shardId(), equalTo(shardId)); } private class TestBroadcastReplicationAction extends TransportBroadcastReplicationAction< @@ -285,7 +286,11 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati } @Override - protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) { + protected BasicReplicationRequest newShardRequest( + DummyBroadcastRequest request, + ShardId shardId, + SplitShardCountSummary shardCountSummary + ) { return new BasicReplicationRequest(shardId); } @@ -304,6 +309,7 @@ protected void shardExecute( Task task, DummyBroadcastRequest request, ShardId shardId, + SplitShardCountSummary shardCountSummary, ActionListener shardActionListener ) { capturedShardRequests.add(new Tuple<>(shardId, 
shardActionListener));
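For reference, the response-combining rules that ReplicationRequestSplitHelper.combineSplitResponses encodes (sum the per-shard counts when every sub-request succeeds, surface an exception as soon as any sub-request fails) can be sketched independently of the Elasticsearch types. ShardInfo and Outcome below are hypothetical simplifications of ReplicationResponse.ShardInfo and the response/exception tuple used in this change.

import java.util.Collection;

final class CombineResponsesSketch {
    record ShardInfo(int total, int successful, int failed) {}

    // Either a ShardInfo or the exception returned by one split sub-request.
    record Outcome(ShardInfo info, Exception failure) {}

    static Outcome combine(Collection<Outcome> outcomes) {
        int total = 0;
        int successful = 0;
        int failed = 0;
        for (Outcome outcome : outcomes) {
            if (outcome.failure() != null) {
                return outcome; // any failed sub-request fails the combined operation
            }
            total += outcome.info().total();
            successful += outcome.info().successful();
            failed += outcome.info().failed();
        }
        return new Outcome(new ShardInfo(total, successful, failed), null);
    }
}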