From b9e73dcecf9ac654654be212932f4fb4178664b9 Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Thu, 14 Jun 2018 23:06:00 +0200 Subject: [PATCH] Fix so that we coerce when necessary in compiled expression In a not so distant future we should be able to rely on planning to have properly coerced for us, but until then we need to handle it explicitly in compiled expressions --- .../v3_5/helpers/PredicateHelper.scala | 2 +- .../cypher/operations/CypherBoolean.java | 75 ++-- .../compiled/expressions/CompiledHelpers.java | 3 +- .../IntermediateCodeGeneration.scala | 75 ++-- .../expressions/CodeGenerationTest.scala | 342 ++++++++++-------- 5 files changed, 268 insertions(+), 229 deletions(-) diff --git a/community/cypher/cypher-planner-3.5/src/main/scala/org/neo4j/cypher/internal/compiler/v3_5/helpers/PredicateHelper.scala b/community/cypher/cypher-planner-3.5/src/main/scala/org/neo4j/cypher/internal/compiler/v3_5/helpers/PredicateHelper.scala index b094d15d7d84e..26feded84a9dc 100644 --- a/community/cypher/cypher-planner-3.5/src/main/scala/org/neo4j/cypher/internal/compiler/v3_5/helpers/PredicateHelper.scala +++ b/community/cypher/cypher-planner-3.5/src/main/scala/org/neo4j/cypher/internal/compiler/v3_5/helpers/PredicateHelper.scala @@ -57,7 +57,7 @@ object PredicateHelper { //i) we do late ast rewrite after semantic analysis, so all semantic table will be missing some expression //ii) For WHERE a.prop semantic analysis will say that a.prop has boolean type since it belongs to a WHERE. // That makes it not usable here since we would need to coerce in that case. - private def isPredicate(expression: Expression) = { + def isPredicate(expression: Expression) = { expression match { case o: OperatorExpression => o.signatures.forall(_.outputType == symbols.CTBoolean) case f: FunctionInvocation => BOOLEAN_FUNCTIONS.contains(f.function) diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherBoolean.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherBoolean.java index f8dfa00e408dd..580c96fe1328a 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherBoolean.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherBoolean.java @@ -58,150 +58,133 @@ private CypherBoolean() public static Value xor( AnyValue lhs, AnyValue rhs ) { - boolean seenNull = false; - if ( lhs == NO_VALUE || rhs == NO_VALUE ) - { - return NO_VALUE; - } - return (lhs == Values.TRUE) ^ (rhs == Values.TRUE) ? Values.TRUE : Values.FALSE; } public static Value not( AnyValue in ) { - if ( in == NO_VALUE ) - { - return NO_VALUE; - } - return in != Values.TRUE ? Values.TRUE : Values.FALSE; } public static Value equals( AnyValue lhs, AnyValue rhs ) { - Boolean equals = lhs.ternaryEquals( rhs ); - if ( equals == null ) + Boolean compare = lhs.ternaryEquals( rhs ); + if ( compare == null ) { return NO_VALUE; } - else - { - return equals ? Values.TRUE : Values.FALSE; - } + return compare ? Values.TRUE : Values.FALSE; } public static Value notEquals( AnyValue lhs, AnyValue rhs ) { - Boolean equals = lhs.ternaryEquals( rhs ); - if ( equals == null ) + Boolean compare = lhs.ternaryEquals( rhs ); + if ( compare == null ) { return NO_VALUE; } - else - { - return equals ? Values.FALSE : Values.TRUE; - } + return compare ? Values.FALSE : Values.TRUE; } - public static BooleanValue coerceToBoolean( AnyValue value ) + public static Value coerceToBoolean( AnyValue value ) { - return value.map( BOOLEAN_MAPPER ) ? Values.TRUE : Values.FALSE; + return value.map( BOOLEAN_MAPPER ); } - private static final class BooleanMapper implements ValueMapper + private static final class BooleanMapper implements ValueMapper { @Override - public Boolean mapPath( PathValue value ) + public Value mapPath( PathValue value ) { - return value.size() > 0; + return value.size() > 0 ? Values.TRUE : Values.FALSE; } @Override - public Boolean mapNode( VirtualNodeValue value ) + public Value mapNode( VirtualNodeValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapRelationship( VirtualRelationshipValue value ) + public Value mapRelationship( VirtualRelationshipValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapMap( MapValue value ) + public Value mapMap( MapValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapNoValue() + public Value mapNoValue() { - return false; + return NO_VALUE; } @Override - public Boolean mapSequence( SequenceValue value ) + public Value mapSequence( SequenceValue value ) { - return value.length() > 0; + return value.length() > 0 ? Values.TRUE : Values.FALSE; } @Override - public Boolean mapText( TextValue value ) + public Value mapText( TextValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapBoolean( BooleanValue value ) + public Value mapBoolean( BooleanValue value ) { - return value.booleanValue(); + return value; } @Override - public Boolean mapNumber( NumberValue value ) + public Value mapNumber( NumberValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapDateTime( DateTimeValue value ) + public Value mapDateTime( DateTimeValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapLocalDateTime( LocalDateTimeValue value ) + public Value mapLocalDateTime( LocalDateTimeValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapDate( DateValue value ) + public Value mapDate( DateValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapTime( TimeValue value ) + public Value mapTime( TimeValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapLocalTime( LocalTimeValue value ) + public Value mapLocalTime( LocalTimeValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapDuration( DurationValue value ) + public Value mapDuration( DurationValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } @Override - public Boolean mapPoint( PointValue value ) + public Value mapPoint( PointValue value ) { throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null ); } diff --git a/enterprise/cypher/compiled-expressions/src/main/java/org/neo4j/cypher/internal/runtime/compiled/expressions/CompiledHelpers.java b/enterprise/cypher/compiled-expressions/src/main/java/org/neo4j/cypher/internal/runtime/compiled/expressions/CompiledHelpers.java index 1bfcf0237b2c8..e2af2ad02fc1e 100644 --- a/enterprise/cypher/compiled-expressions/src/main/java/org/neo4j/cypher/internal/runtime/compiled/expressions/CompiledHelpers.java +++ b/enterprise/cypher/compiled-expressions/src/main/java/org/neo4j/cypher/internal/runtime/compiled/expressions/CompiledHelpers.java @@ -22,7 +22,8 @@ */ package org.neo4j.cypher.internal.runtime.compiled.expressions; -import org.neo4j.cypher.CypherTypeException; +import org.opencypher.v9_0.util.CypherTypeException; + import org.neo4j.values.AnyValue; import org.neo4j.values.storable.BooleanValue; import org.neo4j.values.storable.Value; diff --git a/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala b/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala index 9f794a622a20a..83cc60fbc090d 100644 --- a/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala +++ b/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala @@ -24,6 +24,7 @@ package org.neo4j.cypher.internal.runtime.compiled.expressions import org.neo4j.cypher.internal.compatibility.v3_5.runtime.SlotConfiguration import org.neo4j.cypher.internal.compatibility.v3_5.runtime.ast._ +import org.neo4j.cypher.internal.compiler.v3_5.helpers.PredicateHelper.isPredicate import org.neo4j.cypher.internal.runtime.DbAccess import org.neo4j.cypher.internal.runtime.compiled.expressions.IntermediateRepresentation.method import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext @@ -178,66 +179,88 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { case Or(lhs, rhs) => for {l <- compile(lhs) r <- compile(rhs) - } yield generateOrs(List(l, r)) + } yield { + val left = if (isPredicate(lhs)) l else coerceToPredicate(l) + val right = if (isPredicate(rhs)) r else coerceToPredicate(r) + generateOrs(List(left, right)) + } case Ors(expressions) => - val compiled = expressions.foldLeft[Option[List[IntermediateExpression]]](Some(List.empty)) { (acc, current) => + val compiled = expressions.foldLeft[Option[List[(IntermediateExpression, Boolean)]]](Some(List.empty)) { (acc, current) => for {l <- acc - e <- compile(current)} yield l :+ e + e <- compile(current)} yield l :+ (e -> isPredicate(current)) } for (e <- compiled) yield e match { case Nil => IntermediateExpression(truthValue, nullable = false) //this will not really happen because of rewriters etc - case a :: Nil => a - case list => generateOrs(list) + case (a, isPredicate) :: Nil => if (isPredicate) a else coerceToPredicate(a) + case list => + val coerced = list.map { + case (p, true) => p + case (p, false) => coerceToPredicate(p) + } + generateOrs(coerced) } case Xor(lhs, rhs) => for {l <- compile(lhs) r <- compile(rhs) - } yield IntermediateExpression(invokeStatic(method[CypherBoolean, Value, AnyValue, AnyValue]("xor"), l.ir, r.ir), - l.nullable | l.nullable) + } yield { + val left = if (isPredicate(lhs)) l else coerceToPredicate(l) + val right = if (isPredicate(rhs)) r else coerceToPredicate(r) + IntermediateExpression( + noValueCheck(left, right)(invokeStatic(method[CypherBoolean, Value, AnyValue, AnyValue]("xor"), left.ir, right.ir)), + left.nullable | right.nullable) + } case And(lhs, rhs) => for {l <- compile(lhs) r <- compile(rhs) - } yield generateAnds(List(l, r)) + } yield { + val left = if (isPredicate(lhs)) l else coerceToPredicate(l) + val right = if (isPredicate(rhs)) r else coerceToPredicate(r) + generateAnds(List(left, right)) + } case Ands(expressions) => - val compiled = expressions.foldLeft[Option[List[IntermediateExpression]]](Some(List.empty)) { (acc, current) => - for {l <- acc - e <- compile(current)} yield l :+ e + val compiled = expressions.foldLeft[Option[List[(IntermediateExpression, Boolean)]]](Some(List.empty)) { (acc, current) => + for {l <- acc + e <- compile(current)} yield l :+ (e -> isPredicate(current)) } for (e <- compiled) yield e match { case Nil => IntermediateExpression(truthValue, nullable = false) //this will not really happen because of rewriters etc - case a :: Nil => a - case list => generateAnds(list) + case (a, isPredicate) :: Nil => if (isPredicate) a else coerceToPredicate(a) + case list => + val coerced = list.map { + case (p, true) => p + case (p, false) => coerceToPredicate(p) + } + generateAnds(coerced) } case Not(arg) => - compile(arg).map(a => + compile(arg).map(a => { + val in = if (isPredicate(arg)) a else coerceToPredicate(a) IntermediateExpression( - invokeStatic(method[CypherBoolean, Value, AnyValue]("not"), a.ir), a.nullable)) + noValueCheck(in)(invokeStatic(method[CypherBoolean, Value, AnyValue]("not"), in.ir)), in.nullable) + }) case Equals(lhs, rhs) => for {l <- compile(lhs) r <- compile(rhs) - } yield IntermediateExpression( - invokeStatic(method[CypherBoolean, Value, AnyValue, AnyValue]("equals"), l.ir, r.ir), l.nullable | r.nullable) + } yield IntermediateExpression(invokeStatic(method[CypherBoolean, Value, AnyValue, AnyValue]("equals"), l.ir, r.ir), + l.nullable | r.nullable) case NotEquals(lhs, rhs) => for {l <- compile(lhs) r <- compile(rhs) - } yield IntermediateExpression( - invokeStatic(method[CypherBoolean, Value, AnyValue, AnyValue]("notEquals"), l.ir, r.ir), l.nullable | r.nullable) + } yield IntermediateExpression(invokeStatic(method[CypherBoolean, Value, AnyValue, AnyValue]("notEquals"), l.ir, r.ir), + l.nullable | r.nullable) + + case CoerceToPredicate(inner) => compile(inner).map(coerceToPredicate) - case CoerceToPredicate(inner) => - compile(inner).map(e => - IntermediateExpression( - invokeStatic(method[CypherBoolean, BooleanValue, AnyValue]("coerceToBoolean"), e.ir), - nullable = false)) //data access case Parameter(name, _) => //TODO parameters that are autogenerated from literals should have nullable = false Some(IntermediateExpression(invoke(load("params"), method[MapValue, AnyValue, String]("get"), @@ -375,6 +398,10 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { nextName } + private def coerceToPredicate(e: IntermediateExpression) = IntermediateExpression( + invokeStatic(method[CypherBoolean, Value, AnyValue]("coerceToBoolean"), e.ir), + nullable = e.nullable) + /** * Ok AND and ANDS are complicated. At the core we try to find a single `FALSE` if we find one there is no need to look * at more predicates. If it doesn't find a `FALSE` it will either return `NULL` if any of the predicates has evaluated diff --git a/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala b/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala index b802769783af1..7ff5ba9ac555d 100644 --- a/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala +++ b/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala @@ -27,7 +27,6 @@ import java.time.Duration import java.util.concurrent.ThreadLocalRandom import org.mockito.Mockito.when -import org.neo4j.cypher.CypherTypeException import org.neo4j.cypher.internal.compatibility.v3_5.runtime.SlotConfiguration import org.neo4j.cypher.internal.compatibility.v3_5.runtime.ast._ import org.neo4j.cypher.internal.runtime.DbAccess @@ -39,18 +38,18 @@ import org.neo4j.values.storable.{DoubleValue, Values} import org.neo4j.values.virtual.VirtualValues.{EMPTY_LIST, EMPTY_MAP, list, map} import org.opencypher.v9_0.ast.AstConstructionTestSupport import org.opencypher.v9_0.expressions._ -import org.opencypher.v9_0.util.symbols import org.opencypher.v9_0.util.test_helpers.CypherFunSuite +import org.opencypher.v9_0.util.{CypherTypeException, symbols} class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport { private val ctx = mock[ExecutionContext] - private val dbAccess = mock[DbAccess] + private val db = mock[DbAccess] private val random = ThreadLocalRandom.current() test("round function") { - compile(function("round", literalFloat(PI))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(3.0)) - compile(function("round", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("round", literalFloat(PI))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(3.0)) + compile(function("round", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("rand function") { @@ -61,140 +60,140 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - val value = compiled.evaluate(ctx, dbAccess, EMPTY_MAP).asInstanceOf[DoubleValue].doubleValue() + val value = compiled.evaluate(ctx, db, EMPTY_MAP).asInstanceOf[DoubleValue].doubleValue() value should (be >= 0.0 and be <1.0) } test("sin function") { val arg = random.nextDouble() - compile(function("sin", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.sin(arg))) - compile(function("sin", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("sin", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.sin(arg))) + compile(function("sin", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("asin function") { val arg = random.nextDouble() - compile(function("asin", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.asin(arg))) - compile(function("asin", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("asin", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.asin(arg))) + compile(function("asin", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("haversin function") { val arg = random.nextDouble() - compile(function("haversin", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue((1.0 - Math.cos(arg)) / 2)) - compile(function("haversin", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("haversin", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue((1.0 - Math.cos(arg)) / 2)) + compile(function("haversin", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("acos function") { val arg = random.nextDouble() - compile(function("acos", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.acos(arg))) - compile(function("acos", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("acos", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.acos(arg))) + compile(function("acos", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("cos function") { val arg = random.nextDouble() - compile(function("cos", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.cos(arg))) - compile(function("cos", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("cos", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.cos(arg))) + compile(function("cos", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("cot function") { val arg = random.nextDouble() - compile(function("cot", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(1 / Math.tan(arg))) - compile(function("cot", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("cot", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(1 / Math.tan(arg))) + compile(function("cot", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("atan function") { val arg = random.nextDouble() - compile(function("atan", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.atan(arg))) - compile(function("atan", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("atan", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.atan(arg))) + compile(function("atan", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("atan2 function") { val arg1 = random.nextDouble() val arg2 = random.nextDouble() - compile(function("atan2", literalFloat(arg1), literalFloat(arg2))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.atan2(arg1, arg2))) - compile(function("atan2", noValue,literalFloat(arg1))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) - compile(function("atan2", literalFloat(arg1), noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) - compile(function("atan2", noValue, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("atan2", literalFloat(arg1), literalFloat(arg2))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.atan2(arg1, arg2))) + compile(function("atan2", noValue,literalFloat(arg1))).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) + compile(function("atan2", literalFloat(arg1), noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) + compile(function("atan2", noValue, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("tan function") { val arg = random.nextDouble() - compile(function("tan", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.tan(arg))) - compile(function("tan", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("tan", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.tan(arg))) + compile(function("tan", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("ceil function") { val arg = random.nextDouble() - compile(function("ceil", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.ceil(arg))) - compile(function("ceil", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("ceil", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.ceil(arg))) + compile(function("ceil", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("floor function") { val arg = random.nextDouble() - compile(function("floor", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.floor(arg))) - compile(function("floor", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("floor", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.floor(arg))) + compile(function("floor", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("abs function") { - compile(function("abs", literalFloat(3.2))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(3.2)) - compile(function("abs", literalFloat(-3.2))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(3.2)) - compile(function("abs", literalInt(3))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(longValue(3)) - compile(function("abs", literalInt(-3))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(longValue(3)) - compile(function("abs", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(function("abs", literalFloat(3.2))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(3.2)) + compile(function("abs", literalFloat(-3.2))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(3.2)) + compile(function("abs", literalInt(3))).evaluate(ctx, db, EMPTY_MAP) should equal(longValue(3)) + compile(function("abs", literalInt(-3))).evaluate(ctx, db, EMPTY_MAP) should equal(longValue(3)) + compile(function("abs", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) } test("radians function") { val arg = random.nextDouble() - compile(function("radians", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.toRadians(arg))) - compile(function("radians", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("radians", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.toRadians(arg))) + compile(function("radians", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("degrees function") { val arg = random.nextDouble() - compile(function("degrees", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.toDegrees(arg))) - compile(function("degrees", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("degrees", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.toDegrees(arg))) + compile(function("degrees", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("exp function") { val arg = random.nextDouble() - compile(function("exp", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.exp(arg))) - compile(function("exp", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("exp", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.exp(arg))) + compile(function("exp", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("log function") { val arg = random.nextDouble() - compile(function("log", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.log(arg))) - compile(function("log", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("log", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.log(arg))) + compile(function("log", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("log10 function") { val arg = random.nextDouble() - compile(function("log10", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.log10(arg))) - compile(function("log10", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("log10", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.log10(arg))) + compile(function("log10", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("sign function") { val arg = random.nextInt() - compile(function("sign", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.signum(arg))) - compile(function("sign", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("sign", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.signum(arg))) + compile(function("sign", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("sqrt function") { val arg = random.nextDouble() - compile(function("sqrt", literalFloat(arg))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(doubleValue(Math.sqrt(arg))) - compile(function("sqrt", noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compile(function("sqrt", literalFloat(arg))).evaluate(ctx, db, EMPTY_MAP) should equal(doubleValue(Math.sqrt(arg))) + compile(function("sqrt", noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("pi function") { - compile(function("pi")).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.PI) + compile(function("pi")).evaluate(ctx, db, EMPTY_MAP) should equal(Values.PI) } test("e function") { - compile(function("e")).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.E) + compile(function("e")).evaluate(ctx, db, EMPTY_MAP) should equal(Values.E) } test("range function") { val range = function("range", literalInt(5), literalInt(9), literalInt(2)) - compile(range).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(list(longValue(5), longValue(7), longValue(9))) + compile(range).evaluate(ctx, db, EMPTY_MAP) should equal(list(longValue(5), longValue(7), longValue(9))) } test("add numbers") { @@ -205,24 +204,24 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(longValue(52)) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(longValue(52)) } test("add temporals") { val compiled = compile(add(parameter("a"), parameter("b"))) // temporal + duration - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(temporalValue(localTime(0)), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(temporalValue(localTime(0)), durationValue(Duration.ofHours(10))))) should equal(localTime(10, 0, 0, 0)) // duration + temporal - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(durationValue(Duration.ofHours(10)), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(durationValue(Duration.ofHours(10)), temporalValue(localTime(0))))) should equal(localTime(10, 0, 0, 0)) //duration + duration - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(durationValue(Duration.ofHours(10)), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(durationValue(Duration.ofHours(10)), durationValue(Duration.ofHours(10))))) should equal(durationValue(Duration.ofHours(20))) } @@ -235,8 +234,8 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(longValue(42), NO_VALUE))) should equal(NO_VALUE) - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(NO_VALUE, longValue(42)))) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(42), NO_VALUE))) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(NO_VALUE, longValue(42)))) should equal(NO_VALUE) } test("add strings") { @@ -244,16 +243,16 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(add(parameter("a"), parameter("b"))) // string1 + string2 - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(stringValue("hello "), stringValue("world")))) should equal(stringValue("hello world")) //string + other - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(stringValue("hello "), longValue(1337)))) should equal(stringValue("hello 1337")) //other + string - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(1337), stringValue(" hello")))) should equal(stringValue("1337 hello")) @@ -268,7 +267,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longArray(Array(42, 43)), longArray(Array(44, 45))))) should equal(list(longValue(42), longValue(43), longValue(44), longValue(45))) @@ -279,18 +278,18 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(add(parameter("a"), parameter("b"))) // [a1,a2 ..] + [b1,b2 ..] - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(list(longValue(42), longValue(43)), list(longValue(44), longValue(45))))) should equal(list(longValue(42), longValue(43), longValue(44), longValue(45))) // [a1,a2 ..] + b - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(list(longValue(42), longValue(43)), longValue(44)))) should equal(list(longValue(42), longValue(43), longValue(44))) // a + [b1,b2 ..] - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(43), list(longValue(44), longValue(45))))) should equal(list(longValue(43), longValue(44), longValue(45))) @@ -304,7 +303,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(longValue(32)) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(longValue(32)) } test("subtract with NO_VALUE") { @@ -315,20 +314,20 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(longValue(42), NO_VALUE))) should equal(NO_VALUE) - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(NO_VALUE, longValue(42)))) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(42), NO_VALUE))) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(NO_VALUE, longValue(42)))) should equal(NO_VALUE) } test("subtract temporals") { val compiled = compile(subtract(parameter("a"), parameter("b"))) // temporal - duration - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(temporalValue(localTime(20, 0, 0, 0)), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(temporalValue(localTime(20, 0, 0, 0)), durationValue(Duration.ofHours(10))))) should equal(localTime(10, 0, 0, 0)) //duration - duration - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(durationValue(Duration.ofHours(10)), + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(durationValue(Duration.ofHours(10)), durationValue(Duration.ofHours(10))))) should equal(durationValue(Duration.ofHours(0))) } @@ -341,7 +340,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(longValue(420)) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(longValue(420)) } test("multiply with NO_VALUE") { @@ -352,8 +351,8 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(longValue(42), NO_VALUE))) should equal(NO_VALUE) - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(NO_VALUE, longValue(42)))) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(42), NO_VALUE))) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(NO_VALUE, longValue(42)))) should equal(NO_VALUE) } test("extract parameter") { @@ -364,8 +363,8 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) - compiled.evaluate(ctx, dbAccess, map(Array("prop"), Array(stringValue("foo")))) should equal(stringValue("foo")) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) + compiled.evaluate(ctx, db, map(Array("prop"), Array(stringValue("foo")))) should equal(stringValue("foo")) } test("NULL") { @@ -376,7 +375,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(NO_VALUE) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(NO_VALUE) } test("TRUE") { @@ -387,7 +386,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) } test("FALSE") { @@ -398,138 +397,167 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) } test("OR") { - compile(or(t, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(or(f, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(or(t, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(or(f, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compile(or(t, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(or(f, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(or(t, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(or(f, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) - compile(or(noValue, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(or(noValue, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(or(t, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(or(noValue, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(or(f, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(or(noValue, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(or(noValue, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(or(t, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(or(noValue, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(or(f, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) } test("XOR") { - compile(xor(t, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(xor(f, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(xor(t, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(xor(f, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compile(xor(t, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(xor(f, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(xor(t, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(xor(f, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) - compile(xor(noValue, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(xor(noValue, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(xor(t, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(xor(noValue, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(xor(f, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(xor(noValue, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(xor(noValue, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(xor(t, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(xor(noValue, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(xor(f, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) } test("OR should throw on non-boolean input") { - a [CypherTypeException] should be thrownBy compile(or(literalInt(42), f)).evaluate(ctx, dbAccess, EMPTY_MAP) - a [CypherTypeException] should be thrownBy compile(or(f, literalInt(42))).evaluate(ctx, dbAccess, EMPTY_MAP) - compile(or(t, literalInt(42))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(or(literalInt(42), t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) + a [CypherTypeException] should be thrownBy compile(or(literalInt(42), f)).evaluate(ctx, db, EMPTY_MAP) + a [CypherTypeException] should be thrownBy compile(or(f, literalInt(42))).evaluate(ctx, db, EMPTY_MAP) + compile(or(t, literalInt(42))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(or(literalInt(42), t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + } + + test("OR should handle coercion") { + val expression = compile(or(parameter("a"), parameter("b"))) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.FALSE, EMPTY_LIST))) should equal(Values.FALSE) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.FALSE, list(stringValue("hello"))))) should equal(Values.TRUE) } test("ORS") { - compile(ors(f, f, f, f, f, f, t, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(ors(f, f, f, f, f, f, f, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(ors(f, f, f, f, noValue, f, f, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(ors(f, f, f, t, noValue, t, f, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) + compile(ors(f, f, f, f, f, f, t, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(ors(f, f, f, f, f, f, f, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(ors(f, f, f, f, noValue, f, f, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(ors(f, f, f, t, noValue, t, f, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) } test("ORS should throw on non-boolean input") { val compiled = compile(ors(parameter("a"), parameter("b"), parameter("c"), parameter("d"), parameter("e"))) val keys = Array("a", "b", "c", "d", "e") - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(keys, Array(Values.FALSE, Values.FALSE, Values.FALSE, Values.FALSE, Values.FALSE))) should equal(Values.FALSE) - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(keys, Array(Values.FALSE, Values.FALSE, Values.TRUE, Values.FALSE, Values.FALSE))) should equal(Values.TRUE) - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(keys, Array(intValue(42), Values.FALSE, Values.TRUE, Values.FALSE, Values.FALSE))) should equal(Values.TRUE) - a [CypherTypeException] should be thrownBy compiled.evaluate(ctx, dbAccess, + a [CypherTypeException] should be thrownBy compiled.evaluate(ctx, db, map(keys, Array(intValue(42), Values.FALSE, Values.FALSE, Values.FALSE, Values.FALSE))) } + test("ORS should handle coercion") { + val expression = compile(ors(parameter("a"), parameter("b"))) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.FALSE, EMPTY_LIST))) should equal(Values.FALSE) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.FALSE, list(stringValue("hello"))))) should equal(Values.TRUE) + } + test("AND") { - compile(and(t, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(and(f, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(and(t, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(and(f, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compile(and(t, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(and(f, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(and(t, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(and(f, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) - compile(and(noValue, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(and(noValue, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(and(t, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(and(noValue, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(and(f, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compile(and(noValue, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(and(noValue, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(and(t, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(and(noValue, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(and(f, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) } test("AND should throw on non-boolean input") { - a [CypherTypeException] should be thrownBy compile(and(literalInt(42), t)).evaluate(ctx, dbAccess, EMPTY_MAP) - a [CypherTypeException] should be thrownBy compile(and(t, literalInt(42))).evaluate(ctx, dbAccess, EMPTY_MAP) - compile(and(f, literalInt(42))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(and(literalInt(42), f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + a [CypherTypeException] should be thrownBy compile(and(literalInt(42), t)).evaluate(ctx, db, EMPTY_MAP) + a [CypherTypeException] should be thrownBy compile(and(t, literalInt(42))).evaluate(ctx, db, EMPTY_MAP) + compile(and(f, literalInt(42))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(and(literalInt(42), f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + } + + test("AND should handle coercion") { + val expression = compile(and(parameter("a"), parameter("b"))) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.TRUE, EMPTY_LIST))) should equal(Values.FALSE) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.TRUE, list(stringValue("hello"))))) should equal(Values.TRUE) } test("ANDS") { - compile(ands(t, t, t, t, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(ands(t, t, t, t, t, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(ands(t, t, t, t, noValue, t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(ands(t, t, t, f, noValue, f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compile(ands(t, t, t, t, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(ands(t, t, t, t, t, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(ands(t, t, t, t, noValue, t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(ands(t, t, t, f, noValue, f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) } test("ANDS should throw on non-boolean input") { val compiled = compile(ands(parameter("a"), parameter("b"), parameter("c"), parameter("d"), parameter("e"))) val keys = Array("a", "b", "c", "d", "e") - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(keys, Array(Values.TRUE, Values.TRUE, Values.TRUE, Values.TRUE, Values.TRUE))) should equal(Values.TRUE) - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(keys, Array(Values.TRUE, Values.TRUE, Values.FALSE, Values.TRUE, Values.TRUE))) should equal(Values.FALSE) - compiled.evaluate(ctx, dbAccess, + compiled.evaluate(ctx, db, map(keys, Array(intValue(42), Values.TRUE, Values.FALSE, Values.TRUE, Values.TRUE))) should equal(Values.FALSE) - a [CypherTypeException] should be thrownBy compiled.evaluate(ctx, dbAccess, + a [CypherTypeException] should be thrownBy compiled.evaluate(ctx, db, map(keys, Array(intValue(42), Values.TRUE, Values.TRUE, Values.TRUE, Values.TRUE))) } + test("ANDS should handle coercion") { + val expression = compile(ands(parameter("a"), parameter("b"))) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.TRUE, EMPTY_LIST))) should equal(Values.FALSE) + expression.evaluate(ctx, db, map(Array("a", "b"), Array(Values.TRUE, list(stringValue("hello"))))) should equal(Values.TRUE) + } + test("NOT") { - compile(not(f)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(not(t)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(not(noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(not(f)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(not(t)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(not(noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + } + + test("NOT should handle coercion") { + val expression = compile(not(parameter("a"))) + expression.evaluate(ctx, db, map(Array("a"), Array(EMPTY_LIST))) should equal(Values.TRUE) + expression.evaluate(ctx, db, map(Array("a"), Array(list(stringValue("hello"))))) should equal(Values.FALSE) } test("EQUALS") { - compile(equals(literalInt(42), literalInt(42))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(equals(literalInt(42), literalInt(43))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(equals(noValue, literalInt(43))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(equals(literalInt(42), noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(equals(noValue, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(equals(literalInt(42), literalInt(42))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(equals(literalInt(42), literalInt(43))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(equals(noValue, literalInt(43))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(equals(literalInt(42), noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(equals(noValue, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) } test("NOT EQUALS") { - compile(notEquals(literalInt(42), literalInt(42))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) - compile(notEquals(literalInt(42), literalInt(43))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(notEquals(noValue, literalInt(43))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(notEquals(literalInt(42), noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(notEquals(noValue, noValue)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(notEquals(literalInt(42), literalInt(42))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) + compile(notEquals(literalInt(42), literalInt(43))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(notEquals(noValue, literalInt(43))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(notEquals(literalInt(42), noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(notEquals(noValue, noValue)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) } test("CoerceToPredicate") { val coerced = CoerceToPredicate(parameter("a")) - compile(coerced).evaluate(ctx, dbAccess, map(Array("a"), Array(Values.FALSE))) should equal(Values.FALSE) - compile(coerced).evaluate(ctx, dbAccess, map(Array("a"), Array(Values.TRUE))) should equal(Values.TRUE) - compile(coerced).evaluate(ctx, dbAccess, map(Array("a"), Array(Values.NO_VALUE))) should equal(Values.FALSE) - compile(coerced).evaluate(ctx, dbAccess, map(Array("a"), Array(list(stringValue("A"))))) should equal(Values.TRUE) - compile(coerced).evaluate(ctx, dbAccess, map(Array("a"), Array(list(EMPTY_LIST)))) should equal(Values.TRUE) + compile(coerced).evaluate(ctx, db, map(Array("a"), Array(Values.FALSE))) should equal(Values.FALSE) + compile(coerced).evaluate(ctx, db, map(Array("a"), Array(Values.TRUE))) should equal(Values.TRUE) + compile(coerced).evaluate(ctx, db, map(Array("a"), Array(list(stringValue("A"))))) should equal(Values.TRUE) + compile(coerced).evaluate(ctx, db, map(Array("a"), Array(list(EMPTY_LIST)))) should equal(Values.TRUE) } test("ReferenceFromSlot") { @@ -542,7 +570,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(stringValue("hello")) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(stringValue("hello")) } test("IdFromSlot") { @@ -555,15 +583,15 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport val compiled = compile(expression) // Then - compiled.evaluate(ctx, dbAccess, EMPTY_MAP) should equal(longValue(42)) + compiled.evaluate(ctx, db, EMPTY_MAP) should equal(longValue(42)) } test("PrimitiveEquals") { val compiled = compile(PrimitiveEquals(parameter("a"), parameter("b"))) - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(longValue(42), longValue(42)))) should + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(42), longValue(42)))) should equal(Values.TRUE) - compiled.evaluate(ctx, dbAccess, map(Array("a", "b"), Array(longValue(42), longValue(1337)))) should + compiled.evaluate(ctx, db, map(Array("a", "b"), Array(longValue(42), longValue(1337)))) should equal(Values.FALSE) } @@ -573,8 +601,8 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport when(ctx.getLongAt(nullOffset)).thenReturn(-1L) when(ctx.getLongAt(offset)).thenReturn(42L) - compile(NullCheck(nullOffset, literalFloat(PI))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(NullCheck(offset, literalFloat(PI))).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.PI) + compile(NullCheck(nullOffset, literalFloat(PI))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) + compile(NullCheck(offset, literalFloat(PI))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.PI) } test("NullCheckVariable") { @@ -585,9 +613,9 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport when(ctx.getRefAt(nullOffset)).thenReturn(NO_VALUE) when(ctx.getRefAt(offset)).thenReturn(stringValue("hello")) - compile(NullCheckVariable(nullOffset, ReferenceFromSlot(offset, "a"))).evaluate(ctx, dbAccess, EMPTY_MAP) should + compile(NullCheckVariable(nullOffset, ReferenceFromSlot(offset, "a"))).evaluate(ctx, db, EMPTY_MAP) should equal(Values.NO_VALUE) - compile(NullCheckVariable(offset, ReferenceFromSlot(offset, "a"))).evaluate(ctx, dbAccess, EMPTY_MAP) should + compile(NullCheckVariable(offset, ReferenceFromSlot(offset, "a"))).evaluate(ctx, db, EMPTY_MAP) should equal(stringValue("hello")) } @@ -597,8 +625,8 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport when(ctx.getLongAt(nullOffset)).thenReturn(-1L) when(ctx.getLongAt(offset)).thenReturn(77L) - compile(IsPrimitiveNull(nullOffset)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.TRUE) - compile(IsPrimitiveNull(offset)).evaluate(ctx, dbAccess, EMPTY_MAP) should equal(Values.FALSE) + compile(IsPrimitiveNull(nullOffset)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.TRUE) + compile(IsPrimitiveNull(offset)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) } private def compile(e: Expression) =