TSDB: Add time series aggs cancellation (#83492)
Adds support for low-level cancellation of time-series-based aggregations before
they reach the reduce phase.

Relates to #74660
imotov committed Feb 15, 2022
1 parent 1fe2b0d commit a89d4c3
Showing 7 changed files with 315 additions and 10 deletions.
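
The mechanism is simple: the time-series document walker is handed a list of cancellation checks (plain Runnables that throw when the search should stop) and runs them every few thousand documents instead of on every hit. A minimal standalone sketch of that pattern, with illustrative names (CancellableWalker and onDocument are not the commit's API):

    import java.util.List;

    // Sketch of the periodic-cancellation pattern; in the commit this logic lives
    // in TimeSeriesIndexSearcher#search and #checkCancelled.
    class CancellableWalker {
        // Mirrors CHECK_CANCELLED_SCORER_INTERVAL = 1 << 11 from the commit.
        private static final int CHECK_CANCELLED_INTERVAL = 1 << 11;

        private final List<Runnable> cancellations; // each check throws to abort the search
        private int seen = 0;

        CancellableWalker(List<Runnable> cancellations) {
            this.cancellations = cancellations;
        }

        // Called once per visited document or leaf; only every 2048th call pays
        // the cost of running the cancellation checks.
        void onDocument() {
            if (++seen % CHECK_CANCELLED_INTERVAL == 0) {
                for (Runnable cancellation : cancellations) {
                    cancellation.run();
                }
            }
        }
    }

Checking only every 2^11 documents keeps the per-document cost of the hot loop near zero while still bounding how long a cancelled search can keep running.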
@@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
@@ -28,6 +29,8 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.rest.RestStatus;
@@ -36,13 +39,16 @@
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.timeseries.TimeSeriesAggregationBuilder;
import org.elasticsearch.search.lookup.LeafStoredFieldsLookup;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.TransportService;
import org.junit.BeforeClass;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@@ -55,9 +61,12 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static org.elasticsearch.index.IndexSettings.TIME_SERIES_END_TIME;
import static org.elasticsearch.index.IndexSettings.TIME_SERIES_START_TIME;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.index.query.QueryBuilders.scriptQuery;
import static org.elasticsearch.search.SearchCancellationIT.ScriptedBlockPlugin.SEARCH_BLOCK_SCRIPT_NAME;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.containsString;
@@ -69,14 +78,20 @@
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class SearchCancellationIT extends ESIntegTestCase {

private static boolean lowLevelCancellation;

@BeforeClass
public static void init() {
lowLevelCancellation = randomBoolean();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(ScriptedBlockPlugin.class);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
boolean lowLevelCancellation = randomBoolean();
logger.info("Using lowLevelCancellation: {}", lowLevelCancellation);
return Settings.builder()
.put(super.nodeSettings(nodeOrdinal, otherSettings))
@@ -227,7 +242,12 @@ public void testCancellationDuringAggregation() throws Exception {
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap())
)
.reduceScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.REDUCE_SCRIPT_NAME, Collections.emptyMap())
new Script(
ScriptType.INLINE,
"mockscript",
ScriptedBlockPlugin.REDUCE_BLOCK_SCRIPT_NAME,
Collections.emptyMap()
)
)
)
)
@@ -238,6 +258,80 @@
ensureSearchWasCancelled(searchResponse);
}

public void testCancellationDuringTimeSeriesAggregation() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
int numberOfShards = between(2, 5);
long now = Instant.now().toEpochMilli();
int numberOfRefreshes = between(1, 5);
int numberOfDocsPerRefresh = numberOfShards * between(1500, 2000) / numberOfRefreshes;
assertAcked(
prepareCreate("test").setSettings(
Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexSettings.MODE.getKey(), IndexMode.TIME_SERIES.name())
.put(IndexMetadata.INDEX_ROUTING_PATH.getKey(), "dim")
.put(TIME_SERIES_START_TIME.getKey(), now)
.put(TIME_SERIES_END_TIME.getKey(), now + (long) numberOfRefreshes * numberOfDocsPerRefresh + 1)
.build()
).setMapping("""
{
"properties": {
"@timestamp": {"type": "date", "format": "epoch_millis"},
"dim": {"type": "keyword", "time_series_dimension": true}
}
}
""")
);

for (int i = 0; i < numberOfRefreshes; i++) {
// Make sure we sometimes have a few segments
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int j = 0; j < numberOfDocsPerRefresh; j++) {
bulkRequestBuilder.add(
client().prepareIndex("test")
.setOpType(DocWriteRequest.OpType.CREATE)
.setSource("@timestamp", now + (long) i * numberOfDocsPerRefresh + j, "val", (double) j, "dim", String.valueOf(i))
);
}
assertNoFailures(bulkRequestBuilder.get());
}

logger.info("Executing search");
TimeSeriesAggregationBuilder timeSeriesAggregationBuilder = new TimeSeriesAggregationBuilder("test_agg");
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setQuery(matchAllQuery())
.addAggregation(
timeSeriesAggregationBuilder.subAggregation(
new ScriptedMetricAggregationBuilder("sub_agg").initScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.INIT_SCRIPT_NAME, Collections.emptyMap())
)
.mapScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.MAP_BLOCK_SCRIPT_NAME, Collections.emptyMap())
)
.combineScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap())
)
.reduceScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.REDUCE_FAIL_SCRIPT_NAME, Collections.emptyMap())
)
)
)
.execute();
awaitForBlock(plugins);
cancelSearch(SearchAction.NAME);
disableBlocks(plugins);

SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, searchResponse::actionGet);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
logger.info("All shards failed with", ex);
if (lowLevelCancellation) {
// Ensure that we cancelled in TimeSeriesIndexSearcher and not in reduce phase
assertThat(ExceptionsHelper.stackTrace(ex), containsString("TimeSeriesIndexSearcher"));
}

}

public void testCancellationOfScrollSearches() throws Exception {

List<ScriptedBlockPlugin> plugins = initBlockFactory();
@@ -414,8 +508,11 @@ public static class ScriptedBlockPlugin extends MockScriptPlugin {
static final String SEARCH_BLOCK_SCRIPT_NAME = "search_block";
static final String INIT_SCRIPT_NAME = "init";
static final String MAP_SCRIPT_NAME = "map";
static final String MAP_BLOCK_SCRIPT_NAME = "map_block";
static final String COMBINE_SCRIPT_NAME = "combine";
static final String REDUCE_SCRIPT_NAME = "reduce";
static final String REDUCE_FAIL_SCRIPT_NAME = "reduce_fail";
static final String REDUCE_BLOCK_SCRIPT_NAME = "reduce_block";
static final String TERM_SCRIPT_NAME = "term";

private final AtomicInteger hits = new AtomicInteger();
@@ -449,10 +546,16 @@ public Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
this::nullScript,
MAP_SCRIPT_NAME,
this::nullScript,
MAP_BLOCK_SCRIPT_NAME,
this::mapBlockScript,
COMBINE_SCRIPT_NAME,
this::nullScript,
REDUCE_SCRIPT_NAME,
REDUCE_BLOCK_SCRIPT_NAME,
this::blockScript,
REDUCE_SCRIPT_NAME,
this::termScript,
REDUCE_FAIL_SCRIPT_NAME,
this::reduceFailScript,
TERM_SCRIPT_NAME,
this::termScript
);
@@ -474,6 +577,11 @@ private Object searchBlockScript(Map<String, Object> params) {
return true;
}

private Object reduceFailScript(Map<String, Object> params) {
fail("Shouldn't reach reduce");
return true;
}

private Object nullScript(Map<String, Object> params) {
return null;
}
@@ -483,7 +591,9 @@ private Object blockScript(Map<String, Object> params) {
if (runnable != null) {
runnable.run();
}
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce");
if (shouldBlock.get()) {
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce");
}
hits.incrementAndGet();
try {
assertBusy(() -> assertFalse(shouldBlock.get()));
@@ -493,6 +603,23 @@
return 42;
}

private Object mapBlockScript(Map<String, Object> params) {
final Runnable runnable = beforeExecution.get();
if (runnable != null) {
runnable.run();
}
if (shouldBlock.get()) {
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in map");
}
hits.incrementAndGet();
try {
assertBusy(() -> assertFalse(shouldBlock.get()));
} catch (Exception e) {
throw new RuntimeException(e);
}
return 1;
}

private Object termScript(Map<String, Object> params) {
return 1;
}
@@ -8,11 +8,14 @@
package org.elasticsearch.search.aggregations;

import org.apache.lucene.search.Collector;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.timeseries.TimeSeriesIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector;
import org.elasticsearch.search.query.QueryPhase;

import java.io.IOException;
import java.util.ArrayList;
@@ -40,7 +43,7 @@ public void preProcess(SearchContext context) {
}
if (context.aggregations().factories().context() != null
&& context.aggregations().factories().context().isInSortOrderExecutionRequired()) {
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher());
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context));
try {
searcher.search(context.rewrittenQuery(), bucketCollector);
} catch (IOException e) {
@@ -55,6 +58,36 @@
}
}

private List<Runnable> getCancellationChecks(SearchContext context) {
List<Runnable> cancellationChecks = new ArrayList<>();
if (context.lowLevelCancellation()) {
// This searching doesn't live beyond this phase, so we don't need to remove query cancellation
cancellationChecks.add(() -> {
final SearchShardTask task = context.getTask();
if (task != null) {
task.ensureNotCancelled();
}
});
}

boolean timeoutSet = context.scrollContext() == null
&& context.timeout() != null
&& context.timeout().equals(SearchService.NO_TIMEOUT) == false;

if (timeoutSet) {
final long startTime = context.getRelativeTimeInMillis();
final long timeout = context.timeout().millis();
final long maxTime = startTime + timeout;
cancellationChecks.add(() -> {
final long time = context.getRelativeTimeInMillis();
if (time > maxTime) {
throw new QueryPhase.TimeExceededException();
}
});
}
return cancellationChecks;
}

public void execute(SearchContext context) {
if (context.aggregations() == null) {
context.queryResult().aggregations(null);
@@ -37,22 +37,29 @@
* TODO: Convert it to use index sort instead of hard-coded tsid and timestamp values
*/
public class TimeSeriesIndexSearcher {
private static final int CHECK_CANCELLED_SCORER_INTERVAL = 1 << 11;

// We need to delegate to the other searcher here as opposed to extending IndexSearcher and inheriting default implementations as the
// IndexSearcher would most of the time be a ContextIndexSearcher that has important logic related to e.g. document-level security.
private final IndexSearcher searcher;
private final List<Runnable> cancellations;

public TimeSeriesIndexSearcher(IndexSearcher searcher) {
public TimeSeriesIndexSearcher(IndexSearcher searcher, List<Runnable> cancellations) {
this.searcher = searcher;
this.cancellations = cancellations;
}

public void search(Query query, BucketCollector bucketCollector) throws IOException {
int seen = 0;
query = searcher.rewrite(query);
Weight weight = searcher.createWeight(query, bucketCollector.scoreMode(), 1);

// Create LeafWalker for each subreader
List<LeafWalker> leafWalkers = new ArrayList<>();
for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
if (++seen % CHECK_CANCELLED_SCORER_INTERVAL == 0) {
checkCancelled();
}
LeafBucketCollector leafCollector = bucketCollector.getLeafCollector(leaf);
Scorer scorer = weight.scorer(leaf);
if (scorer != null) {
@@ -76,6 +83,9 @@ protected boolean lessThan(LeafWalker a, LeafWalker b) {
// walkers are ordered by timestamp.
while (populateQueue(leafWalkers, queue)) {
do {
if (++seen % CHECK_CANCELLED_SCORER_INTERVAL == 0) {
checkCancelled();
}
LeafWalker walker = queue.top();
walker.collectCurrent();
if (walker.nextDoc() == DocIdSetIterator.NO_MORE_DOCS || walker.shouldPop()) {
@@ -131,6 +141,12 @@ private boolean queueAllHaveTsid(PriorityQueue<LeafWalker> queue, BytesRef tsid)
return true;
}

private void checkCancelled() {
for (Runnable r : cancellations) {
r.run();
}
}

private static class LeafWalker {
private final LeafCollector collector;
private final Bits liveDocs;
@@ -267,5 +267,5 @@ private static boolean canEarlyTerminate(IndexReader reader, SortAndFormats sort
return true;
}

static class TimeExceededException extends RuntimeException {}
public static class TimeExceededException extends RuntimeException {}
}
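
For a sense of how the checks themselves are assembled, here is a hedged sketch of the two checks AggregationPhase#getCancellationChecks builds (task cancellation and timeout). The suppliers and RuntimeExceptions below are stand-ins for SearchShardTask#ensureNotCancelled, SearchContext#getRelativeTimeInMillis, TaskCancelledException, and QueryPhase.TimeExceededException; build and its parameters are illustrative, not the commit's API:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.BooleanSupplier;
    import java.util.function.LongSupplier;

    class CancellationChecksSketch {
        static List<Runnable> build(boolean lowLevelCancellation, BooleanSupplier isCancelled,
                                    LongSupplier relativeTimeInMillis, long timeoutMillis) {
            List<Runnable> checks = new ArrayList<>();
            if (lowLevelCancellation) {
                // In the commit: SearchShardTask#ensureNotCancelled(), which throws
                // TaskCancelledException once the task has been cancelled.
                checks.add(() -> {
                    if (isCancelled.getAsBoolean()) {
                        throw new RuntimeException("task cancelled");
                    }
                });
            }
            // The commit also requires no scroll context and timeout != SearchService.NO_TIMEOUT.
            if (timeoutMillis > 0) {
                final long maxTime = relativeTimeInMillis.getAsLong() + timeoutMillis;
                checks.add(() -> {
                    if (relativeTimeInMillis.getAsLong() > maxTime) {
                        // In the commit: throw new QueryPhase.TimeExceededException()
                        throw new RuntimeException("time exceeded");
                    }
                });
            }
            return checks;
        }
    }

Because both conditions surface as exceptions thrown from inside the walker loop, cancellation takes effect during collection, before any per-shard results reach the reduce phase — which is exactly what the new test asserts by looking for TimeSeriesIndexSearcher in the stack trace.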
