From 908dc31ffcfc30f11121ed35d56dc0e1c98bc499 Mon Sep 17 00:00:00 2001 From: Mathias Fussenegger Date: Mon, 17 Nov 2014 11:23:46 +0100 Subject: [PATCH] support functions in QueryThenFetchTask removes LocalMerge with TopN projection from qtf-plans --- .../action/sql/query/CrateSearchService.java | 2 +- .../action/sql/query/QueryShardRequest.java | 9 +- .../executor/transport/TransportExecutor.java | 3 +- .../task/elasticsearch/ESFieldExtractor.java | 7 +- .../task/elasticsearch/ESGetTask.java | 74 ++------ .../task/elasticsearch/FieldExtractor.java | 27 +++ .../task/elasticsearch/FunctionExtractor.java | 55 ++++++ .../task/elasticsearch/LiteralExtractor.java | 36 ++++ .../elasticsearch/QueryThenFetchTask.java | 176 ++++++++++++------ .../main/java/io/crate/metadata/Routing.java | 3 +- .../io/crate/planner/DataTypeVisitor.java | 8 +- .../java/io/crate/planner/PlanPrinter.java | 2 +- .../main/java/io/crate/planner/Planner.java | 60 +----- .../io/crate/planner/node/PlanVisitor.java | 2 +- .../planner/node/dql/QueryThenFetchNode.java | 4 +- .../sql/query/QueryShardRequestTest.java | 4 +- .../query/TransportQueryShardActionTest.java | 3 +- .../elasticsearch/QueyThenFetchTaskTest.java | 63 ++++++- .../java/io/crate/planner/PlannerTest.java | 24 +-- 19 files changed, 341 insertions(+), 221 deletions(-) create mode 100644 sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FieldExtractor.java create mode 100644 sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FunctionExtractor.java create mode 100644 sql/src/main/java/io/crate/executor/transport/task/elasticsearch/LiteralExtractor.java diff --git a/sql/src/main/java/io/crate/action/sql/query/CrateSearchService.java b/sql/src/main/java/io/crate/action/sql/query/CrateSearchService.java index f1e39219edbf..a962b1e57242 100644 --- a/sql/src/main/java/io/crate/action/sql/query/CrateSearchService.java +++ b/sql/src/main/java/io/crate/action/sql/query/CrateSearchService.java @@ -240,7 +240,7 @@ private OutputContext(SearchContext searchContext, List partition private static class OutputSymbolVisitor extends SymbolVisitor { - public void process(List outputs, OutputContext context) { + public void process(List outputs, OutputContext context) { for (Symbol output : outputs) { process(output, context); } diff --git a/sql/src/main/java/io/crate/action/sql/query/QueryShardRequest.java b/sql/src/main/java/io/crate/action/sql/query/QueryShardRequest.java index dcea645b37d8..84c28b3be768 100644 --- a/sql/src/main/java/io/crate/action/sql/query/QueryShardRequest.java +++ b/sql/src/main/java/io/crate/action/sql/query/QueryShardRequest.java @@ -38,7 +38,7 @@ public class QueryShardRequest extends ActionRequest { private String index; private Integer shard; - private List outputs; + private List outputs; private List orderBy; private boolean[] reverseFlags; private Boolean[] nullsFirst; @@ -51,7 +51,7 @@ public QueryShardRequest() {} public QueryShardRequest(String index, int shard, - List outputs, + List outputs, List orderBy, boolean[] reverseFlags, Boolean[] nullsFirst, @@ -84,10 +84,11 @@ public void readFrom(StreamInput in) throws IOException { shard = in.readVInt(); int numOutputs = in.readVInt(); - outputs = new ArrayList<>(numOutputs); + List outputs = new ArrayList<>(numOutputs); for (int i = 0; i < numOutputs; i++) { outputs.add(Symbol.fromStream(in)); } + this.outputs = outputs; int numOrderBy = in.readVInt(); orderBy = new ArrayList<>(numOrderBy); @@ -166,7 +167,7 @@ public int shardId() { return shard; } - public List 
outputs() { + public List outputs() { return outputs; } diff --git a/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java b/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java index 1e8564489454..036d5d040e6e 100644 --- a/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java +++ b/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java @@ -165,8 +165,9 @@ public Void visitMergeNode(MergeNode node, Job context) { } @Override - public Void visitESSearchNode(QueryThenFetchNode node, Job context) { + public Void visitQueryThenFetchNode(QueryThenFetchNode node, Job context) { context.addTask(new QueryThenFetchTask( + functions, node, clusterService, transportActionProvider.transportQueryShardAction(), diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESFieldExtractor.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESFieldExtractor.java index f8c396b0dc23..c1ed2d72cf4b 100644 --- a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESFieldExtractor.java +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESFieldExtractor.java @@ -34,7 +34,7 @@ import java.util.List; import java.util.Map; -public abstract class ESFieldExtractor { +public abstract class ESFieldExtractor implements FieldExtractor { private static final Object NOT_FOUND = new Object(); @@ -105,13 +105,11 @@ Object toValue(@Nullable Map source) { public static class PartitionedByColumnExtractor extends ESFieldExtractor { private final Reference reference; - private final List partitionedByInfos; private final int valueIdx; private final Map> cache; public PartitionedByColumnExtractor(Reference reference, List partitionedByInfos) { this.reference = reference; - this.partitionedByInfos = partitionedByInfos; this.valueIdx = partitionedByInfos.indexOf(reference.info()); this.cache = new HashMap<>(); } @@ -121,8 +119,7 @@ public Object extract(SearchHit hit) { try { List values = cache.get(hit.index()); if (values == null) { - values = PartitionName - .fromStringSafe(hit.index()).values(); + values = PartitionName.fromStringSafe(hit.index()).values(); } BytesRef value = values.get(valueIdx); if (value == null) { diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java index 51c2544bbbb6..85ea976d7e49 100644 --- a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java @@ -34,7 +34,6 @@ import io.crate.metadata.Functions; import io.crate.metadata.ReferenceInfo; import io.crate.metadata.Scalar; -import io.crate.operation.Input; import io.crate.operation.ProjectorUpstream; import io.crate.operation.projectors.FlatProjectorChain; import io.crate.operation.projectors.ProjectionToProjectorVisitor; @@ -323,30 +322,30 @@ public void upstreamResult(List> result) { getClass().getSimpleName())); } - static FieldExtractor buildExtractor(final String field, final Context context) { + static FieldExtractor buildExtractor(final String field, final Context context) { if (field.equals("_version")) { - return new FieldExtractor() { + return new FieldExtractor() { @Override public Object extract(GetResponse response) { return response.getVersion(); } }; } else if (field.equals("_id")) { - return new FieldExtractor() { + return new FieldExtractor() { @Override public Object 
extract(GetResponse response) { return response.getId(); } }; } else if (context.partitionValues.containsKey(field)) { - return new FieldExtractor() { + return new FieldExtractor() { @Override public Object extract(GetResponse response) { return context.partitionValues.get(field); } }; } else { - return new FieldExtractor() { + return new FieldExtractor() { @Override public Object extract(GetResponse response) { assert response.getSourceAsMap() != null; @@ -379,83 +378,38 @@ public String[] fields() { } } - static class Visitor extends SymbolVisitor { + static class Visitor extends SymbolVisitor> { @Override - public FieldExtractor visitReference(Reference symbol, Context context) { + public FieldExtractor visitReference(Reference symbol, Context context) { String fieldName = symbol.info().ident().columnIdent().fqn(); context.addField(fieldName); return buildExtractor(fieldName, context); } @Override - public FieldExtractor visitDynamicReference(DynamicReference symbol, Context context) { + public FieldExtractor visitDynamicReference(DynamicReference symbol, Context context) { return visitReference(symbol, context); } @Override - public FieldExtractor visitFunction(Function symbol, Context context) { - List subExtractors = new ArrayList<>(symbol.arguments().size()); + public FieldExtractor visitFunction(Function symbol, Context context) { + List> subExtractors = new ArrayList<>(symbol.arguments().size()); for (Symbol argument : symbol.arguments()) { subExtractors.add(process(argument, context)); } - return new FunctionExtractor((Scalar) context.functions.getSafe(symbol.info().ident()), subExtractors); + return new FunctionExtractor<>((Scalar) context.functions.getSafe(symbol.info().ident()), subExtractors); } @Override - public FieldExtractor visitLiteral(Literal symbol, Context context) { - return new LiteralExtractor(symbol.value()); + public FieldExtractor visitLiteral(Literal symbol, Context context) { + return new LiteralExtractor<>(symbol.value()); } @Override - protected FieldExtractor visitSymbol(Symbol symbol, Context context) { + protected FieldExtractor visitSymbol(Symbol symbol, Context context) { throw new UnsupportedOperationException( SymbolFormatter.format("Get operation not supported with symbol %s in the result column list", symbol)); } } - - private interface FieldExtractor { - Object extract(GetResponse response); - } - - private static class LiteralExtractor implements FieldExtractor { - private final Object literal; - - private LiteralExtractor(Object literal) { - this.literal = literal; - } - - @Override - public Object extract(GetResponse response) { - return literal; - } - } - - private static class FunctionExtractor implements FieldExtractor { - - private final Scalar scalar; - private final List subExtractors; - - public FunctionExtractor(Scalar scalar, List subExtractors) { - this.scalar = scalar; - this.subExtractors = subExtractors; - } - - @Override - public Object extract(final GetResponse response) { - Input[] inputs = new Input[subExtractors.size()]; - int idx = 0; - for (final FieldExtractor subExtractor : subExtractors) { - inputs[idx] = new Input() { - @Override - public Object value() { - return subExtractor.extract(response); - } - }; - idx++; - } - //noinspection unchecked - return scalar.evaluate(inputs); - } - } } diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FieldExtractor.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FieldExtractor.java new file mode 100644 index 000000000000..d3a9e1d2a54a 
--- /dev/null +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FieldExtractor.java @@ -0,0 +1,27 @@ +/* + * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor + * license agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. Crate licenses + * this file to you under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * However, if you have executed another commercial license agreement + * with Crate these terms will supersede the license and you may use the + * software solely pursuant to the terms of the relevant commercial agreement. + */ + +package io.crate.executor.transport.task.elasticsearch; + +public interface FieldExtractor { + + public Object extract(T value); +} diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FunctionExtractor.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FunctionExtractor.java new file mode 100644 index 000000000000..c699c4825c98 --- /dev/null +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/FunctionExtractor.java @@ -0,0 +1,55 @@ +/* + * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor + * license agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. Crate licenses + * this file to you under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * However, if you have executed another commercial license agreement + * with Crate these terms will supersede the license and you may use the + * software solely pursuant to the terms of the relevant commercial agreement. 
+ */ + +package io.crate.executor.transport.task.elasticsearch; + +import io.crate.metadata.Scalar; +import io.crate.operation.Input; + +import java.util.List; + +public class FunctionExtractor implements FieldExtractor { + + private final Scalar scalar; + private final List> subExtractors; + + public FunctionExtractor(Scalar scalar, List> subExtractors) { + this.scalar = scalar; + this.subExtractors = subExtractors; + } + + @Override + public Object extract(final T response) { + Input[] inputs = new Input[subExtractors.size()]; + int idx = 0; + for (final FieldExtractor subExtractor : subExtractors) { + inputs[idx] = new Input() { + @Override + public Object value() { + return subExtractor.extract(response); + } + }; + idx++; + } + //noinspection unchecked + return scalar.evaluate(inputs); + } +} \ No newline at end of file diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/LiteralExtractor.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/LiteralExtractor.java new file mode 100644 index 000000000000..0d337d1177c0 --- /dev/null +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/LiteralExtractor.java @@ -0,0 +1,36 @@ +/* + * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor + * license agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. Crate licenses + * this file to you under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * However, if you have executed another commercial license agreement + * with Crate these terms will supersede the license and you may use the + * software solely pursuant to the terms of the relevant commercial agreement. 
+ */ + +package io.crate.executor.transport.task.elasticsearch; + +public class LiteralExtractor implements FieldExtractor { + + private final Object literal; + + public LiteralExtractor(Object literal) { + this.literal = literal; + } + + @Override + public Object extract(T response) { + return literal; + } +} \ No newline at end of file diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/QueryThenFetchTask.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/QueryThenFetchTask.java index cca239a7628a..541ff3f62e32 100644 --- a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/QueryThenFetchTask.java +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/QueryThenFetchTask.java @@ -31,12 +31,10 @@ import io.crate.executor.QueryResult; import io.crate.executor.Task; import io.crate.executor.TaskResult; -import io.crate.metadata.ColumnIdent; -import io.crate.metadata.Routing; +import io.crate.metadata.*; import io.crate.metadata.doc.DocSysColumns; import io.crate.planner.node.dql.QueryThenFetchNode; -import io.crate.planner.symbol.Reference; -import io.crate.planner.symbol.Symbol; +import io.crate.planner.symbol.*; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.ActionListener; @@ -87,9 +85,9 @@ public class QueryThenFetchTask implements Task { private final AtomicArray firstResults; private final AtomicArray fetchResults; private final DiscoveryNodes nodes; - private final ESFieldExtractor[] extractor; private final int numColumns; private final ClusterState state; + private final List> extractors; volatile ScoreDoc[] sortedShardList; private volatile AtomicArray shardFailures; private final Object shardFailuresMutex = new Object(); @@ -101,7 +99,8 @@ public class QueryThenFetchTask implements Task { private final static SearchRequest EMPTY_SEARCH_REQUEST = new SearchRequest(); - public QueryThenFetchTask(QueryThenFetchNode searchNode, + public QueryThenFetchTask(Functions functions, + QueryThenFetchNode searchNode, ClusterService clusterService, TransportQueryShardAction transportQueryShardAction, SearchServiceTransportAction searchServiceTransportAction, @@ -120,67 +119,63 @@ public QueryThenFetchTask(QueryThenFetchNode searchNode, results = Arrays.>asList(result); routing = searchNode.routing(); - requests = prepareRequests(routing.locations()); + + Context context = new Context(functions); + Visitor fieldExtractorVisitor = new Visitor(searchNode.partitionBy()); + extractors = new ArrayList<>(searchNode.outputs().size()); + for (Symbol symbol : searchNode.outputs()) { + extractors.add(fieldExtractorVisitor.process(symbol, context)); + } + requests = prepareRequests(context.references); docIdsToLoad = new AtomicArray<>(requests.size()); firstResults = new AtomicArray<>(requests.size()); fetchResults = new AtomicArray<>(requests.size()); - extractor = buildExtractor(searchNode.outputs()); numColumns = searchNode.outputs().size(); } - private ESFieldExtractor[] buildExtractor(final List outputs) { - ESFieldExtractor[] extractors = new ESFieldExtractor[outputs.size()]; - int i = 0; - for (Symbol output : outputs) { - assert output instanceof Reference; - Reference reference = ((Reference) output); - final ColumnIdent columnIdent = reference.info().ident().columnIdent(); - if (DocSysColumns.VERSION.equals(columnIdent)) { - extractors[i] = new ESFieldExtractor() { - @Override - public Object extract(SearchHit hit) { - return hit.getVersion(); - } - }; - 
} else if (DocSysColumns.ID.equals(columnIdent)) { - extractors[i] = new ESFieldExtractor() { - @Override - public Object extract(SearchHit hit) { - return new BytesRef(hit.getId()); - } - }; - } else if (DocSysColumns.DOC.equals(columnIdent)) { - extractors[i] = new ESFieldExtractor() { - @Override - public Object extract(SearchHit hit) { - return hit.getSource(); - } - }; - } else if (DocSysColumns.RAW.equals(columnIdent)) { - extractors[i] = new ESFieldExtractor() { - @Override - public Object extract(SearchHit hit) { - return hit.getSourceRef().toBytesRef(); - } - }; - } else if (DocSysColumns.SCORE.equals(columnIdent)) { - extractors[i] = new ESFieldExtractor() { - @Override - public Object extract(SearchHit hit) { - return hit.getScore(); - } - }; - } else if (searchNode.partitionBy().contains(reference.info())) { - extractors[i] = new ESFieldExtractor.PartitionedByColumnExtractor( - reference, searchNode.partitionBy() - ); - } else { - extractors[i] = new ESFieldExtractor.Source(columnIdent); - } - i++; + private static FieldExtractor buildExtractor(Reference reference, List partitionBy) { + final ColumnIdent columnIdent = reference.info().ident().columnIdent(); + if (DocSysColumns.VERSION.equals(columnIdent)) { + return new ESFieldExtractor() { + @Override + public Object extract(SearchHit hit) { + return hit.getVersion(); + } + }; + } else if (DocSysColumns.ID.equals(columnIdent)) { + return new ESFieldExtractor() { + @Override + public Object extract(SearchHit hit) { + return new BytesRef(hit.getId()); + } + }; + } else if (DocSysColumns.DOC.equals(columnIdent)) { + return new ESFieldExtractor() { + @Override + public Object extract(SearchHit hit) { + return hit.getSource(); + } + }; + } else if (DocSysColumns.RAW.equals(columnIdent)) { + return new ESFieldExtractor() { + @Override + public Object extract(SearchHit hit) { + return hit.getSourceRef().toBytesRef(); + } + }; + } else if (DocSysColumns.SCORE.equals(columnIdent)) { + return new ESFieldExtractor() { + @Override + public Object extract(SearchHit hit) { + return hit.getScore(); + } + }; + } else if (partitionBy.contains(reference.info())) { + return new ESFieldExtractor.PartitionedByColumnExtractor(reference, partitionBy); + } else { + return new ESFieldExtractor.Source(columnIdent); } - return extractors; } @Override @@ -204,8 +199,12 @@ public void start() { } } - private List> prepareRequests(Map>> locations) { + private List> prepareRequests(List outputs) { List> requests = new ArrayList<>(); + Map>> locations = searchNode.routing().locations(); + if (locations == null) { + return requests; + } for (Map.Entry>> entry : locations.entrySet()) { String node = entry.getKey(); for (Map.Entry> indexEntry : entry.getValue().entrySet()) { @@ -218,7 +217,7 @@ private List> prepareRequests(Map references = new ArrayList<>(); + + public Context(Functions functions) { + this.functions = functions; + } + + public void addReference(Reference reference) { + references.add(reference); + } + } + static class Visitor extends SymbolVisitor> { + + private List partitionBy; + + public Visitor(List partitionBy) { + this.partitionBy = partitionBy; + } + + @Override + protected FieldExtractor visitSymbol(Symbol symbol, Context context) { + throw new UnsupportedOperationException( + SymbolFormatter.format("QueryThenFetch doesn't support \"%s\" in outputs", symbol)); + } + + @Override + public FieldExtractor visitReference(final Reference reference, Context context) { + context.addReference(reference); + return buildExtractor(reference, 
partitionBy); + } + + @Override + public FieldExtractor visitDynamicReference(DynamicReference symbol, Context context) { + return visitReference(symbol, context); + } + + @Override + public FieldExtractor visitFunction(Function symbol, Context context) { + List> subExtractors = new ArrayList<>(symbol.arguments().size()); + for (Symbol argument : symbol.arguments()) { + subExtractors.add(process(argument, context)); + } + Scalar scalar = (Scalar) context.functions.getSafe(symbol.info().ident()); + return new FunctionExtractor<>(scalar, subExtractors); + } + + @Override + public FieldExtractor visitLiteral(Literal symbol, Context context) { + return new LiteralExtractor<>(symbol.value()); + } + } } diff --git a/sql/src/main/java/io/crate/metadata/Routing.java b/sql/src/main/java/io/crate/metadata/Routing.java index 94f016cdab51..891f4ef5475f 100644 --- a/sql/src/main/java/io/crate/metadata/Routing.java +++ b/sql/src/main/java/io/crate/metadata/Routing.java @@ -31,12 +31,13 @@ public Routing(@Nullable Map>> locations) { *     Map< indexName (string), Set
*     Map<nodeName (string), Map<indexName (string), Set<shardId (int)>>>
*
*/ + @Nullable public Map>> locations() { return locations; } public boolean hasLocations() { - return locations != null && locations().size() > 0; + return locations != null && locations.size() > 0; } public Set nodes() { diff --git a/sql/src/main/java/io/crate/planner/DataTypeVisitor.java b/sql/src/main/java/io/crate/planner/DataTypeVisitor.java index 4ee38d6a6986..18f59f0066a0 100644 --- a/sql/src/main/java/io/crate/planner/DataTypeVisitor.java +++ b/sql/src/main/java/io/crate/planner/DataTypeVisitor.java @@ -38,10 +38,10 @@ public static DataType fromSymbol(Symbol symbol) { return INSTANCE.process(symbol, null); } - public static List fromSymbols(List keys) { - List types = new ArrayList<>(keys.size()); - for (Symbol key : keys) { - types.add(fromSymbol(key)); + public static List fromSymbols(List symbols) { + List types = new ArrayList<>(symbols.size()); + for (Symbol symbol : symbols) { + types.add(fromSymbol(symbol)); } return types; } diff --git a/sql/src/main/java/io/crate/planner/PlanPrinter.java b/sql/src/main/java/io/crate/planner/PlanPrinter.java index 54d6274cee92..421cc1519bb7 100644 --- a/sql/src/main/java/io/crate/planner/PlanPrinter.java +++ b/sql/src/main/java/io/crate/planner/PlanPrinter.java @@ -149,7 +149,7 @@ public Void visitESGetNode(ESGetNode node, PrintContext context) { } @Override - public Void visitESSearchNode(QueryThenFetchNode node, PrintContext context) { + public Void visitQueryThenFetchNode(QueryThenFetchNode node, PrintContext context) { context.print(node.toString()); context.indent(); context.print("outputs:"); diff --git a/sql/src/main/java/io/crate/planner/Planner.java b/sql/src/main/java/io/crate/planner/Planner.java index 1b04373785ad..7d88e78dabff 100644 --- a/sql/src/main/java/io/crate/planner/Planner.java +++ b/sql/src/main/java/io/crate/planner/Planner.java @@ -116,6 +116,7 @@ protected Plan visitSelectAnalysis(SelectAnalysis analysis, Context context) { } else if (analysis.hasAggregates()) { globalAggregates(analysis, plan, context); } else { + WhereClause whereClause = analysis.whereClause(); if (!context.indexWriterProjection.isPresent() && analysis.rowGranularity().ordinal() >= RowGranularity.DOC.ordinal() && @@ -127,7 +128,7 @@ protected Plan visitSelectAnalysis(SelectAnalysis analysis, Context context) { && !analysis.table().isAlias()) { ESGet(analysis, plan, context); } else { - ESSearch(analysis, plan, context); + queryThenFetch(analysis, plan, context); } } else { normalSelect(analysis, plan, context); @@ -569,41 +570,13 @@ private void normalSelect(SelectAnalysis analysis, Plan plan, Context context) { plan.add(PlanNodeBuilder.localMerge(projectionBuilder.build(), collectNode)); } - private void ESSearch(SelectAnalysis analysis, Plan plan, Context context) { - // this is an es query - // this only supports INFOS as order by - PlannerContextBuilder contextBuilder = new PlannerContextBuilder(); - final Predicate symbolIsReference = new Predicate() { - @Override - public boolean apply(@Nullable Symbol input) { - return input instanceof Reference; - } - }; - - boolean needsProjection = !Iterables.all(analysis.outputSymbols(), symbolIsReference) - || context.indexWriterProjection.isPresent(); - List searchSymbols; - if (needsProjection) { - // we must create a deep copy of references if they are function arguments - // or they will be replaced with InputColumn instances by the context builder - if (analysis.whereClause().hasQuery()) { - analysis.whereClause(new WhereClause(functionArgumentCopier.process(analysis.whereClause().query()))); 
- } - List sortSymbols = analysis.sortSymbols(); + private void queryThenFetch(SelectAnalysis analysis, Plan plan, Context context) { + Preconditions.checkArgument(!context.indexWriterProjection.isPresent(), + "Must use QueryAndFetch with indexWriterProjection."); - // do the same for sortsymbols if we have a function there - if (sortSymbols != null && !Iterables.all(sortSymbols, symbolIsReference)) { - functionArgumentCopier.process(sortSymbols); - } - - contextBuilder.searchOutput(analysis.outputSymbols()); - searchSymbols = contextBuilder.toCollect(); - } else { - searchSymbols = analysis.outputSymbols(); - } - QueryThenFetchNode node = new QueryThenFetchNode( + plan.add(new QueryThenFetchNode( analysis.table().getRouting(analysis.whereClause()), - searchSymbols, + analysis.outputSymbols(), analysis.sortSymbols(), analysis.reverseFlags(), analysis.nullsFirst(), @@ -611,24 +584,7 @@ public boolean apply(@Nullable Symbol input) { analysis.offset(), analysis.whereClause(), analysis.table().partitionedByColumns() - ); - node.outputTypes(extractDataTypes(searchSymbols)); - plan.add(node); - // only add projection if we have scalar functions - if (needsProjection) { - TopNProjection topN = new TopNProjection( - Objects.firstNonNull(analysis.limit(), Constants.DEFAULT_SELECT_LIMIT), - TopN.NO_OFFSET - ); - topN.outputs(contextBuilder.outputs()); - - ImmutableList.Builder projectionBuilder = ImmutableList.builder() - .add(topN); - if (context.indexWriterProjection.isPresent()) { - projectionBuilder.add(context.indexWriterProjection.get()); - } - plan.add(PlanNodeBuilder.localMerge(projectionBuilder.build(), node)); - } + )); } private void globalAggregates(SelectAnalysis analysis, Plan plan, Context context) { diff --git a/sql/src/main/java/io/crate/planner/node/PlanVisitor.java b/sql/src/main/java/io/crate/planner/node/PlanVisitor.java index 6dc97d88dd22..0c65d0e8be45 100644 --- a/sql/src/main/java/io/crate/planner/node/PlanVisitor.java +++ b/sql/src/main/java/io/crate/planner/node/PlanVisitor.java @@ -44,7 +44,7 @@ public R visitCollectNode(CollectNode node, C context) { return visitPlanNode(node, context); } - public R visitESSearchNode(QueryThenFetchNode node, C context) { + public R visitQueryThenFetchNode(QueryThenFetchNode node, C context) { return visitPlanNode(node, context); } diff --git a/sql/src/main/java/io/crate/planner/node/dql/QueryThenFetchNode.java b/sql/src/main/java/io/crate/planner/node/dql/QueryThenFetchNode.java index 535e9e71512c..085ffb09b85a 100644 --- a/sql/src/main/java/io/crate/planner/node/dql/QueryThenFetchNode.java +++ b/sql/src/main/java/io/crate/planner/node/dql/QueryThenFetchNode.java @@ -29,6 +29,7 @@ import io.crate.analyze.WhereClause; import io.crate.metadata.ReferenceInfo; import io.crate.metadata.Routing; +import io.crate.planner.DataTypeVisitor; import io.crate.planner.node.PlanVisitor; import io.crate.planner.symbol.Symbol; import org.elasticsearch.common.Nullable; @@ -83,6 +84,7 @@ public QueryThenFetchNode(Routing routing, this.offset = Objects.firstNonNull(offset, 0); this.partitionBy = Objects.firstNonNull(partitionBy, ImmutableList.of()); + outputTypes(DataTypeVisitor.fromSymbols(outputs)); } public Routing routing() { @@ -119,7 +121,7 @@ public WhereClause whereClause() { @Override public R accept(PlanVisitor visitor, C context) { - return visitor.visitESSearchNode(this, context); + return visitor.visitQueryThenFetchNode(this, context); } @Override diff --git a/sql/src/test/java/io/crate/action/sql/query/QueryShardRequestTest.java 
b/sql/src/test/java/io/crate/action/sql/query/QueryShardRequestTest.java index 0817d3dc9189..ad4c61102ee2 100644 --- a/sql/src/test/java/io/crate/action/sql/query/QueryShardRequestTest.java +++ b/sql/src/test/java/io/crate/action/sql/query/QueryShardRequestTest.java @@ -34,6 +34,8 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.junit.Test; +import java.sql.Ref; + import static io.crate.testing.TestingHelpers.createFunction; import static io.crate.testing.TestingHelpers.createReference; import static org.junit.Assert.assertEquals; @@ -49,7 +51,7 @@ public void testQueryShardRequestSerialization() throws Exception { QueryShardRequest request = new QueryShardRequest( "dummyTable", 1, - ImmutableList.of(nameRef), + ImmutableList.of(nameRef), ImmutableList.of(nameRef), new boolean[] { false }, new Boolean[] { null }, diff --git a/sql/src/test/java/io/crate/action/sql/query/TransportQueryShardActionTest.java b/sql/src/test/java/io/crate/action/sql/query/TransportQueryShardActionTest.java index e6c9e11e9c4f..75dd7a8cfeb8 100644 --- a/sql/src/test/java/io/crate/action/sql/query/TransportQueryShardActionTest.java +++ b/sql/src/test/java/io/crate/action/sql/query/TransportQueryShardActionTest.java @@ -26,6 +26,7 @@ import io.crate.analyze.WhereClause; import io.crate.integrationtests.SQLTransportIntegrationTest; import io.crate.metadata.ReferenceInfo; +import io.crate.planner.symbol.Reference; import io.crate.planner.symbol.Symbol; import io.crate.test.integration.CrateIntegrationTest; import org.elasticsearch.action.ActionListener; @@ -54,7 +55,7 @@ public void testQueryShardRequestHandling() throws Exception { discoveryNodes[1].id(), new QueryShardRequest("foo", 1, - ImmutableList.of(), + ImmutableList.of(), ImmutableList.of(), new boolean[0], new Boolean[0], diff --git a/sql/src/test/java/io/crate/executor/transport/task/elasticsearch/QueyThenFetchTaskTest.java b/sql/src/test/java/io/crate/executor/transport/task/elasticsearch/QueyThenFetchTaskTest.java index 0af57fe3e932..15ff6f8c9a3b 100644 --- a/sql/src/test/java/io/crate/executor/transport/task/elasticsearch/QueyThenFetchTaskTest.java +++ b/sql/src/test/java/io/crate/executor/transport/task/elasticsearch/QueyThenFetchTaskTest.java @@ -27,9 +27,15 @@ import com.google.common.util.concurrent.ListenableFuture; import io.crate.action.sql.query.QueryShardRequest; import io.crate.action.sql.query.TransportQueryShardAction; +import io.crate.analyze.WhereClause; import io.crate.executor.QueryResult; +import io.crate.metadata.Functions; import io.crate.metadata.Routing; +import io.crate.operation.aggregation.impl.CountAggregation; import io.crate.planner.node.dql.QueryThenFetchNode; +import io.crate.planner.symbol.Aggregation; +import io.crate.planner.symbol.Symbol; +import org.apache.commons.math3.analysis.function.Exp; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.cluster.ClusterService; import org.elasticsearch.cluster.ClusterState; @@ -45,9 +51,13 @@ import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.threadpool.ThreadPool; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.mockito.Answers; import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -64,17 +74,23 @@ public class QueyThenFetchTaskTest { private QueryThenFetchTask queryThenFetchTask; - private 
QueryThenFetchNode searchNode; - private ClusterService clusterService; private TransportQueryShardAction transportQueryShardAction; private SearchServiceTransportAction searchServiceTransportAction; private SearchPhaseController searchPhaseController; private DiscoveryNodes nodes = mock(DiscoveryNodes.class); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Mock(answer = Answers.RETURNS_MOCKS) + ClusterService clusterService; + @Before public void prepare() { - searchNode = mock(QueryThenFetchNode.class); + MockitoAnnotations.initMocks(this); + QueryThenFetchNode searchNode = mock(QueryThenFetchNode.class); Map>> locations = new HashMap<>(); HashMap> location1 = new HashMap>(); location1.put("loc1", new HashSet(Arrays.asList(1))); @@ -82,7 +98,7 @@ public void prepare() { Routing routing = new Routing(locations); when(searchNode.routing()).thenReturn(routing); - clusterService = mock(ClusterService.class); + ClusterService clusterService = mock(ClusterService.class); ClusterState state = mock(ClusterState.class); when(clusterService.state()).thenReturn(state); when(state.blocks()).thenReturn(mock(ClusterBlocks.class)); @@ -92,14 +108,41 @@ public void prepare() { transportQueryShardAction = mock(TransportQueryShardAction.class); searchServiceTransportAction = mock(SearchServiceTransportAction.class); searchPhaseController = mock(SearchPhaseController.class); - queryThenFetchTask = new QueryThenFetchTask(searchNode, - clusterService, - transportQueryShardAction, - searchServiceTransportAction, - searchPhaseController, - new ThreadPool("testpool")); + queryThenFetchTask = new QueryThenFetchTask( + mock(Functions.class), + searchNode, + clusterService, + transportQueryShardAction, + searchServiceTransportAction, + searchPhaseController, + new ThreadPool("testpool")); } + @Test + public void testAggregationInOutputs() throws Exception { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("QueryThenFetch doesn't support \"count()\" in outputs"); + new QueryThenFetchTask( + mock(Functions.class), + new QueryThenFetchNode( + new Routing(), + Arrays.asList(new Aggregation(CountAggregation.COUNT_STAR_FUNCTION, Collections.emptyList(), + Aggregation.Step.ITER, Aggregation.Step.FINAL)), + null, + null, + null, + null, + null, + WhereClause.MATCH_ALL, + null + ), + clusterService, + mock(TransportQueryShardAction.class), + mock(SearchServiceTransportAction.class), + mock(SearchPhaseController.class), + mock(ThreadPool.class) + ); + } @Test public void testFinishWithErrors() throws Throwable{ diff --git a/sql/src/test/java/io/crate/planner/PlannerTest.java b/sql/src/test/java/io/crate/planner/PlannerTest.java index 932dfd30ac87..cacba66e9650 100644 --- a/sql/src/test/java/io/crate/planner/PlannerTest.java +++ b/sql/src/test/java/io/crate/planner/PlannerTest.java @@ -54,6 +54,7 @@ import static org.hamcrest.Matchers.*; import static org.hamcrest.core.Is.is; import static org.junit.Assert.*; +import static org.mockito.Matchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -84,6 +85,7 @@ public class PlannerTest { .put("nodeTwo", ImmutableMap.>of()) .build()); + class TestModule extends MetaDataModule { @Override @@ -526,27 +528,17 @@ public void testESSearchPlanFunction() throws Exception { assertThat(planNode, instanceOf(QueryThenFetchNode.class)); QueryThenFetchNode searchNode = (QueryThenFetchNode) planNode; - assertThat(searchNode.outputs().size(), is(1)); - 
assertThat(((Reference) searchNode.outputs().get(0)).info().ident().columnIdent().fqn(), is("name"));
+        assertThat(searchNode.outputs().size(), is(2));
+        assertThat(searchNode.outputs().get(0), isFunction("format"));
+        assertThat(searchNode.outputs().get(1), isReference("name"));
-        assertThat(searchNode.outputTypes().size(), is(1));
+        assertThat(searchNode.outputTypes().size(), is(2));
         assertEquals(DataTypes.STRING, searchNode.outputTypes().get(0));
+        assertEquals(DataTypes.STRING, searchNode.outputTypes().get(1));
         assertTrue(searchNode.whereClause().hasQuery());
         assertThat(searchNode.partitionBy().size(), is(0));
-        planNode = iterator.next();
-        assertThat(planNode, instanceOf(MergeNode.class));
-        MergeNode mergeNode = (MergeNode)planNode;
-        assertTrue(mergeNode.hasProjections());
-        assertThat(mergeNode.projections().get(0), instanceOf(TopNProjection.class));
-        assertThat(mergeNode.outputTypes().size(), is(2));
-        assertEquals(DataTypes.STRING, mergeNode.outputTypes().get(0));
-        assertThat(mergeNode.projections().get(0).outputs().get(0), instanceOf(Function.class));
-        assertEquals(DataTypes.STRING, mergeNode.outputTypes().get(1));
-        assertThat(mergeNode.projections().get(0).outputs().get(1), instanceOf(InputColumn.class));
-
-        assertFalse(iterator.hasNext());
-        assertFalse(plan.expectsAffectedRows());
+        assertThat(iterator.hasNext(), is(false));
     }

     @Test
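
Reviewer note (not part of the patch): the heart of this change is the new generic FieldExtractor<T> abstraction, plus LiteralExtractor and FunctionExtractor, which lets QueryThenFetchTask evaluate scalar functions per SearchHit and drop the extra LocalMerge with TopN projection from QTF plans. Below is a minimal, self-contained sketch of that composition pattern. It is illustrative only: ExtractorSketch, SourceFieldExtractor and the Function<Object[], Object> stand-in for Crate's Scalar/Input do not appear in the patch; the real code builds extractors via a SymbolVisitor and evaluates Crate scalars.

import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Sketch of the extractor composition used by the patch: extractors are built
// once per plan (one per output symbol) and then applied to every hit, so the
// per-hit work is just a row of extract() calls.
public class ExtractorSketch {

    /** Mirrors the new FieldExtractor<T> interface: pull one output value from a hit of type T. */
    interface FieldExtractor<T> {
        Object extract(T hit);
    }

    /** Constant output, independent of the hit (cf. LiteralExtractor in the patch). */
    static class LiteralExtractor<T> implements FieldExtractor<T> {
        private final Object literal;
        LiteralExtractor(Object literal) { this.literal = literal; }
        @Override public Object extract(T hit) { return literal; }
    }

    /** Reads a named field from a source map (simplified stand-in for ESFieldExtractor.Source). */
    static class SourceFieldExtractor implements FieldExtractor<Map<String, Object>> {
        private final String field;
        SourceFieldExtractor(String field) { this.field = field; }
        @Override public Object extract(Map<String, Object> hit) { return hit.get(field); }
    }

    /**
     * Applies a scalar function to the values produced by its argument extractors
     * (cf. FunctionExtractor, which wraps Crate's Scalar and Input[]).
     */
    static class FunctionExtractor<T> implements FieldExtractor<T> {
        private final Function<Object[], Object> scalar;
        private final List<FieldExtractor<T>> args;
        FunctionExtractor(Function<Object[], Object> scalar, List<FieldExtractor<T>> args) {
            this.scalar = scalar;
            this.args = args;
        }
        @Override public Object extract(T hit) {
            Object[] values = new Object[args.size()];
            for (int i = 0; i < args.size(); i++) {
                values[i] = args.get(i).extract(hit);
            }
            return scalar.apply(values);
        }
    }

    public static void main(String[] args) {
        // Conceptually: select format('%s!', name), 42 from t
        FieldExtractor<Map<String, Object>> name = new SourceFieldExtractor("name");
        FieldExtractor<Map<String, Object>> formatted = new FunctionExtractor<>(
                vals -> String.format("%s!", vals[0]),
                List.of(name));
        FieldExtractor<Map<String, Object>> constant = new LiteralExtractor<>(42);

        Map<String, Object> hit = Map.of("name", "Arthur");
        System.out.println(formatted.extract(hit)); // Arthur!
        System.out.println(constant.extract(hit));  // 42
    }
}

As in ESGetTask, the visitor in the patch builds one extractor per output symbol in the task constructor, so no symbol resolution happens in the per-hit loop; unsupported symbols (e.g. aggregations) fail fast with an UnsupportedOperationException, which the new QueyThenFetchTaskTest case asserts.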