AsyncIOProcessor preserve thread context #43729

Merged
server/src/main/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessor.java
@@ -28,6 +28,7 @@
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
* This async IO processor allows to batch IO operations and have a single writer processing the write operations.
@@ -39,11 +40,13 @@
public abstract class AsyncIOProcessor<Item> {
private final Logger logger;
private final ArrayBlockingQueue<Tuple<Item, Consumer<Exception>>> queue;
private final ThreadContext threadContext;
private final Semaphore promiseSemaphore = new Semaphore(1);

protected AsyncIOProcessor(Logger logger, int queueSize) {
protected AsyncIOProcessor(Logger logger, int queueSize, ThreadContext threadContext) {
this.logger = logger;
this.queue = new ArrayBlockingQueue<>(queueSize);
this.threadContext = threadContext;
}

/**
@@ -58,11 +61,10 @@ public final void put(Item item, Consumer<Exception> listener) {

// we first try make a promise that we are responsible for the processing
final boolean promised = promiseSemaphore.tryAcquire();
final Tuple<Item, Consumer<Exception>> itemTuple = new Tuple<>(item, listener);
if (promised == false) {
// in this case we are not responsible and can just block until there is space
try {
queue.put(new Tuple<>(item, listener));
queue.put(new Tuple<>(item, preserveContext(listener)));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
listener.accept(e);
@@ -76,7 +78,8 @@ public final void put(Item item, Consumer<Exception> listener) {
try {
if (promised) {
// we are responsible for processing we don't need to add the tuple to the queue we can just add it to the candidates
candidates.add(itemTuple);
// no need to preserve context for listener since it runs in current thread.
candidates.add(new Tuple<>(item, listener));
}
// since we made the promise to process we gotta do it here at least once
drainAndProcess(candidates);
@@ -121,6 +124,15 @@ private void processList(List<Tuple<Item, Consumer<Exception>>> candidates) {
}
}

private Consumer<Exception> preserveContext(Consumer<Exception> consumer) {
Supplier<ThreadContext.StoredContext> restorableContext = threadContext.newRestorableContext(false);
return e -> {
try (ThreadContext.StoredContext ignore = restorableContext.get()) {
consumer.accept(e);
}
};
}

/**
* Writes or processes the items out or to disk.
*/
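The core of the change is the preserveContext wrapper shown above: it snapshots the submitting thread's ThreadContext with newRestorableContext(false) and restores that snapshot around the listener, so a callback that is queued by one thread but executed by whichever thread holds the write promise still observes the caller's context. A minimal, self-contained sketch of the same idea, using a plain ThreadLocal in place of Elasticsearch's ThreadContext (the names and the String-valued "context" below are illustrative only, not part of this PR):

import java.util.function.Consumer;

public class PreserveContextSketch {
    // Per-thread "context" standing in for ThreadContext's headers.
    private static final ThreadLocal<String> CONTEXT = ThreadLocal.withInitial(() -> "empty");

    // Capture the caller's context now; restore it around the deferred callback later.
    static Consumer<Exception> preserveContext(Consumer<Exception> consumer) {
        final String captured = CONTEXT.get();      // snapshot taken on the submitting thread
        return e -> {
            final String previous = CONTEXT.get();  // whatever the executing thread had
            CONTEXT.set(captured);                  // restore the caller's context
            try {
                consumer.accept(e);
            } finally {
                CONTEXT.set(previous);              // put the executing thread back as it was
            }
        };
    }

    public static void main(String[] args) throws InterruptedException {
        CONTEXT.set("request-42");
        Consumer<Exception> wrapped = preserveContext(
            e -> System.out.println("callback sees context: " + CONTEXT.get()));

        // Run the callback on a different thread, as the queue/promise handover does.
        Thread worker = new Thread(() -> wrapped.accept(null));
        worker.start();
        worker.join();
        // Prints "callback sees context: request-42" even though the worker thread
        // never set that value itself.
    }
}

Note that the PR applies the wrapper only on the queue path; when the submitting thread holds the promise it runs the listener itself, so no restoration is needed (see the comment added above candidates.add).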
31 changes: 19 additions & 12 deletions server/src/main/java/org/elasticsearch/index/shard/IndexShard.java
@@ -66,6 +66,7 @@
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.AsyncIOProcessor;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.gateway.WriteStateException;
@@ -291,6 +292,7 @@ public IndexShard(
this.indexSortSupplier = indexSortSupplier;
this.indexEventListener = indexEventListener;
this.threadPool = threadPool;
this.translogSyncProcessor = createTranslogSyncProcessor(logger, threadPool.getThreadContext(), this::getEngine);
this.mapperService = mapperService;
this.indexCache = indexCache;
this.internalIndexingStats = new InternalIndexingStats();
@@ -2789,19 +2791,24 @@ public List<String> getActiveOperations() {
return indexShardOperationPermits.getActiveOperations();
}

private final AsyncIOProcessor<Translog.Location> translogSyncProcessor = new AsyncIOProcessor<Translog.Location>(logger, 1024) {
@Override
protected void write(List<Tuple<Translog.Location, Consumer<Exception>>> candidates) throws IOException {
try {
getEngine().ensureTranslogSynced(candidates.stream().map(Tuple::v1));
} catch (AlreadyClosedException ex) {
// that's fine since we already synced everything on engine close - this also is conform with the methods
// documentation
} catch (IOException ex) { // if this fails we are in deep shit - fail the request
logger.debug("failed to sync translog", ex);
throw ex;
private final AsyncIOProcessor<Translog.Location> translogSyncProcessor;

private static AsyncIOProcessor<Translog.Location> createTranslogSyncProcessor(Logger logger, ThreadContext threadContext,
Supplier<Engine> engineSupplier) {
return new AsyncIOProcessor<>(logger, 1024, threadContext) {
@Override
protected void write(List<Tuple<Translog.Location, Consumer<Exception>>> candidates) throws IOException {
try {
engineSupplier.get().ensureTranslogSynced(candidates.stream().map(Tuple::v1));
} catch (AlreadyClosedException ex) {
// that's fine since we already synced everything on engine close - this also is conform with the methods
// documentation
} catch (IOException ex) { // if this fails we are in deep shit - fail the request
logger.debug("failed to sync translog", ex);
throw ex;
}
}
}
};
};

/**
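A side effect of the change worth noting: translogSyncProcessor used to be created by an anonymous-subclass field initializer that reached getEngine() through the enclosing IndexShard, while the new createTranslogSyncProcessor is a static factory invoked from the constructor and handed only the logger, the thread pool's ThreadContext, and a Supplier<Engine>. A generic, runnable sketch of that pattern follows; ShardLike, EngineLike and the other names are hypothetical and not part of the PR:

import java.util.function.LongConsumer;
import java.util.function.Supplier;

// Generic sketch of the static-factory refactoring: the helper is static, so it cannot
// silently capture the enclosing instance; the engine is reached lazily through a
// Supplier because it may not exist yet (or may be replaced) when the processor is built.
class ShardLike {
    interface EngineLike {
        void ensureSynced(long translogLocation);
    }

    private final LongConsumer syncProcessor;
    private volatile EngineLike engine; // started after construction, may change over time

    ShardLike() {
        // Pass a method reference, not the engine itself: it is resolved on every write.
        this.syncProcessor = createSyncProcessor(this::getEngine);
    }

    void startEngine(EngineLike engine) {
        this.engine = engine;
    }

    void sync(long location) {
        syncProcessor.accept(location);
    }

    private EngineLike getEngine() {
        EngineLike current = engine;
        if (current == null) {
            throw new IllegalStateException("engine not started yet");
        }
        return current;
    }

    private static LongConsumer createSyncProcessor(Supplier<EngineLike> engineSupplier) {
        // The supplier is queried at write time, so the latest engine instance is used.
        return location -> engineSupplier.get().ensureSynced(location);
    }

    public static void main(String[] args) {
        ShardLike shard = new ShardLike();
        shard.startEngine(location -> System.out.println("synced up to " + location));
        shard.sync(42L); // prints "synced up to 42"
    }
}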
server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java
@@ -19,22 +19,40 @@
package org.elasticsearch.common.util.concurrent;

import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class AsyncIOProcessorTests extends ESTestCase {

private ThreadContext threadContext;

@Before
public void setUpThreadContext() {
threadContext = new ThreadContext(Settings.EMPTY);
}

@After
public void tearDownThreadContext() {
threadContext.close();
}

public void testPut() throws InterruptedException {
boolean blockInternal = randomBoolean();
AtomicInteger received = new AtomicInteger(0);
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024)) {
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024), threadContext) {
@Override
protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws IOException {
if (blockInternal) {
@@ -83,7 +101,7 @@ public void testRandomFail() throws InterruptedException {
AtomicInteger received = new AtomicInteger(0);
AtomicInteger failed = new AtomicInteger(0);
AtomicInteger actualFailed = new AtomicInteger(0);
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024)) {
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024), threadContext) {
@Override
protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws IOException {
received.addAndGet(candidates.size());
@@ -137,7 +155,7 @@ public void testConsumerCanThrowExceptions() {
AtomicInteger received = new AtomicInteger(0);
AtomicInteger notified = new AtomicInteger(0);

AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024)) {
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024), threadContext) {
@Override
protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws IOException {
received.addAndGet(candidates.size());
@@ -156,7 +174,7 @@ protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws
}

public void testNullArguments() {
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024)) {
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(1, 2024), threadContext) {
@Override
protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws IOException {
}
@@ -165,4 +183,59 @@ protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws
expectThrows(NullPointerException.class, () -> processor.put(null, (e) -> {}));
expectThrows(NullPointerException.class, () -> processor.put(new Object(), null));
}

public void testPreserveThreadContext() throws InterruptedException {
final int threadCount = randomIntBetween(2, 10);
final String testHeader = "testheader";

AtomicInteger received = new AtomicInteger(0);
AtomicInteger notified = new AtomicInteger(0);

CountDownLatch writeDelay = new CountDownLatch(1);
AsyncIOProcessor<Object> processor = new AsyncIOProcessor<Object>(logger, scaledRandomIntBetween(threadCount - 1, 2024),
threadContext) {
@Override
protected void write(List<Tuple<Object, Consumer<Exception>>> candidates) throws IOException {
try {
assertTrue(writeDelay.await(10, TimeUnit.SECONDS));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
received.addAndGet(candidates.size());
}
};

// first thread blocks, the rest should be non blocking.
CountDownLatch nonBlockingDone = new CountDownLatch(randomIntBetween(0, threadCount - 1));
List<Thread> threads = IntStream.range(0, threadCount).mapToObj(i -> new Thread(getTestName() + "_" + i) {
private final String response = randomAlphaOfLength(10);
{
setDaemon(true);
}

@Override
public void run() {
threadContext.addResponseHeader(testHeader, response);
processor.put(new Object(), (e) -> {
assertEquals(Map.of(testHeader, List.of(response)), threadContext.getResponseHeaders());
notified.incrementAndGet();
});
nonBlockingDone.countDown();
}
}).collect(Collectors.toList());
threads.forEach(Thread::start);
assertTrue(nonBlockingDone.await(10, TimeUnit.SECONDS));
writeDelay.countDown();
threads.forEach(t -> {
try {
t.join(20000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});

assertEquals(threadCount, notified.get());
assertEquals(threadCount, received.get());
threads.forEach(t -> assertFalse(t.isAlive()));
}
}
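
Finally, for readers looking at these tests without the full production class at hand, the single-writer batching loop that AsyncIOProcessor implements (and that testPut and testRandomFail hammer concurrently) can be summarized in a simplified, self-contained sketch. Listener plumbing, ThreadContext handling and exception handling are stripped out, and all names below are illustrative:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Semaphore;

// Simplified sketch of the single-writer batching loop: many threads call put(), but at
// any moment only the thread holding the "promise" semaphore drains the queue and writes.
abstract class MiniBatchProcessor<T> {
    private final ArrayBlockingQueue<T> queue;
    private final Semaphore promise = new Semaphore(1);

    MiniBatchProcessor(int queueSize) {
        this.queue = new ArrayBlockingQueue<>(queueSize);
    }

    final void put(T item) throws InterruptedException {
        // Try to become the writer for this round.
        final boolean promised = promise.tryAcquire();
        if (promised == false) {
            // Another thread is writing: hand the item over via the queue.
            queue.put(item);
        }
        if (promised || promise.tryAcquire()) {
            final List<T> batch = new ArrayList<>();
            try {
                if (promised) {
                    batch.add(item); // our own item never went through the queue
                }
                drainAndWrite(batch);
            } finally {
                promise.release();
            }
            // Items may have been queued while we were writing; re-check after the release
            // so nothing is left stranded in the queue.
            while (queue.isEmpty() == false && promise.tryAcquire()) {
                try {
                    drainAndWrite(batch);
                } finally {
                    promise.release();
                }
            }
        }
    }

    private void drainAndWrite(List<T> batch) {
        queue.drainTo(batch);
        if (batch.isEmpty() == false) {
            write(batch); // called by exactly one thread at a time
            batch.clear();
        }
    }

    /** Performs the batched write while the promise is held. */
    protected abstract void write(List<T> batch);

    public static void main(String[] args) throws InterruptedException {
        MiniBatchProcessor<String> processor = new MiniBatchProcessor<String>(16) {
            @Override
            protected void write(List<String> batch) {
                System.out.println("writing batch of size " + batch.size());
            }
        };
        processor.put("hello"); // a single caller holds the promise and writes its own item
    }
}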