From 8c5e8b34753c156a1ac9c16da181bf9a03ad22e5 Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Mon, 25 Jun 2018 16:55:20 +0200 Subject: [PATCH] Support for container-index access --- .../commands/expressions/ContainerIndex.scala | 57 ++---------- .../expressions/ContainerIndexTest.scala | 35 ++----- .../cypher/operations/CypherFunctions.java | 92 ++++++++++++++++--- .../IntermediateCodeGeneration.scala | 9 +- .../expressions/CodeGenerationTest.scala | 43 +++++++++ 5 files changed, 148 insertions(+), 88 deletions(-) diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndex.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndex.scala index 0be5b51fb11a2..5bba65251530a 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndex.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndex.scala @@ -19,61 +19,20 @@ */ package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions -import org.opencypher.v9_0.util.{CypherTypeException, InvalidArgumentException} -import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext -import org.neo4j.cypher.internal.runtime.interpreted.{CastSupport, IsList, IsMap, ListSupport} import org.neo4j.cypher.internal.runtime.interpreted.pipes.QueryState +import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, ListSupport} +import org.neo4j.cypher.operations.CypherFunctions import org.neo4j.values._ -import org.neo4j.values.storable._ +import org.neo4j.values.storable.Values -case class ContainerIndex(expression: Expression, index: Expression) extends NullInNullOutExpression(expression) +case class ContainerIndex(expression: Expression, index: Expression) extends Expression with ListSupport { def arguments = Seq(expression, index) - override def compute(value: AnyValue, ctx: ExecutionContext, state: QueryState): AnyValue = { - value match { - case IsMap(m) => - val item = index(ctx, state) - if (item == Values.NO_VALUE) Values.NO_VALUE - else { - val key = CastSupport.castOrFail[TextValue](item) - m(state.query).get(key.stringValue()) - } - - case IsList(collection) => - val item = index(ctx, state) - if (item == Values.NO_VALUE) Values.NO_VALUE - else { - var idx = validateTypeAndRange(item) - - if (idx < 0) - idx = collection.size + idx - - if (idx >= collection.size || idx < 0) Values.NO_VALUE - else collection.value(idx) - } - - case _ => - val indexValue = index(ctx, state) - throw new CypherTypeException( - s"`$value` is not a collection or a map. Element access is only possible by performing a collection lookup using an integer index, or by performing a map lookup using a string key (found: $value[$indexValue])") - } - } - - private def validateTypeAndRange(item: AnyValue): Int = { - val number = CastSupport.castOrFail[NumberValue](item) - - val longValue = number match { - case _: FloatValue | _: DoubleValue=> - throw new CypherTypeException(s"Cannot index a list using an non-integer number, got $number") - case _ => number.longValue() - } - - if (longValue > Int.MaxValue || longValue < Int.MinValue) - throw new InvalidArgumentException( - s"Cannot index a list using a value greater than ${Int.MaxValue} or lesser than ${Int.MinValue}, got $number") - - longValue.toInt + override def apply(ctx: ExecutionContext, + state: QueryState): AnyValue = expression(ctx, state) match { + case Values.NO_VALUE => Values.NO_VALUE + case value => CypherFunctions.containerIndex(value, index(ctx, state), state.query) } def rewrite(f: (Expression) => Expression): Expression = f(ContainerIndex(expression.rewrite(f), index.rewrite(f))) diff --git a/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndexTest.scala b/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndexTest.scala index 880601d632160..59cfe3a20a730 100644 --- a/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndexTest.scala +++ b/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/ContainerIndexTest.scala @@ -20,17 +20,14 @@ package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer +import org.neo4j.cypher.internal.runtime.QueryContext import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, QueryStateHelper} -import org.neo4j.cypher.internal.runtime.{Operations, QueryContext} -import org.opencypher.v9_0.util.test_helpers.CypherFunSuite -import org.opencypher.v9_0.util.{CypherTypeException, InvalidArgumentException} import org.neo4j.graphdb.{Node, Relationship} import org.neo4j.values.AnyValue +import org.neo4j.values.storable.Values import org.neo4j.values.storable.Values.longValue -import org.neo4j.values.storable.{Value, Values} -import org.neo4j.values.virtual.{RelationshipValue, NodeValue} +import org.opencypher.v9_0.util.test_helpers.CypherFunSuite +import org.opencypher.v9_0.util.{CypherTypeException, InvalidArgumentException} import scala.collection.JavaConverters._ @@ -86,17 +83,9 @@ class ContainerIndexTest extends CypherFunSuite { val node = mock[Node] when(node.getId).thenReturn(0) implicit val expression = Literal(node) - when(qtx.getOptPropertyKeyId("v")).thenReturn(Some(0)) - when(qtx.getOptPropertyKeyId("c")).thenReturn(Some(1)) - val nodeOps = mock[Operations[NodeValue]] - when(nodeOps.getProperty(0, 0)).thenAnswer(new Answer[Value] { - override def answer(invocation: InvocationOnMock): Value = Values.longValue(1) - }) - when(nodeOps.getProperty(0, 1)).thenAnswer(new Answer[Value] { - override def answer(invocation: InvocationOnMock): Value = Values.NO_VALUE - }) - when(qtx.nodeOps).thenReturn(nodeOps) + when(qtx.nodeProperty(0, "v")).thenReturn(longValue(1)) + when(qtx.nodeProperty(0, "c")).thenReturn(Values.NO_VALUE) idx("v") should equal(longValue(1)) idx("c") should equal(expectedNull) } @@ -105,17 +94,9 @@ class ContainerIndexTest extends CypherFunSuite { val rel = mock[Relationship] when(rel.getId).thenReturn(0) implicit val expression = Literal(rel) - when(qtx.getOptPropertyKeyId("v")).thenReturn(Some(0)) - when(qtx.getOptPropertyKeyId("c")).thenReturn(Some(1)) - val relOps = mock[Operations[RelationshipValue]] - when(relOps.getProperty(0, 0)).thenAnswer(new Answer[Value] { - override def answer(invocation: InvocationOnMock): Value = Values.longValue(1) - }) - when(relOps.getProperty(0, 1)).thenAnswer(new Answer[Value] { - override def answer(invocation: InvocationOnMock): Value = Values.NO_VALUE - }) - when(qtx.relationshipOps).thenReturn(relOps) + when(qtx.relationshipProperty(0, "v")).thenReturn(longValue(1)) + when(qtx.relationshipProperty(0, "c")).thenReturn(Values.NO_VALUE) idx("v") should equal(longValue(1)) idx("c") should equal(expectedNull) } diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherFunctions.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherFunctions.java index c9ff1983738fa..606d5e36c148e 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherFunctions.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/operations/CypherFunctions.java @@ -26,12 +26,14 @@ import org.neo4j.cypher.internal.runtime.DbAccess; import org.neo4j.values.AnyValue; +import org.neo4j.values.SequenceValue; import org.neo4j.values.storable.BooleanValue; import org.neo4j.values.storable.DoubleValue; import org.neo4j.values.storable.IntegralValue; import org.neo4j.values.storable.LongValue; import org.neo4j.values.storable.NumberValue; import org.neo4j.values.storable.PointValue; +import org.neo4j.values.storable.TextValue; import org.neo4j.values.storable.Value; import org.neo4j.values.virtual.ListValue; import org.neo4j.values.virtual.MapValue; @@ -41,6 +43,7 @@ import org.neo4j.values.virtual.VirtualRelationshipValue; import org.neo4j.values.virtual.VirtualValues; +import static java.lang.String.format; import static org.neo4j.values.storable.Values.FALSE; import static org.neo4j.values.storable.Values.NO_VALUE; import static org.neo4j.values.storable.Values.TRUE; @@ -343,7 +346,7 @@ public static NodeValue startNode( AnyValue anyValue, DbAccess access ) } else { - throw new CypherTypeException( String.format( "Expected %s to be a RelationshipValue", anyValue), null ); + throw new CypherTypeException( format( "Expected %s to be a RelationshipValue", anyValue), null ); } } @@ -355,30 +358,97 @@ public static NodeValue endNode( AnyValue anyValue, DbAccess access ) } else { - throw new CypherTypeException( String.format( "Expected %s to be a RelationshipValue", anyValue), null ); + throw new CypherTypeException( format( "Expected %s to be a RelationshipValue", anyValue), null ); } } - public static BooleanValue propertyExists( String key, AnyValue holder, DbAccess dbAccess ) + public static BooleanValue propertyExists( String key, AnyValue container, DbAccess dbAccess ) { - if ( holder instanceof VirtualNodeValue ) + if ( container instanceof VirtualNodeValue ) { - return dbAccess.nodeHasProperty( ((VirtualNodeValue) holder).id(), key ) ? TRUE : FALSE; + return dbAccess.nodeHasProperty( ((VirtualNodeValue) container).id(), key ) ? TRUE : FALSE; } - else if ( holder instanceof VirtualRelationshipValue ) + else if ( container instanceof VirtualRelationshipValue ) { - return dbAccess.relationshipHasProperty( ((VirtualRelationshipValue) holder).id(), key ) ? TRUE : FALSE; + return dbAccess.relationshipHasProperty( ((VirtualRelationshipValue) container).id(), key ) ? TRUE : FALSE; } - else if ( holder instanceof MapValue ) + else if ( container instanceof MapValue ) { - return ((MapValue) holder).containsKey( key ) ? TRUE : FALSE; + return ((MapValue) container).containsKey( key ) ? TRUE : FALSE; } else { - throw new CypherTypeException( String.format( "Expected %s to be a property container", holder), null ); + throw new CypherTypeException( format( "Expected %s to be a property container", container), null ); } } + public static AnyValue containerIndex( AnyValue container, AnyValue index, DbAccess dbAccess ) + { + if ( container instanceof VirtualNodeValue ) + { + return dbAccess.nodeProperty( ((VirtualNodeValue) container).id(), asString( index ) ); + } + else if ( container instanceof VirtualRelationshipValue ) + { + return dbAccess.relationshipProperty( ((VirtualRelationshipValue) container).id(), asString( index ) ); + } + if ( container instanceof MapValue ) + { + return mapAccess( (MapValue) container, index ); + } + else if ( container instanceof SequenceValue ) + { + return listAccess( (SequenceValue) container, index ); + } + else + { + throw new CypherTypeException( format( + "`%s` is not a collection or a map. Element access is only possible by performing a collection " + + "lookup using an integer index, or by performing a map lookup using a string key (found: %s[%s])", + container, container, index ), null ); + } + } + + private static AnyValue listAccess( SequenceValue container, AnyValue index ) + { + if ( !(index instanceof IntegralValue) ) + { + throw new CypherTypeException( format( "Expected %s to be an integer", index), null ); + } + long idx = ((IntegralValue) index).longValue(); + if ( idx > Integer.MAX_VALUE || idx < Integer.MIN_VALUE ) + { + throw new InvalidArgumentException( + format( "Cannot index a list using a value greater than %d or lesser than %d, got %d", + Integer.MAX_VALUE, Integer.MIN_VALUE, idx ), null ); + } + + if ( idx < 0 ) + { + idx = container.length() + idx; + } + if ( idx >= container.length() || idx < 0 ) + { + return NO_VALUE; + } + return container.value( (int) idx ); + } + + private static AnyValue mapAccess( MapValue container, AnyValue index ) + { + + return container.get( asString( index ) ); + } + + private static String asString( AnyValue value ) + { + if ( !(value instanceof TextValue) ) + { + throw new CypherTypeException( format( "Expected %s to be an index key", value), null ); + } + return ((TextValue) value).stringValue(); + } + private static Value calculateDistance( PointValue p1, PointValue p2 ) { if ( p1.getCoordinateReferenceSystem().equals( p2.getCoordinateReferenceSystem() ) ) @@ -406,6 +476,6 @@ private static long asLong( AnyValue value ) private static CypherTypeException needsNumbers( String method ) { - return new CypherTypeException( String.format( "%s requires numbers", method ), null ); + return new CypherTypeException( format( "%s requires numbers", method ), null ); } } diff --git a/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala b/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala index aa347195e5b38..d58a7979ccf92 100644 --- a/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala +++ b/enterprise/cypher/compiled-expressions/src/main/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/IntermediateCodeGeneration.scala @@ -183,6 +183,13 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { case CoerceToPredicate(inner) => compile(inner).map(coerceToPredicate) //data access + case ContainerIndex(container, index) => + for {c <- compile(container) + idx <- compile(index) + } yield IntermediateExpression( + noValueCheck(c)(invokeStatic(method[CypherFunctions, AnyValue, AnyValue, AnyValue, DbAccess]("containerIndex"), + c.ir, idx.ir, load("dbAccess"))), nullable = true) + case Parameter(name, _) => //TODO parameters that are autogenerated from literals should have nullable = false Some(IntermediateExpression(invoke(load("params"), method[MapValue, AnyValue, String]("get"), constant(name)), nullable = true)) @@ -441,10 +448,10 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) { invokeStatic(method[CypherFunctions, BooleanValue, String, AnyValue, DbAccess]("propertyExists"), constant(property.propertyKey.name), in.ir, load("dbAccess") )), in.nullable)) - case e: ContainerIndex => None case _: PatternExpression => None//TODO case _: NestedPipeExpression => None//TODO? case _: NestedPlanExpression => None//TODO + case _ => None } case _ => None diff --git a/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala b/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala index 700c1fa209151..17073c4aeabde 100644 --- a/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala +++ b/enterprise/cypher/compiled-expressions/src/test/scala/org/neo4j/cypher/internal/runtime/compiled/expressions/CodeGenerationTest.scala @@ -42,6 +42,7 @@ import org.neo4j.values.storable.{DoubleValue, Values} import org.neo4j.values.virtual.VirtualValues._ import org.neo4j.values.virtual.{NodeValue, RelationshipValue} import org.opencypher.v9_0.ast.AstConstructionTestSupport +import org.opencypher.v9_0.expressions import org.opencypher.v9_0.expressions._ import org.opencypher.v9_0.util.test_helpers.CypherFunSuite import org.opencypher.v9_0.util.{CypherTypeException, symbols} @@ -716,6 +717,43 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport compile(IsPrimitiveNull(offset)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE) } + test("containerIndex on node") { + val node = nodeValue(1, EMPTY_TEXT_ARRAY, map(Array("prop"), Array(stringValue("hello")))) + when(db.nodeProperty(1, "prop")).thenReturn(stringValue("hello")) + val compiled = compile(containerIndex(parameter("a"), literalString("prop"))) + + compiled.evaluate(ctx, db, map(Array("a"), Array(node))) should equal(stringValue("hello")) + compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE) + } + + test("containerIndex on relationship") { + val rel = relationshipValue(43, + nodeValue(1, EMPTY_TEXT_ARRAY, EMPTY_MAP), + nodeValue(2, EMPTY_TEXT_ARRAY, EMPTY_MAP), + stringValue("R"), map(Array("prop"), Array(stringValue("hello")))) + when(db.relationshipProperty(43, "prop")).thenReturn(stringValue("hello")) + val compiled = compile(containerIndex(parameter("a"), literalString("prop"))) + + compiled.evaluate(ctx, db, map(Array("a"), Array(rel))) should equal(stringValue("hello")) + compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE) + } + + test("containerIndex on map") { + val mapValue = map(Array("prop"), Array(stringValue("hello"))) + val compiled = compile(containerIndex(parameter("a"), literalString("prop"))) + + compiled.evaluate(ctx, db, map(Array("a"), Array(mapValue))) should equal(stringValue("hello")) + compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE) + } + + test("containerIndex on list") { + val listValue = list(longValue(42), stringValue("hello"), intValue(42)) + val compiled = compile(containerIndex(parameter("a"), literalInt(1))) + + compiled.evaluate(ctx, db, map(Array("a"), Array(listValue))) should equal(stringValue("hello")) + compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE) + } + private def compile(e: Expression) = CodeGeneration.compile(new IntermediateCodeGeneration(SlotConfiguration.empty).compile(e).map(_.ir).getOrElse(fail())) @@ -757,4 +795,9 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport private def notEquals(lhs: Expression, rhs: Expression) = NotEquals(lhs, rhs)(pos) private def property(map: Expression, key: String) = Property(map, PropertyKeyName(key)(pos))(pos) + + private def containerIndex(container: Expression, index: Expression) = ContainerIndex(container, index)(pos) + + private def literalString(s: String) = expressions.StringLiteral(s)(pos) + }