Navigation Menu

Skip to content

Commit

Permalink
Keep context during reindex's retries (#21941)
Browse files Browse the repository at this point in the history
* Keep context during reindex's retries

This fixes reindex and friend's retries to keep the context.

* Docs
  • Loading branch information
nik9000 committed Dec 2, 2016
1 parent 0cfcd82 commit 248021e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 38 deletions.
4 changes: 3 additions & 1 deletion core/src/main/java/org/elasticsearch/action/bulk/Retry.java
Expand Up @@ -142,7 +142,9 @@ private void retry(BulkRequest bulkRequestForRetry) {
assert backoff.hasNext();
TimeValue next = backoff.next();
logger.trace("Retry of bulk request scheduled in {} ms.", next.millis());
scheduledRequestFuture = client.threadPool().schedule(next, ThreadPool.Names.SAME, (() -> this.execute(bulkRequestForRetry)));
Runnable retry = () -> this.execute(bulkRequestForRetry);
retry = client.threadPool().getThreadContext().preserveContext(retry);
scheduledRequestFuture = client.threadPool().schedule(next, ThreadPool.Names.SAME, retry);
}

private BulkRequest createBulkRequestForRetry(BulkResponse bulkItemResponses) {
Expand Down
19 changes: 18 additions & 1 deletion core/src/test/java/org/elasticsearch/action/bulk/RetryTests.java
Expand Up @@ -30,6 +30,8 @@
import org.junit.After;
import org.junit.Before;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

Expand All @@ -43,12 +45,21 @@ public class RetryTests extends ESTestCase {
private static final int CALLS_TO_FAIL = 5;

private MockBulkClient bulkClient;
/**
* Headers that are expected to be sent with all bulk requests.
*/
private Map<String, String> expectedHeaders = new HashMap<>();

@Override
@Before
public void setUp() throws Exception {
super.setUp();
this.bulkClient = new MockBulkClient(getTestName(), CALLS_TO_FAIL);
// Stash some random headers so we can assert that we preserve them
bulkClient.threadPool().getThreadContext().stashContext();
expectedHeaders.clear();
expectedHeaders.put(randomAsciiOfLength(5), randomAsciiOfLength(5));
bulkClient.threadPool().getThreadContext().putHeader(expectedHeaders);
}

@Override
Expand Down Expand Up @@ -177,7 +188,7 @@ public void assertOnFailureNeverCalled() {
}
}

private static class MockBulkClient extends NoOpClient {
private class MockBulkClient extends NoOpClient {
private int numberOfCallsToFail;

private MockBulkClient(String testName, int numberOfCallsToFail) {
Expand All @@ -194,6 +205,12 @@ public ActionFuture<BulkResponse> bulk(BulkRequest request) {

@Override
public void bulk(BulkRequest request, ActionListener<BulkResponse> listener) {
if (false == expectedHeaders.equals(threadPool().getThreadContext().getHeaders())) {
listener.onFailure(
new RuntimeException("Expected " + expectedHeaders + " but got " + threadPool().getThreadContext().getHeaders()));
return;
}

// do everything synchronously, that's fine for a test
boolean shouldFail = numberOfCallsToFail > 0;
numberOfCallsToFail--;
Expand Down
Expand Up @@ -81,10 +81,12 @@ public void doStart(Consumer<? super Response> onResponse) {

@Override
protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, Consumer<? super Response> onResponse) {
SearchScrollRequest request = new SearchScrollRequest();
// Add the wait time into the scroll timeout so it won't timeout while we wait for throttling
request.scrollId(scrollId).scroll(timeValueNanos(firstSearchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos()));
searchWithRetry(listener -> client.searchScroll(request, listener), r -> consume(r, onResponse));
searchWithRetry(listener -> {
SearchScrollRequest request = new SearchScrollRequest();
// Add the wait time into the scroll timeout so it won't timeout while we wait for throttling
request.scrollId(scrollId).scroll(timeValueNanos(firstSearchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos()));
client.searchScroll(request, listener);
}, r -> consume(r, onResponse));
}

@Override
Expand Down Expand Up @@ -128,6 +130,10 @@ private void searchWithRetry(Consumer<ActionListener<SearchResponse>> action, Co
*/
class RetryHelper extends AbstractRunnable implements ActionListener<SearchResponse> {
private final Iterator<TimeValue> retries = backoffPolicy.iterator();
/**
* The runnable to run that retries in the same context as the original call.
*/
private Runnable retryWithContext;
private volatile int retryCount = 0;

@Override
Expand All @@ -148,7 +154,7 @@ public void onFailure(Exception e) {
TimeValue delay = retries.next();
logger.trace((Supplier<?>) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e);
countSearchRetry.run();
threadPool.schedule(delay, ThreadPool.Names.SAME, this);
threadPool.schedule(delay, ThreadPool.Names.SAME, retryWithContext);
} else {
logger.warn(
(Supplier<?>) () -> new ParameterizedMessage(
Expand All @@ -161,7 +167,10 @@ public void onFailure(Exception e) {
}
}
}
new RetryHelper().run();
RetryHelper helper = new RetryHelper();
// Wrap the helper in a runnable that preserves the current context so we keep it on retry.
helper.retryWithContext = threadPool.getThreadContext().preserveContext(helper);
helper.run();
}

private void consume(SearchResponse response, Consumer<? super Response> onResponse) {
Expand Down
Expand Up @@ -80,6 +80,7 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -95,8 +96,10 @@
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static java.util.Collections.newSetFromMap;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static java.util.Collections.synchronizedSet;
import static org.apache.lucene.util.TestUtil.randomSimpleString;
import static org.elasticsearch.action.bulk.BackoffPolicy.constantBackoff;
import static org.elasticsearch.common.unit.TimeValue.timeValueMillis;
Expand All @@ -114,7 +117,6 @@

public class AsyncBulkByScrollActionTests extends ESTestCase {
private MyMockClient client;
private ThreadPool threadPool;
private DummyAbstractBulkByScrollRequest testRequest;
private SearchRequest firstSearchRequest;
private PlainActionFuture<BulkIndexByScrollResponse> listener;
Expand All @@ -127,28 +129,34 @@ public class AsyncBulkByScrollActionTests extends ESTestCase {

@Before
public void setupForTest() {
client = new MyMockClient(new NoOpClient(getTestName()));
threadPool = new TestThreadPool(getTestName());
// Fill the context with something random so we can make sure we inherited it appropriately.
expectedHeaders.clear();
expectedHeaders.put(randomSimpleString(random()), randomSimpleString(random()));

setupClient(new TestThreadPool(getTestName()));
firstSearchRequest = new SearchRequest();
testRequest = new DummyAbstractBulkByScrollRequest(firstSearchRequest);
listener = new PlainActionFuture<>();
scrollId = null;
taskManager = new TaskManager(Settings.EMPTY);
testTask = (WorkingBulkByScrollTask) taskManager.register("don'tcare", "hereeither", testRequest);

// Fill the context with something random so we can make sure we inherited it appropriately.
expectedHeaders.clear();
expectedHeaders.put(randomSimpleString(random()), randomSimpleString(random()));
threadPool.getThreadContext().newStoredContext();
threadPool.getThreadContext().putHeader(expectedHeaders);
localNode = new DiscoveryNode("thenode", new LocalTransportAddress("dead.end:666"), emptyMap(), emptySet(), Version.CURRENT);
taskId = new TaskId(localNode.getId(), testTask.getId());
}

private void setupClient(ThreadPool threadPool) {
if (client != null) {
client.close();
}
client = new MyMockClient(new NoOpClient(threadPool));
client.threadPool().getThreadContext().newStoredContext();
client.threadPool().getThreadContext().putHeader(expectedHeaders);
}

@After
public void tearDownAndVerifyCommonStuff() {
client.close();
threadPool.shutdown();
}

/**
Expand Down Expand Up @@ -237,12 +245,6 @@ public void testScrollResponseBatchingBehavior() throws Exception {
// Use assert busy because the update happens on another thread
final int expectedBatches = batches;
assertBusy(() -> assertEquals(expectedBatches, testTask.getStatus().getBatches()));

/*
* Also while we're here check that we preserved the headers from the last request. assertBusy because no requests might have
* come in yet.
*/
assertBusy(() -> assertEquals(expectedHeaders, client.lastHeaders.get()));
}
}

Expand Down Expand Up @@ -301,8 +303,7 @@ public void testBulkResponseSetsLotsOfStatus() {
*/
public void testThreadPoolRejectionsAbortRequest() throws Exception {
testTask.rethrottle(1);
threadPool.shutdown();
threadPool = new TestThreadPool(getTestName()) {
setupClient(new TestThreadPool(getTestName()) {
@Override
public ScheduledFuture<?> schedule(TimeValue delay, String name, Runnable command) {
// While we're here we can check that the sleep made it through
Expand All @@ -311,7 +312,7 @@ public ScheduledFuture<?> schedule(TimeValue delay, String name, Runnable comman
((AbstractRunnable) command).onRejection(new EsRejectedExecutionException("test"));
return null;
}
};
});
ScrollableHitSource.Response response = new ScrollableHitSource.Response(false, emptyList(), 0, emptyList(), null);
simulateScrollResponse(new DummyAbstractAsyncBulkByScrollAction(), timeValueNanos(System.nanoTime()), 10, response);
ExecutionException e = expectThrows(ExecutionException.class, () -> listener.get());
Expand Down Expand Up @@ -413,15 +414,14 @@ public void testScrollDelay() throws Exception {
*/
AtomicReference<TimeValue> capturedDelay = new AtomicReference<>();
AtomicReference<Runnable> capturedCommand = new AtomicReference<>();
threadPool.shutdown();
threadPool = new TestThreadPool(getTestName()) {
setupClient(new TestThreadPool(getTestName()) {
@Override
public ScheduledFuture<?> schedule(TimeValue delay, String name, Runnable command) {
capturedDelay.set(delay);
capturedCommand.set(command);
return null;
}
};
});

DummyAbstractAsyncBulkByScrollAction action = new DummyAbstractAsyncBulkByScrollAction();
action.setScroll(scrollId());
Expand Down Expand Up @@ -493,7 +493,7 @@ void startNextScroll(TimeValue lastBatchStartTime, int lastBatchSize) {
assertThat(response.getSearchFailures(), empty());
assertNull(response.getReasonCancelled());
} else {
successLatch.await(10, TimeUnit.SECONDS);
assertTrue(successLatch.await(10, TimeUnit.SECONDS));
}
}

Expand Down Expand Up @@ -589,8 +589,7 @@ public void testCancelWhileDelayedAfterScrollResponse() throws Exception {
* Replace the thread pool with one that will cancel the task as soon as anything is scheduled, which reindex tries to do when there
* is a delay.
*/
threadPool.shutdown();
threadPool = new TestThreadPool(getTestName()) {
setupClient(new TestThreadPool(getTestName()) {
@Override
public ScheduledFuture<?> schedule(TimeValue delay, String name, Runnable command) {
/*
Expand All @@ -605,7 +604,7 @@ public ScheduledFuture<?> schedule(TimeValue delay, String name, Runnable comman
}
return super.schedule(delay, name, command);
}
};
});

// Send the scroll response which will trigger the custom thread pool above, canceling the request before running the response
DummyAbstractAsyncBulkByScrollAction action = new DummyAbstractAsyncBulkByScrollAction();
Expand Down Expand Up @@ -656,7 +655,7 @@ private class DummyAbstractAsyncBulkByScrollAction
extends AbstractAsyncBulkByScrollAction<DummyAbstractBulkByScrollRequest> {
public DummyAbstractAsyncBulkByScrollAction() {
super(testTask, AsyncBulkByScrollActionTests.this.logger, new ParentTaskAssigningClient(client, localNode, testTask),
AsyncBulkByScrollActionTests.this.threadPool, testRequest, listener);
client.threadPool(), testRequest, listener);
}

@Override
Expand Down Expand Up @@ -702,7 +701,6 @@ private class MyMockClient extends FilterClient {
private final AtomicInteger bulksAttempts = new AtomicInteger();
private final AtomicInteger searchAttempts = new AtomicInteger();
private final AtomicInteger scrollAttempts = new AtomicInteger();
private final AtomicReference<Map<String, String>> lastHeaders = new AtomicReference<>();
private final AtomicReference<RefreshRequest> lastRefreshRequest = new AtomicReference<>();
/**
* Last search attempt that wasn't rejected outright.
Expand All @@ -712,7 +710,10 @@ private class MyMockClient extends FilterClient {
* Last scroll attempt that wasn't rejected outright.
*/
private final AtomicReference<RequestAndListener<SearchScrollRequest, SearchResponse>> lastScroll = new AtomicReference<>();

/**
* Set of all scrolls we've already used. Used to check that we don't reuse the same request twice.
*/
private final Set<SearchScrollRequest> usedScolls = synchronizedSet(newSetFromMap(new IdentityHashMap<>()));

private int bulksToReject = 0;
private int searchesToReject = 0;
Expand All @@ -727,7 +728,12 @@ public MyMockClient(Client in) {
protected <Request extends ActionRequest, Response extends ActionResponse,
RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>> void doExecute(
Action<Request, Response, RequestBuilder> action, Request request, ActionListener<Response> listener) {
lastHeaders.set(threadPool.getThreadContext().getHeaders());
if (false == expectedHeaders.equals(threadPool().getThreadContext().getHeaders())) {
listener.onFailure(
new RuntimeException("Expected " + expectedHeaders + " but got " + threadPool().getThreadContext().getHeaders()));
return;
}

if (request instanceof ClearScrollRequest) {
assertEquals(TaskId.EMPTY_TASK_ID, request.getParentTask());
} else {
Expand All @@ -747,11 +753,14 @@ RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>>
return;
}
if (request instanceof SearchScrollRequest) {
SearchScrollRequest scroll = (SearchScrollRequest) request;
boolean newRequest = usedScolls.add(scroll);
assertTrue("We can't reuse scroll requests", newRequest);
if (scrollAttempts.incrementAndGet() <= scrollsToReject) {
listener.onFailure(wrappedRejectedException());
return;
}
lastScroll.set(new RequestAndListener<>((SearchScrollRequest) request, (ActionListener<SearchResponse>) listener));
lastScroll.set(new RequestAndListener<>(scroll, (ActionListener<SearchResponse>) listener));
return;
}
if (request instanceof ClearScrollRequest) {
Expand Down
Expand Up @@ -32,8 +32,20 @@

import java.util.concurrent.TimeUnit;

/**
* Client that always responds with {@code null} to every request. Override this for testing.
*/
public class NoOpClient extends AbstractClient {
/**
* Build with {@link ThreadPool}. This {@linkplain ThreadPool} is terminated on {@link #close()}.
*/
public NoOpClient(ThreadPool threadPool) {
super(Settings.EMPTY, threadPool);
}

/**
* Create a new {@link TestThreadPool} for this client.
*/
public NoOpClient(String testName) {
super(Settings.EMPTY, new TestThreadPool(testName));
}
Expand Down

0 comments on commit 248021e

Please sign in to comment.