From 33acebe519c912b10fff791635704974807aa8de Mon Sep 17 00:00:00 2001 From: Mathias Fussenegger Date: Mon, 27 Oct 2014 16:04:18 +0100 Subject: [PATCH] support functions in select-list on queries that do pk lookups Also made the ESGetTask standalone so it doesn't require an additional MergeTask. In addition the plan for insert from query has changed so that it will always use QueryAndFetch to avoid handler round-trips. --- .../executor/transport/TransportExecutor.java | 8 + .../task/elasticsearch/ESGetTask.java | 422 +++++++++++++----- .../projectors/FlatProjectorChain.java | 2 - .../crate/operation/projectors/Projector.java | 1 - .../main/java/io/crate/planner/Planner.java | 42 +- .../io/crate/planner/node/dql/ESGetNode.java | 60 ++- .../transport/TransportExecutorTest.java | 57 ++- .../WherePKIntegrationTest.java | 84 ++++ .../java/io/crate/planner/PlannerTest.java | 23 +- 9 files changed, 515 insertions(+), 184 deletions(-) create mode 100644 sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java diff --git a/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java b/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java index 7a7d196a7242..1e8564489454 100644 --- a/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java +++ b/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java @@ -38,6 +38,7 @@ import io.crate.operation.ImplementationSymbolVisitor; import io.crate.operation.collect.HandlerSideDataCollectOperation; import io.crate.operation.collect.StatsTables; +import io.crate.operation.projectors.ProjectionToProjectorVisitor; import io.crate.planner.Plan; import io.crate.planner.RowGranularity; import io.crate.planner.node.PlanNode; @@ -71,6 +72,7 @@ public class TransportExecutor implements Executor { // operation for handler side collecting private final HandlerSideDataCollectOperation handlerSideDataCollectOperation; + private final ProjectionToProjectorVisitor projectorVisitor; @Inject public TransportExecutor(Settings settings, @@ -92,6 +94,10 @@ public TransportExecutor(Settings settings, this.statsTables = statsTables; this.clusterService = clusterService; this.visitor = new Visitor(); + ImplementationSymbolVisitor clusterImplementationSymbolVisitor = + new ImplementationSymbolVisitor(referenceResolver, functions, RowGranularity.CLUSTER); + projectorVisitor = new ProjectionToProjectorVisitor( + clusterService, settings, transportActionProvider, clusterImplementationSymbolVisitor); } @Override @@ -173,6 +179,8 @@ public Void visitESSearchNode(QueryThenFetchNode node, Job context) { @Override public Void visitESGetNode(ESGetNode node, Job context) { context.addTask(new ESGetTask( + functions, + projectorVisitor, transportActionProvider.transportMultiGetAction(), transportActionProvider.transportGetAction(), node)); diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java index c4ceb1455d53..5ea45ed4e0fc 100644 --- a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java @@ -22,16 +22,26 @@ package io.crate.executor.transport.task.elasticsearch; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.crate.Constants; -import io.crate.executor.QueryResult; -import io.crate.executor.TaskResult; -import io.crate.types.DataType; import io.crate.PartitionName; +import io.crate.executor.QueryResult; import io.crate.executor.Task; +import io.crate.executor.TaskResult; +import io.crate.metadata.Functions; import io.crate.metadata.ReferenceInfo; +import io.crate.metadata.Scalar; +import io.crate.operation.Input; +import io.crate.operation.ProjectorUpstream; +import io.crate.operation.projectors.FlatProjectorChain; +import io.crate.operation.projectors.ProjectionToProjectorVisitor; +import io.crate.operation.projectors.Projector; import io.crate.planner.node.dql.ESGetNode; +import io.crate.planner.projection.Projection; +import io.crate.planner.projection.TopNProjection; import io.crate.planner.symbol.*; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -39,149 +49,214 @@ import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.search.fetch.source.FetchSourceContext; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.*; public class ESGetTask implements Task { private final static Visitor visitor = new Visitor(); - private final ESGetNode node; + private final static OrderByVisitor ORDER_BY_VISITOR = new OrderByVisitor(); private final List> results; private final TransportAction transportAction; private final ActionRequest request; private final ActionListener listener; - private final Map partitionValues; - - public ESGetTask(TransportMultiGetAction multiGetAction, TransportGetAction getAction, ESGetNode node) { + public ESGetTask(Functions functions, + ProjectionToProjectorVisitor projectionToProjectorVisitor, + TransportMultiGetAction multiGetAction, + TransportGetAction getAction, + ESGetNode node) { assert multiGetAction != null; assert getAction != null; assert node != null; assert node.ids().size() > 0; + assert node.limit() == null || node.limit() > 0 : "shouldn't execute ESGetTask if limit is 0"; - this.node = node; - - if (this.node.partitionBy().isEmpty()) { - this.partitionValues = ImmutableMap.of(); - } else { - PartitionName partitionName = PartitionName.fromStringSafe(node.index()); - int numPartitionColumns = this.node.partitionBy().size(); - this.partitionValues = new HashMap<>(numPartitionColumns); - for (int i = 0; i partitionValues = preparePartitionValues(node); + final Context ctx = new Context(functions, node.outputs().size(), partitionValues); + FieldExtractor[] extractors = new FieldExtractor[node.outputs().size()]; + int idx = 0; for (Symbol symbol : node.outputs()) { - visitor.process(symbol, ctx); + extractors[idx] = visitor.process(symbol, ctx); + idx++; + } + for (Symbol symbol : node.sortSymbols()) { + ORDER_BY_VISITOR.process(symbol, ctx); } - final FetchSourceContext fsc = new FetchSourceContext(ctx.fields); - final FieldExtractor[] extractors = buildExtractors(ctx.fields, ctx.types); + + final FetchSourceContext fsc = new FetchSourceContext(ctx.fields()); final SettableFuture result = SettableFuture.create(); results = Arrays.>asList(result); if (node.ids().size() > 1) { - MultiGetRequest multiGetRequest = new MultiGetRequest(); - for (int i = 0; i < node.ids().size(); i++) { - String id = node.ids().get(i); - MultiGetRequest.Item item = new MultiGetRequest.Item(node.index(), Constants.DEFAULT_MAPPING_TYPE, id); - item.fetchSourceContext(fsc); - item.routing(node.routingValues().get(i)); - multiGetRequest.add(item); - } - multiGetRequest.realtime(true); - + MultiGetRequest multiGetRequest = prepareMultiGetRequest(node, fsc); transportAction = multiGetAction; request = multiGetRequest; - listener = new MultiGetResponseListener(result, extractors); + FlatProjectorChain projectorChain = getFlatProjectorChain( + projectionToProjectorVisitor, node, ctx.orderBySymbols); + listener = new MultiGetResponseListener(result, extractors, projectorChain); } else { - GetRequest getRequest = new GetRequest(node.index(), Constants.DEFAULT_MAPPING_TYPE, node.ids().get(0)); - getRequest.fetchSourceContext(fsc); - getRequest.realtime(true); - getRequest.routing(node.routingValues().get(0)); - + GetRequest getRequest = prepareGetRequest(node, fsc); transportAction = getAction; request = getRequest; listener = new GetResponseListener(result, extractors); } } - private FieldExtractor[] buildExtractors(String[] fields, DataType[] types) { - FieldExtractor[] extractors = new FieldExtractor[fields.length]; - int i = 0; - for (final String field : fields) { - if (field.equals("_version")) { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - return response.getVersion(); - } - }; - } else if (field.equals("_id")) { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - return response.getId(); - } - }; - } else if (partitionValues.containsKey(field)) { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - return partitionValues.get(field); - } - }; - } else { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - assert response.getSourceAsMap() != null; - return response.getSourceAsMap().get(field); - } - }; + private Map preparePartitionValues(ESGetNode node) { + Map partitionValues; + if (node.partitionBy().isEmpty()) { + partitionValues = ImmutableMap.of(); + } else { + PartitionName partitionName = PartitionName.fromStringSafe(node.index()); + int numPartitionColumns = node.partitionBy().size(); + partitionValues = new HashMap<>(numPartitionColumns); + for (int i = 0; i < node.partitionBy().size(); i++) { + ReferenceInfo info = node.partitionBy().get(i); + partitionValues.put( + info.ident().columnIdent().fqn(), + info.type().value(partitionName.values().get(i)) + ); } - i++; } - return extractors; + return partitionValues; + } + + private GetRequest prepareGetRequest(ESGetNode node, FetchSourceContext fsc) { + GetRequest getRequest = new GetRequest(node.index(), Constants.DEFAULT_MAPPING_TYPE, node.ids().get(0)); + getRequest.fetchSourceContext(fsc); + getRequest.realtime(true); + getRequest.routing(node.routingValues().get(0)); + return getRequest; + } + + private MultiGetRequest prepareMultiGetRequest(ESGetNode node, FetchSourceContext fsc) { + MultiGetRequest multiGetRequest = new MultiGetRequest(); + for (int i = 0; i < node.ids().size(); i++) { + String id = node.ids().get(i); + MultiGetRequest.Item item = new MultiGetRequest.Item(node.index(), Constants.DEFAULT_MAPPING_TYPE, id); + item.fetchSourceContext(fsc); + item.routing(node.routingValues().get(i)); + multiGetRequest.add(item); + } + multiGetRequest.realtime(true); + return multiGetRequest; + } + + private FlatProjectorChain getFlatProjectorChain(ProjectionToProjectorVisitor projectionToProjectorVisitor, + ESGetNode node, + List orderBySymbols) { + FlatProjectorChain projectorChain = null; + if (node.limit() != null || node.offset() > 0 || (node.sortSymbols() != null && !node.sortSymbols().isEmpty())) { + TopNProjection topNProjection = new TopNProjection( + com.google.common.base.Objects.firstNonNull(node.limit(), Constants.DEFAULT_SELECT_LIMIT), + node.offset(), + orderBySymbols, + node.reverseFlags(), + node.nullsFirst() + ); + topNProjection.outputs(genInputColumns(node.outputs().size())); + projectorChain = new FlatProjectorChain( + Arrays.asList(topNProjection), + projectionToProjectorVisitor + ); + } + return projectorChain; } - static class MultiGetResponseListener implements ActionListener { + private static List genInputColumns(int size) { + List inputColumns = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + inputColumns.add(new InputColumn(i)); + } + return inputColumns; + } + + static class MultiGetResponseListener implements ActionListener, ProjectorUpstream { private final SettableFuture result; private final FieldExtractor[] fieldExtractor; + @Nullable + private final FlatProjectorChain projectorChain; + private final Projector downstream; - public MultiGetResponseListener(SettableFuture result, FieldExtractor[] extractors) { + public MultiGetResponseListener(final SettableFuture result, + FieldExtractor[] extractors, + @Nullable FlatProjectorChain projectorChain) { this.result = result; this.fieldExtractor = extractors; + this.projectorChain = projectorChain; + if (projectorChain == null) { + downstream = null; + } else { + downstream = projectorChain.firstProjector(); + downstream.registerUpstream(this); + Futures.addCallback(projectorChain.result(), new FutureCallback() { + @Override + public void onSuccess(@Nullable Object[][] rows) { + result.set(new QueryResult(rows)); + } + + @Override + public void onFailure(@Nonnull Throwable t) { + result.setException(t); + } + }); + } } @Override public void onResponse(MultiGetResponse responses) { - List rows = new ArrayList<>(responses.getResponses().length); - for (MultiGetItemResponse response : responses) { - if (response.isFailed() || !response.getResponse().isExists()) { - continue; + if (projectorChain == null) { + List rows = new ArrayList<>(responses.getResponses().length); + for (MultiGetItemResponse response : responses) { + if (response.isFailed() || !response.getResponse().isExists()) { + continue; + } + final Object[] row = new Object[fieldExtractor.length]; + int c = 0; + for (FieldExtractor extractor : fieldExtractor) { + row[c] = extractor.extract(response.getResponse()); + c++; + } + rows.add(row); } - final Object[] row = new Object[fieldExtractor.length]; - int c = 0; - for (FieldExtractor extractor : fieldExtractor) { - row[c] = extractor.extract(response.getResponse()); - c++; + + result.set(new QueryResult(rows.toArray(new Object[rows.size()][]))); + } else { + projectorChain.startProjections(); + for (MultiGetItemResponse response : responses) { + if (response.isFailed() || !response.getResponse().isExists()) { + continue; + } + final Object[] row = new Object[fieldExtractor.length]; + int c = 0; + for (FieldExtractor extractor : fieldExtractor) { + row[c] = extractor.extract(response.getResponse()); + c++; + } + if (!downstream.setNextRow(row)) { + break; + } } - rows.add(row); + downstream.upstreamFinished(); } - - result.set(new QueryResult(rows.toArray(new Object[rows.size()][]))); } @Override public void onFailure(Throwable e) { result.setException(e); } + + @Override + public void downstream(Projector downstream) { + throw new UnsupportedOperationException("Setting downstream isn't supported on MultiGetResponseListener"); + } + + @Override + public Projector downstream() { + return downstream ; + } } static class GetResponseListener implements ActionListener { @@ -197,7 +272,7 @@ public GetResponseListener(SettableFuture result, FieldExtractor[] @Override public void onResponse(GetResponse response) { if (!response.isExists()) { - result.set((QueryResult) TaskResult.EMPTY_RESULT); + result.set( TaskResult.EMPTY_RESULT); return; } @@ -240,39 +315,125 @@ public void upstreamResult(List> result) { } static class Context { - final DataType[] types; - final String[] fields; - int idx; + private final List fields; + private final Functions functions; + private final Map partitionValues; + private final List orderBySymbols = new ArrayList<>(); + private String[] fieldsArray; + + Context(Functions functions, int size, Map partitionValues) { + this.functions = functions; + this.partitionValues = partitionValues; + fields = new ArrayList<>(size); + } - Context(int size) { - idx = 0; - fields = new String[size]; - types = new DataType[size]; + public void addField(String fieldName) { + fields.add(fieldName); } - void add(Reference reference) { - fields[idx] = reference.info().ident().columnIdent().fqn(); - types[idx] = reference.valueType(); - idx++; + public String[] fields() { + if (fieldsArray == null) { + fieldsArray = fields.toArray(new String[fields.size()]); + } + return fieldsArray; + } + + public void addOrderByField(String fieldName) { + int i = fields.indexOf(fieldName); + if (i < 0) { + fields.add(fieldName); + orderBySymbols.add(new InputColumn(fields.size())); + } else { + orderBySymbols.add(new InputColumn(i)); + } } } - static class Visitor extends SymbolVisitor { + static class OrderByVisitor extends SymbolVisitor { @Override public Void visitReference(Reference symbol, Context context) { - context.add(symbol); + context.addOrderByField(symbol.info().ident().columnIdent().fqn()); return null; } @Override public Void visitDynamicReference(DynamicReference symbol, Context context) { - context.add(symbol); + return visitReference(symbol, context); + } + + @Override + public Void visitFunction(Function symbol, Context context) { + for (Symbol arg : symbol.arguments()) { + process(arg, context); + } return null; } + } + + static class Visitor extends SymbolVisitor { + + private static FieldExtractor buildExtractor(final String field, final Context context) { + if (field.equals("_version")) { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + return response.getVersion(); + } + }; + } else if (field.equals("_id")) { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + return response.getId(); + } + }; + } else if (context.partitionValues.containsKey(field)) { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + return context.partitionValues.get(field); + } + }; + } else { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + assert response.getSourceAsMap() != null; + return response.getSourceAsMap().get(field); + } + }; + } + } + + @Override + public FieldExtractor visitReference(Reference symbol, Context context) { + String fieldName = symbol.info().ident().columnIdent().fqn(); + context.addField(fieldName); + return buildExtractor(fieldName, context); + } + + @Override + public FieldExtractor visitDynamicReference(DynamicReference symbol, Context context) { + return visitReference(symbol, context); + } + + @Override + public FieldExtractor visitFunction(Function symbol, Context context) { + List subExtractors = new ArrayList<>(symbol.arguments().size()); + for (Symbol argument : symbol.arguments()) { + subExtractors.add(process(argument, context)); + } + return new FunctionExtractor((Scalar) context.functions.getSafe(symbol.info().ident()), subExtractors); + } + + @Override + public FieldExtractor visitLiteral(Literal symbol, Context context) { + return new LiteralExtractor(symbol.value()); + } @Override - protected Void visitSymbol(Symbol symbol, Context context) { + protected FieldExtractor visitSymbol(Symbol symbol, Context context) { throw new UnsupportedOperationException( SymbolFormatter.format("Get operation not supported with symbol %s in the result column list", symbol)); } @@ -281,4 +442,45 @@ protected Void visitSymbol(Symbol symbol, Context context) { private interface FieldExtractor { Object extract(GetResponse response); } + + private static class LiteralExtractor implements FieldExtractor { + private final Object literal; + + private LiteralExtractor(Object literal) { + this.literal = literal; + } + + @Override + public Object extract(GetResponse response) { + return literal; + } + } + + private static class FunctionExtractor implements FieldExtractor { + + private final Scalar scalar; + private final List subExtractors; + + public FunctionExtractor(Scalar scalar, List subExtractors) { + this.scalar = scalar; + this.subExtractors = subExtractors; + } + + @Override + public Object extract(final GetResponse response) { + Input[] inputs = new Input[subExtractors.size()]; + int idx = 0; + for (final FieldExtractor subExtractor : subExtractors) { + inputs[idx] = new Input() { + @Override + public Object value() { + return subExtractor.extract(response); + } + }; + idx++; + } + //noinspection unchecked + return scalar.evaluate(inputs); + } + } } diff --git a/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java b/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java index e3deab6f8fed..4057f4bae287 100644 --- a/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java +++ b/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java @@ -44,14 +44,12 @@ */ public class FlatProjectorChain { - private final ProjectionToProjectorVisitor projectorVisitor; private Projector firstProjector; private final List projectors; private ResultProvider lastProjector; public FlatProjectorChain(List projections, ProjectionToProjectorVisitor projectorVisitor) { projectors = new ArrayList<>(); - this.projectorVisitor = projectorVisitor; if (projections.size() == 0) { firstProjector = new CollectingProjector(); lastProjector = (ResultProvider)firstProjector; diff --git a/sql/src/main/java/io/crate/operation/projectors/Projector.java b/sql/src/main/java/io/crate/operation/projectors/Projector.java index d565ec6054a9..5f2554cf7435 100644 --- a/sql/src/main/java/io/crate/operation/projectors/Projector.java +++ b/sql/src/main/java/io/crate/operation/projectors/Projector.java @@ -40,7 +40,6 @@ public interface Projector extends ProjectorUpstream { * * This method must be thread safe. * - * @param row * @return false if this projection does not need any more rows, true otherwise. */ public boolean setNextRow(Object ... row); diff --git a/sql/src/main/java/io/crate/planner/Planner.java b/sql/src/main/java/io/crate/planner/Planner.java index 466113f79d3f..7cd2db466fbe 100644 --- a/sql/src/main/java/io/crate/planner/Planner.java +++ b/sql/src/main/java/io/crate/planner/Planner.java @@ -116,10 +116,10 @@ protected Plan visitSelectAnalysis(SelectAnalysis analysis, Context context) { globalAggregates(analysis, plan, context); } else { WhereClause whereClause = analysis.whereClause(); - if (analysis.rowGranularity().ordinal() >= RowGranularity.DOC.ordinal() && + if (!context.indexWriterProjection.isPresent() + && analysis.rowGranularity().ordinal() >= RowGranularity.DOC.ordinal() && analysis.table().getRouting(whereClause).hasLocations() && - (analysis.table().ident().schema() == null || analysis.table().ident().schema().equals(DocSchemaInfo.NAME)) - && (analysis.isLimited() || analysis.ids().size() > 0 || !context.indexWriterProjection.isPresent())) { + analysis.table().schemaInfo().name().equals(DocSchemaInfo.NAME)) { if (analysis.ids().size() > 0 && analysis.routingValues().size() > 0 @@ -479,10 +479,7 @@ private void ESDeleteByQuery(DeleteAnalysis.NestedDeleteAnalysis analysis, Plan } private void ESGet(SelectAnalysis analysis, Plan plan, Context context) { - PlannerContextBuilder contextBuilder = new PlannerContextBuilder() - .output(analysis.outputSymbols()) - .orderBy(analysis.sortSymbols()); - + assert !context.indexWriterProjection.isPresent() : "shouldn't use ESGet with indexWriterProjection"; String indexName; if (analysis.table().isPartitioned()) { assert analysis.whereClause().partitions().size() == 1 : "ambiguous partitions for ESGet"; @@ -490,34 +487,19 @@ private void ESGet(SelectAnalysis analysis, Plan plan, Context context) { } else { indexName = analysis.table().ident().name(); } - ESGetNode getNode = new ESGetNode( indexName, + analysis.outputSymbols(), + extractDataTypes(analysis.outputSymbols()), analysis.ids(), analysis.routingValues(), + analysis.sortSymbols(), + analysis.reverseFlags(), + analysis.nullsFirst(), + analysis.limit(), + analysis.offset(), analysis.table().partitionedByColumns()); - getNode.outputs(contextBuilder.toCollect()); - getNode.outputTypes(extractDataTypes(analysis.outputSymbols())); plan.add(getNode); - - // handle sorting, limit and offset - if (analysis.isSorted() || analysis.limit() != null - || analysis.offset() > 0 - || context.indexWriterProjection.isPresent()) { - TopNProjection tnp = new TopNProjection( - Objects.firstNonNull(analysis.limit(), Constants.DEFAULT_SELECT_LIMIT), - analysis.offset(), - contextBuilder.orderBy(), - analysis.reverseFlags(), - analysis.nullsFirst() - ); - tnp.outputs(contextBuilder.outputs()); - ImmutableList.Builder projectionBuilder = ImmutableList.builder().add(tnp); - if (context.indexWriterProjection.isPresent()) { - projectionBuilder.add(context.indexWriterProjection.get()); - } - plan.add(PlanNodeBuilder.localMerge(projectionBuilder.build(), getNode)); - } } private void normalSelect(SelectAnalysis analysis, Plan plan, Context context) { @@ -1091,7 +1073,7 @@ private void ESIndex(InsertFromValuesAnalysis analysis, Plan plan) { } - static List extractDataTypes(List symbols) { + public static List extractDataTypes(List symbols) { List types = new ArrayList<>(symbols.size()); for (Symbol symbol : symbols) { types.add(DataTypeVisitor.fromSymbol(symbol)); diff --git a/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java b/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java index bb2bfcc1a53e..43a26606f42c 100644 --- a/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java +++ b/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java @@ -25,9 +25,10 @@ import com.google.common.collect.ImmutableList; import io.crate.metadata.ReferenceInfo; import io.crate.planner.node.PlanVisitor; +import io.crate.planner.symbol.Symbol; +import io.crate.types.DataType; import org.elasticsearch.common.Nullable; -import java.util.Arrays; import java.util.List; @@ -36,27 +37,38 @@ public class ESGetNode extends ESDQLPlanNode implements DQLPlanNode { private final String index; private final List ids; private final List routingValues; + private final List sortSymbols; + private final boolean[] reverseFlags; + private final Boolean[] nullsFirst; + private final Integer limit; + private final int offset; private final List partitionBy; + private final static boolean[] EMPTY_REVERSE_FLAGS = new boolean[0]; + private final static Boolean[] EMPTY_NULLS_FIRST = new Boolean[0]; + public ESGetNode(String index, + List outputs, + List outputTypes, List ids, List routingValues, + @Nullable List sortSymbols, + @Nullable boolean[] reverseFlags, + @Nullable Boolean[] nullsFirst, + @Nullable Integer limit, + int offset, @Nullable List partitionBy) { this.index = index; + this.outputs = outputs; + outputTypes(outputTypes); this.ids = ids; this.routingValues = routingValues; - this.partitionBy = Objects.firstNonNull(partitionBy, - ImmutableList.of()); - } - - public ESGetNode(String index, - List ids, - List routingValues) { - this(index, ids, routingValues, null); - } - - public ESGetNode(String index, String id, String routingValue) { - this(index, Arrays.asList(id), Arrays.asList(routingValue), null); + this.sortSymbols = Objects.firstNonNull(sortSymbols, ImmutableList.of()); + this.reverseFlags = Objects.firstNonNull(reverseFlags, EMPTY_REVERSE_FLAGS); + this.nullsFirst = Objects.firstNonNull(nullsFirst, EMPTY_NULLS_FIRST); + this.limit = limit; + this.offset = offset; + this.partitionBy = Objects.firstNonNull(partitionBy, ImmutableList.of()); } public String index() { @@ -76,6 +88,28 @@ public List routingValues() { return routingValues; } + @Nullable + public Integer limit() { + return limit; + } + + public int offset() { + return offset; + } + + public List sortSymbols() { + return sortSymbols; + } + + public boolean[] reverseFlags() { + return reverseFlags; + } + + public Boolean[] nullsFirst() { + return nullsFirst; + } + + public List partitionBy() { return partitionBy; } diff --git a/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java b/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java index 70e9b48b3735..5144ef414aaf 100644 --- a/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java +++ b/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java @@ -42,6 +42,7 @@ import io.crate.operation.projectors.TopN; import io.crate.operation.scalar.DateTruncFunction; import io.crate.planner.Plan; +import io.crate.planner.Planner; import io.crate.planner.RowGranularity; import io.crate.planner.node.dml.ESDeleteByQueryNode; import io.crate.planner.node.dml.ESDeleteNode; @@ -105,6 +106,26 @@ public class TransportExecutorTest extends SQLTransportIntegrationTest { Reference parted_date_ref = new Reference(new ReferenceInfo( new ReferenceIdent(partedTable, "date"), RowGranularity.DOC, DataTypes.TIMESTAMP)); + private static ESGetNode newGetNode(String index, List outputs, String id) { + return newGetNode(index, outputs, Arrays.asList(id)); + } + + private static ESGetNode newGetNode(String index, List outputs, List ids) { + return new ESGetNode( + index, + outputs, + Planner.extractDataTypes(outputs), + ids, + ids, + ImmutableList.of(), + new boolean[0], + new Boolean[0], + null, + 0, + null + ); + } + @Before public void transportSetUp() { CrateTestCluster cluster = cluster(); @@ -196,8 +217,8 @@ public void testMapSideCollectTask() throws Exception { public void testESGetTask() throws Exception { insertCharacters(); - ESGetNode node = new ESGetNode("characters", "2", "2"); - node.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode node = newGetNode("characters", outputs, "2"); Plan plan = new Plan(); plan.add(node); Job job = executor.newJob(plan); @@ -213,9 +234,9 @@ public void testESGetTask() throws Exception { public void testESGetTaskWithDynamicReference() throws Exception { insertCharacters(); - ESGetNode node = new ESGetNode("characters", "2", "2"); - node.outputs(ImmutableList.of(id_ref, new DynamicReference( - new ReferenceIdent(new TableIdent(null, "characters"), "foo"), RowGranularity.DOC))); + ImmutableList outputs = ImmutableList.of(id_ref, new DynamicReference( + new ReferenceIdent(new TableIdent(null, "characters"), "foo"), RowGranularity.DOC)); + ESGetNode node = newGetNode("characters", outputs, "2"); Plan plan = new Plan(); plan.add(node); Job job = executor.newJob(plan); @@ -230,8 +251,8 @@ public void testESGetTaskWithDynamicReference() throws Exception { @Test public void testESMultiGet() throws Exception { insertCharacters(); - ESGetNode node = new ESGetNode("characters", asList("1", "2"), asList("1", "2")); - node.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode node = newGetNode("characters", outputs, asList("1", "2")); Plan plan = new Plan(); plan.add(node); Job job = executor.newJob(plan); @@ -472,8 +493,8 @@ public void testESDeleteTask() throws Exception { assertThat(taskResult.rowCount(), is(1L)); // verify deletion - ESGetNode getNode = new ESGetNode("characters", "2", "2"); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, "2"); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -514,8 +535,8 @@ public void testESIndexTask() throws Exception { // verify insertion - ESGetNode getNode = new ESGetNode("characters", "99", "99"); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, "99"); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -633,10 +654,8 @@ public void testESBulkInsertTask() throws Exception { // verify insertion - ESGetNode getNode = new ESGetNode("characters", - Arrays.asList("99", "42"), - Arrays.asList("99", "42")); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, Arrays.asList("99", "42")); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -681,8 +700,8 @@ public void testESUpdateByIdTask() throws Exception { assertThat(taskResult.rowCount(), is(1L)); // verify update - ESGetNode getNode = new ESGetNode("characters", Arrays.asList("1"), Arrays.asList("1")); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, "1"); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -734,8 +753,8 @@ public void testUpdateByQueryTaskWithVersion() throws Exception { assertThat(result.get(0).get().errorMessage(), is(nullValue())); assertThat(result.get(0).get().rowCount(), is(1L)); - ESGetNode getNode = new ESGetNode("characters", "1", "1"); - getNode.outputs(Arrays.asList(id_ref, name_ref, version_ref)); + List outputs = Arrays.asList(id_ref, name_ref, version_ref); + ESGetNode getNode = newGetNode("characters", outputs, "1"); plan = new Plan(); plan.add(getNode); plan.expectsAffectedRows(false); diff --git a/sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java b/sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java new file mode 100644 index 000000000000..1481ed835b55 --- /dev/null +++ b/sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor + * license agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. Crate licenses + * this file to you under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * However, if you have executed another commercial license agreement + * with Crate these terms will supersede the license and you may use the + * software solely pursuant to the terms of the relevant commercial agreement. + */ + +package io.crate.integrationtests; + +import io.crate.test.integration.CrateIntegrationTest; +import org.junit.Test; + +import static org.hamcrest.core.Is.is; + +@CrateIntegrationTest.ClusterScope(scope = CrateIntegrationTest.Scope.GLOBAL) +public class WherePKIntegrationTest extends SQLTransportIntegrationTest { + + @Test + public void testWherePkColInWithLimit() throws Exception { + execute("create table users (" + + " id int primary key," + + " name string" + + ") clustered into 2 shards with (number_of_replicas = 0)"); + ensureGreen(); + execute("insert into users (id, name) values (?, ?)", new Object[][] { + new Object[] { 1, "Arthur" }, + new Object[] { 2, "Trillian" }, + new Object[] { 3, "Marvin" }, + new Object[] { 4, "Slartibartfast" }, + }); + execute("refresh table users"); + + execute("select name from users where id in (1, 3, 4) order by name desc limit 2"); + assertThat(response.rowCount(), is(2L)); + assertThat((String) response.rows()[0][0], is("Slartibartfast")); + assertThat((String) response.rows()[1][0], is("Marvin")); + } + + @Test + public void testWherePKWithFunctionInOutputsAndOrderBy() throws Exception { + execute("create table users (" + + " id int primary key," + + " name string" + + ") clustered into 2 shards with (number_of_replicas = 0)"); + ensureGreen(); + execute("insert into users (id, name) values (?, ?)", new Object[][] { + new Object[] { 1, "Arthur" }, + new Object[] { 2, "Trillian" }, + new Object[] { 3, "Marvin" }, + new Object[] { 4, "Slartibartfast" }, + }); + execute("refresh table users"); + execute("select substr(name, 1, 1) from users where id in (1, 2, 3) order by substr(name, 1, 1) desc"); + assertThat(response.rowCount(), is(3L)); + assertThat((String) response.rows()[0][0], is("T")); + assertThat((String) response.rows()[1][0], is("M")); + assertThat((String) response.rows()[2][0], is("A")); + } + + @Test + public void testWherePkColLimit0() throws Exception { + execute("create table users (id int primary key, name string) " + + "clustered into 1 shards with (number_of_replicas = 0)"); + ensureGreen(); + execute("insert into users (id, name) values (1, 'Arthur')"); + execute("refresh table users"); + execute("select name from users where id = 1 limit 0"); + assertThat(response.rowCount(), is(0L)); + } +} diff --git a/sql/src/test/java/io/crate/planner/PlannerTest.java b/sql/src/test/java/io/crate/planner/PlannerTest.java index 5f8583291fc5..1b370078d980 100644 --- a/sql/src/test/java/io/crate/planner/PlannerTest.java +++ b/sql/src/test/java/io/crate/planner/PlannerTest.java @@ -116,6 +116,7 @@ protected void bindSchemas() { .addPrimaryKey("id") .clusteredBy("id") .build(); + when(userTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); TableIdent charactersTableIdent = new TableIdent(null, "characters"); TableInfo charactersTableInfo = TestingTableInfo.builder(charactersTableIdent, RowGranularity.DOC, shardRouting) .add("name", DataTypes.STRING, null) @@ -123,6 +124,7 @@ protected void bindSchemas() { .addPrimaryKey("id") .clusteredBy("id") .build(); + when(charactersTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); TableIdent partedTableIdent = new TableIdent(null, "parted"); TableInfo partedTableInfo = TestingTableInfo.builder(partedTableIdent, RowGranularity.DOC, partedRouting) .add("name", DataTypes.STRING, null) @@ -137,6 +139,7 @@ protected void bindSchemas() { .addPrimaryKey("date") .clusteredBy("id") .build(); + when(partedTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); TableIdent emptyPartedTableIdent = new TableIdent(null, "empty_parted"); TableInfo emptyPartedTableInfo = TestingTableInfo.builder(partedTableIdent, RowGranularity.DOC, shardRouting) .add("name", DataTypes.STRING, null) @@ -146,6 +149,7 @@ protected void bindSchemas() { .addPrimaryKey("date") .clusteredBy("id") .build(); + when(emptyPartedTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); when(schemaInfo.getTableInfo(charactersTableIdent.name())).thenReturn(charactersTableInfo); when(schemaInfo.getTableInfo(userTableIdent.name())).thenReturn(userTableInfo); when(schemaInfo.getTableInfo(partedTableIdent.name())).thenReturn(partedTableInfo); @@ -1063,27 +1067,28 @@ public void testInsertFromSubQueryGlobalAggregate() throws Exception { @Test public void testInsertFromSubQueryESGet() throws Exception { + // doesn't use ESGetNode but CollectNode. + // Round-trip to handler can be skipped by writing from the shards directly Plan plan = plan("insert into users (date, id, name) (select date, id, name from users where id=1)"); Iterator iterator = plan.iterator(); PlanNode planNode = iterator.next(); - assertThat(planNode, instanceOf(ESGetNode.class)); - - planNode = iterator.next(); - assertThat(planNode, instanceOf(MergeNode.class)); - MergeNode localMergeNode = (MergeNode)planNode; + assertThat(planNode, instanceOf(CollectNode.class)); + CollectNode collectNode = (CollectNode) planNode; - assertThat(localMergeNode.projections().size(), is(2)); - assertThat(localMergeNode.projections().get(1), instanceOf(ColumnIndexWriterProjection.class)); - ColumnIndexWriterProjection projection = (ColumnIndexWriterProjection)localMergeNode.projections().get(1); + assertThat(collectNode.projections().size(), is(1)); + assertThat(collectNode.projections().get(0), instanceOf(ColumnIndexWriterProjection.class)); + ColumnIndexWriterProjection projection = (ColumnIndexWriterProjection)collectNode.projections().get(0); assertThat(projection.columnIdents().size(), is(3)); assertThat(projection.columnIdents().get(0).fqn(), is("date")); assertThat(projection.columnIdents().get(1).fqn(), is("id")); assertThat(projection.columnIdents().get(2).fqn(), is("name")); - assertThat(((InputColumn)projection.ids().get(0)).index(), is(1)); assertThat(((InputColumn)projection.clusteredBy()).index(), is(1)); assertThat(projection.partitionedBySymbols().isEmpty(), is(true)); + + planNode = iterator.next(); + assertThat(planNode, instanceOf(MergeNode.class)); } @Test (expected = UnsupportedFeatureException.class)