Skip to content

Commit

Permalink
[7.x] [ML][Transform] reset failure count when a transform aggregation page is handled successfully (#76355) (#76365)
Browse files Browse the repository at this point in the history

* [ML][Transform] reset failure count when a transform aggregation page is handled successfully (#76355)

Failure count should not only be reset at checkpoints. Checkpoints could have many pages of data. Consequently, we should reset the failure count once we handle a single composite aggregation page.

This way, the transform won't mark itself as failed erroneously when it has actually succeeded searches + indexing results within the same checkpoint.

closes #76074

* fixing compilation
  • Loading branch information
benwtrent committed Aug 11, 2021
1 parent 28d6c89 commit a87610e
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,71 +143,72 @@ protected void doNextBulk(BulkRequest request, ActionListener<BulkResponse> next
client,
BulkAction.INSTANCE,
request,
ActionListener.wrap(bulkResponse -> {
if (bulkResponse.hasFailures()) {
int failureCount = 0;
// dedup the failures by the type of the exception, as they most likely have the same cause
Map<String, BulkItemResponse> deduplicatedFailures = new LinkedHashMap<>();

for (BulkItemResponse item : bulkResponse.getItems()) {
if (item.isFailed()) {
deduplicatedFailures.putIfAbsent(item.getFailure().getCause().getClass().getSimpleName(), item);
failureCount++;
}
}

// note: bulk failures are audited/logged in {@link TransformIndexer#handleFailure(Exception)}

// This calls AsyncTwoPhaseIndexer#finishWithIndexingFailure
// Determine whether the failure is irrecoverable (transform should go into failed state) or not (transform increments
// the indexing failure counter
// and possibly retries)
Throwable irrecoverableException = ExceptionRootCauseFinder.getFirstIrrecoverableExceptionFromBulkResponses(
deduplicatedFailures.values()
);
if (irrecoverableException == null) {
String failureMessage = getBulkIndexDetailedFailureMessage("Significant failures: ", deduplicatedFailures);
logger.debug("[{}] Bulk index experienced [{}] failures. {}", getJobId(), failureCount, failureMessage);

Exception firstException = deduplicatedFailures.values().iterator().next().getFailure().getCause();
nextPhase.onFailure(
new BulkIndexingException(
"Bulk index experienced [{}] failures. {}",
firstException,
false,
failureCount,
failureMessage
)
);
} else {
deduplicatedFailures.remove(irrecoverableException.getClass().getSimpleName());
String failureMessage = getBulkIndexDetailedFailureMessage("Other failures: ", deduplicatedFailures);
irrecoverableException = decorateBulkIndexException(irrecoverableException);
ActionListener.wrap(bulkResponse -> handleBulkResponse(bulkResponse, nextPhase), nextPhase::onFailure)
);
}

logger.debug(
"[{}] Bulk index experienced [{}] failures and at least 1 irrecoverable [{}]. {}",
getJobId(),
failureCount,
ExceptionRootCauseFinder.getDetailedMessage(irrecoverableException),
failureMessage
);
/**
 * Handles the response of a bulk indexing request issued for one page of transform results.
 *
 * On a fully successful bulk response the context's failure counter is reset — even mid-checkpoint —
 * because a successful search + process + bulk index means an entire page was handled correctly.
 * On failure, the item failures are deduplicated by root-cause exception class and reported to
 * {@code nextPhase.onFailure(...)} as a {@code BulkIndexingException}, marked irrecoverable or not
 * depending on whether any failure is classified irrecoverable by {@code ExceptionRootCauseFinder}.
 */
protected void handleBulkResponse(BulkResponse bulkResponse, ActionListener<BulkResponse> nextPhase) {
if (bulkResponse.hasFailures() == false) {
// We don't know the type of failures that have occurred (searching, processing, indexing, etc.),
// but if we search, process and bulk index then we have
// successfully processed an entire page of the transform and should reset the counter, even if we are in the middle
// of a checkpoint
context.resetReasonAndFailureCounter();
nextPhase.onResponse(bulkResponse);
return;
}
int failureCount = 0;
// dedup the failures by the type of the exception, as they most likely have the same cause
Map<String, BulkItemResponse> deduplicatedFailures = new LinkedHashMap<>();

for (BulkItemResponse item : bulkResponse.getItems()) {
if (item.isFailed()) {
deduplicatedFailures.putIfAbsent(item.getFailure().getCause().getClass().getSimpleName(), item);
failureCount++;
}
}

// NOTE(review): the lines from here down to "}, nextPhase::onFailure)" appear to be removed-hunk
// residue from the diff rendering (the pre-refactor inline listener), not part of this method —
// confirm against the committed file before treating this span as compilable source.
nextPhase.onFailure(
new BulkIndexingException(
"Bulk index experienced [{}] failures and at least 1 irrecoverable [{}]. {}",
irrecoverableException,
true,
failureCount,
ExceptionRootCauseFinder.getDetailedMessage(irrecoverableException),
failureMessage
)
);
}
} else {
nextPhase.onResponse(bulkResponse);
}
}, nextPhase::onFailure)
// note: bulk failures are audited/logged in {@link TransformIndexer#handleFailure(Exception)}

// This calls AsyncTwoPhaseIndexer#finishWithIndexingFailure
// Determine whether the failure is irrecoverable (transform should go into failed state) or not (transform increments
// the indexing failure counter
// and possibly retries)
Throwable irrecoverableException = ExceptionRootCauseFinder.getFirstIrrecoverableExceptionFromBulkResponses(
deduplicatedFailures.values()
);
if (irrecoverableException == null) {
// Only recoverable failures: fail this page with a BulkIndexingException built from the first
// deduplicated failure cause (the boolean argument presumably flags irrecoverability — TODO confirm
// against the BulkIndexingException constructor).
String failureMessage = getBulkIndexDetailedFailureMessage("Significant failures: ", deduplicatedFailures);
logger.debug("[{}] Bulk index experienced [{}] failures. {}", getJobId(), failureCount, failureMessage);

Exception firstException = deduplicatedFailures.values().iterator().next().getFailure().getCause();
nextPhase.onFailure(
new BulkIndexingException("Bulk index experienced [{}] failures. {}", firstException, false, failureCount, failureMessage)
);
} else {
// At least one irrecoverable failure: drop it from the dedup map so the remaining failures can be
// summarized separately, decorate it, then report it with the irrecoverable flag set to true.
deduplicatedFailures.remove(irrecoverableException.getClass().getSimpleName());
String failureMessage = getBulkIndexDetailedFailureMessage("Other failures: ", deduplicatedFailures);
irrecoverableException = decorateBulkIndexException(irrecoverableException);

logger.debug(
"[{}] Bulk index experienced [{}] failures and at least 1 irrecoverable [{}]. {}",
getJobId(),
failureCount,
ExceptionRootCauseFinder.getDetailedMessage(irrecoverableException),
failureMessage
);

nextPhase.onFailure(
new BulkIndexingException(
"Bulk index experienced [{}] failures and at least 1 irrecoverable [{}]. {}",
irrecoverableException,
true,
failureCount,
ExceptionRootCauseFinder.getDetailedMessage(irrecoverableException),
failureMessage
)
);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public TransformIndexer(
this.checkpointProvider = ExceptionsHelper.requireNonNull(checkpointProvider, "checkpointProvider");
this.auditor = transformServices.getAuditor();
this.transformConfig = ExceptionsHelper.requireNonNull(transformConfig, "transformConfig");
this.progress = progress != null ? progress : new TransformProgress();
this.progress = transformProgress != null ? transformProgress : new TransformProgress();
this.lastCheckpoint = ExceptionsHelper.requireNonNull(lastCheckpoint, "lastCheckpoint");
this.nextCheckpoint = ExceptionsHelper.requireNonNull(nextCheckpoint, "nextCheckpoint");
this.context = ExceptionsHelper.requireNonNull(context, "context");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.elasticsearch.xpack.transform.notifications.MockTransformAuditor;
import org.elasticsearch.xpack.transform.notifications.TransformAuditor;
import org.elasticsearch.xpack.transform.persistence.IndexBasedTransformConfigManager;
import org.elasticsearch.xpack.transform.persistence.SeqNoPrimaryTermAndIndex;
import org.junit.After;
import org.junit.Before;

Expand All @@ -63,6 +64,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
Expand Down Expand Up @@ -90,7 +92,7 @@ public class TransformIndexerFailureHandlingTests extends ESTestCase {
private Client client;
private ThreadPool threadPool;

class MockedTransformIndexer extends TransformIndexer {
static class MockedTransformIndexer extends ClientTransformIndexer {

private final Function<SearchRequest, SearchResponse> searchFunction;
private final Function<BulkRequest, BulkResponse> bulkFunction;
Expand Down Expand Up @@ -126,14 +128,17 @@ class MockedTransformIndexer extends TransformIndexer {
mock(SchedulerEngine.class)
),
checkpointProvider,
transformConfig,
initialState,
initialPosition,
mock(Client.class),
jobStats,
transformConfig,
/* TransformProgress */ null,
TransformCheckpoint.EMPTY,
TransformCheckpoint.EMPTY,
context
new SeqNoPrimaryTermAndIndex(1, 1, "foo"),
context,
false
);
this.searchFunction = searchFunction;
this.bulkFunction = bulkFunction;
Expand Down Expand Up @@ -188,7 +193,7 @@ protected void doNextBulk(BulkRequest request, ActionListener<BulkResponse> next

try {
BulkResponse response = bulkFunction.apply(request);
nextPhase.onResponse(response);
super.handleBulkResponse(response, nextPhase);
} catch (Exception e) {
nextPhase.onFailure(e);
}
Expand Down Expand Up @@ -253,7 +258,7 @@ void doGetInitialProgress(SearchRequest request, ActionListener<SearchResponse>
}

@Override
void doDeleteByQuery(DeleteByQueryRequest deleteByQueryRequest, ActionListener<BulkByScrollResponse> responseListener) {
protected void doDeleteByQuery(DeleteByQueryRequest deleteByQueryRequest, ActionListener<BulkByScrollResponse> responseListener) {
try {
BulkByScrollResponse response = deleteByQueryFunction.apply(deleteByQueryRequest);
responseListener.onResponse(response);
Expand All @@ -263,7 +268,7 @@ void doDeleteByQuery(DeleteByQueryRequest deleteByQueryRequest, ActionListener<B
}

@Override
void refreshDestinationIndex(ActionListener<RefreshResponse> responseListener) {
protected void refreshDestinationIndex(ActionListener<RefreshResponse> responseListener) {
responseListener.onResponse(new RefreshResponse(1, 1, 0, Collections.emptyList()));
}

Expand Down Expand Up @@ -705,6 +710,116 @@ public void testRetentionPolicyDeleteByQueryThrowsTemporaryProblem() throws Exce
assertEquals(1, context.getFailureCount());
}

/**
 * Verifies that the transform's failure counter is reset after a fully successful page,
 * not only at checkpoint completion: the first indexer run hits a recoverable search
 * failure (counter goes to 1), the second run completes a page successfully and the
 * counter drops back to 0.
 */
public void testFailureCounterIsResetOnSuccess() throws Exception {
String transformId = randomAlphaOfLength(10);
// Minimal pivot transform config; optional fields left null.
TransformConfig config = new TransformConfig(
transformId,
randomSourceConfig(),
randomDestConfig(),
null,
null,
null,
randomPivotConfig(),
null,
randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000),
null,
null,
null,
null
);

// Canned successful search response: one hit, deliberately null aggs.
final SearchResponse searchResponse = new SearchResponse(
new InternalSearchResponse(
new SearchHits(new SearchHit[] { new SearchHit(1) }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1.0f),
// Simulate completely null aggs
null,
new Suggest(Collections.emptyList()),
new SearchProfileShardResults(Collections.emptyMap()),
false,
false,
1
),
"",
1,
1,
0,
0,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);

AtomicReference<IndexerState> state = new AtomicReference<>(IndexerState.STOPPED);
// Search fails on the very first call (recoverable shard failure) and succeeds on
// every subsequent call, so only the first indexer run records a failure.
Function<SearchRequest, SearchResponse> searchFunction = new Function<SearchRequest, SearchResponse>() {
final AtomicInteger calls = new AtomicInteger(0);

@Override
public SearchResponse apply(SearchRequest searchRequest) {
int call = calls.getAndIncrement();
if (call == 0) {
throw new SearchPhaseExecutionException(
"query",
"Partial shards failure",
new ShardSearchFailure[] { new ShardSearchFailure(new Exception()) }
);
}
return searchResponse;
}
};

// Bulk indexing always succeeds (empty item list, no failures).
Function<BulkRequest, BulkResponse> bulkFunction = request -> new BulkResponse(new BulkItemResponse[0], 1);

// Records whether the indexer was ever driven into the failed state, and the first failure message.
final AtomicBoolean failIndexerCalled = new AtomicBoolean(false);
final AtomicReference<String> failureMessage = new AtomicReference<>();
Consumer<String> failureConsumer = message -> {
failIndexerCalled.compareAndSet(false, true);
failureMessage.compareAndSet(null, message);
};

MockTransformAuditor auditor = MockTransformAuditor.createMockAuditor();
TransformContext.Listener contextListener = mock(TransformContext.Listener.class);
TransformContext context = new TransformContext(TransformTaskState.STARTED, "", 0, contextListener);

MockedTransformIndexer indexer = createMockIndexer(
config,
state,
searchFunction,
bulkFunction,
null,
failureConsumer,
threadPool,
ThreadPool.Names.GENERIC,
auditor,
context
);

final CountDownLatch latch = indexer.newLatch(1);

// First run: the search throws, so this run should record exactly one failure
// without failing the indexer (the failure is recoverable).
indexer.start();
assertThat(indexer.getState(), equalTo(IndexerState.STARTED));
assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()));
assertThat(indexer.getState(), equalTo(IndexerState.INDEXING));

latch.countDown();
assertBusy(() -> assertThat(indexer.getState(), equalTo(IndexerState.STARTED)), 10, TimeUnit.SECONDS);
assertFalse(failIndexerCalled.get());
assertThat(indexer.getState(), equalTo(IndexerState.STARTED));
assertEquals(1, context.getFailureCount());

final CountDownLatch secondLatch = indexer.newLatch(1);

// Second run: search and bulk both succeed, so handling a full page successfully
// should reset the failure counter back to zero even within the same checkpoint.
indexer.start();
assertThat(indexer.getState(), equalTo(IndexerState.STARTED));
assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()));
assertThat(indexer.getState(), equalTo(IndexerState.INDEXING));

secondLatch.countDown();
assertBusy(() -> assertThat(indexer.getState(), equalTo(IndexerState.STARTED)), 10, TimeUnit.SECONDS);
assertFalse(failIndexerCalled.get());
assertThat(indexer.getState(), equalTo(IndexerState.STARTED));
auditor.assertAllExpectationsMatched();
assertEquals(0, context.getFailureCount());
}

private MockedTransformIndexer createMockIndexer(
TransformConfig config,
AtomicReference<IndexerState> state,
Expand Down

0 comments on commit a87610e

Please sign in to comment.