diff --git a/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java b/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java index 7a7d196a7242..1e8564489454 100644 --- a/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java +++ b/sql/src/main/java/io/crate/executor/transport/TransportExecutor.java @@ -38,6 +38,7 @@ import io.crate.operation.ImplementationSymbolVisitor; import io.crate.operation.collect.HandlerSideDataCollectOperation; import io.crate.operation.collect.StatsTables; +import io.crate.operation.projectors.ProjectionToProjectorVisitor; import io.crate.planner.Plan; import io.crate.planner.RowGranularity; import io.crate.planner.node.PlanNode; @@ -71,6 +72,7 @@ public class TransportExecutor implements Executor { // operation for handler side collecting private final HandlerSideDataCollectOperation handlerSideDataCollectOperation; + private final ProjectionToProjectorVisitor projectorVisitor; @Inject public TransportExecutor(Settings settings, @@ -92,6 +94,10 @@ public TransportExecutor(Settings settings, this.statsTables = statsTables; this.clusterService = clusterService; this.visitor = new Visitor(); + ImplementationSymbolVisitor clusterImplementationSymbolVisitor = + new ImplementationSymbolVisitor(referenceResolver, functions, RowGranularity.CLUSTER); + projectorVisitor = new ProjectionToProjectorVisitor( + clusterService, settings, transportActionProvider, clusterImplementationSymbolVisitor); } @Override @@ -173,6 +179,8 @@ public Void visitESSearchNode(QueryThenFetchNode node, Job context) { @Override public Void visitESGetNode(ESGetNode node, Job context) { context.addTask(new ESGetTask( + functions, + projectorVisitor, transportActionProvider.transportMultiGetAction(), transportActionProvider.transportGetAction(), node)); diff --git a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java index c4ceb1455d53..5ea45ed4e0fc 100644 --- a/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java +++ b/sql/src/main/java/io/crate/executor/transport/task/elasticsearch/ESGetTask.java @@ -22,16 +22,26 @@ package io.crate.executor.transport.task.elasticsearch; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.crate.Constants; -import io.crate.executor.QueryResult; -import io.crate.executor.TaskResult; -import io.crate.types.DataType; import io.crate.PartitionName; +import io.crate.executor.QueryResult; import io.crate.executor.Task; +import io.crate.executor.TaskResult; +import io.crate.metadata.Functions; import io.crate.metadata.ReferenceInfo; +import io.crate.metadata.Scalar; +import io.crate.operation.Input; +import io.crate.operation.ProjectorUpstream; +import io.crate.operation.projectors.FlatProjectorChain; +import io.crate.operation.projectors.ProjectionToProjectorVisitor; +import io.crate.operation.projectors.Projector; import io.crate.planner.node.dql.ESGetNode; +import io.crate.planner.projection.Projection; +import io.crate.planner.projection.TopNProjection; import io.crate.planner.symbol.*; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -39,149 +49,214 @@ import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.search.fetch.source.FetchSourceContext; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.*; public class ESGetTask implements Task { private final static Visitor visitor = new Visitor(); - private final ESGetNode node; + private final static OrderByVisitor ORDER_BY_VISITOR = new OrderByVisitor(); private final List> results; private final TransportAction transportAction; private final ActionRequest request; private final ActionListener listener; - private final Map partitionValues; - - public ESGetTask(TransportMultiGetAction multiGetAction, TransportGetAction getAction, ESGetNode node) { + public ESGetTask(Functions functions, + ProjectionToProjectorVisitor projectionToProjectorVisitor, + TransportMultiGetAction multiGetAction, + TransportGetAction getAction, + ESGetNode node) { assert multiGetAction != null; assert getAction != null; assert node != null; assert node.ids().size() > 0; + assert node.limit() == null || node.limit() > 0 : "shouldn't execute ESGetTask if limit is 0"; - this.node = node; - - if (this.node.partitionBy().isEmpty()) { - this.partitionValues = ImmutableMap.of(); - } else { - PartitionName partitionName = PartitionName.fromStringSafe(node.index()); - int numPartitionColumns = this.node.partitionBy().size(); - this.partitionValues = new HashMap<>(numPartitionColumns); - for (int i = 0; i partitionValues = preparePartitionValues(node); + final Context ctx = new Context(functions, node.outputs().size(), partitionValues); + FieldExtractor[] extractors = new FieldExtractor[node.outputs().size()]; + int idx = 0; for (Symbol symbol : node.outputs()) { - visitor.process(symbol, ctx); + extractors[idx] = visitor.process(symbol, ctx); + idx++; + } + for (Symbol symbol : node.sortSymbols()) { + ORDER_BY_VISITOR.process(symbol, ctx); } - final FetchSourceContext fsc = new FetchSourceContext(ctx.fields); - final FieldExtractor[] extractors = buildExtractors(ctx.fields, ctx.types); + + final FetchSourceContext fsc = new FetchSourceContext(ctx.fields()); final SettableFuture result = SettableFuture.create(); results = Arrays.>asList(result); if (node.ids().size() > 1) { - MultiGetRequest multiGetRequest = new MultiGetRequest(); - for (int i = 0; i < node.ids().size(); i++) { - String id = node.ids().get(i); - MultiGetRequest.Item item = new MultiGetRequest.Item(node.index(), Constants.DEFAULT_MAPPING_TYPE, id); - item.fetchSourceContext(fsc); - item.routing(node.routingValues().get(i)); - multiGetRequest.add(item); - } - multiGetRequest.realtime(true); - + MultiGetRequest multiGetRequest = prepareMultiGetRequest(node, fsc); transportAction = multiGetAction; request = multiGetRequest; - listener = new MultiGetResponseListener(result, extractors); + FlatProjectorChain projectorChain = getFlatProjectorChain( + projectionToProjectorVisitor, node, ctx.orderBySymbols); + listener = new MultiGetResponseListener(result, extractors, projectorChain); } else { - GetRequest getRequest = new GetRequest(node.index(), Constants.DEFAULT_MAPPING_TYPE, node.ids().get(0)); - getRequest.fetchSourceContext(fsc); - getRequest.realtime(true); - getRequest.routing(node.routingValues().get(0)); - + GetRequest getRequest = prepareGetRequest(node, fsc); transportAction = getAction; request = getRequest; listener = new GetResponseListener(result, extractors); } } - private FieldExtractor[] buildExtractors(String[] fields, DataType[] types) { - FieldExtractor[] extractors = new FieldExtractor[fields.length]; - int i = 0; - for (final String field : fields) { - if (field.equals("_version")) { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - return response.getVersion(); - } - }; - } else if (field.equals("_id")) { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - return response.getId(); - } - }; - } else if (partitionValues.containsKey(field)) { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - return partitionValues.get(field); - } - }; - } else { - extractors[i] = new FieldExtractor() { - @Override - public Object extract(GetResponse response) { - assert response.getSourceAsMap() != null; - return response.getSourceAsMap().get(field); - } - }; + private Map preparePartitionValues(ESGetNode node) { + Map partitionValues; + if (node.partitionBy().isEmpty()) { + partitionValues = ImmutableMap.of(); + } else { + PartitionName partitionName = PartitionName.fromStringSafe(node.index()); + int numPartitionColumns = node.partitionBy().size(); + partitionValues = new HashMap<>(numPartitionColumns); + for (int i = 0; i < node.partitionBy().size(); i++) { + ReferenceInfo info = node.partitionBy().get(i); + partitionValues.put( + info.ident().columnIdent().fqn(), + info.type().value(partitionName.values().get(i)) + ); } - i++; } - return extractors; + return partitionValues; + } + + private GetRequest prepareGetRequest(ESGetNode node, FetchSourceContext fsc) { + GetRequest getRequest = new GetRequest(node.index(), Constants.DEFAULT_MAPPING_TYPE, node.ids().get(0)); + getRequest.fetchSourceContext(fsc); + getRequest.realtime(true); + getRequest.routing(node.routingValues().get(0)); + return getRequest; + } + + private MultiGetRequest prepareMultiGetRequest(ESGetNode node, FetchSourceContext fsc) { + MultiGetRequest multiGetRequest = new MultiGetRequest(); + for (int i = 0; i < node.ids().size(); i++) { + String id = node.ids().get(i); + MultiGetRequest.Item item = new MultiGetRequest.Item(node.index(), Constants.DEFAULT_MAPPING_TYPE, id); + item.fetchSourceContext(fsc); + item.routing(node.routingValues().get(i)); + multiGetRequest.add(item); + } + multiGetRequest.realtime(true); + return multiGetRequest; + } + + private FlatProjectorChain getFlatProjectorChain(ProjectionToProjectorVisitor projectionToProjectorVisitor, + ESGetNode node, + List orderBySymbols) { + FlatProjectorChain projectorChain = null; + if (node.limit() != null || node.offset() > 0 || (node.sortSymbols() != null && !node.sortSymbols().isEmpty())) { + TopNProjection topNProjection = new TopNProjection( + com.google.common.base.Objects.firstNonNull(node.limit(), Constants.DEFAULT_SELECT_LIMIT), + node.offset(), + orderBySymbols, + node.reverseFlags(), + node.nullsFirst() + ); + topNProjection.outputs(genInputColumns(node.outputs().size())); + projectorChain = new FlatProjectorChain( + Arrays.asList(topNProjection), + projectionToProjectorVisitor + ); + } + return projectorChain; } - static class MultiGetResponseListener implements ActionListener { + private static List genInputColumns(int size) { + List inputColumns = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + inputColumns.add(new InputColumn(i)); + } + return inputColumns; + } + + static class MultiGetResponseListener implements ActionListener, ProjectorUpstream { private final SettableFuture result; private final FieldExtractor[] fieldExtractor; + @Nullable + private final FlatProjectorChain projectorChain; + private final Projector downstream; - public MultiGetResponseListener(SettableFuture result, FieldExtractor[] extractors) { + public MultiGetResponseListener(final SettableFuture result, + FieldExtractor[] extractors, + @Nullable FlatProjectorChain projectorChain) { this.result = result; this.fieldExtractor = extractors; + this.projectorChain = projectorChain; + if (projectorChain == null) { + downstream = null; + } else { + downstream = projectorChain.firstProjector(); + downstream.registerUpstream(this); + Futures.addCallback(projectorChain.result(), new FutureCallback() { + @Override + public void onSuccess(@Nullable Object[][] rows) { + result.set(new QueryResult(rows)); + } + + @Override + public void onFailure(@Nonnull Throwable t) { + result.setException(t); + } + }); + } } @Override public void onResponse(MultiGetResponse responses) { - List rows = new ArrayList<>(responses.getResponses().length); - for (MultiGetItemResponse response : responses) { - if (response.isFailed() || !response.getResponse().isExists()) { - continue; + if (projectorChain == null) { + List rows = new ArrayList<>(responses.getResponses().length); + for (MultiGetItemResponse response : responses) { + if (response.isFailed() || !response.getResponse().isExists()) { + continue; + } + final Object[] row = new Object[fieldExtractor.length]; + int c = 0; + for (FieldExtractor extractor : fieldExtractor) { + row[c] = extractor.extract(response.getResponse()); + c++; + } + rows.add(row); } - final Object[] row = new Object[fieldExtractor.length]; - int c = 0; - for (FieldExtractor extractor : fieldExtractor) { - row[c] = extractor.extract(response.getResponse()); - c++; + + result.set(new QueryResult(rows.toArray(new Object[rows.size()][]))); + } else { + projectorChain.startProjections(); + for (MultiGetItemResponse response : responses) { + if (response.isFailed() || !response.getResponse().isExists()) { + continue; + } + final Object[] row = new Object[fieldExtractor.length]; + int c = 0; + for (FieldExtractor extractor : fieldExtractor) { + row[c] = extractor.extract(response.getResponse()); + c++; + } + if (!downstream.setNextRow(row)) { + break; + } } - rows.add(row); + downstream.upstreamFinished(); } - - result.set(new QueryResult(rows.toArray(new Object[rows.size()][]))); } @Override public void onFailure(Throwable e) { result.setException(e); } + + @Override + public void downstream(Projector downstream) { + throw new UnsupportedOperationException("Setting downstream isn't supported on MultiGetResponseListener"); + } + + @Override + public Projector downstream() { + return downstream ; + } } static class GetResponseListener implements ActionListener { @@ -197,7 +272,7 @@ public GetResponseListener(SettableFuture result, FieldExtractor[] @Override public void onResponse(GetResponse response) { if (!response.isExists()) { - result.set((QueryResult) TaskResult.EMPTY_RESULT); + result.set( TaskResult.EMPTY_RESULT); return; } @@ -240,39 +315,125 @@ public void upstreamResult(List> result) { } static class Context { - final DataType[] types; - final String[] fields; - int idx; + private final List fields; + private final Functions functions; + private final Map partitionValues; + private final List orderBySymbols = new ArrayList<>(); + private String[] fieldsArray; + + Context(Functions functions, int size, Map partitionValues) { + this.functions = functions; + this.partitionValues = partitionValues; + fields = new ArrayList<>(size); + } - Context(int size) { - idx = 0; - fields = new String[size]; - types = new DataType[size]; + public void addField(String fieldName) { + fields.add(fieldName); } - void add(Reference reference) { - fields[idx] = reference.info().ident().columnIdent().fqn(); - types[idx] = reference.valueType(); - idx++; + public String[] fields() { + if (fieldsArray == null) { + fieldsArray = fields.toArray(new String[fields.size()]); + } + return fieldsArray; + } + + public void addOrderByField(String fieldName) { + int i = fields.indexOf(fieldName); + if (i < 0) { + fields.add(fieldName); + orderBySymbols.add(new InputColumn(fields.size())); + } else { + orderBySymbols.add(new InputColumn(i)); + } } } - static class Visitor extends SymbolVisitor { + static class OrderByVisitor extends SymbolVisitor { @Override public Void visitReference(Reference symbol, Context context) { - context.add(symbol); + context.addOrderByField(symbol.info().ident().columnIdent().fqn()); return null; } @Override public Void visitDynamicReference(DynamicReference symbol, Context context) { - context.add(symbol); + return visitReference(symbol, context); + } + + @Override + public Void visitFunction(Function symbol, Context context) { + for (Symbol arg : symbol.arguments()) { + process(arg, context); + } return null; } + } + + static class Visitor extends SymbolVisitor { + + private static FieldExtractor buildExtractor(final String field, final Context context) { + if (field.equals("_version")) { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + return response.getVersion(); + } + }; + } else if (field.equals("_id")) { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + return response.getId(); + } + }; + } else if (context.partitionValues.containsKey(field)) { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + return context.partitionValues.get(field); + } + }; + } else { + return new FieldExtractor() { + @Override + public Object extract(GetResponse response) { + assert response.getSourceAsMap() != null; + return response.getSourceAsMap().get(field); + } + }; + } + } + + @Override + public FieldExtractor visitReference(Reference symbol, Context context) { + String fieldName = symbol.info().ident().columnIdent().fqn(); + context.addField(fieldName); + return buildExtractor(fieldName, context); + } + + @Override + public FieldExtractor visitDynamicReference(DynamicReference symbol, Context context) { + return visitReference(symbol, context); + } + + @Override + public FieldExtractor visitFunction(Function symbol, Context context) { + List subExtractors = new ArrayList<>(symbol.arguments().size()); + for (Symbol argument : symbol.arguments()) { + subExtractors.add(process(argument, context)); + } + return new FunctionExtractor((Scalar) context.functions.getSafe(symbol.info().ident()), subExtractors); + } + + @Override + public FieldExtractor visitLiteral(Literal symbol, Context context) { + return new LiteralExtractor(symbol.value()); + } @Override - protected Void visitSymbol(Symbol symbol, Context context) { + protected FieldExtractor visitSymbol(Symbol symbol, Context context) { throw new UnsupportedOperationException( SymbolFormatter.format("Get operation not supported with symbol %s in the result column list", symbol)); } @@ -281,4 +442,45 @@ protected Void visitSymbol(Symbol symbol, Context context) { private interface FieldExtractor { Object extract(GetResponse response); } + + private static class LiteralExtractor implements FieldExtractor { + private final Object literal; + + private LiteralExtractor(Object literal) { + this.literal = literal; + } + + @Override + public Object extract(GetResponse response) { + return literal; + } + } + + private static class FunctionExtractor implements FieldExtractor { + + private final Scalar scalar; + private final List subExtractors; + + public FunctionExtractor(Scalar scalar, List subExtractors) { + this.scalar = scalar; + this.subExtractors = subExtractors; + } + + @Override + public Object extract(final GetResponse response) { + Input[] inputs = new Input[subExtractors.size()]; + int idx = 0; + for (final FieldExtractor subExtractor : subExtractors) { + inputs[idx] = new Input() { + @Override + public Object value() { + return subExtractor.extract(response); + } + }; + idx++; + } + //noinspection unchecked + return scalar.evaluate(inputs); + } + } } diff --git a/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java b/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java index e3deab6f8fed..4057f4bae287 100644 --- a/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java +++ b/sql/src/main/java/io/crate/operation/projectors/FlatProjectorChain.java @@ -44,14 +44,12 @@ */ public class FlatProjectorChain { - private final ProjectionToProjectorVisitor projectorVisitor; private Projector firstProjector; private final List projectors; private ResultProvider lastProjector; public FlatProjectorChain(List projections, ProjectionToProjectorVisitor projectorVisitor) { projectors = new ArrayList<>(); - this.projectorVisitor = projectorVisitor; if (projections.size() == 0) { firstProjector = new CollectingProjector(); lastProjector = (ResultProvider)firstProjector; diff --git a/sql/src/main/java/io/crate/operation/projectors/Projector.java b/sql/src/main/java/io/crate/operation/projectors/Projector.java index d565ec6054a9..5f2554cf7435 100644 --- a/sql/src/main/java/io/crate/operation/projectors/Projector.java +++ b/sql/src/main/java/io/crate/operation/projectors/Projector.java @@ -40,7 +40,6 @@ public interface Projector extends ProjectorUpstream { * * This method must be thread safe. * - * @param row * @return false if this projection does not need any more rows, true otherwise. */ public boolean setNextRow(Object ... row); diff --git a/sql/src/main/java/io/crate/planner/Planner.java b/sql/src/main/java/io/crate/planner/Planner.java index 466113f79d3f..7cd2db466fbe 100644 --- a/sql/src/main/java/io/crate/planner/Planner.java +++ b/sql/src/main/java/io/crate/planner/Planner.java @@ -116,10 +116,10 @@ protected Plan visitSelectAnalysis(SelectAnalysis analysis, Context context) { globalAggregates(analysis, plan, context); } else { WhereClause whereClause = analysis.whereClause(); - if (analysis.rowGranularity().ordinal() >= RowGranularity.DOC.ordinal() && + if (!context.indexWriterProjection.isPresent() + && analysis.rowGranularity().ordinal() >= RowGranularity.DOC.ordinal() && analysis.table().getRouting(whereClause).hasLocations() && - (analysis.table().ident().schema() == null || analysis.table().ident().schema().equals(DocSchemaInfo.NAME)) - && (analysis.isLimited() || analysis.ids().size() > 0 || !context.indexWriterProjection.isPresent())) { + analysis.table().schemaInfo().name().equals(DocSchemaInfo.NAME)) { if (analysis.ids().size() > 0 && analysis.routingValues().size() > 0 @@ -479,10 +479,7 @@ private void ESDeleteByQuery(DeleteAnalysis.NestedDeleteAnalysis analysis, Plan } private void ESGet(SelectAnalysis analysis, Plan plan, Context context) { - PlannerContextBuilder contextBuilder = new PlannerContextBuilder() - .output(analysis.outputSymbols()) - .orderBy(analysis.sortSymbols()); - + assert !context.indexWriterProjection.isPresent() : "shouldn't use ESGet with indexWriterProjection"; String indexName; if (analysis.table().isPartitioned()) { assert analysis.whereClause().partitions().size() == 1 : "ambiguous partitions for ESGet"; @@ -490,34 +487,19 @@ private void ESGet(SelectAnalysis analysis, Plan plan, Context context) { } else { indexName = analysis.table().ident().name(); } - ESGetNode getNode = new ESGetNode( indexName, + analysis.outputSymbols(), + extractDataTypes(analysis.outputSymbols()), analysis.ids(), analysis.routingValues(), + analysis.sortSymbols(), + analysis.reverseFlags(), + analysis.nullsFirst(), + analysis.limit(), + analysis.offset(), analysis.table().partitionedByColumns()); - getNode.outputs(contextBuilder.toCollect()); - getNode.outputTypes(extractDataTypes(analysis.outputSymbols())); plan.add(getNode); - - // handle sorting, limit and offset - if (analysis.isSorted() || analysis.limit() != null - || analysis.offset() > 0 - || context.indexWriterProjection.isPresent()) { - TopNProjection tnp = new TopNProjection( - Objects.firstNonNull(analysis.limit(), Constants.DEFAULT_SELECT_LIMIT), - analysis.offset(), - contextBuilder.orderBy(), - analysis.reverseFlags(), - analysis.nullsFirst() - ); - tnp.outputs(contextBuilder.outputs()); - ImmutableList.Builder projectionBuilder = ImmutableList.builder().add(tnp); - if (context.indexWriterProjection.isPresent()) { - projectionBuilder.add(context.indexWriterProjection.get()); - } - plan.add(PlanNodeBuilder.localMerge(projectionBuilder.build(), getNode)); - } } private void normalSelect(SelectAnalysis analysis, Plan plan, Context context) { @@ -1091,7 +1073,7 @@ private void ESIndex(InsertFromValuesAnalysis analysis, Plan plan) { } - static List extractDataTypes(List symbols) { + public static List extractDataTypes(List symbols) { List types = new ArrayList<>(symbols.size()); for (Symbol symbol : symbols) { types.add(DataTypeVisitor.fromSymbol(symbol)); diff --git a/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java b/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java index bb2bfcc1a53e..43a26606f42c 100644 --- a/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java +++ b/sql/src/main/java/io/crate/planner/node/dql/ESGetNode.java @@ -25,9 +25,10 @@ import com.google.common.collect.ImmutableList; import io.crate.metadata.ReferenceInfo; import io.crate.planner.node.PlanVisitor; +import io.crate.planner.symbol.Symbol; +import io.crate.types.DataType; import org.elasticsearch.common.Nullable; -import java.util.Arrays; import java.util.List; @@ -36,27 +37,38 @@ public class ESGetNode extends ESDQLPlanNode implements DQLPlanNode { private final String index; private final List ids; private final List routingValues; + private final List sortSymbols; + private final boolean[] reverseFlags; + private final Boolean[] nullsFirst; + private final Integer limit; + private final int offset; private final List partitionBy; + private final static boolean[] EMPTY_REVERSE_FLAGS = new boolean[0]; + private final static Boolean[] EMPTY_NULLS_FIRST = new Boolean[0]; + public ESGetNode(String index, + List outputs, + List outputTypes, List ids, List routingValues, + @Nullable List sortSymbols, + @Nullable boolean[] reverseFlags, + @Nullable Boolean[] nullsFirst, + @Nullable Integer limit, + int offset, @Nullable List partitionBy) { this.index = index; + this.outputs = outputs; + outputTypes(outputTypes); this.ids = ids; this.routingValues = routingValues; - this.partitionBy = Objects.firstNonNull(partitionBy, - ImmutableList.of()); - } - - public ESGetNode(String index, - List ids, - List routingValues) { - this(index, ids, routingValues, null); - } - - public ESGetNode(String index, String id, String routingValue) { - this(index, Arrays.asList(id), Arrays.asList(routingValue), null); + this.sortSymbols = Objects.firstNonNull(sortSymbols, ImmutableList.of()); + this.reverseFlags = Objects.firstNonNull(reverseFlags, EMPTY_REVERSE_FLAGS); + this.nullsFirst = Objects.firstNonNull(nullsFirst, EMPTY_NULLS_FIRST); + this.limit = limit; + this.offset = offset; + this.partitionBy = Objects.firstNonNull(partitionBy, ImmutableList.of()); } public String index() { @@ -76,6 +88,28 @@ public List routingValues() { return routingValues; } + @Nullable + public Integer limit() { + return limit; + } + + public int offset() { + return offset; + } + + public List sortSymbols() { + return sortSymbols; + } + + public boolean[] reverseFlags() { + return reverseFlags; + } + + public Boolean[] nullsFirst() { + return nullsFirst; + } + + public List partitionBy() { return partitionBy; } diff --git a/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java b/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java index 70e9b48b3735..5144ef414aaf 100644 --- a/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java +++ b/sql/src/test/java/io/crate/executor/transport/TransportExecutorTest.java @@ -42,6 +42,7 @@ import io.crate.operation.projectors.TopN; import io.crate.operation.scalar.DateTruncFunction; import io.crate.planner.Plan; +import io.crate.planner.Planner; import io.crate.planner.RowGranularity; import io.crate.planner.node.dml.ESDeleteByQueryNode; import io.crate.planner.node.dml.ESDeleteNode; @@ -105,6 +106,26 @@ public class TransportExecutorTest extends SQLTransportIntegrationTest { Reference parted_date_ref = new Reference(new ReferenceInfo( new ReferenceIdent(partedTable, "date"), RowGranularity.DOC, DataTypes.TIMESTAMP)); + private static ESGetNode newGetNode(String index, List outputs, String id) { + return newGetNode(index, outputs, Arrays.asList(id)); + } + + private static ESGetNode newGetNode(String index, List outputs, List ids) { + return new ESGetNode( + index, + outputs, + Planner.extractDataTypes(outputs), + ids, + ids, + ImmutableList.of(), + new boolean[0], + new Boolean[0], + null, + 0, + null + ); + } + @Before public void transportSetUp() { CrateTestCluster cluster = cluster(); @@ -196,8 +217,8 @@ public void testMapSideCollectTask() throws Exception { public void testESGetTask() throws Exception { insertCharacters(); - ESGetNode node = new ESGetNode("characters", "2", "2"); - node.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode node = newGetNode("characters", outputs, "2"); Plan plan = new Plan(); plan.add(node); Job job = executor.newJob(plan); @@ -213,9 +234,9 @@ public void testESGetTask() throws Exception { public void testESGetTaskWithDynamicReference() throws Exception { insertCharacters(); - ESGetNode node = new ESGetNode("characters", "2", "2"); - node.outputs(ImmutableList.of(id_ref, new DynamicReference( - new ReferenceIdent(new TableIdent(null, "characters"), "foo"), RowGranularity.DOC))); + ImmutableList outputs = ImmutableList.of(id_ref, new DynamicReference( + new ReferenceIdent(new TableIdent(null, "characters"), "foo"), RowGranularity.DOC)); + ESGetNode node = newGetNode("characters", outputs, "2"); Plan plan = new Plan(); plan.add(node); Job job = executor.newJob(plan); @@ -230,8 +251,8 @@ public void testESGetTaskWithDynamicReference() throws Exception { @Test public void testESMultiGet() throws Exception { insertCharacters(); - ESGetNode node = new ESGetNode("characters", asList("1", "2"), asList("1", "2")); - node.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode node = newGetNode("characters", outputs, asList("1", "2")); Plan plan = new Plan(); plan.add(node); Job job = executor.newJob(plan); @@ -472,8 +493,8 @@ public void testESDeleteTask() throws Exception { assertThat(taskResult.rowCount(), is(1L)); // verify deletion - ESGetNode getNode = new ESGetNode("characters", "2", "2"); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, "2"); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -514,8 +535,8 @@ public void testESIndexTask() throws Exception { // verify insertion - ESGetNode getNode = new ESGetNode("characters", "99", "99"); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, "99"); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -633,10 +654,8 @@ public void testESBulkInsertTask() throws Exception { // verify insertion - ESGetNode getNode = new ESGetNode("characters", - Arrays.asList("99", "42"), - Arrays.asList("99", "42")); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, Arrays.asList("99", "42")); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -681,8 +700,8 @@ public void testESUpdateByIdTask() throws Exception { assertThat(taskResult.rowCount(), is(1L)); // verify update - ESGetNode getNode = new ESGetNode("characters", Arrays.asList("1"), Arrays.asList("1")); - getNode.outputs(ImmutableList.of(id_ref, name_ref)); + ImmutableList outputs = ImmutableList.of(id_ref, name_ref); + ESGetNode getNode = newGetNode("characters", outputs, "1"); plan = new Plan(); plan.add(getNode); job = executor.newJob(plan); @@ -734,8 +753,8 @@ public void testUpdateByQueryTaskWithVersion() throws Exception { assertThat(result.get(0).get().errorMessage(), is(nullValue())); assertThat(result.get(0).get().rowCount(), is(1L)); - ESGetNode getNode = new ESGetNode("characters", "1", "1"); - getNode.outputs(Arrays.asList(id_ref, name_ref, version_ref)); + List outputs = Arrays.asList(id_ref, name_ref, version_ref); + ESGetNode getNode = newGetNode("characters", outputs, "1"); plan = new Plan(); plan.add(getNode); plan.expectsAffectedRows(false); diff --git a/sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java b/sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java new file mode 100644 index 000000000000..1481ed835b55 --- /dev/null +++ b/sql/src/test/java/io/crate/integrationtests/WherePKIntegrationTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor + * license agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. Crate licenses + * this file to you under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may + * obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * However, if you have executed another commercial license agreement + * with Crate these terms will supersede the license and you may use the + * software solely pursuant to the terms of the relevant commercial agreement. + */ + +package io.crate.integrationtests; + +import io.crate.test.integration.CrateIntegrationTest; +import org.junit.Test; + +import static org.hamcrest.core.Is.is; + +@CrateIntegrationTest.ClusterScope(scope = CrateIntegrationTest.Scope.GLOBAL) +public class WherePKIntegrationTest extends SQLTransportIntegrationTest { + + @Test + public void testWherePkColInWithLimit() throws Exception { + execute("create table users (" + + " id int primary key," + + " name string" + + ") clustered into 2 shards with (number_of_replicas = 0)"); + ensureGreen(); + execute("insert into users (id, name) values (?, ?)", new Object[][] { + new Object[] { 1, "Arthur" }, + new Object[] { 2, "Trillian" }, + new Object[] { 3, "Marvin" }, + new Object[] { 4, "Slartibartfast" }, + }); + execute("refresh table users"); + + execute("select name from users where id in (1, 3, 4) order by name desc limit 2"); + assertThat(response.rowCount(), is(2L)); + assertThat((String) response.rows()[0][0], is("Slartibartfast")); + assertThat((String) response.rows()[1][0], is("Marvin")); + } + + @Test + public void testWherePKWithFunctionInOutputsAndOrderBy() throws Exception { + execute("create table users (" + + " id int primary key," + + " name string" + + ") clustered into 2 shards with (number_of_replicas = 0)"); + ensureGreen(); + execute("insert into users (id, name) values (?, ?)", new Object[][] { + new Object[] { 1, "Arthur" }, + new Object[] { 2, "Trillian" }, + new Object[] { 3, "Marvin" }, + new Object[] { 4, "Slartibartfast" }, + }); + execute("refresh table users"); + execute("select substr(name, 1, 1) from users where id in (1, 2, 3) order by substr(name, 1, 1) desc"); + assertThat(response.rowCount(), is(3L)); + assertThat((String) response.rows()[0][0], is("T")); + assertThat((String) response.rows()[1][0], is("M")); + assertThat((String) response.rows()[2][0], is("A")); + } + + @Test + public void testWherePkColLimit0() throws Exception { + execute("create table users (id int primary key, name string) " + + "clustered into 1 shards with (number_of_replicas = 0)"); + ensureGreen(); + execute("insert into users (id, name) values (1, 'Arthur')"); + execute("refresh table users"); + execute("select name from users where id = 1 limit 0"); + assertThat(response.rowCount(), is(0L)); + } +} diff --git a/sql/src/test/java/io/crate/planner/PlannerTest.java b/sql/src/test/java/io/crate/planner/PlannerTest.java index 5f8583291fc5..1b370078d980 100644 --- a/sql/src/test/java/io/crate/planner/PlannerTest.java +++ b/sql/src/test/java/io/crate/planner/PlannerTest.java @@ -116,6 +116,7 @@ protected void bindSchemas() { .addPrimaryKey("id") .clusteredBy("id") .build(); + when(userTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); TableIdent charactersTableIdent = new TableIdent(null, "characters"); TableInfo charactersTableInfo = TestingTableInfo.builder(charactersTableIdent, RowGranularity.DOC, shardRouting) .add("name", DataTypes.STRING, null) @@ -123,6 +124,7 @@ protected void bindSchemas() { .addPrimaryKey("id") .clusteredBy("id") .build(); + when(charactersTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); TableIdent partedTableIdent = new TableIdent(null, "parted"); TableInfo partedTableInfo = TestingTableInfo.builder(partedTableIdent, RowGranularity.DOC, partedRouting) .add("name", DataTypes.STRING, null) @@ -137,6 +139,7 @@ protected void bindSchemas() { .addPrimaryKey("date") .clusteredBy("id") .build(); + when(partedTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); TableIdent emptyPartedTableIdent = new TableIdent(null, "empty_parted"); TableInfo emptyPartedTableInfo = TestingTableInfo.builder(partedTableIdent, RowGranularity.DOC, shardRouting) .add("name", DataTypes.STRING, null) @@ -146,6 +149,7 @@ protected void bindSchemas() { .addPrimaryKey("date") .clusteredBy("id") .build(); + when(emptyPartedTableInfo.schemaInfo().name()).thenReturn(DocSchemaInfo.NAME); when(schemaInfo.getTableInfo(charactersTableIdent.name())).thenReturn(charactersTableInfo); when(schemaInfo.getTableInfo(userTableIdent.name())).thenReturn(userTableInfo); when(schemaInfo.getTableInfo(partedTableIdent.name())).thenReturn(partedTableInfo); @@ -1063,27 +1067,28 @@ public void testInsertFromSubQueryGlobalAggregate() throws Exception { @Test public void testInsertFromSubQueryESGet() throws Exception { + // doesn't use ESGetNode but CollectNode. + // Round-trip to handler can be skipped by writing from the shards directly Plan plan = plan("insert into users (date, id, name) (select date, id, name from users where id=1)"); Iterator iterator = plan.iterator(); PlanNode planNode = iterator.next(); - assertThat(planNode, instanceOf(ESGetNode.class)); - - planNode = iterator.next(); - assertThat(planNode, instanceOf(MergeNode.class)); - MergeNode localMergeNode = (MergeNode)planNode; + assertThat(planNode, instanceOf(CollectNode.class)); + CollectNode collectNode = (CollectNode) planNode; - assertThat(localMergeNode.projections().size(), is(2)); - assertThat(localMergeNode.projections().get(1), instanceOf(ColumnIndexWriterProjection.class)); - ColumnIndexWriterProjection projection = (ColumnIndexWriterProjection)localMergeNode.projections().get(1); + assertThat(collectNode.projections().size(), is(1)); + assertThat(collectNode.projections().get(0), instanceOf(ColumnIndexWriterProjection.class)); + ColumnIndexWriterProjection projection = (ColumnIndexWriterProjection)collectNode.projections().get(0); assertThat(projection.columnIdents().size(), is(3)); assertThat(projection.columnIdents().get(0).fqn(), is("date")); assertThat(projection.columnIdents().get(1).fqn(), is("id")); assertThat(projection.columnIdents().get(2).fqn(), is("name")); - assertThat(((InputColumn)projection.ids().get(0)).index(), is(1)); assertThat(((InputColumn)projection.clusteredBy()).index(), is(1)); assertThat(projection.partitionedBySymbols().isEmpty(), is(true)); + + planNode = iterator.next(); + assertThat(planNode, instanceOf(MergeNode.class)); } @Test (expected = UnsupportedFeatureException.class)