Skip to content

Commit

Permalink
Support for container-index access
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Jul 9, 2018
1 parent 7640d16 commit 8c5e8b3
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 88 deletions.
Expand Up @@ -19,61 +19,20 @@
*/
package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions

import org.opencypher.v9_0.util.{CypherTypeException, InvalidArgumentException}
import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext
import org.neo4j.cypher.internal.runtime.interpreted.{CastSupport, IsList, IsMap, ListSupport}
import org.neo4j.cypher.internal.runtime.interpreted.pipes.QueryState
import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, ListSupport}
import org.neo4j.cypher.operations.CypherFunctions
import org.neo4j.values._
import org.neo4j.values.storable._
import org.neo4j.values.storable.Values

case class ContainerIndex(expression: Expression, index: Expression) extends NullInNullOutExpression(expression)
case class ContainerIndex(expression: Expression, index: Expression) extends Expression
with ListSupport {
def arguments = Seq(expression, index)

override def compute(value: AnyValue, ctx: ExecutionContext, state: QueryState): AnyValue = {
value match {
case IsMap(m) =>
val item = index(ctx, state)
if (item == Values.NO_VALUE) Values.NO_VALUE
else {
val key = CastSupport.castOrFail[TextValue](item)
m(state.query).get(key.stringValue())
}

case IsList(collection) =>
val item = index(ctx, state)
if (item == Values.NO_VALUE) Values.NO_VALUE
else {
var idx = validateTypeAndRange(item)

if (idx < 0)
idx = collection.size + idx

if (idx >= collection.size || idx < 0) Values.NO_VALUE
else collection.value(idx)
}

case _ =>
val indexValue = index(ctx, state)
throw new CypherTypeException(
s"`$value` is not a collection or a map. Element access is only possible by performing a collection lookup using an integer index, or by performing a map lookup using a string key (found: $value[$indexValue])")
}
}

private def validateTypeAndRange(item: AnyValue): Int = {
val number = CastSupport.castOrFail[NumberValue](item)

val longValue = number match {
case _: FloatValue | _: DoubleValue=>
throw new CypherTypeException(s"Cannot index a list using an non-integer number, got $number")
case _ => number.longValue()
}

if (longValue > Int.MaxValue || longValue < Int.MinValue)
throw new InvalidArgumentException(
s"Cannot index a list using a value greater than ${Int.MaxValue} or lesser than ${Int.MinValue}, got $number")

longValue.toInt
override def apply(ctx: ExecutionContext,
state: QueryState): AnyValue = expression(ctx, state) match {
case Values.NO_VALUE => Values.NO_VALUE
case value => CypherFunctions.containerIndex(value, index(ctx, state), state.query)
}

def rewrite(f: (Expression) => Expression): Expression = f(ContainerIndex(expression.rewrite(f), index.rewrite(f)))
Expand Down
Expand Up @@ -20,17 +20,14 @@
package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions

import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.neo4j.cypher.internal.runtime.QueryContext
import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, QueryStateHelper}
import org.neo4j.cypher.internal.runtime.{Operations, QueryContext}
import org.opencypher.v9_0.util.test_helpers.CypherFunSuite
import org.opencypher.v9_0.util.{CypherTypeException, InvalidArgumentException}
import org.neo4j.graphdb.{Node, Relationship}
import org.neo4j.values.AnyValue
import org.neo4j.values.storable.Values
import org.neo4j.values.storable.Values.longValue
import org.neo4j.values.storable.{Value, Values}
import org.neo4j.values.virtual.{RelationshipValue, NodeValue}
import org.opencypher.v9_0.util.test_helpers.CypherFunSuite
import org.opencypher.v9_0.util.{CypherTypeException, InvalidArgumentException}

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -86,17 +83,9 @@ class ContainerIndexTest extends CypherFunSuite {
val node = mock[Node]
when(node.getId).thenReturn(0)
implicit val expression = Literal(node)
when(qtx.getOptPropertyKeyId("v")).thenReturn(Some(0))
when(qtx.getOptPropertyKeyId("c")).thenReturn(Some(1))
val nodeOps = mock[Operations[NodeValue]]
when(nodeOps.getProperty(0, 0)).thenAnswer(new Answer[Value] {
override def answer(invocation: InvocationOnMock): Value = Values.longValue(1)
})
when(nodeOps.getProperty(0, 1)).thenAnswer(new Answer[Value] {
override def answer(invocation: InvocationOnMock): Value = Values.NO_VALUE
})
when(qtx.nodeOps).thenReturn(nodeOps)

when(qtx.nodeProperty(0, "v")).thenReturn(longValue(1))
when(qtx.nodeProperty(0, "c")).thenReturn(Values.NO_VALUE)
idx("v") should equal(longValue(1))
idx("c") should equal(expectedNull)
}
Expand All @@ -105,17 +94,9 @@ class ContainerIndexTest extends CypherFunSuite {
val rel = mock[Relationship]
when(rel.getId).thenReturn(0)
implicit val expression = Literal(rel)
when(qtx.getOptPropertyKeyId("v")).thenReturn(Some(0))
when(qtx.getOptPropertyKeyId("c")).thenReturn(Some(1))
val relOps = mock[Operations[RelationshipValue]]
when(relOps.getProperty(0, 0)).thenAnswer(new Answer[Value] {
override def answer(invocation: InvocationOnMock): Value = Values.longValue(1)
})
when(relOps.getProperty(0, 1)).thenAnswer(new Answer[Value] {
override def answer(invocation: InvocationOnMock): Value = Values.NO_VALUE
})
when(qtx.relationshipOps).thenReturn(relOps)

when(qtx.relationshipProperty(0, "v")).thenReturn(longValue(1))
when(qtx.relationshipProperty(0, "c")).thenReturn(Values.NO_VALUE)
idx("v") should equal(longValue(1))
idx("c") should equal(expectedNull)
}
Expand Down
Expand Up @@ -26,12 +26,14 @@

import org.neo4j.cypher.internal.runtime.DbAccess;
import org.neo4j.values.AnyValue;
import org.neo4j.values.SequenceValue;
import org.neo4j.values.storable.BooleanValue;
import org.neo4j.values.storable.DoubleValue;
import org.neo4j.values.storable.IntegralValue;
import org.neo4j.values.storable.LongValue;
import org.neo4j.values.storable.NumberValue;
import org.neo4j.values.storable.PointValue;
import org.neo4j.values.storable.TextValue;
import org.neo4j.values.storable.Value;
import org.neo4j.values.virtual.ListValue;
import org.neo4j.values.virtual.MapValue;
Expand All @@ -41,6 +43,7 @@
import org.neo4j.values.virtual.VirtualRelationshipValue;
import org.neo4j.values.virtual.VirtualValues;

import static java.lang.String.format;
import static org.neo4j.values.storable.Values.FALSE;
import static org.neo4j.values.storable.Values.NO_VALUE;
import static org.neo4j.values.storable.Values.TRUE;
Expand Down Expand Up @@ -343,7 +346,7 @@ public static NodeValue startNode( AnyValue anyValue, DbAccess access )
}
else
{
throw new CypherTypeException( String.format( "Expected %s to be a RelationshipValue", anyValue), null );
throw new CypherTypeException( format( "Expected %s to be a RelationshipValue", anyValue), null );
}
}

Expand All @@ -355,30 +358,97 @@ public static NodeValue endNode( AnyValue anyValue, DbAccess access )
}
else
{
throw new CypherTypeException( String.format( "Expected %s to be a RelationshipValue", anyValue), null );
throw new CypherTypeException( format( "Expected %s to be a RelationshipValue", anyValue), null );
}
}

public static BooleanValue propertyExists( String key, AnyValue holder, DbAccess dbAccess )
public static BooleanValue propertyExists( String key, AnyValue container, DbAccess dbAccess )
{
if ( holder instanceof VirtualNodeValue )
if ( container instanceof VirtualNodeValue )
{
return dbAccess.nodeHasProperty( ((VirtualNodeValue) holder).id(), key ) ? TRUE : FALSE;
return dbAccess.nodeHasProperty( ((VirtualNodeValue) container).id(), key ) ? TRUE : FALSE;
}
else if ( holder instanceof VirtualRelationshipValue )
else if ( container instanceof VirtualRelationshipValue )
{
return dbAccess.relationshipHasProperty( ((VirtualRelationshipValue) holder).id(), key ) ? TRUE : FALSE;
return dbAccess.relationshipHasProperty( ((VirtualRelationshipValue) container).id(), key ) ? TRUE : FALSE;
}
else if ( holder instanceof MapValue )
else if ( container instanceof MapValue )
{
return ((MapValue) holder).containsKey( key ) ? TRUE : FALSE;
return ((MapValue) container).containsKey( key ) ? TRUE : FALSE;
}
else
{
throw new CypherTypeException( String.format( "Expected %s to be a property container", holder), null );
throw new CypherTypeException( format( "Expected %s to be a property container", container), null );
}
}

public static AnyValue containerIndex( AnyValue container, AnyValue index, DbAccess dbAccess )
{
if ( container instanceof VirtualNodeValue )
{
return dbAccess.nodeProperty( ((VirtualNodeValue) container).id(), asString( index ) );
}
else if ( container instanceof VirtualRelationshipValue )
{
return dbAccess.relationshipProperty( ((VirtualRelationshipValue) container).id(), asString( index ) );
}
if ( container instanceof MapValue )
{
return mapAccess( (MapValue) container, index );
}
else if ( container instanceof SequenceValue )
{
return listAccess( (SequenceValue) container, index );
}
else
{
throw new CypherTypeException( format(
"`%s` is not a collection or a map. Element access is only possible by performing a collection " +
"lookup using an integer index, or by performing a map lookup using a string key (found: %s[%s])",
container, container, index ), null );
}
}

private static AnyValue listAccess( SequenceValue container, AnyValue index )
{
if ( !(index instanceof IntegralValue) )
{
throw new CypherTypeException( format( "Expected %s to be an integer", index), null );
}
long idx = ((IntegralValue) index).longValue();
if ( idx > Integer.MAX_VALUE || idx < Integer.MIN_VALUE )
{
throw new InvalidArgumentException(
format( "Cannot index a list using a value greater than %d or lesser than %d, got %d",
Integer.MAX_VALUE, Integer.MIN_VALUE, idx ), null );
}

if ( idx < 0 )
{
idx = container.length() + idx;
}
if ( idx >= container.length() || idx < 0 )
{
return NO_VALUE;
}
return container.value( (int) idx );
}

private static AnyValue mapAccess( MapValue container, AnyValue index )
{

return container.get( asString( index ) );
}

private static String asString( AnyValue value )
{
if ( !(value instanceof TextValue) )
{
throw new CypherTypeException( format( "Expected %s to be an index key", value), null );
}
return ((TextValue) value).stringValue();
}

private static Value calculateDistance( PointValue p1, PointValue p2 )
{
if ( p1.getCoordinateReferenceSystem().equals( p2.getCoordinateReferenceSystem() ) )
Expand Down Expand Up @@ -406,6 +476,6 @@ private static long asLong( AnyValue value )

private static CypherTypeException needsNumbers( String method )
{
return new CypherTypeException( String.format( "%s requires numbers", method ), null );
return new CypherTypeException( format( "%s requires numbers", method ), null );
}
}
Expand Up @@ -183,6 +183,13 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
case CoerceToPredicate(inner) => compile(inner).map(coerceToPredicate)

//data access
case ContainerIndex(container, index) =>
for {c <- compile(container)
idx <- compile(index)
} yield IntermediateExpression(
noValueCheck(c)(invokeStatic(method[CypherFunctions, AnyValue, AnyValue, AnyValue, DbAccess]("containerIndex"),
c.ir, idx.ir, load("dbAccess"))), nullable = true)

case Parameter(name, _) => //TODO parameters that are autogenerated from literals should have nullable = false
Some(IntermediateExpression(invoke(load("params"), method[MapValue, AnyValue, String]("get"),
constant(name)), nullable = true))
Expand Down Expand Up @@ -441,10 +448,10 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
invokeStatic(method[CypherFunctions, BooleanValue, String, AnyValue, DbAccess]("propertyExists"),
constant(property.propertyKey.name),
in.ir, load("dbAccess") )), in.nullable))
case e: ContainerIndex => None
case _: PatternExpression => None//TODO
case _: NestedPipeExpression => None//TODO?
case _: NestedPlanExpression => None//TODO
case _ => None
}

case _ => None
Expand Down
Expand Up @@ -42,6 +42,7 @@ import org.neo4j.values.storable.{DoubleValue, Values}
import org.neo4j.values.virtual.VirtualValues._
import org.neo4j.values.virtual.{NodeValue, RelationshipValue}
import org.opencypher.v9_0.ast.AstConstructionTestSupport
import org.opencypher.v9_0.expressions
import org.opencypher.v9_0.expressions._
import org.opencypher.v9_0.util.test_helpers.CypherFunSuite
import org.opencypher.v9_0.util.{CypherTypeException, symbols}
Expand Down Expand Up @@ -716,6 +717,43 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport
compile(IsPrimitiveNull(offset)).evaluate(ctx, db, EMPTY_MAP) should equal(Values.FALSE)
}

test("containerIndex on node") {
val node = nodeValue(1, EMPTY_TEXT_ARRAY, map(Array("prop"), Array(stringValue("hello"))))
when(db.nodeProperty(1, "prop")).thenReturn(stringValue("hello"))
val compiled = compile(containerIndex(parameter("a"), literalString("prop")))

compiled.evaluate(ctx, db, map(Array("a"), Array(node))) should equal(stringValue("hello"))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

test("containerIndex on relationship") {
val rel = relationshipValue(43,
nodeValue(1, EMPTY_TEXT_ARRAY, EMPTY_MAP),
nodeValue(2, EMPTY_TEXT_ARRAY, EMPTY_MAP),
stringValue("R"), map(Array("prop"), Array(stringValue("hello"))))
when(db.relationshipProperty(43, "prop")).thenReturn(stringValue("hello"))
val compiled = compile(containerIndex(parameter("a"), literalString("prop")))

compiled.evaluate(ctx, db, map(Array("a"), Array(rel))) should equal(stringValue("hello"))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

test("containerIndex on map") {
val mapValue = map(Array("prop"), Array(stringValue("hello")))
val compiled = compile(containerIndex(parameter("a"), literalString("prop")))

compiled.evaluate(ctx, db, map(Array("a"), Array(mapValue))) should equal(stringValue("hello"))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

test("containerIndex on list") {
val listValue = list(longValue(42), stringValue("hello"), intValue(42))
val compiled = compile(containerIndex(parameter("a"), literalInt(1)))

compiled.evaluate(ctx, db, map(Array("a"), Array(listValue))) should equal(stringValue("hello"))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

private def compile(e: Expression) =
CodeGeneration.compile(new IntermediateCodeGeneration(SlotConfiguration.empty).compile(e).map(_.ir).getOrElse(fail()))

Expand Down Expand Up @@ -757,4 +795,9 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport
private def notEquals(lhs: Expression, rhs: Expression) = NotEquals(lhs, rhs)(pos)

private def property(map: Expression, key: String) = Property(map, PropertyKeyName(key)(pos))(pos)

private def containerIndex(container: Expression, index: Expression) = ContainerIndex(container, index)(pos)

private def literalString(s: String) = expressions.StringLiteral(s)(pos)

}

0 comments on commit 8c5e8b3

Please sign in to comment.