Skip to content

Commit

Permalink
Change unwind to use a WhileLoop instruction with iterators
Browse files Browse the repository at this point in the history
Introduce two loop data generators that produces iterators for
collections of primitive vs. non-primitive types

This fixes an acceptance tests that failed because we were relying on
Iterable being the base type of the unwind collection expression
when type information was not known at compile time.

Move increment db hits from WhileLoop into the produceNext of the
loop data generators that are actually hitting the database.
  • Loading branch information
alexaverbuch authored and henriknyman committed Feb 1, 2017
1 parent 103f2f9 commit 6802730
Show file tree
Hide file tree
Showing 15 changed files with 262 additions and 67 deletions.
Expand Up @@ -27,7 +27,7 @@ import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions._
import org.neo4j.cypher.internal.compiler.v3_2.commands.{ManyQueryExpression, QueryExpression, RangeQueryExpression, SingleQueryExpression}
import org.neo4j.cypher.internal.compiler.v3_2.helpers.{One, ZeroOneOrMany}
import org.neo4j.cypher.internal.compiler.v3_2.planner.logical.plans
import org.neo4j.cypher.internal.compiler.v3_2.planner.logical.plans.{UnwindCollection, _}
import org.neo4j.cypher.internal.compiler.v3_2.planner.logical.plans._
import org.neo4j.cypher.internal.compiler.v3_2.planner.{CantCompileQueryException, logical}
import org.neo4j.cypher.internal.frontend.v3_2.ast.Expression
import org.neo4j.cypher.internal.frontend.v3_2.helpers.Eagerly.immutableMapValues
Expand Down Expand Up @@ -55,7 +55,7 @@ object LogicalPlanConverter {
case p: plans.Aggregation => aggregationAsCodeGenPlan(p)
case p: plans.NodeCountFromCountStore => nodeCountFromCountStore(p)
case p: plans.RelationshipCountFromCountStore => relCountFromCountStore(p)
case p: UnwindCollection => unwindAsCodeGenPlan(p)
case p: plans.UnwindCollection => unwindAsCodeGenPlan(p)
case p: Sort => sortAsCodeGenPlan(p)

case _ =>
Expand Down Expand Up @@ -567,18 +567,21 @@ object LogicalPlanConverter {
override val logicalPlan = unwind

override def consume(context: CodeGenContext, child: CodeGenPlan) = {
val collection = ExpressionConverter.createExpression(unwind.expression)(context)
val collection: CodeGenExpression = ExpressionConverter.createExpression(unwind.expression)(context)

// TODO: Handle range
val collectionCodeGenType = collection.codeGenType(context)

val (elementCodeGenType, castedCollection) = collectionCodeGenType match {
case CodeGenType(symbols.ListType(innerType), ListReferenceType(innerReprType)) =>
(CodeGenType(innerType, innerReprType), collection)
val opName = context.registerOperator(logicalPlan)

val (elementCodeGenType, loopDataGenerator) = collectionCodeGenType match {
case CodeGenType(symbols.ListType(innerType), ListReferenceType(innerReprType))
if RepresentationType.isPrimitive(innerReprType) =>
(CodeGenType(innerType, innerReprType), UnwindPrimitiveCollection(opName, collection))
case CodeGenType(symbols.ListType(innerType), _) =>
(CodeGenType(innerType, ReferenceType), collection)
(CodeGenType(innerType, ReferenceType), ir.UnwindCollection(opName, collection))
case CodeGenType(symbols.CTAny, _) =>
(CodeGenType(symbols.CTAny, ReferenceType), CastToCollection(collection))
(CodeGenType(symbols.CTAny, ReferenceType), ir.UnwindCollection(opName, collection))
case t =>
throw new CantCompileQueryException(s"Unwind collection type $t not supported")
}
Expand All @@ -589,7 +592,7 @@ object LogicalPlanConverter {

val (methodHandle, actions :: tl) = context.popParent().consume(context, this)

(methodHandle, ForEachExpression(variable, castedCollection, actions) :: tl)
(methodHandle, WhileLoop(variable, loopDataGenerator, actions) :: tl)
}
}

Expand Down
Expand Up @@ -43,8 +43,10 @@ case class ExpandAllLoopDataGenerator(opName: String, fromVar: Variable, dir: Se
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext) =
(implicit context: CodeGenContext) = {
generator.incrementDbHits()
generator.nextRelationshipAndNode(toVar.name, iterVar, dir, fromVar.name, relVar.name)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E = generator.hasNextRelationship(iterVar)
}
Expand Up @@ -43,8 +43,10 @@ case class ExpandIntoLoopDataGenerator(opName: String, fromVar: Variable, dir: S
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext) =
(implicit context: CodeGenContext) = {
generator.incrementDbHits()
generator.nextRelationship(iterVar, dir, relVar.name)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E = generator.hasNextRelationship(iterVar)
}
Expand Up @@ -41,8 +41,10 @@ case class IndexSeek(opName: String, labelName: String, propName: String, descri
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext) =
(implicit context: CodeGenContext) = {
generator.incrementDbHits()
generator.nextNode(nextVar.name, iterVar)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E = generator.hasNextNode(iterVar)
}
Expand Up @@ -32,8 +32,10 @@ case class ScanAllNodes(opName: String) extends LoopDataGenerator {
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext) =
(implicit context: CodeGenContext) = {
generator.incrementDbHits()
generator.nextNode(nextVar.name, iterVar)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E = generator.hasNextNode(iterVar)
}
Expand Up @@ -33,8 +33,10 @@ case class ScanForLabel(opName: String, labelName: String, labelVar: String) ext
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext) =
(implicit context: CodeGenContext) = {
generator.incrementDbHits()
generator.nextNode(nextVar.name, iterVar)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E = generator.hasNextNode(iterVar)
}
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.compiler.v3_2.codegen.ir

import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions.{CodeGenExpression, CodeGenType}
import org.neo4j.cypher.internal.compiler.v3_2.codegen.{CodeGenContext, Variable}
import org.neo4j.cypher.internal.compiler.v3_2.codegen.spi.MethodStructure

case class UnwindCollection(opName: String, collection: CodeGenExpression) extends LoopDataGenerator {
override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit =
collection.init(generator)

override def produceIterator[E](iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext): Unit = {
generator.declareIterator(iterVar)
val iterator = generator.iteratorFrom(collection.generateExpression(generator))
generator.assign(iterVar, CodeGenType.Any, iterator)
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext): Unit = {
val next = generator.iteratorNext(generator.loadVariable(iterVar))
generator.assign(nextVar.name, CodeGenType.Any, next)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E =
generator.iteratorHasNext(generator.loadVariable(iterVar))
}
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.compiler.v3_2.codegen.ir

import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions._
import org.neo4j.cypher.internal.compiler.v3_2.codegen.spi.MethodStructure
import org.neo4j.cypher.internal.compiler.v3_2.codegen.{CodeGenContext, Variable}
import org.neo4j.cypher.internal.frontend.v3_2.symbols

case class UnwindPrimitiveCollection(opName: String, collection: CodeGenExpression) extends LoopDataGenerator {
override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit =
collection.init(generator)

override def produceIterator[E](iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext): Unit = {
generator.declarePrimitiveIterator(iterVar, collection.codeGenType)
val iterator = generator.primitiveIteratorFrom(collection.generateExpression(generator), collection.codeGenType)
generator.assign(iterVar, CodeGenType.Any, iterator)
}

override def produceNext[E](nextVar: Variable, iterVar: String, generator: MethodStructure[E])
(implicit context: CodeGenContext): Unit = {
val elementType = collection.codeGenType match {
case CodeGenType(symbols.ListType(innerCt), ListReferenceType(innerRepr)) => CodeGenType(innerCt, innerRepr)
case _ => throw new IllegalArgumentException(s"CodeGenType $collection.codeGenType not supported as primitive iterator")
}
val next = generator.primitiveIteratorNext(generator.loadVariable(iterVar), collection.codeGenType)
generator.assign(nextVar.name, elementType, next)
}

override def hasNext[E](generator: MethodStructure[E], iterVar: String): E =
generator.iteratorHasNext(generator.loadVariable(iterVar))
}
Expand Up @@ -29,7 +29,6 @@ case class WhileLoop(variable: Variable, producer: LoopDataGenerator, action: In
generator.trace(producer.opName) { body =>
producer.produceIterator(iterator, body)
body.whileLoop(producer.hasNext(body, iterator)) { loopBody =>
loopBody.incrementDbHits()
loopBody.incrementRows()
producer.produceNext(variable, iterator, loopBody)
action.body(loopBody)
Expand Down
Expand Up @@ -63,6 +63,16 @@ trait MethodStructure[E] {
def asList(values: Seq[E]): E
def asPrimitiveStream(values: Seq[E], codeGenType: CodeGenType): E

def declarePrimitiveIterator(name: String, iterableCodeGenType: CodeGenType): Unit
def primitiveIteratorFrom(iterable: E, iterableCodeGenType: CodeGenType): E
def primitiveIteratorNext(iterator: E, iterableCodeGenType: CodeGenType): E
def primitiveIteratorHasNext(iterator: E, iterableCodeGenType: CodeGenType): E

def declareIterator(name: String): Unit
def iteratorFrom(iterable: E): E
def iteratorNext(iterator: E): E
def iteratorHasNext(iterator: E): E

def toSet(value: E): E
def newDistinctSet(name: String, codeGenTypes: Iterable[CodeGenType])
def distinctSetIfNotContains(name: String, structure: Map[String,(CodeGenType,E)])(block: MethodStructure[E] => Unit)
Expand Down Expand Up @@ -210,4 +220,4 @@ sealed trait TupleDescriptor {
case class SimpleTupleDescriptor(structure: Map[String, CodeGenType]) extends TupleDescriptor
case class HashableTupleDescriptor(structure: Map[String, CodeGenType]) extends TupleDescriptor
case class OrderableTupleDescriptor(structure: Map[String, CodeGenType],
sortItems: Iterable[SortItem]) extends TupleDescriptor
sortItems: Iterable[SortItem]) extends TupleDescriptor
Expand Up @@ -23,6 +23,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -230,4 +231,36 @@ else if ( anyValue instanceof IntStream )
return anyValue;
}
}

public static final Iterator iteratorFrom( Object iterable )
{
if ( iterable instanceof Iterable )
{
return ((Iterable) iterable).iterator();
}
else if ( iterable instanceof PrimitiveEntityStream )
{
return ((PrimitiveEntityStream) iterable).iterator();
}
else if ( iterable instanceof LongStream )
{
return ((LongStream) iterable).iterator();
}
else if ( iterable instanceof DoubleStream )
{
return ((DoubleStream) iterable).iterator();
}
else if ( iterable instanceof IntStream )
{
return ((IntStream) iterable).iterator();
}
else if ( iterable == null )
{
return Collections.emptyIterator();
}
else
{
throw new CypherTypeException( "Don't know how to create an iterator out of " + iterable.getClass().getSimpleName(), null );
}
}
}
Expand Up @@ -20,39 +20,27 @@
package org.neo4j.cypher.internal.codegen;

import java.util.Iterator;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.PrimitiveIterator;
import java.util.stream.LongStream;

public abstract class PrimitiveEntityStream implements Iterable<Long>
public abstract class PrimitiveEntityStream<T>
{
private final LongStream inner;
protected final LongStream inner;

public PrimitiveEntityStream( LongStream inner )
{
this.inner = inner;
}

public LongStream longStream()
{
return inner;
}

@Override
public Iterator<Long> iterator()
public PrimitiveIterator.OfLong primitiveIterator()
{
return inner.iterator();
}

@Override
public void forEach( Consumer<? super Long> action )
public LongStream longStream()
{
throw new UnsupportedOperationException( "PrimitiveEntityStream does not support forEach" );
return inner;
}

@Override
public Spliterator<Long> spliterator()
{
throw new UnsupportedOperationException( "PrimitiveEntityStream does not support spliterator" );
}
public abstract Iterator<T> iterator();
}
Expand Up @@ -19,9 +19,10 @@
*/
package org.neo4j.cypher.internal.codegen;

import java.util.Iterator;
import java.util.stream.LongStream;

public class PrimitiveNodeStream extends PrimitiveEntityStream
public class PrimitiveNodeStream extends PrimitiveEntityStream<NodeIdWrapper>
{
public PrimitiveNodeStream( LongStream inner )
{
Expand All @@ -32,4 +33,11 @@ public static PrimitiveNodeStream of( long[] array )
{
return new PrimitiveNodeStream( LongStream.of( array ) );
}

@Override
// This method is only used when we do not know the element type at compile time, so it has to box the elements
public Iterator<NodeIdWrapper> iterator()
{
return inner.mapToObj( NodeIdWrapper::new ).iterator();
}
}
Expand Up @@ -19,9 +19,10 @@
*/
package org.neo4j.cypher.internal.codegen;

import java.util.Iterator;
import java.util.stream.LongStream;

public class PrimitiveRelationshipStream extends PrimitiveEntityStream
public class PrimitiveRelationshipStream extends PrimitiveEntityStream<RelationshipIdWrapper>
{
public PrimitiveRelationshipStream( LongStream inner )
{
Expand All @@ -32,4 +33,11 @@ public static PrimitiveRelationshipStream of( long[] array )
{
return new PrimitiveRelationshipStream( LongStream.of( array ) );
}

@Override
// This method is only used when we do not know the element type at compile time, so it has to box the elements
public Iterator<RelationshipIdWrapper> iterator()
{
return inner.mapToObj( RelationshipIdWrapper::new ).iterator();
}
}

0 comments on commit 6802730

Please sign in to comment.