[7.17] Cancel cold cache prewarming tasks if store is closing (#95891) (#96024)

Cold cache prewarming tasks are not stopped immediately when the
shard is closed, causing unnecessary disk usage. This change adds a
Supplier<Boolean> to the prewarming logic that can be checked before
executing any resource-consuming operation, to know whether the Store
is closing.

I'm not super happy with my test changes, but the logic for checking
that prewarming works correctly is tricky.

Closes #95504
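
The core pattern, as a minimal standalone sketch (the Part interface and
prewarm helper below are illustrative names, not the actual
searchable-snapshots API):

import java.util.function.Supplier;

// Illustrative sketch only: Part and prewarm are hypothetical names,
// not the actual searchable-snapshots API.
class PrewarmingSketch {

    interface Part {
        void prefetch(); // reads one blob part into the local cache
    }

    // The Supplier<Boolean> is polled before every resource-consuming step,
    // so prewarming tasks that are already queued when the store starts
    // closing exit quickly instead of warming a cache that is about to be
    // discarded.
    static void prewarm(Iterable<Part> parts, Supplier<Boolean> cancelPreWarming) {
        for (Part part : parts) {
            if (cancelPreWarming.get()) {
                return; // store is closing: stop consuming disk and network
            }
            part.prefetch();
        }
    }
}

In the change itself the supplier is store::isClosing, threaded from
loadSnapshot through prewarmCache and down into prefetchPart, which also
polls it inside its copy loop so an in-flight part download stops early.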
tlrx committed May 11, 2023
1 parent fc5e40b commit 868373e
Showing 8 changed files with 165 additions and 32 deletions.
@@ -12,9 +12,11 @@
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
import org.elasticsearch.action.admin.indices.recovery.RecoveryResponse;
import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.routing.allocation.decider.ThrottlingAllocationDecider;
import org.elasticsearch.cluster.service.ClusterService;
@@ -26,11 +28,16 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.env.Environment;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot;
import org.elasticsearch.index.store.LuceneFilesExtensions;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.recovery.RecoverySettings;
import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.license.LicenseService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
@@ -57,24 +64,33 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.CACHE_PREWARMING_THREAD_POOL_NAME;
import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_CACHE_ENABLED_SETTING;
import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_CACHE_PREWARM_ENABLED_SETTING;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.emptyCollectionOf;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
@@ -189,17 +205,22 @@ public void testConcurrentPrewarming() throws Exception {

final Map<String, List<String>> exclusionsPerIndex = new HashMap<>();
for (int index = 0; index < nbIndices; index++) {
exclusionsPerIndex.put("index-" + index, randomSubsetOf(Arrays.asList("fdt", "fdx", "nvd", "dvd", "tip", "cfs", "dim")));
exclusionsPerIndex.put(
"index-" + index,
randomSubsetOf(Stream.of(LuceneFilesExtensions.values()).map(LuceneFilesExtensions::getExtension).collect(toList()))
);
}

logger.debug("--> mounting indices");
final List<String> mountedIndices = new ArrayList<>(nbIndices);
final Thread[] threads = new Thread[nbIndices];
final AtomicArray<Throwable> throwables = new AtomicArray<>(nbIndices);
final CountDownLatch latch = new CountDownLatch(1);

for (int i = 0; i < threads.length; i++) {
int threadId = i;
final String indexName = "index-" + threadId;
mountedIndices.add(indexName);
final Thread thread = new Thread(() -> {
try {
latch.await();
@@ -245,18 +266,73 @@ public void testConcurrentPrewarming() throws Exception {
thread.start();
}

// some indices are randomly removed before prewarming completes
final Set<String> deletedIndicesDuringPrewarming = randomBoolean() ? new HashSet<>(randomSubsetOf(mountedIndices)) : emptySet();

final CountDownLatch startPrewarmingLatch = new CountDownLatch(1);
final ThreadPool threadPool = getInstanceFromNode(ThreadPool.class);
final int maxUploadTasks = threadPool.info(CACHE_PREWARMING_THREAD_POOL_NAME).getMax();
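// Block every prewarming thread with latch-gated no-op tasks so that real
// prewarming work stays queued until startPrewarmingLatch is released below.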
for (int i = 0; i < maxUploadTasks; i++) {
threadPool.executor(CACHE_PREWARMING_THREAD_POOL_NAME).execute(new AbstractRunnable() {

@Override
protected void doRun() throws Exception {
startPrewarmingLatch.await();
}

@Override
public void onFailure(Exception e) {
throw new AssertionError(e);
}
});
}

ExecutorService prewarmingExecutor = threadPool.executor(SearchableSnapshots.CACHE_PREWARMING_THREAD_POOL_NAME);
assertThat(prewarmingExecutor, instanceOf(ThreadPoolExecutor.class));
assertThat(((ThreadPoolExecutor) prewarmingExecutor).getActiveCount(), equalTo(maxUploadTasks));

latch.countDown();
for (Thread thread : threads) {
thread.join();
}

assertThat("Failed to mount snapshot as indices", throwables.asList(), emptyCollectionOf(Throwable.class));

logger.debug("--> waiting for background cache prewarming to");
logger.debug("--> waiting for background cache to complete");
assertBusy(() -> {
final ThreadPool threadPool = getInstanceFromNode(ThreadPool.class);
assertThat(threadPool.info(SearchableSnapshots.CACHE_FETCH_ASYNC_THREAD_POOL_NAME).getQueueSize(), nullValue());
assertThat(threadPool.info(SearchableSnapshots.CACHE_PREWARMING_THREAD_POOL_NAME).getQueueSize(), nullValue());
ExecutorService executor = threadPool.executor(SearchableSnapshots.CACHE_FETCH_ASYNC_THREAD_POOL_NAME);
if (executor instanceof ThreadPoolExecutor) {
assertThat(((ThreadPoolExecutor) executor).getQueue().size(), equalTo(0));
assertThat(((ThreadPoolExecutor) executor).getActiveCount(), equalTo(0));
}
});

if (deletedIndicesDuringPrewarming.isEmpty() == false) {
Set<Index> deletedIndices = deletedIndicesDuringPrewarming.stream().map(this::resolveIndex).collect(Collectors.toSet());
logger.debug("--> deleting indices [{}] before prewarming", deletedIndices);
assertAcked(client().admin().indices().prepareDelete(deletedIndicesDuringPrewarming.toArray(new String[] {})));

IndicesService indicesService = getInstanceFromNode(IndicesService.class);
assertBusy(() -> deletedIndices.forEach(deletedIndex -> assertThat(indicesService.hasIndex(deletedIndex), is(false))));
}

startPrewarmingLatch.countDown();

// wait for recovery to be DONE
assertBusy(() -> {
RecoveryResponse recoveryResponse = client().admin()
.indices()
.prepareRecoveries(mountedIndices.toArray(new String[] {}))
.setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN)
.get();
assertThat(
recoveryResponse.shardRecoveryStates()
.values()
.stream()
.flatMap(Collection::stream)
.allMatch(recoveryState -> recoveryState.getStage().equals(RecoveryState.Stage.DONE)),
is(true)
);
});

logger.debug("--> loading snapshot metadata");
@@ -287,14 +363,55 @@ public void testConcurrentPrewarming() throws Exception {
.stream()
.filter(file -> file.metadata().hashEqualsContents() == false)
.filter(file -> exclusionsPerIndex.get(indexName).contains(IndexFileNames.getExtension(file.physicalName())) == false)
.collect(Collectors.toList());

for (BlobStoreIndexShardSnapshot.FileInfo expectedPrewarmedBlob : expectedPrewarmedBlobs) {
for (int part = 0; part < expectedPrewarmedBlob.numberOfParts(); part++) {
final String blobName = expectedPrewarmedBlob.partName(part);
long actualBytesRead = tracker.totalBytesRead(blobName);
long expectedBytesRead = expectedPrewarmedBlob.partBytes(part);
assertThat("Blob [" + blobName + "] not fully warmed", actualBytesRead, greaterThanOrEqualTo(expectedBytesRead));
.collect(toList());

// Surviving indices must have every expected blob part fully prewarmed,
// while deleted indices must have no reads recorded; only .cfe/.cfs blobs
// are allowed to deviate in either case.
if (deletedIndicesDuringPrewarming.contains(indexName) == false) {
for (BlobStoreIndexShardSnapshot.FileInfo blob : expectedPrewarmedBlobs) {
for (int part = 0; part < blob.numberOfParts(); part++) {
final String blobName = blob.partName(part);
try {
assertThat(
"Blob [" + blobName + "][" + blob.physicalName() + "] not prewarmed",
tracker.totalBytesRead(blobName),
equalTo(blob.partBytes(part))
);
} catch (AssertionError ae) {
assertThat(
"Only blobs from physical file with specific extensions are expected to be prewarmed over their sizes ["
+ blobName
+ "]["
+ blob.physicalName()
+ "] but got :"
+ ae,
blob.physicalName(),
anyOf(endsWith(".cfe"), endsWith(".cfs"))
);
}
}
}
} else {
for (BlobStoreIndexShardSnapshot.FileInfo blob : expectedPrewarmedBlobs) {
for (int part = 0; part < blob.numberOfParts(); part++) {
final String blobName = blob.partName(part);
try {
assertThat(
"Blob [" + blobName + "][" + blob.physicalName() + "] should not have been fully prewarmed",
tracker.totalBytesRead(blobName),
nullValue()
);
} catch (AssertionError ae) {
assertThat(
"Only blobs from physical file with specific extensions are expected to be prewarmed over their sizes ["
+ blobName
+ "]["
+ blob.physicalName()
+ "] but got :"
+ ae,
blob.physicalName(),
anyOf(endsWith(".cfe"), endsWith(".cfs"))
);
}
}
}
}
}
@@ -22,6 +22,7 @@
import org.elasticsearch.index.shard.IndexEventListener;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.store.Store;
import org.elasticsearch.index.translog.Translog;
import org.elasticsearch.index.translog.TranslogException;
import org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason;
@@ -69,10 +70,11 @@ public void beforeIndexShardRecovery(IndexShard indexShard, IndexSettings indexS
}

private static void ensureSnapshotIsLoaded(IndexShard indexShard) {
final SearchableSnapshotDirectory directory = unwrapDirectory(indexShard.store().directory());
final Store store = indexShard.store();
final SearchableSnapshotDirectory directory = unwrapDirectory(store.directory());
assert directory != null;
final StepListener<Void> preWarmListener = new StepListener<>();
final boolean success = directory.loadSnapshot(indexShard.recoveryState(), preWarmListener);
final boolean success = directory.loadSnapshot(indexShard.recoveryState(), store::isClosing, preWarmListener);
final ShardRouting shardRouting = indexShard.routingEntry();
if (success && shardRouting.isRelocationTarget()) {
final Runnable preWarmCondition = indexShard.addCleanFilesDependency();
@@ -208,7 +208,11 @@ protected final boolean assertCurrentThreadMayLoadSnapshot() {
*
* @return true if the snapshot was loaded by executing this method, false otherwise
*/
public boolean loadSnapshot(RecoveryState snapshotRecoveryState, ActionListener<Void> preWarmListener) {
public boolean loadSnapshot(
RecoveryState snapshotRecoveryState,
Supplier<Boolean> cancelPreWarming,
ActionListener<Void> preWarmListener
) {
assert snapshotRecoveryState != null;
assert snapshotRecoveryState instanceof SearchableSnapshotRecoveryState;
assert snapshotRecoveryState.getRecoverySource().getType() == RecoverySource.Type.SNAPSHOT
@@ -230,7 +234,7 @@ public boolean loadSnapshot(RecoveryState snapshotRecoveryState, ActionListener<
cleanExistingRegularShardFiles();
waitForPendingEvictions();
this.recoveryState = (SearchableSnapshotRecoveryState) snapshotRecoveryState;
prewarmCache(preWarmListener);
prewarmCache(preWarmListener, cancelPreWarming);
}
}
}
@@ -492,8 +496,8 @@ private void waitForPendingEvictions() {
cacheService.waitForCacheFilesEvictionIfNeeded(snapshotId.getUUID(), indexId.getName(), shardId);
}

private void prewarmCache(ActionListener<Void> listener) {
if (prewarmCache == false) {
private void prewarmCache(ActionListener<Void> listener, Supplier<Boolean> cancelPreWarming) {
if (prewarmCache == false || cancelPreWarming.get()) {
recoveryState.setPreWarmComplete();
listener.onResponse(null);
return;
@@ -508,6 +512,10 @@ private void prewarmCache(ActionListener<Void> listener) {
}, listener::onFailure), snapshot().totalFileCount());

for (BlobStoreIndexShardSnapshot.FileInfo file : snapshot().indexFiles()) {
if (cancelPreWarming.get()) {
completionListener.onResponse(null);
continue;
}
boolean hashEqualsContents = file.metadata().hashEqualsContents();
if (hashEqualsContents || isExcludedFromCache(file.physicalName())) {
if (hashEqualsContents) {
@@ -540,11 +548,9 @@ private void prewarmCache(ActionListener<Void> listener) {
for (int p = 0; p < numberOfParts; p++) {
final int part = p;
queue.add(Tuple.tuple(partsListener, () -> {
ensureOpen();

logger.trace("{} warming cache for [{}] part [{}/{}]", shardId, file.physicalName(), part + 1, numberOfParts);
final long startTimeInNanos = statsCurrentTimeNanosSupplier.getAsLong();
final long persistentCacheLength = ((CachedBlobContainerIndexInput) input).prefetchPart(part).v1();
long persistentCacheLength = ((CachedBlobContainerIndexInput) input).prefetchPart(part, cancelPreWarming).v1();
if (persistentCacheLength == file.length()) {
recoveryState.markIndexFileAsReused(file.physicalName());
} else {
@@ -25,6 +25,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsUtils.toIntBytes;
@@ -34,7 +35,7 @@ public class CachedBlobContainerIndexInput extends MetadataCachingIndexInput {
/**
* Specific IOContext used for prewarming the cache. This context allows to write
* a complete part of the {@link #fileInfo} at once in the cache and should not be
* used for anything else than what the {@link #prefetchPart(int)} method does.
* used for anything else than what the {@link #prefetchPart(int, Supplier)} method does.
*/
public static final IOContext CACHE_WARMING_CONTEXT = new IOContext();

@@ -140,11 +141,15 @@ protected void readWithoutBlobCache(ByteBuffer b) throws Exception {
* Prefetches a complete part and writes it in cache. This method is used to prewarm the cache.
* @return a tuple with {@code Tuple<Persistent Cache Length, Prefetched Length>} values
*/
public Tuple<Long, Long> prefetchPart(final int part) throws IOException {
public Tuple<Long, Long> prefetchPart(final int part, Supplier<Boolean> isCancelled) throws IOException {
ensureContext(ctx -> ctx == CACHE_WARMING_CONTEXT);
if (part >= fileInfo.numberOfParts()) {
throw new IllegalArgumentException("Unexpected part number [" + part + "]");
}
if (isCancelled.get()) {
return Tuple.tuple(0L, 0L);
}

final ByteRange partRange = computeRange(IntStream.range(0, part).mapToLong(fileInfo::partBytes).sum());
assert assertRangeIsAlignedWithPart(partRange);

@@ -182,6 +187,9 @@ public Tuple<Long, Long> prefetchPart(final int part) throws IOException {
try (InputStream input = openInputStreamFromBlobStore(range.start(), range.length())) {
while (remainingBytes > 0L) {
assert totalBytesRead + remainingBytes == range.length();
if (isCancelled.get()) {
return Tuple.tuple(cacheFile.getInitialLength(), totalBytesRead);
}
final int bytesRead = readSafe(input, copyBuffer, range.start(), range.end(), remainingBytes, cacheFileReference);

// The range to prewarm in cache
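For context, a hypothetical driver loop modeled on the prewarmCache()
changes above; it shows the intended use of the new
prefetchPart(int, Supplier<Boolean>) signature and its
Tuple<persistent cache length, prefetched length> return value (the helper
name and import path are assumptions, not part of the commit):

import java.io.IOException;
import java.util.function.Supplier;

// Assumed import path for the 7.17 source layout:
import org.elasticsearch.xpack.searchablesnapshots.store.input.CachedBlobContainerIndexInput;

class PrefetchDriverSketch {

    // Hypothetical helper: polls the cancellation supplier between parts and
    // returns the total number of bytes actually prefetched.
    static long prefetchAllParts(CachedBlobContainerIndexInput input, int numberOfParts, Supplier<Boolean> isCancelled)
        throws IOException {
        long prefetched = 0L;
        for (int part = 0; part < numberOfParts; part++) {
            if (isCancelled.get()) {
                break; // store is closing: later parts are skipped entirely
            }
            // v2() is the prefetched length for this part; it can be short of
            // the part size if cancellation happened mid-copy.
            prefetched += input.prefetchPart(part, isCancelled).v2();
        }
        return prefetched;
    }
}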
@@ -697,7 +697,7 @@ protected IndexInputStats createIndexInputStats(long numFiles, long totalSize, l
DiscoveryNode targetNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
RecoveryState recoveryState = new SearchableSnapshotRecoveryState(shardRouting, targetNode, null);
final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
final boolean loaded = directory.loadSnapshot(recoveryState, future);
final boolean loaded = directory.loadSnapshot(recoveryState, () -> false, future);
future.get();
assertThat("Failed to load snapshot", loaded, is(true));
assertThat("Snapshot should be loaded", directory.snapshot(), notNullValue());