Skip to content

Commit

Permalink
fix group by queries on system tables with limit or the same aggregation twice
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Bogensberger authored and msbt committed Dec 1, 2014
1 parent b09c253 commit ed5a443
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 31 deletions.
3 changes: 3 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Changes for Crate
Unreleased
==========

- Fix: group by queries using limit or selecting the same aggregation twice
did not execute correctly.

- Fixed possible race condition that could cause the `LIMIT` to not be applied
correctly.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.crate.metadata.shard.unassigned.UnassignedShardCollectorExpression;
import io.crate.operation.Input;
import io.crate.operation.projectors.Projector;
import io.crate.operation.reference.DocLevelReferenceResolver;
import io.crate.operation.reference.sys.shard.unassigned.UnassignedShardsReferenceResolver;
import io.crate.planner.node.dql.CollectNode;
import io.crate.planner.symbol.Literal;
Expand All @@ -57,9 +56,9 @@ public class UnassignedShardsCollectService implements CollectService {
/**
 * Creates the collect service and builds the input symbol visitor it uses.
 *
 * @param functions handed to the {@code CollectInputSymbolVisitor} created here
 * @param clusterService kept in a field of this service
 * @param unassignedShardsReferenceResolver reference resolver handed to the
 *        {@code CollectInputSymbolVisitor} created here
 */
public UnassignedShardsCollectService(Functions functions,
                                      ClusterService clusterService,
                                      UnassignedShardsReferenceResolver unassignedShardsReferenceResolver) {
    // The resolver is passed as-is; the previous variant cast it to
    // DocLevelReferenceResolver and spelled out the <Input<?>> type argument.
    // NOTE(review): the visitor is instantiated as a raw type here - confirm
    // whether an explicit type argument can be restored.
    this.inputSymbolVisitor = new CollectInputSymbolVisitor(
            functions,
            unassignedShardsReferenceResolver
    );
    this.clusterService = clusterService;
}
Expand Down
4 changes: 4 additions & 0 deletions sql/src/main/java/io/crate/planner/PlanNodeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@

package io.crate.planner;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.crate.PartitionName;
import io.crate.analyze.AbstractDataAnalyzedStatement;
import io.crate.metadata.Routing;
import io.crate.planner.node.dql.CollectNode;
import io.crate.planner.node.dql.DQLPlanNode;
import io.crate.planner.node.dql.MergeNode;
import io.crate.planner.projection.Projection;
import io.crate.planner.symbol.InputColumn;
import io.crate.planner.symbol.Symbol;
import io.crate.planner.symbol.Symbols;

Expand Down Expand Up @@ -103,6 +106,7 @@ static CollectNode collect(AbstractDataAnalyzedStatement analysis,
List<Symbol> toCollect,
ImmutableList<Projection> projections,
@Nullable String partitionIdent) {
assert !Iterables.any(toCollect, Predicates.instanceOf(InputColumn.class)) : "cannot collect inputcolumns";
Routing routing = analysis.table().getRouting(analysis.whereClause());
if (partitionIdent != null && routing.hasLocations()) {
routing = filterRouting(routing, PartitionName.fromPartitionIdent(
Expand Down
33 changes: 20 additions & 13 deletions sql/src/main/java/io/crate/planner/Planner.java
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,24 @@ private void nonDistributedGroupBy(SelectAnalyzedStatement analysis, Plan plan,
contextBuilder.nextStep();

projectionBuilder.add(groupProjection);
boolean topNDone = addTopNIfApplicableOnReducer(analysis, contextBuilder, projectionBuilder);
if (requireLimitOnReducer(analysis, contextBuilder.aggregationsWrappedInScalar)) {

TopNProjection topN = new TopNProjection(
firstNonNull(analysis.limit(), Constants.DEFAULT_SELECT_LIMIT) + analysis.offset(),
0,
contextBuilder.orderBy(),
analysis.orderBy().reverseFlags(),
analysis.orderBy().nullsFirst()
);
// pass through on collectnode
List<Symbol> topNOutputs = new ArrayList<>(groupProjection.outputs().size());
int i = 0;
for (Symbol groupOutput : groupProjection.outputs()) {
topNOutputs.add(new InputColumn(i++, groupOutput.valueType()));
}
topN.outputs(topNOutputs);
projectionBuilder.add(topN);
}
CollectNode collectNode = PlanNodeBuilder.collect(
analysis,
toCollect,
Expand All @@ -738,7 +754,7 @@ private void nonDistributedGroupBy(SelectAnalyzedStatement analysis, Plan plan,
contextBuilder.nextStep();

// handler
ImmutableList.Builder<Projection> builder = ImmutableList.<Projection>builder();
ImmutableList.Builder<Projection> builder = ImmutableList.builder();

if (havingClause != null) {
FilterProjection fp = new FilterProjection((Function)havingClause);
Expand All @@ -752,23 +768,14 @@ private void nonDistributedGroupBy(SelectAnalyzedStatement analysis, Plan plan,
builder.add(new GroupProjection(contextBuilder.groupBy(), contextBuilder.aggregations()));
}
if (!ignoreSorting) {
List<Symbol> outputs;
List<Symbol> orderBy;
if (topNDone) {
orderBy = contextBuilder.passThroughOrderBy();
outputs = contextBuilder.passThroughOutputs();
} else {
orderBy = contextBuilder.orderBy();
outputs = contextBuilder.outputs();
}
TopNProjection topN = new TopNProjection(
firstNonNull(analysis.limit(), Constants.DEFAULT_SELECT_LIMIT),
analysis.offset(),
orderBy,
contextBuilder.orderBy(),
analysis.orderBy().reverseFlags(),
analysis.orderBy().nullsFirst()
);
topN.outputs(outputs);
topN.outputs(contextBuilder.outputs());
builder.add(topN);
}
if (context.indexWriterProjection.isPresent()) {
Expand Down
12 changes: 12 additions & 0 deletions sql/src/main/java/io/crate/planner/PlannerContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.crate.planner.symbol.Aggregation;
import io.crate.planner.symbol.InputColumn;
import io.crate.planner.symbol.Symbol;
import io.crate.planner.symbol.SymbolFormatter;

import java.util.*;

Expand Down Expand Up @@ -64,6 +65,17 @@ public Aggregation.Step step() {
}

Symbol allocateToCollect(Symbol symbol) {

// handle the case that we got 1 function twice
// symbol is already an InputColumn
if (symbol instanceof InputColumn) {
if (toCollectAllocation.containsValue(symbol)) {
return symbol;
} else {
throw new IllegalArgumentException(
SymbolFormatter.format("Symbol %s cannot be collected.", symbol));
}
}
InputColumn inputColumn = toCollectAllocation.get(symbol);
if (inputColumn == null) {
inputColumn = new InputColumn(toCollectAllocation.size(), symbol.valueType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

package io.crate.planner.node;

import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import io.crate.Streamer;
import io.crate.breaker.RamAccountingContext;
import com.google.common.collect.Lists;
import io.crate.exceptions.ResourceUnknownException;
import io.crate.metadata.Functions;
import io.crate.operation.aggregation.AggregationFunction;
Expand Down Expand Up @@ -115,16 +115,19 @@ private Streamer<?> resolveStreamer(Aggregation aggregation, Aggregation.Step st
@Override
public Void visitCollectNode(CollectNode node, Context context) {
// get aggregations, if any
Optional<Projection> finalProjection = node.finalProjection();
List<Aggregation> aggregations = ImmutableList.of();
if (finalProjection.isPresent()) {
if (finalProjection.get().projectionType() == ProjectionType.AGGREGATION) {
aggregations = ((AggregationProjection)finalProjection.get()).aggregations();
} else if (finalProjection.get().projectionType() == ProjectionType.GROUP) {
aggregations = ((GroupProjection)finalProjection.get()).values();
List<Projection> projections = Lists.reverse(node.projections());
for(Projection projection : projections){
if (projection.projectionType() == ProjectionType.AGGREGATION) {
aggregations = ((AggregationProjection)projection).aggregations();
break;
} else if (projection.projectionType() == ProjectionType.GROUP) {
aggregations = ((GroupProjection)projection).values();
break;
}
}


int aggIdx = 0;
Aggregation aggregation;
for (DataType outputType : node.outputTypes()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ public <C, R> R accept(ProjectionVisitor<C, R> visitor, C context) {
return visitor.visitGroupProjection(this, context);
}

/**
* Returns the list of outputs: the group-by keys come first,
* followed by the aggregations.
*/
@Override
public List<? extends Symbol> outputs() {
if (outputs == null) {
Expand Down
22 changes: 20 additions & 2 deletions sql/src/test/java/io/crate/integrationtests/SysShardsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.crate.blob.v2.BlobIndices;
import io.crate.test.integration.ClassLifecycleIntegrationTest;
import io.crate.testing.SQLTransportExecutor;
import io.crate.testing.TestingHelpers;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -37,6 +38,7 @@
import org.junit.rules.ExpectedException;

import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;

public class SysShardsTest extends ClassLifecycleIntegrationTest {

Expand Down Expand Up @@ -94,6 +96,22 @@ public void testSelectGroupByAllTables() throws Exception {
assertEquals("quotes", response.rows()[2][1]);
}

/**
 * Groups {@code sys.shards} by table name while selecting the same aggregation
 * ({@code sum(num_docs)}) twice and applying a limit, then checks the printed result.
 *
 * A table {@code t} with {@code number_of_replicas=2} is created for the test;
 * judging by the test name the replicas are expected to stay unassigned
 * (the test only waits for yellow health) - TODO confirm.
 */
@Test
public void testGroupByWithLimitUnassignedShards() throws Exception {
    transportExecutor.exec("create table t (id int, name string) with (number_of_replicas=2)");
    try {
        transportExecutor.ensureYellow();

        SQLResponse response = transportExecutor.exec("select sum(num_docs), table_name, sum(num_docs) from sys.shards group by table_name order by table_name desc limit 1000");
        assertThat(response.rowCount(), is(4L));
        assertThat(TestingHelpers.printedTable(response.rows()),
                is("0.0| t| 0.0\n" +
                   "0.0| quotes| 0.0\n" +
                   "14.0| characters| 14.0\n" +
                   "0.0| blobs| 0.0\n"));
    } finally {
        // Drop the table even when an assertion above fails, so it cannot
        // linger and change the results of other queries against sys.shards.
        transportExecutor.exec("drop table t");
    }
}

@Test
public void testSelectGroupByWhereNotLike() throws Exception {
SQLResponse response = transportExecutor.exec("select count(*), table_name from sys.shards " +
Expand Down Expand Up @@ -256,15 +274,15 @@ public void testGroupByUnknownOrderBy() throws Exception {
public void testGroupByUnknownWhere() throws Exception {
expectedException.expect(SQLActionException.class);
expectedException.expectMessage("Column 'lol' unknown");
SQLResponse response = transportExecutor.exec(
transportExecutor.exec(
"select sum(num_docs), table_name from sys.shards where lol='funky' group by table_name");
}

@Test
public void testGlobalAggregateUnknownWhere() throws Exception {
expectedException.expect(SQLActionException.class);
expectedException.expectMessage("Column 'lol' unknown");
SQLResponse response = transportExecutor.exec(
transportExecutor.exec(
"select sum(num_docs) from sys.shards where lol='funky'");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock;

public class PlanNodeStreamerVisitorTest {

Expand Down Expand Up @@ -122,6 +123,25 @@ public void testGetOutputStreamersFromCollectNodeWithAggregations() throws Excep
assertThat(streamers[3], instanceOf(DataTypes.DOUBLE.streamer().getClass()));
}

/**
 * A collect node whose projections are a group projection followed by a
 * top-N projection, and whose single output type is UNDEFINED, must yield
 * exactly one output streamer of the LONG streamer's class.
 */
@Test
public void testGetOutputStreamersFromCollectNodeWithGroupAndTopNProjection() throws Exception {
    // group projection: one literal key, one count aggregation (PARTIAL -> FINAL)
    Aggregation countAggregation = new Aggregation(
            countInfo,
            ImmutableList.<Symbol>of(),
            Aggregation.Step.PARTIAL, Aggregation.Step.FINAL);
    GroupProjection grouping = new GroupProjection(
            Arrays.<Symbol>asList(Literal.newLiteral("key")),
            Arrays.asList(countAggregation));
    TopNProjection topN = new TopNProjection(10,0);

    // collect node with empty routing; the group projection is NOT the last projection
    Routing emptyRouting = new Routing(new HashMap<String, Map<String, Set<Integer>>>());
    CollectNode node = new CollectNode("mynode", emptyRouting);
    node.outputTypes(Arrays.<DataType>asList(DataTypes.UNDEFINED));
    node.projections(Arrays.<Projection>asList(grouping, topN));

    RamAccountingContext ramAccounting = mock(RamAccountingContext.class);
    Streamer<?>[] outputStreamers = visitor.process(node, ramAccounting).outputStreamers();

    assertThat(outputStreamers.length, is(1));
    assertThat(outputStreamers[0], instanceOf(DataTypes.LONG.streamer().getClass()));
}

@Test
public void testGetInputStreamersForMergeNode() throws Exception {
MergeNode mergeNode = new MergeNode("mörtsch", 2);
Expand Down
20 changes: 14 additions & 6 deletions sql/src/test/java/io/crate/testing/SQLTransportExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,25 @@ public ActionFuture<SQLBulkResponse> execute(SQLBulkRequest request) {
}

/**
 * Delegates to {@code ensureState} with {@code ClusterHealthStatus.GREEN}.
 *
 * @return whatever {@code ensureState} returns for that status
 */
public ClusterHealthStatus ensureGreen() {
return ensureState(ClusterHealthStatus.GREEN);
}

/**
 * Delegates to {@code ensureState} with {@code ClusterHealthStatus.YELLOW}.
 *
 * @return whatever {@code ensureState} returns for that status
 */
public ClusterHealthStatus ensureYellow() {
return ensureState(ClusterHealthStatus.YELLOW);
}

private ClusterHealthStatus ensureState(ClusterHealthStatus state) {
ClusterHealthResponse actionGet = client().admin().cluster().health(
Requests.clusterHealthRequest()
.waitForGreenStatus()
.waitForEvents(Priority.LANGUID).waitForRelocatingShards(0)
Requests.clusterHealthRequest()
.waitForStatus(state)
.waitForEvents(Priority.LANGUID).waitForRelocatingShards(0)
).actionGet();

if (actionGet.isTimedOut()) {
logger.info("ensureGreen timed out, cluster state:\n{}\n{}", client().admin().cluster().prepareState().get().getState().prettyPrint(), client().admin().cluster().preparePendingClusterTasks().get().prettyPrint());
assertThat("timed out waiting for green state", actionGet.isTimedOut(), equalTo(false));
logger.info("ensure state timed out, cluster state:\n{}\n{}", client().admin().cluster().prepareState().get().getState().prettyPrint(), client().admin().cluster().preparePendingClusterTasks().get().prettyPrint());
assertThat("timed out waiting for state", actionGet.isTimedOut(), equalTo(false));
}
assertThat(actionGet.getStatus(), equalTo(ClusterHealthStatus.GREEN));
assertThat(actionGet.getStatus(), equalTo(state));
return actionGet.getStatus();
}

Expand Down

0 comments on commit ed5a443

Please sign in to comment.