Skip to content

Commit

Permalink
[7.9][ML] Ensure bulk requests are not over memory limit (#60219) (#6…
Browse files Browse the repository at this point in the history
…0284)

Data frame analytics jobs that work with very large datasets
may produce bulk requests that are over the memory limit
for indexing. This commit adds a helper class that bundles
index requests in bulk requests that steer away from the
memory limit. We then use this class both from the results
joiner and the inference runner ensuring data frame analytics
jobs do not generate bulk requests that are too large.

Note the limit was implemented in #58885.

Backport of #60219
  • Loading branch information
dimitris-athanasiou committed Jul 28, 2020
1 parent 6ae66b1 commit 0ccdcf9
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,9 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
this.modelLoadingService.set(modelLoadingService);

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client,
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(
settings,
client,
threadPool,
analyticsProcessFactory,
dataFrameAnalyticsAuditor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ClientHelper;
Expand All @@ -27,6 +28,7 @@
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.util.Deque;
Expand All @@ -40,8 +42,8 @@ public class InferenceRunner {
private static final Logger LOGGER = LogManager.getLogger(InferenceRunner.class);

private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;
private static final int RESULTS_BATCH_SIZE = 1000;

private final Settings settings;
private final Client client;
private final ModelLoadingService modelLoadingService;
private final ResultsPersisterService resultsPersisterService;
Expand All @@ -52,9 +54,10 @@ public class InferenceRunner {
private final DataCountsTracker dataCountsTracker;
private volatile boolean isCancelled;

public InferenceRunner(Client client, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService,
TaskId parentTaskId, DataFrameAnalyticsConfig config, ExtractedFields extractedFields,
ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
public InferenceRunner(Settings settings, Client client, ModelLoadingService modelLoadingService,
ResultsPersisterService resultsPersisterService, TaskId parentTaskId, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
this.settings = Objects.requireNonNull(settings);
this.client = Objects.requireNonNull(client);
this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
Expand Down Expand Up @@ -92,36 +95,29 @@ public void run(String modelId) {
void inferTestDocs(LocalModel model, TestDocsIterator testDocsIterator) {
long totalDocCount = 0;
long processedDocCount = 0;
BulkRequest bulkRequest = new BulkRequest();

while (testDocsIterator.hasNext()) {
if (isCancelled) {
break;
}
try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
while (testDocsIterator.hasNext()) {
if (isCancelled) {
break;
}

Deque<SearchHit> batch = testDocsIterator.next();
Deque<SearchHit> batch = testDocsIterator.next();

if (totalDocCount == 0) {
totalDocCount = testDocsIterator.getTotalHits();
}
if (totalDocCount == 0) {
totalDocCount = testDocsIterator.getTotalHits();
}

for (SearchHit doc : batch) {
dataCountsTracker.incrementTestDocsCount();
InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc));
bulkRequest.add(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField()));
for (SearchHit doc : batch) {
dataCountsTracker.incrementTestDocsCount();
InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc));
bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField()));

processedDocCount++;
int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION);
progressTracker.updateInferenceProgress(progressPercent);
processedDocCount++;
int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION);
progressTracker.updateInferenceProgress(progressPercent);
}
}

if (bulkRequest.numberOfActions() == RESULTS_BATCH_SIZE) {
executeBulkRequest(bulkRequest);
bulkRequest = new BulkRequest();
}
}
if (bulkRequest.numberOfActions() > 0 && isCancelled == false) {
executeBulkRequest(bulkRequest);
}

if (isCancelled == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
Expand Down Expand Up @@ -61,6 +62,7 @@ public class AnalyticsProcessManager {

private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);

private final Settings settings;
private final Client client;
private final ExecutorService executorServiceForJob;
private final ExecutorService executorServiceForProcess;
Expand All @@ -72,7 +74,8 @@ public class AnalyticsProcessManager {
private final ResultsPersisterService resultsPersisterService;
private final int numAllocatedProcessors;

public AnalyticsProcessManager(Client client,
public AnalyticsProcessManager(Settings settings,
Client client,
ThreadPool threadPool,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
Expand All @@ -81,6 +84,7 @@ public AnalyticsProcessManager(Client client,
ResultsPersisterService resultsPersisterService,
int numAllocatedProcessors) {
this(
settings,
client,
threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
Expand All @@ -93,7 +97,8 @@ public AnalyticsProcessManager(Client client,
}

// Visible for testing
public AnalyticsProcessManager(Client client,
public AnalyticsProcessManager(Settings settings,
Client client,
ExecutorService executorServiceForJob,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
Expand All @@ -102,6 +107,7 @@ public AnalyticsProcessManager(Client client,
ModelLoadingService modelLoadingService,
ResultsPersisterService resultsPersisterService,
int numAllocatedProcessors) {
this.settings = Objects.requireNonNull(settings);
this.client = Objects.requireNonNull(client);
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
Expand Down Expand Up @@ -330,7 +336,7 @@ private void runInference(ParentTaskAssigningClient parentTaskClient, DataFrameA

if (processContext.config.getAnalysis().supportsInference()) {
refreshDest(parentTaskClient, processContext.config);
InferenceRunner inferenceRunner = new InferenceRunner(parentTaskClient, modelLoadingService, resultsPersisterService,
InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService,
task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(),
task.getStatsHolder().getDataCountsTracker());
processContext.setInferenceRunner(inferenceRunner);
Expand Down Expand Up @@ -489,7 +495,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataEx
private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
DataFrameDataExtractorFactory dataExtractorFactory) {
DataFrameRowsJoiner dataFrameRowsJoiner =
new DataFrameRowsJoiner(config.getId(), task.getParentTaskId(),
new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(),
dataExtractorFactory.newExtractor(true), resultsPersisterService);
return new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
Expand All @@ -36,6 +38,7 @@ class DataFrameRowsJoiner implements AutoCloseable {
private static final int RESULTS_BATCH_SIZE = 1000;

private final String analyticsId;
private final Settings settings;
private final TaskId parentTaskId;
private final DataFrameDataExtractor dataExtractor;
private final ResultsPersisterService resultsPersisterService;
Expand All @@ -44,9 +47,10 @@ class DataFrameRowsJoiner implements AutoCloseable {
private volatile String failure;
private volatile boolean isCancelled;

DataFrameRowsJoiner(String analyticsId, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
DataFrameRowsJoiner(String analyticsId, Settings settings, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
ResultsPersisterService resultsPersisterService) {
this.analyticsId = Objects.requireNonNull(analyticsId);
this.settings = Objects.requireNonNull(settings);
this.parentTaskId = Objects.requireNonNull(parentTaskId);
this.dataExtractor = Objects.requireNonNull(dataExtractor);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
Expand Down Expand Up @@ -86,25 +90,28 @@ private void addResultAndJoinIfEndOfBatch(RowResults rowResults) {
}

private void joinCurrentResults() {
BulkRequest bulkRequest = new BulkRequest();
while (currentResults.isEmpty() == false) {
RowResults result = currentResults.pop();
DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
checkChecksumsMatch(row, result);
bulkRequest.add(createIndexRequest(result, row.getHit()));
}
if (bulkRequest.numberOfActions() > 0) {
bulkRequest.setParentTask(parentTaskId);
resultsPersisterService.bulkIndexWithHeadersWithRetry(
dataExtractor.getHeaders(),
bulkRequest,
analyticsId,
() -> isCancelled == false,
errorMsg -> {});
try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
while (currentResults.isEmpty() == false) {
RowResults result = currentResults.pop();
DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
checkChecksumsMatch(row, result);
bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(result, row.getHit()));
}
}

currentResults = new LinkedList<>();
}

// Sends the given bulk request through the results persister service, tagging it
// with the parent task id so the work is associated with the analytics task.
private void executeBulkRequest(BulkRequest bulkRequest) {
bulkRequest.setParentTask(parentTaskId);
resultsPersisterService.bulkIndexWithHeadersWithRetry(
dataExtractor.getHeaders(),
bulkRequest,
analyticsId,
// Keep retrying only while the joiner has not been cancelled.
() -> isCancelled == false,
// NOTE(review): retry-failure messages are deliberately dropped here —
// presumably failures surface via the persister service itself; confirm.
errorMsg -> {});
}

private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) {
if (row.getChecksum() != result.getChecksum()) {
String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; ";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.utils.persistence;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexingPressure;

import java.util.Objects;
import java.util.function.Consumer;

/**
 * A helper class that gathers index requests into bulk requests
 * that do not exceed 1000 operations or half the available memory
 * limit for indexing.
 */
public class LimitAwareBulkIndexer implements AutoCloseable {

    private static final Logger LOGGER = LogManager.getLogger(LimitAwareBulkIndexer.class);

    /** Maximum number of index operations bundled into a single bulk request. */
    private static final int BATCH_SIZE = 1000;

    /** Byte threshold at which the pending bulk request is flushed. */
    private final long bytesLimit;

    /** Callback that actually executes an assembled bulk request. */
    private final Consumer<BulkRequest> executor;

    /** Bulk request currently being filled. */
    private BulkRequest currentBulkRequest = new BulkRequest();

    /** Estimated heap bytes accumulated in the current bulk request. */
    private long currentRamBytes;

    public LimitAwareBulkIndexer(Settings settings, Consumer<BulkRequest> executor) {
        // Use half of the configured indexing pressure limit so the in-flight
        // bulk request leaves headroom for other indexing activity on the node.
        this((long) Math.ceil(0.5 * IndexingPressure.MAX_INDEXING_BYTES.get(settings).getBytes()), executor);
    }

    // Package-private so tests can set an explicit byte limit directly.
    LimitAwareBulkIndexer(long bytesLimit, Consumer<BulkRequest> executor) {
        this.bytesLimit = bytesLimit;
        this.executor = Objects.requireNonNull(executor);
    }

    /**
     * Queues the given index request, first flushing the pending bulk request
     * when adding it would cross the byte limit or when the pending request
     * already holds {@code BATCH_SIZE} actions. Note a single request larger
     * than the byte limit is still queued (it cannot be split) and will be
     * sent on its own at the next flush.
     */
    public void addAndExecuteIfNeeded(IndexRequest indexRequest) {
        long requestBytes = indexRequest.ramBytesUsed();
        boolean wouldExceedBytes = currentRamBytes + requestBytes > bytesLimit;
        boolean batchIsFull = currentBulkRequest.numberOfActions() == BATCH_SIZE;
        if (wouldExceedBytes || batchIsFull) {
            execute();
        }
        currentBulkRequest.add(indexRequest);
        currentRamBytes += requestBytes;
    }

    /** Executes and resets the pending bulk request if it holds any actions. */
    private void execute() {
        if (currentBulkRequest.numberOfActions() > 0) {
            LOGGER.debug("Executing bulk request; current bytes [{}]; bytes limit [{}]; number of actions [{}]",
                currentRamBytes, bytesLimit, currentBulkRequest.numberOfActions());
            executor.accept(currentBulkRequest);
            currentBulkRequest = new BulkRequest();
            currentRamBytes = 0;
        }
    }

    /** Flushes whatever is still pending when the indexer is closed. */
    @Override
    public void close() {
        execute();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.search.SearchHit;
Expand Down Expand Up @@ -164,7 +165,7 @@ private LocalModel localModelInferences(InferenceResults first, InferenceResults
}

private InferenceRunner createInferenceRunner(ExtractedFields extractedFields) {
return new InferenceRunner(client, modelLoadingService, resultsPersisterService, parentTaskId, config, extractedFields,
progressTracker, new DataCountsTracker());
return new InferenceRunner(Settings.EMPTY, client, modelLoadingService, resultsPersisterService, parentTaskId, config,
extractedFields, progressTracker, new DataCountsTracker());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ public void setUpMocks() {

resultsPersisterService = mock(ResultsPersisterService.class);
modelLoadingService = mock(ModelLoadingService.class);
processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess,
processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
}

public void testRunJob_TaskIsStopping() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -31,6 +32,7 @@
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -264,7 +266,10 @@ public void testProcess_GivenMoreResultsThanRows() throws IOException {
RowResults result2 = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result1, result2));

verifyNoMoreInteractions(resultsPersisterService);
List<BulkRequest> capturedBulkRequests = bulkRequestCaptor.getAllValues();
assertThat(capturedBulkRequests, hasSize(1));
BulkRequest capturedBulkRequest = capturedBulkRequests.get(0);
assertThat(capturedBulkRequest.numberOfActions(), equalTo(1));
}

public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
Expand All @@ -284,7 +289,8 @@ public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws
}

private void givenProcessResults(List<RowResults> results) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, new TaskId(""), dataExtractor, resultsPersisterService)) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, Settings.EMPTY, new TaskId(""), dataExtractor,
resultsPersisterService)) {
results.forEach(joiner::processRowResults);
}
}
Expand Down

0 comments on commit 0ccdcf9

Please sign in to comment.