Skip to content

Commit

Permalink
Auto sharding uses the sum of shards write loads (#106785)
Browse files Browse the repository at this point in the history
Data stream auto sharding uses the index write load to decide the
optimal number of shards. We read this previously from the indexing
stats output, using the `total/write_load` value however, this
proved to be wrong as that value takes into account the search shard
write load (which will always be 0).
Even more, the `total/write_load` value averages the write loads for
every shard so you can end up with indices that only have one primary
and one replica, with the primary shard having a write load of 1.7 and
the `total/write_load` reporting to be `0.8`.

For data stream auto sharding we're interested in the **total** index
write load, defined as the sum of all the shards write loads (yes we
can include the replica shard write loads in this sum as they're 0).

This PR changes the rollover write load computation to sum all the shard
write loads for the data stream write index, and in the
`DataStreamAutoShardingService` when looking at the historic write load
over the cooldown period to, again, sum the write loads of every shard
in the index metadata/stats.
  • Loading branch information
andreidan committed Mar 27, 2024
1 parent 5f132ca commit 9776f54
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,17 @@ public void testRolloverOnAutoShardCondition() throws Exception {
for (int i = 0; i < firstGenerationMeta.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 75.0 which will make the auto sharding service recommend an optimal number
// of 5 shards
shards.add(getShardStats(firstGenerationMeta, i, 75, assignedShardNodeId));
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(), shards, List.of())
);
});
shards.add(
getShardStats(
firstGenerationMeta,
i,
(long) Math.ceil(75.0 / firstGenerationMeta.getNumberOfShards()),
assignedShardNodeId
)
);
}

mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, firstGenerationMeta, shards);
assertAcked(indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());

ClusterState clusterStateAfterRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
Expand Down Expand Up @@ -180,21 +175,16 @@ instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(),
for (int i = 0; i < secondGenerationMeta.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 100.0 which will make the auto sharding service recommend an optimal number of
// 7 shards
shards.add(getShardStats(secondGenerationMeta, i, 100, assignedShardNodeId));
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), secondGenerationMeta.getNumberOfShards(), shards, List.of())
);
});
shards.add(
getShardStats(
secondGenerationMeta,
i,
(long) Math.ceil(100.0 / secondGenerationMeta.getNumberOfShards()),
assignedShardNodeId
)
);
}
mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, secondGenerationMeta, shards);

RolloverResponse response = indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet();
assertAcked(response);
Expand Down Expand Up @@ -232,21 +222,11 @@ instance.new NodeResponse(node.getId(), secondGenerationMeta.getNumberOfShards()
for (int i = 0; i < thirdGenIndex.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 100.0 which will make the auto sharding service recommend an optimal
// number of 7 shards
shards.add(getShardStats(thirdGenIndex, i, 100, assignedShardNodeId));
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), thirdGenIndex.getNumberOfShards(), shards, List.of())
);
});
shards.add(
getShardStats(thirdGenIndex, i, (long) Math.ceil(100.0 / thirdGenIndex.getNumberOfShards()), assignedShardNodeId)
);
}
mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, thirdGenIndex, shards);

RolloverRequest request = new RolloverRequest(dataStreamName, null);
request.setConditions(RolloverConditions.newBuilder().addMaxIndexDocsCondition(1_000_000L).build());
Expand Down Expand Up @@ -309,22 +289,10 @@ public void testReduceShardsOnRollover() throws IOException {
for (int i = 0; i < firstGenerationMeta.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 2.0 which will make the auto sharding service recommend an optimal number
// of 2 shards
shards.add(getShardStats(firstGenerationMeta, i, 2, assignedShardNodeId));
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(), shards, List.of())
);
});
shards.add(getShardStats(firstGenerationMeta, i, i < 2 ? 1 : 0, assignedShardNodeId));
}

mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, firstGenerationMeta, shards);
assertAcked(indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());

ClusterState clusterStateAfterRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
Expand Down Expand Up @@ -356,23 +324,11 @@ instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(),
.index(dataStreamBeforeRollover.getIndices().get(1));
List<ShardStats> shards = new ArrayList<>(secondGenerationIndex.getNumberOfShards());
for (int i = 0; i < secondGenerationIndex.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 2.0 which will make the auto sharding service recommend an optimal
// number of 2 shards
shards.add(getShardStats(secondGenerationIndex, i, 2, assignedShardNodeId));
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), secondGenerationIndex.getNumberOfShards(), shards, List.of())
);
});
// the shard stats will yield a write load of 2.0 which will make the auto sharding service recommend an
// optimal number of 2 shards
shards.add(getShardStats(secondGenerationIndex, i, i < 2 ? 1 : 0, assignedShardNodeId));
}
mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, secondGenerationIndex, shards);

RolloverRequest request = new RolloverRequest(dataStreamName, null);
// adding condition that does NOT match
Expand Down Expand Up @@ -438,36 +394,25 @@ public void testLazyRolloverKeepsPreviousAutoshardingDecision() throws IOExcepti
IndexMetadata firstGenerationMeta = clusterStateBeforeRollover.getMetadata().index(firstGenerationIndex);

List<ShardStats> shards = new ArrayList<>(firstGenerationMeta.getNumberOfShards());
String assignedShardNodeId = clusterStateBeforeRollover.routingTable()
.index(dataStreamBeforeRollover.getWriteIndex())
.shard(0)
.primaryShard()
.currentNodeId();
for (int i = 0; i < firstGenerationMeta.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 75.0 which will make the auto sharding service recommend an optimal number
// of 5 shards
shards.add(
getShardStats(
firstGenerationMeta,
i,
75,
clusterStateBeforeRollover.routingTable()
.index(dataStreamBeforeRollover.getWriteIndex())
.shard(0)
.primaryShard()
.currentNodeId()
(long) Math.ceil(75.0 / firstGenerationMeta.getNumberOfShards()),
assignedShardNodeId
)
);
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(), shards, List.of())
);
});
}

mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, firstGenerationMeta, shards);
assertAcked(indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());

ClusterState clusterStateAfterRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
Expand All @@ -491,37 +436,22 @@ instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(),
ClusterState clusterStateBeforeRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
DataStream dataStreamBeforeRollover = clusterStateBeforeRollover.getMetadata().dataStreams().get(dataStreamName);

String assignedShardNodeId = clusterStateBeforeRollover.routingTable()
.index(dataStreamBeforeRollover.getWriteIndex())
.shard(0)
.primaryShard()
.currentNodeId();
IndexMetadata secondGenIndex = clusterStateBeforeRollover.metadata().index(dataStreamBeforeRollover.getIndices().get(1));
List<ShardStats> shards = new ArrayList<>(secondGenIndex.getNumberOfShards());
for (int i = 0; i < secondGenIndex.getNumberOfShards(); i++) {
// the shard stats will yield a write load of 100.0 which will make the auto sharding service recommend an optimal
// number of 7 shards
shards.add(
getShardStats(
secondGenIndex,
i,
100,
clusterStateBeforeRollover.routingTable()
.index(dataStreamBeforeRollover.getWriteIndex())
.shard(i)
.primaryShard()
.currentNodeId()
)
getShardStats(secondGenIndex, i, (long) Math.ceil(100.0 / secondGenIndex.getNumberOfShards()), assignedShardNodeId)
);
}

for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(
instance.new NodeResponse(node.getId(), secondGenIndex.getNumberOfShards(), shards, List.of())
);
});
}
mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, secondGenIndex, shards);

RolloverRequest request = new RolloverRequest(dataStreamName, null);
request.lazy(true);
Expand Down Expand Up @@ -612,4 +542,33 @@ public Settings additionalSettings() {
}
}

private static void mockStatsForIndex(
ClusterState clusterState,
String assignedShardNodeId,
IndexMetadata indexMetadata,
List<ShardStats> shards
) {
for (DiscoveryNode node : clusterState.nodes().getAllNodes()) {
// one node returns the stats for all our shards, the other nodes don't return any stats
if (node.getId().equals(assignedShardNodeId)) {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(instance.new NodeResponse(node.getId(), indexMetadata.getNumberOfShards(), shards, List.of()));
});
} else {
MockTransportService.getInstance(node.getName())
.addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
TransportIndicesStatsAction instance = internalCluster().getInstance(
TransportIndicesStatsAction.class,
node.getName()
);
channel.sendResponse(instance.new NodeResponse(node.getId(), 0, List.of(), List.of()));
});
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,16 @@ protected void masterOperation(
final Optional<IndexStats> indexStats = Optional.ofNullable(statsResponse)
.map(stats -> stats.getIndex(dataStream.getWriteIndex().getName()));

Double writeLoad = indexStats.map(stats -> stats.getTotal().getIndexing())
.map(indexing -> indexing.getTotal().getWriteLoad())
.orElse(null);

rolloverAutoSharding = dataStreamAutoShardingService.calculate(clusterState, dataStream, writeLoad);
Double indexWriteLoad = indexStats.map(
stats -> Arrays.stream(stats.getShards())
.filter(shardStats -> shardStats.getStats().indexing != null)
// only take primaries into account as in stateful the replicas also index data
.filter(shardStats -> shardStats.getShardRouting().primary())
.map(shardStats -> shardStats.getStats().indexing.getTotal().getWriteLoad())
.reduce(0.0, Double::sum)
).orElse(null);

rolloverAutoSharding = dataStreamAutoShardingService.calculate(clusterState, dataStream, indexWriteLoad);
logger.debug("auto sharding result for data stream [{}] is [{}]", dataStream.getName(), rolloverAutoSharding);

// if auto sharding recommends increasing the number of shards we want to trigger a rollover even if there are no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.util.List;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.OptionalLong;
import java.util.function.Function;
import java.util.function.LongSupplier;

Expand Down Expand Up @@ -381,27 +380,11 @@ static double getMaxIndexLoadWithinCoolingPeriod(
// assume the current write index load is the highest observed and look back to find the actual maximum
double maxIndexLoadWithinCoolingPeriod = writeIndexLoad;
for (IndexWriteLoad writeLoad : writeLoadsWithinCoolingPeriod) {
// the IndexWriteLoad stores _for each shard_ a shard average write load ( calculated using : shard indexing time / shard
// uptime ) and its corresponding shard uptime
//
// to reconstruct the average _index_ write load we recalculate the shard indexing time by multiplying the shard write load
// to its uptime, and then, having the indexing time and uptime for each shard we calculate the average _index_ write load using
// (indexingTime_shard0 + indexingTime_shard1) / (uptime_shard0 + uptime_shard1)
// as {@link org.elasticsearch.index.shard.IndexingStats#add} does
double totalShardIndexingTime = 0;
long totalShardUptime = 0;
double totalIndexLoad = 0;
for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
if (writeLoadForShard.isPresent()) {
assert uptimeInMillisForShard.isPresent();
double shardIndexingTime = writeLoadForShard.getAsDouble() * uptimeInMillisForShard.getAsLong();
long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
totalShardIndexingTime += shardIndexingTime;
totalShardUptime += shardUptimeInMillis;
}
totalIndexLoad += writeLoadForShard.orElse(0);
}
double totalIndexLoad = totalShardUptime == 0 ? 0.0 : (totalShardIndexingTime / totalShardUptime);
if (totalIndexLoad > maxIndexLoadWithinCoolingPeriod) {
maxIndexLoadWithinCoolingPeriod = totalIndexLoad;
}
Expand Down

0 comments on commit 9776f54

Please sign in to comment.