Skip to content

Commit

Permalink
SQL: Avoid empty last pages for GROUP BY queries when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
Luegg committed Feb 28, 2022
1 parent 3823fac commit 6a585a1
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 52 deletions.
Expand Up @@ -479,7 +479,7 @@ public void testCountAndCountDistinct() throws IOException {
);

String cursor = (String) response.remove("cursor");
assertNotNull(cursor);
assertNull(cursor);
assertResponse(expected, response);
}

Expand Down Expand Up @@ -1461,7 +1461,7 @@ public void testFetchAllPagesSearchHitCursor(String format) throws IOException {
List<String> texts = IntStream.range(0, size).mapToObj(i -> String.format(Locale.ROOT, "text%02d", i)).toList();
index(texts.stream().map(t -> "{\"field\": \"" + t + "\"}").toArray(String[]::new));

testFetchAllPages(format, "SELECT field FROM " + indexPattern("test") + " ORDER BY field", texts, pageSize, true);
testFetchAllPages(format, "SELECT field FROM " + indexPattern("test") + " ORDER BY field", texts, pageSize, size % pageSize == 0);
}

/**
Expand Down Expand Up @@ -1496,7 +1496,13 @@ public void testFetchAllPagesCompositeAggCursor(String format) throws IOExceptio
List<String> texts = IntStream.range(0, size).mapToObj(i -> String.format(Locale.ROOT, "text%02d", i)).toList();
index(texts.stream().map(t -> "{\"field\": \"" + t + "\"}").toArray(String[]::new));

testFetchAllPages(format, "SELECT field FROM " + indexPattern("test") + " GROUP BY field ORDER BY field", texts, pageSize, true);
testFetchAllPages(
format,
"SELECT field FROM " + indexPattern("test") + " GROUP BY field ORDER BY field",
texts,
pageSize,
size % pageSize == 0
);
}

public void testFetchAllPagesListCursorTxt() throws IOException {
Expand Down
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.BucketSelectorPipelineAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.ql.execution.search.extractor.BucketExtractor;
import org.elasticsearch.xpack.ql.type.Schema;
Expand All @@ -27,6 +28,7 @@
import org.elasticsearch.xpack.sql.session.Cursor;
import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SqlConfiguration;
import org.elasticsearch.xpack.sql.util.Check;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -148,7 +150,8 @@ public void onResponse(SearchResponse response) {
}

protected Supplier<CompositeAggRowSet> makeRowSet(SearchResponse response) {
return () -> new CompositeAggRowSet(extractors, mask, response, limit);
CompositeAggregationBuilder aggregation = getCompositeBuilder(nextQuery);
return () -> new CompositeAggRowSet(extractors, mask, response, aggregation.size(), limit, mightProducePartialPages(aggregation));
}

protected BiFunction<SearchSourceBuilder, CompositeAggRowSet, CompositeAggCursor> makeCursor() {
Expand Down Expand Up @@ -200,33 +203,35 @@ static void handle(
private static boolean shouldRetryDueToEmptyPage(SearchResponse response) {
CompositeAggregation composite = getComposite(response);
// if there are no buckets but a next page, go fetch it instead of sending an empty response to the client
return composite != null
&& composite.getBuckets().isEmpty()
&& composite.afterKey() != null
&& composite.afterKey().isEmpty() == false;
return composite.getBuckets().isEmpty() && composite.afterKey() != null && composite.afterKey().isEmpty() == false;
}

static CompositeAggregation getComposite(SearchResponse response) {
Aggregation agg = response.getAggregations().get(Aggs.ROOT_GROUP_NAME);
if (agg == null) {
return null;
}
static CompositeAggregationBuilder getCompositeBuilder(SearchSourceBuilder source) {
AggregationBuilder aggregation = source.aggregations()
.getAggregatorFactories()
.stream()
.filter(a -> Objects.equals(a.getName(), Aggs.ROOT_GROUP_NAME))
.findFirst()
.orElse(null);

if (agg instanceof CompositeAggregation) {
return (CompositeAggregation) agg;
}
Check.isTrue(aggregation instanceof CompositeAggregationBuilder, "Unexpected aggregation builder " + aggregation);

throw new SqlIllegalArgumentException("Unrecognized root group found; {}", agg.getClass());
return (CompositeAggregationBuilder) aggregation;
}

private static void updateCompositeAfterKey(SearchResponse r, SearchSourceBuilder search) {
CompositeAggregation composite = getComposite(r);
/**
 * Checks whether the composite aggregation has a {@code bucket_selector} pipeline
 * aggregation attached. A bucket selector filters buckets after they are built, so
 * a page can legitimately come back smaller than the requested size even when more
 * pages exist (NOTE(review): rationale inferred from the surrounding paging logic —
 * confirm against the callers).
 */
static boolean mightProducePartialPages(CompositeAggregationBuilder aggregation) {
    return aggregation.getPipelineAggregations()
        .stream()
        .filter(BucketSelectorPipelineAggregationBuilder.class::isInstance)
        .findAny()
        .isPresent();
}

if (composite == null) {
throw new SqlIllegalArgumentException("Invalid server response; no group-by detected");
}
static CompositeAggregation getComposite(SearchResponse response) {
Aggregation agg = response.getAggregations().get(Aggs.ROOT_GROUP_NAME);
Check.isTrue(agg instanceof CompositeAggregation, "Unrecognized root group found; " + agg);

updateSourceAfterKey(composite.afterKey(), search);
return (CompositeAggregation) agg;
}

// Copies the composite aggregation's after_key from the response into the source
// builder, so the next search request resumes from where this page ended.
private static void updateCompositeAfterKey(SearchResponse r, SearchSourceBuilder search) {
    CompositeAggregation composite = getComposite(r);
    updateSourceAfterKey(composite.afterKey(), search);
}

private static void updateSourceAfterKey(Map<String, Object> afterKey, SearchSourceBuilder search) {
Expand Down
Expand Up @@ -15,8 +15,6 @@
import java.util.List;
import java.util.Map;

import static java.util.Collections.emptyList;

/**
* {@link RowSet} specific to (GROUP BY) aggregation.
*/
Expand All @@ -29,21 +27,24 @@ class CompositeAggRowSet extends ResultRowSet<BucketExtractor> {
int size;
int row = 0;

CompositeAggRowSet(List<BucketExtractor> exts, BitSet mask, SearchResponse response, int limit) {
CompositeAggRowSet(
List<BucketExtractor> exts,
BitSet mask,
SearchResponse response,
int sizeRequested,
int remainingLimit,
boolean mightProducePartialPages
) {
super(exts, mask);

CompositeAggregation composite = CompositeAggCursor.getComposite(response);
if (composite != null) {
buckets = composite.getBuckets();
afterKey = composite.afterKey();
} else {
buckets = emptyList();
afterKey = null;
}
buckets = composite.getBuckets();
afterKey = composite.afterKey();

// page size
size = limit == -1 ? buckets.size() : Math.min(buckets.size(), limit);
remainingData = remainingData(afterKey != null, size, limit);
size = remainingLimit == -1 ? buckets.size() : Math.min(buckets.size(), remainingLimit);
boolean hasNextPage = mightProducePartialPages || buckets.size() == sizeRequested;
remainingData = remainingData(hasNextPage, size, remainingLimit);
}

static int remainingData(boolean hasNextPage, int size, int limit) {
Expand Down
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.ql.execution.search.extractor.BucketExtractor;
import org.elasticsearch.xpack.ql.type.Schema;
Expand Down Expand Up @@ -64,7 +65,17 @@ public String getWriteableName() {

@Override
protected Supplier<CompositeAggRowSet> makeRowSet(SearchResponse response) {
return () -> new PivotRowSet(Schema.EMPTY, extractors(), mask(), response, limit(), previousKey);
CompositeAggregationBuilder aggregation = getCompositeBuilder(next());
return () -> new PivotRowSet(
Schema.EMPTY,
extractors(),
mask(),
response,
aggregation.size(),
limit(),
previousKey,
mightProducePartialPages(aggregation)
);
}

@Override
Expand Down
Expand Up @@ -31,10 +31,12 @@ class PivotRowSet extends SchemaCompositeAggRowSet {
List<BucketExtractor> exts,
BitSet mask,
SearchResponse response,
int sizeRequested,
int limit,
Map<String, Object> previousLastKey
Map<String, Object> previousLastKey,
boolean mightProducePartialPages
) {
super(schema, exts, mask, response, limit);
super(schema, exts, mask, response, sizeRequested, limit, mightProducePartialPages);

data = buckets.isEmpty() ? emptyList() : new ArrayList<>();

Expand Down
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -452,22 +453,27 @@ static class CompositeActionListener extends BaseAggActionListener {

@Override
protected void handleResponse(SearchResponse response, ActionListener<Page> listener) {
CompositeAggregationBuilder aggregation = CompositeAggCursor.getCompositeBuilder(request.source());

Supplier<CompositeAggRowSet> makeRowSet = isPivot
? () -> new PivotRowSet(
schema,
initBucketExtractors(response),
mask,
response,
aggregation.size(),
query.sortingColumns().isEmpty() ? query.limit() : -1,
null
null,
CompositeAggCursor.mightProducePartialPages(aggregation)
)
: () -> new SchemaCompositeAggRowSet(
schema,
initBucketExtractors(response),
mask,
response,
query.sortingColumns().isEmpty() ? query.limit() : -1
aggregation.size(),
query.sortingColumns().isEmpty() ? query.limit() : -1,
CompositeAggCursor.mightProducePartialPages(aggregation)
);

BiFunction<SearchSourceBuilder, CompositeAggRowSet, CompositeAggCursor> makeCursor = isPivot ? (q, r) -> {
Expand Down
Expand Up @@ -23,8 +23,16 @@ class SchemaCompositeAggRowSet extends CompositeAggRowSet implements SchemaRowSe

private final Schema schema;

SchemaCompositeAggRowSet(Schema schema, List<BucketExtractor> exts, BitSet mask, SearchResponse r, int limitAggs) {
super(exts, mask, r, limitAggs);
/**
 * Builds a composite-agg row set that also carries the {@link Schema} describing
 * its columns. All row/bucket handling is delegated to the superclass; this type
 * only records the schema.
 *
 * @param schema column schema exposed alongside the rows
 * @param exts per-column bucket extractors
 * @param mask bitset selecting which extractors are visible in the output
 * @param r search response holding the composite aggregation buckets
 * @param sizeRequested page size requested from the composite agg — presumably used
 *        by the superclass to detect a partial (last) page; confirm in CompositeAggRowSet
 * @param limitAggs remaining row limit (-1 appears to mean unlimited elsewhere in this
 *        change — TODO confirm)
 * @param mightProducePartialPages whether pages smaller than {@code sizeRequested}
 *        may still be followed by more data
 */
SchemaCompositeAggRowSet(
Schema schema,
List<BucketExtractor> exts,
BitSet mask,
SearchResponse r,
int sizeRequested,
int limitAggs,
boolean mightProducePartialPages
) {
super(exts, mask, r, sizeRequested, limitAggs, mightProducePartialPages);
this.schema = schema;
}

Expand Down
Expand Up @@ -108,21 +108,14 @@ public void nextPage(SqlConfiguration cfg, Client client, ActionListener<Page> l
client.search(
request,
ActionListener.wrap(
(SearchResponse response) -> handle(
client,
response,
request.source(),
makeRowSet(nextQuery.size(), response),
listener,
includeFrozen
),
(SearchResponse response) -> handle(client, response, request.source(), makeRowSet(response), listener, includeFrozen),
listener::onFailure
)
);
}

private Supplier<SearchHitRowSet> makeRowSet(int sizeRequested, SearchResponse response) {
return () -> new SearchHitRowSet(extractors, mask, sizeRequested, limit, response);
// Lazily builds the row set for this page. Note: nextQuery.size() is evaluated
// when the supplier runs, not when makeRowSet is called — keep the lambda's lazy
// read intact, as nextQuery may be updated between the two points.
private Supplier<SearchHitRowSet> makeRowSet(SearchResponse response) {
return () -> new SearchHitRowSet(extractors, mask, nextQuery.size(), limit, response);
}

static void handle(
Expand Down

0 comments on commit 6a585a1

Please sign in to comment.