Skip to content

Commit

Permalink
Speed up building RoutingTable from RoutingNodes (#86903)
Browse files Browse the repository at this point in the history
We do the routing nodes -> routing table step during reroute and it's a significant contributor
to the runtime of reroute.
The PR speeds it up by more than double in the common one shard + one primary case by saving the expensive
and slow-to-iterate hashmap keyed by shard id, the needless precomputation of the allocation ids set
and a redundant round of building `IndexShardRoutingTable` in `addShard`.

This gives us another 2-3% speedup on the initial indices bootstrap in the many-shards benchmark.
  • Loading branch information
original-brownbear committed May 19, 2022
1 parent b9c504b commit ab5ff6f
Show file tree
Hide file tree
Showing 40 changed files with 251 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,18 @@ static ClusterState createClusterState(PersistentTasksCustomMetadata tasksCustom
if (noStartedShards == false) {
shardRouting = shardRouting.moveToStarted();
}
IndexShardRoutingTable table = new IndexShardRoutingTable.Builder(new ShardId(index, 0)).addShard(shardRouting).build();
return ClusterState.builder(new ClusterName("name"))
.metadata(Metadata.builder().putCustom(TYPE, tasksCustomMetadata).put(idxMeta))
.nodes(
DiscoveryNodes.builder().add(new DiscoveryNode("_id1", buildNewFakeTransportAddress(), Version.CURRENT)).localNodeId("_id1")
)
.routingTable(RoutingTable.builder().add(IndexRoutingTable.builder(index).addIndexShard(table)))
.routingTable(
RoutingTable.builder()
.add(
IndexRoutingTable.builder(index)
.addIndexShard(IndexShardRoutingTable.builder(new ShardId(index, 0)).addShard(shardRouting))
)
)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ private IndexRoutingTable randomIndexRoutingTable(String index, String[] nodeIds
)
);
}
builder.addIndexShard(indexShard.build());
builder.addIndexShard(indexShard);
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ public void testShardActiveElseWhere() throws Exception {
indexRoutingTableBuilder.addIndexShard(
new IndexShardRoutingTable.Builder(shardId).addShard(
TestShardRouting.newShardRouting(shardId, masterId, true, ShardRoutingState.STARTED)
).build()
)
);
}
ClusterState newState = ClusterState.builder(currentState)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

Expand Down Expand Up @@ -66,14 +64,13 @@ public class IndexRoutingTable implements SimpleDiffable<IndexRoutingTable> {

private final List<ShardRouting> allActiveShards;

IndexRoutingTable(Index index, Map<Integer, IndexShardRoutingTable> shards) {
IndexRoutingTable(Index index, IndexShardRoutingTable[] shards) {
this.index = index;
this.shuffler = new RotationShardShuffler(Randomness.get().nextInt());
this.shards = new IndexShardRoutingTable[shards.size()];
this.shards = shards;
List<ShardRouting> allActiveShards = new ArrayList<>();
for (Map.Entry<Integer, IndexShardRoutingTable> cursor : shards.entrySet()) {
this.shards[cursor.getKey()] = cursor.getValue();
allActiveShards.addAll(cursor.getValue().activeShards());
for (IndexShardRoutingTable shard : shards) {
allActiveShards.addAll(shard.activeShards());
}
this.allActiveShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(allActiveShards);
}
Expand Down Expand Up @@ -301,6 +298,7 @@ public static IndexRoutingTable readFrom(StreamInput in) throws IOException {
Builder builder = new Builder(index);

int size = in.readVInt();
builder.ensureShardArray(size);
for (int i = 0; i < size; i++) {
builder.addIndexShard(IndexShardRoutingTable.Builder.readFromThin(in, index));
}
Expand Down Expand Up @@ -328,7 +326,7 @@ public static Builder builder(Index index) {
public static class Builder {

private final Index index;
private final Map<Integer, IndexShardRoutingTable> shards = new HashMap<>();
private IndexShardRoutingTable.Builder[] shards;

public Builder(Index index) {
this.index = index;
Expand Down Expand Up @@ -414,12 +412,13 @@ private Builder initializeAsRestore(
UnassignedInfo unassignedInfo
) {
assert indexMetadata.getIndex().equals(index);
if (shards.isEmpty() == false) {
if (shards != null) {
throw new IllegalStateException("trying to initialize an index with fresh shards, but already has shards created");
}
shards = new IndexShardRoutingTable.Builder[indexMetadata.getNumberOfShards()];
for (int shardNumber = 0; shardNumber < indexMetadata.getNumberOfShards(); shardNumber++) {
ShardId shardId = new ShardId(index, shardNumber);
IndexShardRoutingTable.Builder indexShardRoutingBuilder = new IndexShardRoutingTable.Builder(shardId);
IndexShardRoutingTable.Builder indexShardRoutingBuilder = IndexShardRoutingTable.builder(shardId);
for (int i = 0; i <= indexMetadata.getNumberOfReplicas(); i++) {
boolean primary = i == 0;
if (asNew && ignoreShards.contains(shardNumber)) {
Expand All @@ -443,7 +442,7 @@ private Builder initializeAsRestore(
);
}
}
shards.put(shardNumber, indexShardRoutingBuilder.build());
shards[shardNumber] = indexShardRoutingBuilder;
}
return this;
}
Expand All @@ -453,9 +452,10 @@ private Builder initializeAsRestore(
*/
private Builder initializeEmpty(IndexMetadata indexMetadata, UnassignedInfo unassignedInfo) {
assert indexMetadata.getIndex().equals(index);
if (shards.isEmpty() == false) {
if (shards != null) {
throw new IllegalStateException("trying to initialize an index with fresh shards, but already has shards created");
}
shards = new IndexShardRoutingTable.Builder[indexMetadata.getNumberOfShards()];
for (int shardNumber = 0; shardNumber < indexMetadata.getNumberOfShards(); shardNumber++) {
ShardId shardId = new ShardId(index, shardNumber);
final RecoverySource primaryRecoverySource;
Expand All @@ -469,7 +469,7 @@ private Builder initializeEmpty(IndexMetadata indexMetadata, UnassignedInfo unas
// a freshly created index with no restriction
primaryRecoverySource = EmptyStoreRecoverySource.INSTANCE;
}
IndexShardRoutingTable.Builder indexShardRoutingBuilder = new IndexShardRoutingTable.Builder(shardId);
IndexShardRoutingTable.Builder indexShardRoutingBuilder = IndexShardRoutingTable.builder(shardId);
for (int i = 0; i <= indexMetadata.getNumberOfReplicas(); i++) {
boolean primary = i == 0;
indexShardRoutingBuilder.addShard(
Expand All @@ -481,35 +481,40 @@ private Builder initializeEmpty(IndexMetadata indexMetadata, UnassignedInfo unas
)
);
}
shards.put(shardNumber, indexShardRoutingBuilder.build());
shards[shardNumber] = indexShardRoutingBuilder;
}
return this;
}

public Builder addReplica() {
for (var shardNumber : shards.keySet()) {
ShardId shardId = new ShardId(index, shardNumber);
assert shards != null;
for (IndexShardRoutingTable.Builder existing : shards) {
assert existing != null;
// version 0, will get updated when reroute will happen
ShardRouting shard = ShardRouting.newUnassigned(
shardId,
false,
PeerRecoverySource.INSTANCE,
new UnassignedInfo(UnassignedInfo.Reason.REPLICA_ADDED, null)
existing.addShard(
ShardRouting.newUnassigned(
existing.shardId(),
false,
PeerRecoverySource.INSTANCE,
new UnassignedInfo(UnassignedInfo.Reason.REPLICA_ADDED, null)
)
);
shards.put(shardNumber, new IndexShardRoutingTable.Builder(shards.get(shard.id())).addShard(shard).build());
}
return this;
}

public Builder removeReplica() {
for (var shardId : shards.keySet()) {
IndexShardRoutingTable indexShard = shards.get(shardId);
assert shards != null;
for (int shardId = 0; shardId < shards.length; shardId++) {
IndexShardRoutingTable.Builder found = shards[shardId];
assert found != null;
final IndexShardRoutingTable indexShard = found.build();
if (indexShard.replicaShards().isEmpty()) {
// nothing to do here!
return this;
}
// re-add all the current ones
IndexShardRoutingTable.Builder builder = new IndexShardRoutingTable.Builder(indexShard.shardId());
IndexShardRoutingTable.Builder builder = IndexShardRoutingTable.builder(indexShard.shardId());
for (int copy = 0; copy < indexShard.size(); copy++) {
ShardRouting shardRouting = indexShard.shard(copy);
builder.addShard(shardRouting);
Expand All @@ -528,15 +533,17 @@ public Builder removeReplica() {
}
}
}
shards.put(shardId, builder.build());
shards[shardId] = builder;
}
return this;
}

public Builder addIndexShard(IndexShardRoutingTable indexShard) {
public Builder addIndexShard(IndexShardRoutingTable.Builder indexShard) {
assert indexShard.shardId().getIndex().equals(index)
: "cannot add shard routing table for " + indexShard.shardId() + " to index routing table for " + index;
shards.put(indexShard.shardId().id(), indexShard);
final int sid = indexShard.shardId().id();
ensureShardArray(sid + 1);
shards[sid] = indexShard;
return this;
}

Expand All @@ -546,18 +553,38 @@ public Builder addIndexShard(IndexShardRoutingTable indexShard) {
*/
public Builder addShard(ShardRouting shard) {
assert shard.index().equals(index) : "cannot add [" + shard + "] to routing table for " + index;
IndexShardRoutingTable indexShard = shards.get(shard.id());
int shardId = shard.id();
ensureShardArray(shardId + 1);
IndexShardRoutingTable.Builder indexShard = shards[shardId];
if (indexShard == null) {
indexShard = new IndexShardRoutingTable.Builder(shard.shardId()).addShard(shard).build();
shards[shardId] = IndexShardRoutingTable.builder(shard.shardId()).addShard(shard);
} else {
indexShard = new IndexShardRoutingTable.Builder(indexShard).addShard(shard).build();
indexShard.addShard(shard);
}
shards.put(indexShard.shardId().id(), indexShard);
return this;
}

void ensureShardArray(int shardCount) {
if (shards == null) {
shards = new IndexShardRoutingTable.Builder[shardCount];
} else if (shards.length < shardCount) {
IndexShardRoutingTable.Builder[] updated = new IndexShardRoutingTable.Builder[shardCount];
System.arraycopy(shards, 0, updated, 0, shards.length);
shards = updated;
}
}

public IndexRoutingTable build() {
return new IndexRoutingTable(index, shards);
final IndexShardRoutingTable[] res;
if (shards != null) {
res = new IndexShardRoutingTable[shards.length];
for (int i = 0; i < shards.length; i++) {
res[i] = shards[i].build();
}
} else {
res = new IndexShardRoutingTable[0];
}
return new IndexRoutingTable(index, res);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.cluster.routing;

import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.ExponentiallyWeightedMovingAverage;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -50,7 +51,6 @@ public class IndexShardRoutingTable {
final ShardRouting[] shards;
final List<ShardRouting> activeShards;
final List<ShardRouting> assignedShards;
final Set<String> allAllocationIds;
final boolean allShardsStarted;

/**
Expand All @@ -69,7 +69,6 @@ public class IndexShardRoutingTable {
List<ShardRouting> activeShards = new ArrayList<>();
List<ShardRouting> assignedShards = new ArrayList<>();
List<ShardRouting> allInitializingShards = new ArrayList<>();
Set<String> allAllocationIds = new HashSet<>();
boolean allShardsStarted = true;
for (ShardRouting shard : this.shards) {
if (shard.primary()) {
Expand All @@ -87,15 +86,12 @@ public class IndexShardRoutingTable {
if (shard.relocating()) {
// create the target initializing shard routing on the node the shard is relocating to
allInitializingShards.add(shard.getTargetRelocatingShard());
allAllocationIds.add(shard.getTargetRelocatingShard().allocationId().getId());

assert shard.assignedToNode() : "relocating from unassigned " + shard;
assert shard.getTargetRelocatingShard().assignedToNode() : "relocating to unassigned " + shard.getTargetRelocatingShard();
assignedShards.add(shard.getTargetRelocatingShard());
}
if (shard.assignedToNode()) {
assignedShards.add(shard);
allAllocationIds.add(shard.allocationId().getId());
}
if (shard.state() != ShardRoutingState.STARTED) {
allShardsStarted = false;
Expand All @@ -107,7 +103,6 @@ public class IndexShardRoutingTable {
this.activeShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(activeShards);
this.assignedShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(assignedShards);
this.allInitializingShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(allInitializingShards);
this.allAllocationIds = Collections.unmodifiableSet(allAllocationIds);
}

/**
Expand Down Expand Up @@ -477,6 +472,16 @@ public ShardRouting getByAllocationId(String allocationId) {
}

public Set<String> getAllAllocationIds() {
assert MasterService.assertNotMasterUpdateThread("not using this on the master thread so we don't have to pre-compute this");
Set<String> allAllocationIds = new HashSet<>();
for (ShardRouting shard : shards) {
if (shard.relocating()) {
allAllocationIds.add(shard.getTargetRelocatingShard().allocationId().getId());
}
if (shard.assignedToNode()) {
allAllocationIds.add(shard.allocationId().getId());
}
}
return allAllocationIds;
}

Expand Down Expand Up @@ -515,6 +520,10 @@ public List<ShardRouting> shardsWithState(ShardRoutingState state) {
return shards;
}

public static Builder builder(ShardId shardId) {
return new Builder(shardId);
}

public static class Builder {

private final ShardId shardId;
Expand All @@ -526,6 +535,10 @@ public Builder(IndexShardRoutingTable indexShard) {
Collections.addAll(this.shards, indexShard.shards);
}

public ShardId shardId() {
return shardId;
}

public Builder(ShardId shardId) {
this.shardId = shardId;
this.shards = new ArrayList<>();
Expand Down Expand Up @@ -581,12 +594,12 @@ static boolean noDuplicatePrimary(List<ShardRouting> shards) {
return true;
}

public static IndexShardRoutingTable readFrom(StreamInput in) throws IOException {
public static IndexShardRoutingTable.Builder readFrom(StreamInput in) throws IOException {
Index index = new Index(in);
return readFromThin(in, index);
}

public static IndexShardRoutingTable readFromThin(StreamInput in, Index index) throws IOException {
public static IndexShardRoutingTable.Builder readFromThin(StreamInput in, Index index) throws IOException {
int iShardId = in.readVInt();
ShardId shardId = new ShardId(index, iShardId);
Builder builder = new Builder(shardId);
Expand All @@ -597,7 +610,7 @@ public static IndexShardRoutingTable readFromThin(StreamInput in, Index index) t
builder.addShard(shard);
}

return builder.build();
return builder;
}

public static void writeTo(IndexShardRoutingTable indexShard, StreamOutput out) throws IOException {
Expand Down

0 comments on commit ab5ff6f

Please sign in to comment.