Skip to content

Commit

Permalink
SQL: Fix rearranging columns in PIVOT queries (#81032) (#81287)
Browse files Browse the repository at this point in the history
Resolves #80952 and also fixes arbitrary rearrangements of pivoted columns (dropping, duplicating and reordering).

The bug was caused by the field reordering in a `ProjectExec` wrapping a `PivotExec` being ignored. As a result, the field mask (from `QueryContainer.columnMask()`) was created incorrectly, because the attributes in `EsQueryExec.output()` appear in a different order than the corresponding entries in `QueryContainer.fields()`.
  • Loading branch information
Lukas Wegmann committed Dec 3, 2021
1 parent 66daa10 commit 2ba3632
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 39 deletions.
15 changes: 15 additions & 0 deletions x-pack/plugin/sql/qa/server/src/main/resources/pivot.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ null |48396.28571428572|62140.666666666664
// end::averageWithTwoValuesAndAlias
;

averageWithTwoValuesAndAlias-DuplicateDropAndReorderColumns
schema::c1:d|c2:d|languages:bt|c3:d
SELECT XY as c1, XY as c2, languages, XY as c3 FROM (SELECT languages, gender, salary FROM test_emp) PIVOT (AVG(salary) FOR gender IN ('M' AS "XY", 'F' "XX"));

c1 | c2 | languages | c3
-----------------+-----------------+---------------+-----------------
48396.28571428572|48396.28571428572|null |48396.28571428572
49767.22222222222|49767.22222222222|1 |49767.22222222222
44103.90909090909|44103.90909090909|2 |44103.90909090909
51741.90909090909|51741.90909090909|3 |51741.90909090909
47058.90909090909|47058.90909090909|4 |47058.90909090909
39052.875 |39052.875 |5 |39052.875

;

averageWithThreeValuesIncludingNull
schema::languages:bt|'M':d|'F':d
SELECT * FROM (SELECT languages, gender, salary FROM test_emp) PIVOT (AVG(salary) FOR gender IN ('M', 'F'));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ static class CompositeActionListener extends BaseAggActionListener {
) {
super(listener, client, cfg, output, query, request);

isPivot = query.fields().stream().anyMatch(t -> t.v1() instanceof PivotColumnRef);
isPivot = query.fields().stream().anyMatch(t -> t.extraction() instanceof PivotColumnRef);
}

@Override
Expand Down Expand Up @@ -462,12 +462,12 @@ abstract static class BaseAggActionListener extends BaseActionListener {

protected List<BucketExtractor> initBucketExtractors(SearchResponse response) {
// create response extractors for the first time
List<Tuple<FieldExtraction, String>> refs = query.fields();
List<QueryContainer.FieldInfo> refs = query.fields();

List<BucketExtractor> exts = new ArrayList<>(refs.size());
ConstantExtractor totalCount = new ConstantExtractor(response.getHits().getTotalHits().value);
for (Tuple<FieldExtraction, String> ref : refs) {
exts.add(createExtractor(ref.v1(), totalCount));
for (QueryContainer.FieldInfo ref : refs) {
exts.add(createExtractor(ref.extraction(), totalCount));
}
return exts;
}
Expand Down Expand Up @@ -537,11 +537,11 @@ static class ScrollActionListener extends BaseActionListener {
@Override
protected void handleResponse(SearchResponse response, ActionListener<Page> listener) {
// create response extractors for the first time
List<Tuple<FieldExtraction, String>> refs = query.fields();
List<QueryContainer.FieldInfo> refs = query.fields();

List<HitExtractor> exts = new ArrayList<>(refs.size());
for (Tuple<FieldExtraction, String> ref : refs) {
exts.add(createExtractor(ref.v1()));
for (QueryContainer.FieldInfo ref : refs) {
exts.add(createExtractor(ref.extraction()));
}

ScrollCursor.handle(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static SearchSourceBuilder sourceBuilder(QueryContainer container, QueryB
// need to be retrieved from the result documents

// NB: the sortBuilder takes care of eliminating duplicates
container.fields().forEach(f -> f.v1().collectFields(sortBuilder));
container.fields().forEach(f -> f.extraction().collectFields(sortBuilder));
sortBuilder.build(source);

// add the aggs (if present)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.ql.execution.search.AggRef;
import org.elasticsearch.xpack.ql.execution.search.FieldExtraction;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
Expand All @@ -17,6 +16,7 @@
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.Foldables;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.expression.NameId;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.Order;
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
Expand Down Expand Up @@ -143,9 +143,17 @@ protected PhysicalPlan rule(ProjectExec project) {
AttributeMap.Builder<Expression> aliases = AttributeMap.<Expression>builder().putAll(queryC.aliases());
AttributeMap.Builder<Pipe> processors = AttributeMap.<Pipe>builder().putAll(queryC.scalarFunctions());

// recreate the query container's fields such that they appear in order of the projection and with hidden fields
// last. This is mostly needed for PIVOT queries where we have to fold projections on aggregations because they cannot be
// optimized away. Most (all) other queries usually prune nested projections in earlier steps.
List<QueryContainer.FieldInfo> fields = new ArrayList<>(queryC.fields().size());
List<QueryContainer.FieldInfo> hiddenFields = new ArrayList<>(queryC.fields());

for (NamedExpression pj : project.projections()) {
Attribute attr = pj.toAttribute();
NameId attributeId = attr.id();

if (pj instanceof Alias) {
Attribute attr = pj.toAttribute();
Expression e = ((Alias) pj).child();

// track all aliases (to determine their reference later on)
Expand All @@ -155,13 +163,27 @@ protected PhysicalPlan rule(ProjectExec project) {
if (e instanceof ScalarFunction) {
processors.put(attr, ((ScalarFunction) e).asPipe());
}

if (e instanceof NamedExpression) {
attributeId = ((NamedExpression) e).toAttribute().id();
}
}

for (QueryContainer.FieldInfo field : queryC.fields()) {
if (field.attribute().id().equals(attributeId)) {
fields.add(field);
hiddenFields.remove(field);
break;
}
}
}

fields.addAll(hiddenFields);

QueryContainer clone = new QueryContainer(
queryC.query(),
queryC.aggs(),
queryC.fields(),
fields,
aliases.build(),
queryC.pseudoFunctions(),
processors.build(),
Expand Down Expand Up @@ -574,7 +596,7 @@ else if (target instanceof Function) {
}

// add the computed column
queryC = qC.get().addColumn(new ComputedRef(proc), id);
queryC = qC.get().addColumn(new ComputedRef(proc), id, ne.toAttribute());
}

// apply the same logic above (for function inputs) to non-scalar functions with small variations:
Expand All @@ -588,11 +610,19 @@ else if (target instanceof Function) {
// attributes can only refer to declared groups
if (target instanceof Attribute) {
Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(target));
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, isDateBased(target.dataType())), id);
queryC = queryC.addColumn(
new GroupByRef(matchingGroup.id(), null, isDateBased(target.dataType())),
id,
ne.toAttribute()
);
}
// handle histogram
else if (target instanceof GroupingFunction) {
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, isDateBased(target.dataType())), id);
queryC = queryC.addColumn(
new GroupByRef(matchingGroup.id(), null, isDateBased(target.dataType())),
id,
ne.toAttribute()
);
}
// handle literal
else if (target.foldable()) {
Expand All @@ -609,7 +639,7 @@ else if (target.foldable()) {
AggregateFunction af = (AggregateFunction) target;
Tuple<QueryContainer, AggPathInput> withAgg = addAggFunction(matchingGroup, af, compoundAggMap, queryC);
// make sure to add the inner id (to handle compound aggs)
queryC = withAgg.v1().addColumn(withAgg.v2().context(), id);
queryC = withAgg.v1().addColumn(withAgg.v2().context(), id, ne.toAttribute());
}
}

Expand All @@ -623,7 +653,11 @@ else if (target.foldable()) {
matchingGroup = groupingContext.groupFor(target);
Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(ne));

queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, isDateBased(ne.dataType())), id);
queryC = queryC.addColumn(
new GroupByRef(matchingGroup.id(), null, isDateBased(ne.dataType())),
id,
ne.toAttribute()
);
}
// fallback
else {
Expand All @@ -636,7 +670,7 @@ else if (target.foldable()) {
if (a.aggregates().stream().allMatch(e -> e.anyMatch(Expression::foldable))) {
for (Expression grouping : a.groupings()) {
GroupByKey matchingGroup = groupingContext.groupFor(grouping);
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, false), id);
queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, false), id, null);
}
}
return new EsQueryExec(exec.source(), exec.index(), a.output(), queryC);
Expand Down Expand Up @@ -842,19 +876,20 @@ protected PhysicalPlan rule(PivotExec plan) {
// due to the Pivot structure - the column is the last entry in the grouping set
QueryContainer query = fold.queryContainer();

List<Tuple<FieldExtraction, String>> fields = new ArrayList<>(query.fields());
List<QueryContainer.FieldInfo> fields = new ArrayList<>(query.fields());
int startingIndex = fields.size() - p.aggregates().size() - 1;
// pivot grouping
Tuple<FieldExtraction, String> groupTuple = fields.remove(startingIndex);
QueryContainer.FieldInfo groupField = fields.remove(startingIndex);
AttributeMap<Literal> values = p.valuesToLiterals();

for (int i = startingIndex; i < fields.size(); i++) {
Tuple<FieldExtraction, String> tuple = fields.remove(i);
QueryContainer.FieldInfo field = fields.remove(i);
for (Map.Entry<Attribute, Literal> entry : values.entrySet()) {
fields.add(
new Tuple<>(
new PivotColumnRef(groupTuple.v1(), tuple.v1(), entry.getValue().value()),
Expressions.id(entry.getKey())
new QueryContainer.FieldInfo(
new PivotColumnRef(groupField.extraction(), field.extraction(), entry.getValue().value()),
Expressions.id(entry.getKey()),
entry.getKey()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
public final class GlobalCountRef extends AggRef {
public static final GlobalCountRef INSTANCE = new GlobalCountRef();

private GlobalCountRef() {}

@Override
public String toString() {
return "#_Total_Hits_#";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,34 @@ public class QueryContainer {
private final Aggs aggs;
private final Query query;

/**
 * Metadata kept per extracted field of the query: the {@link FieldExtraction}
 * used to pull the value out of the search response, the expression id used for
 * custom sorting ({@code sortingColumns()}) and column masking
 * ({@code columnMask()}), and the {@link Attribute} the field originates from,
 * used to re-align fields with projections (e.g. for PIVOT queries).
 *
 * Instances are immutable; the class is final as it is a plain data carrier
 * not designed for extension.
 */
public static final class FieldInfo {
    private final FieldExtraction extraction;
    private final String id;
    // may be null for fields with no originating attribute (e.g. folded groupings)
    private final Attribute attribute;

    public FieldInfo(FieldExtraction extraction, String id, Attribute attribute) {
        this.extraction = extraction;
        this.id = id;
        this.attribute = attribute;
    }

    /** How the field's value is extracted from the search response. */
    public FieldExtraction extraction() {
        return extraction;
    }

    /** Expression id of the field, used for sorting and column masking. */
    public String id() {
        return id;
    }

    /** Originating attribute of the field; may be null. */
    public Attribute attribute() {
        return attribute;
    }
}

// fields extracted from the response - not necessarily what the client sees
// for example in case of grouping or custom sorting, the response has extra columns
// that is filtered before getting to the client

// the list contains both the field extraction and its id (for custom sorting)
private final List<Tuple<FieldExtraction, String>> fields;
private final List<FieldInfo> fields;

// aliases found in the tree
private final AttributeMap<Expression> aliases;
Expand Down Expand Up @@ -99,7 +121,7 @@ public QueryContainer() {
public QueryContainer(
Query query,
Aggs aggs,
List<Tuple<FieldExtraction, String>> fields,
List<FieldInfo> fields,
AttributeMap<Expression> aliases,
Map<String, GroupByKey> pseudoFunctions,
AttributeMap<Pipe> scalarFunctions,
Expand Down Expand Up @@ -152,8 +174,8 @@ public List<Tuple<Integer, Comparator>> sortingColumns() {

int atIndex = -1;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, String> field = fields.get(i);
if (field.v2().equals(expressionId)) {
FieldInfo field = fields.get(i);
if (field.id().equals(expressionId)) {
atIndex = i;
break;
}
Expand Down Expand Up @@ -195,10 +217,10 @@ public BitSet columnMask(List<Attribute> columns) {
int index = -1;

for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, String> tuple = fields.get(i);
FieldInfo field = fields.get(i);
// if the index is already set there is a collision,
// so continue searching for the other tuple with the same id
if (mask.get(i) == false && tuple.v2().equals(id)) {
// so continue searching for the other field with the same id
if (mask.get(i) == false && field.id().equals(id)) {
index = i;
break;
}
Expand All @@ -221,7 +243,7 @@ public Aggs aggs() {
return aggs;
}

public List<Tuple<FieldExtraction, String>> fields() {
public List<FieldInfo> fields() {
return fields;
}

Expand All @@ -243,7 +265,7 @@ public int limit() {

public boolean isAggsOnly() {
if (aggsOnly == null) {
aggsOnly = Boolean.valueOf(this.fields.stream().anyMatch(t -> t.v1().supportedByAggsOnlyQuery()));
aggsOnly = Boolean.valueOf(this.fields.stream().anyMatch(t -> t.extraction().supportedByAggsOnlyQuery()));
}

return aggsOnly.booleanValue();
Expand Down Expand Up @@ -523,7 +545,7 @@ public FieldExtraction resolve(Attribute attribute) {
public QueryContainer addColumn(Attribute attr) {
Expression expression = aliases.resolve(attr, attr);
Tuple<QueryContainer, FieldExtraction> tuple = asFieldExtraction(attr);
return tuple.v1().addColumn(tuple.v2(), Expressions.id(expression));
return tuple.v1().addColumn(tuple.v2(), Expressions.id(expression), attr);
}

private Tuple<QueryContainer, FieldExtraction> asFieldExtraction(Attribute attr) {
Expand Down Expand Up @@ -558,11 +580,11 @@ private Tuple<QueryContainer, FieldExtraction> asFieldExtraction(Attribute attr)
throw new SqlIllegalArgumentException("Unknown output attribute {}", attr);
}

public QueryContainer addColumn(FieldExtraction ref, String id) {
public QueryContainer addColumn(FieldExtraction ref, String id, Attribute attribute) {
return new QueryContainer(
query,
aggs,
combine(fields, new Tuple<>(ref, id)),
combine(fields, new FieldInfo(ref, id, attribute)),
aliases,
pseudoFunctions,
scalarFunctions,
Expand Down

0 comments on commit 2ba3632

Please sign in to comment.