diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessor.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessor.java index ad68471041b8d..9dd76b1bbc959 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessor.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessor.java @@ -28,6 +28,7 @@ import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Semaphore; import java.util.function.Consumer; +import java.util.function.Supplier; /** * This async IO processor allows to batch IO operations and have a single writer processing the write operations. @@ -39,11 +40,13 @@ public abstract class AsyncIOProcessor { private final Logger logger; private final ArrayBlockingQueue>> queue; + private final ThreadContext threadContext; private final Semaphore promiseSemaphore = new Semaphore(1); - protected AsyncIOProcessor(Logger logger, int queueSize) { + protected AsyncIOProcessor(Logger logger, int queueSize, ThreadContext threadContext) { this.logger = logger; this.queue = new ArrayBlockingQueue<>(queueSize); + this.threadContext = threadContext; } /** @@ -58,11 +61,10 @@ public final void put(Item item, Consumer listener) { // we first try make a promise that we are responsible for the processing final boolean promised = promiseSemaphore.tryAcquire(); - final Tuple> itemTuple = new Tuple<>(item, listener); if (promised == false) { // in this case we are not responsible and can just block until there is space try { - queue.put(new Tuple<>(item, listener)); + queue.put(new Tuple<>(item, preserveContext(listener))); } catch (InterruptedException e) { Thread.currentThread().interrupt(); listener.accept(e); @@ -76,7 +78,8 @@ public final void put(Item item, Consumer listener) { try { if (promised) { // we are responsible for processing we don't need to add the tuple to the queue we can just add it to the candidates - candidates.add(itemTuple); + // no need to preserve context for listener since it runs in current thread. + candidates.add(new Tuple<>(item, listener)); } // since we made the promise to process we gotta do it here at least once drainAndProcess(candidates); @@ -121,6 +124,15 @@ private void processList(List>> candidates) { } } + private Consumer preserveContext(Consumer consumer) { + Supplier restorableContext = threadContext.newRestorableContext(false); + return e -> { + try (ThreadContext.StoredContext ignore = restorableContext.get()) { + consumer.accept(e); + } + }; + } + /** * Writes or processes the items out or to disk. */ diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index 7b4e06a451c7d..8a741661116c5 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -66,6 +66,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AsyncIOProcessor; import org.elasticsearch.common.util.concurrent.RunOnce; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.gateway.WriteStateException; @@ -291,6 +292,7 @@ public IndexShard( this.indexSortSupplier = indexSortSupplier; this.indexEventListener = indexEventListener; this.threadPool = threadPool; + this.translogSyncProcessor = createTranslogSyncProcessor(logger, threadPool.getThreadContext(), this::getEngine); this.mapperService = mapperService; this.indexCache = indexCache; this.internalIndexingStats = new InternalIndexingStats(); @@ -2789,19 +2791,24 @@ public List getActiveOperations() { return indexShardOperationPermits.getActiveOperations(); } - private final AsyncIOProcessor translogSyncProcessor = new AsyncIOProcessor(logger, 1024) { - @Override - protected void write(List>> candidates) throws IOException { - try { - getEngine().ensureTranslogSynced(candidates.stream().map(Tuple::v1)); - } catch (AlreadyClosedException ex) { - // that's fine since we already synced everything on engine close - this also is conform with the methods - // documentation - } catch (IOException ex) { // if this fails we are in deep shit - fail the request - logger.debug("failed to sync translog", ex); - throw ex; + private final AsyncIOProcessor translogSyncProcessor; + + private static AsyncIOProcessor createTranslogSyncProcessor(Logger logger, ThreadContext threadContext, + Supplier engineSupplier) { + return new AsyncIOProcessor<>(logger, 1024, threadContext) { + @Override + protected void write(List>> candidates) throws IOException { + try { + engineSupplier.get().ensureTranslogSynced(candidates.stream().map(Tuple::v1)); + } catch (AlreadyClosedException ex) { + // that's fine since we already synced everything on engine close - this also is conform with the methods + // documentation + } catch (IOException ex) { // if this fails we are in deep shit - fail the request + logger.debug("failed to sync translog", ex); + throw ex; + } } - } + }; }; /** diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java index 72a1e21d78865..fb6a880f2d4de 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java @@ -19,22 +19,40 @@ package org.elasticsearch.common.util.concurrent; import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; import java.io.IOException; import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class AsyncIOProcessorTests extends ESTestCase { + private ThreadContext threadContext; + + @Before + public void setUpThreadContext() { + threadContext = new ThreadContext(Settings.EMPTY); + } + + @After + public void tearDownThreadContext() { + threadContext.close(); + } + public void testPut() throws InterruptedException { boolean blockInternal = randomBoolean(); AtomicInteger received = new AtomicInteger(0); - AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024)) { + AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024), threadContext) { @Override protected void write(List>> candidates) throws IOException { if (blockInternal) { @@ -83,7 +101,7 @@ public void testRandomFail() throws InterruptedException { AtomicInteger received = new AtomicInteger(0); AtomicInteger failed = new AtomicInteger(0); AtomicInteger actualFailed = new AtomicInteger(0); - AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024)) { + AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024), threadContext) { @Override protected void write(List>> candidates) throws IOException { received.addAndGet(candidates.size()); @@ -137,7 +155,7 @@ public void testConsumerCanThrowExceptions() { AtomicInteger received = new AtomicInteger(0); AtomicInteger notified = new AtomicInteger(0); - AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024)) { + AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024), threadContext) { @Override protected void write(List>> candidates) throws IOException { received.addAndGet(candidates.size()); @@ -156,7 +174,7 @@ protected void write(List>> candidates) throws } public void testNullArguments() { - AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024)) { + AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(1, 2024), threadContext) { @Override protected void write(List>> candidates) throws IOException { } @@ -165,4 +183,59 @@ protected void write(List>> candidates) throws expectThrows(NullPointerException.class, () -> processor.put(null, (e) -> {})); expectThrows(NullPointerException.class, () -> processor.put(new Object(), null)); } + + public void testPreserveThreadContext() throws InterruptedException { + final int threadCount = randomIntBetween(2, 10); + final String testHeader = "testheader"; + + AtomicInteger received = new AtomicInteger(0); + AtomicInteger notified = new AtomicInteger(0); + + CountDownLatch writeDelay = new CountDownLatch(1); + AsyncIOProcessor processor = new AsyncIOProcessor(logger, scaledRandomIntBetween(threadCount - 1, 2024), + threadContext) { + @Override + protected void write(List>> candidates) throws IOException { + try { + assertTrue(writeDelay.await(10, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + received.addAndGet(candidates.size()); + } + }; + + // first thread blocks, the rest should be non blocking. + CountDownLatch nonBlockingDone = new CountDownLatch(randomIntBetween(0, threadCount - 1)); + List threads = IntStream.range(0, threadCount).mapToObj(i -> new Thread(getTestName() + "_" + i) { + private final String response = randomAlphaOfLength(10); + { + setDaemon(true); + } + + @Override + public void run() { + threadContext.addResponseHeader(testHeader, response); + processor.put(new Object(), (e) -> { + assertEquals(Map.of(testHeader, List.of(response)), threadContext.getResponseHeaders()); + notified.incrementAndGet(); + }); + nonBlockingDone.countDown(); + } + }).collect(Collectors.toList()); + threads.forEach(Thread::start); + assertTrue(nonBlockingDone.await(10, TimeUnit.SECONDS)); + writeDelay.countDown(); + threads.forEach(t -> { + try { + t.join(20000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(threadCount, notified.get()); + assertEquals(threadCount, received.get()); + threads.forEach(t -> assertFalse(t.isAlive())); + } }