From e96bab0a76f8f3768f1bc7cc4cabb5678104cc4e Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Fri, 20 Jul 2018 08:27:40 +0200 Subject: [PATCH] CoerceTo list must be done recursively --- .../cypher/operations/CypherCoercions.java | 239 ++++++++++++++++-- .../IntermediateCodeGeneration.scala | 41 ++- .../expressions/CodeGenerationTest.scala | 46 ++-- 3 files changed, 278 insertions(+), 48 deletions(-) diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherCoercions.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherCoercions.java index 9f4598a011d5..0aa2e98a6bf1 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherCoercions.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherCoercions.java @@ -21,8 +21,14 @@ import org.opencypher.v9_0.util.CypherTypeException; +import java.util.HashMap; +import java.util.Map; + import org.neo4j.cypher.internal.runtime.DbAccess; +import org.neo4j.internal.kernel.api.procs.Neo4jTypes; import org.neo4j.values.AnyValue; +import org.neo4j.values.SequenceValue; +import org.neo4j.values.ValueMapper; import org.neo4j.values.storable.ArrayValue; import org.neo4j.values.storable.BooleanValue; import org.neo4j.values.storable.DateTimeValue; @@ -42,10 +48,30 @@ import org.neo4j.values.virtual.NodeValue; import org.neo4j.values.virtual.PathValue; import org.neo4j.values.virtual.RelationshipValue; +import org.neo4j.values.virtual.VirtualNodeValue; +import org.neo4j.values.virtual.VirtualRelationshipValue; import org.neo4j.values.virtual.VirtualValues; import static java.lang.String.format; -import static org.neo4j.values.storable.Values.NO_VALUE; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTAny; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTBoolean; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTDate; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTDateTime; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTDuration; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTFloat; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTGeometry; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTInteger; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTLocalDateTime; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTLocalTime; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTMap; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTNode; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTNumber; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTPath; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTPoint; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTRelationship; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTString; +import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTTime; +import static org.neo4j.values.SequenceValue.IterationPreference.RANDOM_ACCESS; @SuppressWarnings( {"unused", "WeakerAccess"} ) public final class CypherCoercions @@ -210,52 +236,219 @@ else if ( value instanceof RelationshipValue ) } } - public static ListValue asListValueFailOnPaths( AnyValue value ) + public static ListValue asList( AnyValue value, Neo4jTypes.AnyType innerType, DbAccess access ) + { + return new ListCoercer().apply( value, innerType, access ); + } + + private static CypherTypeException cantCoerce( AnyValue value, String type ) { - if ( value instanceof PathValue ) + return new CypherTypeException( format( "Can't coerce `%s` to %s", value, type ), null ); + } + + private static class ListMapper implements ValueMapper + { + + @Override + public ListValue mapPath( PathValue value ) { - throw cantCoerce( value, "List" ); + return null; } - else + + @Override + public ListValue mapNode( VirtualNodeValue value ) { - return asList( value ); + return null; } - } - public static ListValue asListValueSupportPaths( AnyValue value ) - { - if ( value instanceof PathValue ) + @Override + public ListValue mapRelationship( VirtualRelationshipValue value ) { - return ((PathValue) value).asList(); + return null; } - else + + @Override + public ListValue mapMap( MapValue value ) + { + return null; + } + + @Override + public ListValue mapNoValue() + { + return null; + } + + @Override + public ListValue mapSequence( SequenceValue value ) + { + return null; + } + + @Override + public ListValue mapText( TextValue value ) + { + return null; + } + + @Override + public ListValue mapBoolean( BooleanValue value ) + { + return null; + } + + @Override + public ListValue mapNumber( NumberValue value ) + { + return null; + } + + @Override + public ListValue mapDateTime( DateTimeValue value ) + { + return null; + } + + @Override + public ListValue mapLocalDateTime( LocalDateTimeValue value ) + { + return null; + } + + @Override + public ListValue mapDate( DateValue value ) { - return asList( value ); + return null; + } + + @Override + public ListValue mapTime( TimeValue value ) + { + return null; + } + + @Override + public ListValue mapLocalTime( LocalTimeValue value ) + { + return null; + } + + @Override + public ListValue mapDuration( DurationValue value ) + { + return null; + } + + @Override + public ListValue mapPoint( PointValue value ) + { + return null; } } - private static CypherTypeException cantCoerce( AnyValue value, String type ) + + @FunctionalInterface + interface Coercer { - return new CypherTypeException( format( "Can't coerce `%s` to %s", value, type ), null ); + AnyValue apply( AnyValue value, Neo4jTypes.AnyType coerceTo, DbAccess access ); + } + + private final static Map,Coercer> CONVERTERS = new HashMap<>(); + + private AnyValue coerceTo( AnyValue value, DbAccess access, Neo4jTypes.AnyType types ) + { + Coercer function = CONVERTERS.get( types.getClass() ); + + return function.apply( value, types, access ); } - private static ListValue asList( AnyValue value ) + private static class ListCoercer implements Coercer + { + @Override + public ListValue apply( AnyValue value, Neo4jTypes.AnyType innerType, DbAccess access ) + { + //Fast route + if ( innerType == NTAny ) + { + return fastListConversion( value ); + } + + //slow route, recursively convert the list + if ( !(value instanceof SequenceValue) ) + { + throw cantCoerce( value, "List" ); + } + SequenceValue listValue = (SequenceValue) value; + Coercer innerCoercer = CONVERTERS.get( innerType.getClass() ); + AnyValue[] coercedValues = new AnyValue[listValue.length()]; + Neo4jTypes.AnyType nextInner = nextInner( innerType ); + if ( listValue.iterationPreference() == RANDOM_ACCESS ) + { + for ( int i = 0; i < coercedValues.length; i++ ) + { + coercedValues[i] = innerCoercer.apply( listValue.value( i ), nextInner, access ); + } + } + else + { + int i = 0; + for ( AnyValue anyValue : listValue ) + { + coercedValues[i++] = innerCoercer.apply( anyValue, nextInner, access ); + } + } + return VirtualValues.list( coercedValues ); + } + } + + private static Neo4jTypes.AnyType nextInner( Neo4jTypes.AnyType type ) + { + if (type instanceof Neo4jTypes.ListType ) + { + return ((Neo4jTypes.ListType) type).innerType(); + } + else + { + return type; + } + } + + private static ListValue fastListConversion( AnyValue value ) { if ( value instanceof ListValue ) { return (ListValue) value; } - else if ( value instanceof ArrayValue ) + else if (value instanceof ArrayValue ) { return VirtualValues.fromArray( (ArrayValue) value ); } - else if ( value == NO_VALUE ) - { - return VirtualValues.EMPTY_LIST; - } - else + else if ( value instanceof PathValue ) { - return VirtualValues.list( value ); + return ((PathValue) value).asList(); } + throw cantCoerce( value, "List" ); + } + + static { + CONVERTERS.put( NTAny.getClass(), (a, ignore1, ignore2) -> a ); + CONVERTERS.put( NTString.getClass(), (a, ignore1, ignore2) -> asTextValue(a) ); + CONVERTERS.put( NTNumber.getClass(), (a, ignore1, ignore2) -> asNumberValue(a) ); + CONVERTERS.put( NTInteger.getClass(), (a, ignore1, ignore2) -> asIntegralValue(a) ); + CONVERTERS.put( NTFloat.getClass(), (a, ignore1, ignore2) -> asFloatingPointValue(a) ); + CONVERTERS.put( NTBoolean.getClass(), (a, ignore1, ignore2) -> asBooleanValue(a) ); + CONVERTERS.put( NTMap.getClass(), (a, ignore, c) -> asMapValue(a, c) ); + CONVERTERS.put( NTNode.getClass(), (a, ignore1, ignore2) ->asNodeValue(a) ); + CONVERTERS.put( NTRelationship.getClass(), (a, ignore1, ignore2) -> asRelationshipValue(a) ); + CONVERTERS.put( NTPath.getClass(), (a, ignore1, ignore2) -> asPathValue(a) ); + CONVERTERS.put( NTGeometry.getClass(), (a, ignore1, ignore2) ->asPointValue(a) ); + CONVERTERS.put( NTPoint.getClass(), (a, ignore1, ignore2) ->asPointValue(a) ); + CONVERTERS.put( NTDateTime.getClass(),(a, ignore1, ignore2) -> asDateTimeValue(a) ); + CONVERTERS.put( NTLocalDateTime.getClass(), (a, ignore1, ignore2) -> asLocalDateTimeValue(a) ); + CONVERTERS.put( NTDate.getClass(), (a, ignore1, ignore2) -> asDateValue(a) ); + CONVERTERS.put( NTTime.getClass(), (a, ignore1, ignore2) -> asTimeValue(a) ); + CONVERTERS.put( NTLocalTime.getClass(), (a, ignore1, ignore2) -> asLocalTimeValue(a) ); + CONVERTERS.put( NTDuration.getClass(), (a, ignore1, ignore2) -> asDurationValue(a) ); + CONVERTERS.put( Neo4jTypes.ListType.class, new ListCoercer() ); } } diff --git a/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala b/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala index 1ada53c7780e..80b017e8bc88 100644 --- a/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala +++ b/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala @@ -33,12 +33,14 @@ import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext import org.neo4j.cypher.internal.runtime.interpreted.pipes.NestedPipeExpression import org.neo4j.cypher.internal.v3_5.logical.plans.{CoerceToPredicate, NestedPlanExpression} import org.neo4j.cypher.operations.{CypherBoolean, CypherCoercions, CypherFunctions, CypherMath} +import org.neo4j.internal.kernel.api.procs.Neo4jTypes +import org.neo4j.internal.kernel.api.procs.Neo4jTypes.AnyType import org.neo4j.values.AnyValue import org.neo4j.values.storable._ import org.neo4j.values.virtual._ import org.opencypher.v9_0.expressions import org.opencypher.v9_0.expressions._ -import org.opencypher.v9_0.util.symbols.{CTAny, CTBoolean, CTDate, CTDateTime, CTDuration, CTFloat, CTGeometry, CTInteger, CTLocalDateTime, CTLocalTime, CTMap, CTNode, CTNumber, CTPath, CTPoint, CTRelationship, CTString, CTTime, ListType} +import org.opencypher.v9_0.util.symbols.{CTAny, CTBoolean, CTDate, CTDateTime, CTDuration, CTFloat, CTGeometry, CTInteger, CTLocalDateTime, CTLocalTime, CTMap, CTNode, CTNumber, CTPath, CTPoint, CTRelationship, CTString, CTTime, CypherType, ListType} import org.opencypher.v9_0.util.{CypherTypeException, InternalException} /** @@ -321,14 +323,11 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { noValueCheck(e)(invokeStatic(method[CypherCoercions, MapValue, AnyValue, DbAccess]("asMapValue"), e.ir, DB_ACCESS)), nullable = false, e.fields) - case t: ListType if t.innerType == CTNode || t.innerType == CTRelationship => - IntermediateExpression( - noValueCheck(e)(invokeStatic(method[CypherCoercions, ListValue, AnyValue]("asListValueFailOnPaths"), e.ir)), - nullable = false, e.fields) + case l: ListType => + val typ = asNeoType(l.innerType) - case _: ListType => IntermediateExpression( - noValueCheck(e)(invokeStatic(method[CypherCoercions, ListValue, AnyValue]("asListValueSupportPaths"), e.ir)), + noValueCheck(e)(invokeStatic(method[CypherCoercions, ListValue, AnyValue, AnyType, DbAccess]("asList"), e.ir, typ, DB_ACCESS)), nullable = false, e.fields) case CTBoolean => @@ -413,7 +412,7 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { Some(IntermediateExpression( block( condition(equal(loadField(f), constant(-1)))( - setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("getPropertyKeyId"), constant(key)))), + setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("propertyKeyId"), constant(key)))), ternary( invoke(DB_ACCESS, method[DbAccess, Boolean, Long, Int]("nodeHasProperty"), getLongAt(offset), loadField(f)), truthValue, falseValue)), nullable = false, Seq(f))) @@ -428,7 +427,7 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { Some(IntermediateExpression( block( condition(equal(loadField(f), constant(-1)))( - setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("getPropertyKeyId"), constant(key)))), + setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("propertyKeyId"), constant(key)))), invoke(DB_ACCESS, method[DbAccess, Value, Long, Int]("relationshipProperty"), getLongAt(offset), loadField(f))), nullable = true, Seq(f))) @@ -446,7 +445,7 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { Some(IntermediateExpression( block( condition(equal(loadField(f), constant(-1)))( - setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("getPropertyKeyId"), constant(key)))), + setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("propertyKeyId"), constant(key)))), ternary( invoke(DB_ACCESS, method[DbAccess, Boolean, Long, Int]("relationshipHasProperty"), getLongAt(offset), loadField(f)), @@ -1091,6 +1090,28 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { if (nullable) ternary(load(seenNull), noValue, load(returnValue)) else load(returnValue)): _*) IntermediateExpression(ir, nullable, expressions.foldLeft(Seq.empty[Field])((a,b) => a ++ b.fields)) } + + private def asNeoType(ct: CypherType): IntermediateRepresentation = ct match { + case CTString => getStatic[Neo4jTypes, Neo4jTypes.TextType]("NTString") + case CTInteger => getStatic[Neo4jTypes, Neo4jTypes.IntegerType]("NTInteger") + case CTFloat => getStatic[Neo4jTypes, Neo4jTypes.FloatType]("NTFloat") + case CTNumber => getStatic[Neo4jTypes, Neo4jTypes.NumberType]("NTNumber") + case CTBoolean => getStatic[Neo4jTypes, Neo4jTypes.BooleanType]("NTBoolean") + case l: ListType => invokeStatic(method[Neo4jTypes , Neo4jTypes.ListType, AnyType]("NTList"), asNeoType(l.innerType)) + case CTDateTime => getStatic[Neo4jTypes, Neo4jTypes.DateTimeType]("NTDateTime") + case CTLocalDateTime => getStatic[Neo4jTypes, Neo4jTypes.LocalDateTimeType]("NTLocalDateTime") + case CTDate => getStatic[Neo4jTypes, Neo4jTypes.DateType]("NTDate") + case CTTime => getStatic[Neo4jTypes, Neo4jTypes.TimeType]("NTTime") + case CTLocalTime => getStatic[Neo4jTypes, Neo4jTypes.LocalTimeType]("NTLocalTime") + case CTDuration =>getStatic[Neo4jTypes, Neo4jTypes.DurationType]("NTDuration") + case CTPoint => getStatic[Neo4jTypes, Neo4jTypes.PointType]("NTPoint") + case CTNode =>getStatic[Neo4jTypes, Neo4jTypes.NodeType]("NTNode") + case CTRelationship => getStatic[Neo4jTypes, Neo4jTypes.RelationshipType]("NTRelationship") + case CTPath => getStatic[Neo4jTypes, Neo4jTypes.PathType]("NTPath") + case CTGeometry => getStatic[Neo4jTypes, Neo4jTypes.GeometryType]("NTGeometry") + case CTMap =>getStatic[Neo4jTypes, Neo4jTypes.MapType]("NTMap") + case CTAny => getStatic[Neo4jTypes, Neo4jTypes.AnyType]("NTAny") + } } object IntermediateCodeGeneration { diff --git a/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala b/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala index c6f5f83b395e..86a19bf92c0e 100644 --- a/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala +++ b/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala @@ -1334,26 +1334,31 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport coerce(list(longValue(42), longValue(43)), ListType(symbols.CTAny)) should equal(list(longValue(42), longValue(43))) coerce(path(7), ListType(symbols.CTAny)) should equal(path(7).asList()) coerce(list(node(42), node(43)), ListType(symbols.CTNode)) should equal(list(node(42), node(43))) - coerce(list(relationship(42), relationship(43)), ListType(symbols.CTNode)) should equal(list(relationship(42), relationship(43))) + coerce(list(relationship(42), relationship(43)), ListType(symbols.CTRelationship)) should equal(list(relationship(42), relationship(43))) + coerce(list(doubleValue(1.2), longValue(2), doubleValue(3.1)), + ListType(symbols.CTInteger)) should equal(list(longValue(1), longValue(2), longValue(3))) + coerce(list(doubleValue(1.2), longValue(2), doubleValue(3.1)), + ListType(symbols.CTFloat)) should equal(list(doubleValue(1.2), doubleValue(2), doubleValue(3.1))) + coerce(list(list(doubleValue(1.2), longValue(2)), list(doubleValue(3.1))), + ListType(ListType(symbols.CTInteger))) should equal(list(list(longValue(1), longValue(2)), list(longValue(3)))) + a [CypherTypeException] should be thrownBy coerce(path(11), ListType(symbols.CTNode)) a [CypherTypeException] should be thrownBy coerce(path(11), ListType(symbols.CTRelationship)) } + test("coerceTo list happy path") { + types.foreach { + case (v, typ) => + coerce(list(v), ListType(typ)) should equal(list(v)) + coerce(list(list(v)), ListType(ListType(typ))) should equal(list(list(v))) + coerce(list(list(list(v))), ListType(ListType(ListType(typ)))) should equal(list(list(list(v)))) + } + } + test("coerceTo unhappy path") { - val toTest = Map(longValue(42) -> symbols.CTNumber, stringValue("hello") -> symbols.CTString, - Values.TRUE -> symbols.CTBoolean, node(42) -> symbols.CTNode, - relationship(1337) -> symbols.CTRelationship, path(13) -> symbols.CTPath, - pointValue(Cartesian, 1.0, 3.6) -> symbols.CTPoint, - DateTimeValue.now(Clock.systemUTC()) -> symbols.CTDateTime, - LocalDateTimeValue.now(Clock.systemUTC()) -> symbols.CTLocalDateTime, - TimeValue.now(Clock.systemUTC()) -> symbols.CTTime, - LocalTimeValue.now(Clock.systemUTC()) -> symbols.CTLocalTime, - DateValue.now(Clock.systemUTC()) -> symbols.CTDate, - durationValue(Duration.ofHours(3)) -> symbols.CTDuration) - - for {value <- toTest.keys - typ <- toTest.values} { - if (toTest(value) == typ) coerce(value, typ) should equal(value) + for {value <- types.keys + typ <- types.values} { + if (types(value) == typ) coerce(value, typ) should equal(value) else a [CypherTypeException] should be thrownBy coerce(value, typ) } } @@ -1536,4 +1541,15 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport } } + private val types = Map(longValue(42) -> symbols.CTNumber, stringValue("hello") -> symbols.CTString, + Values.TRUE -> symbols.CTBoolean, node(42) -> symbols.CTNode, + relationship(1337) -> symbols.CTRelationship, path(13) -> symbols.CTPath, + pointValue(Cartesian, 1.0, 3.6) -> symbols.CTPoint, + DateTimeValue.now(Clock.systemUTC()) -> symbols.CTDateTime, + LocalDateTimeValue.now(Clock.systemUTC()) -> symbols.CTLocalDateTime, + TimeValue.now(Clock.systemUTC()) -> symbols.CTTime, + LocalTimeValue.now(Clock.systemUTC()) -> symbols.CTLocalTime, + DateValue.now(Clock.systemUTC()) -> symbols.CTDate, + durationValue(Duration.ofHours(3)) -> symbols.CTDuration) + }