Skip to content

Commit

Permalink
Add support for or
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed May 30, 2018
1 parent aae4c82 commit 62db1fb
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,29 @@
import org.neo4j.internal.kernel.api.Transaction;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.values.AnyValue;
import org.neo4j.values.SequenceValue;
import org.neo4j.values.ValueMapper;
import org.neo4j.values.storable.ArrayValue;
import org.neo4j.values.storable.BooleanValue;
import org.neo4j.values.storable.DateTimeValue;
import org.neo4j.values.storable.DateValue;
import org.neo4j.values.storable.DoubleValue;
import org.neo4j.values.storable.DurationValue;
import org.neo4j.values.storable.LocalDateTimeValue;
import org.neo4j.values.storable.LocalTimeValue;
import org.neo4j.values.storable.NumberValue;
import org.neo4j.values.storable.PointValue;
import org.neo4j.values.storable.StringValue;
import org.neo4j.values.storable.TemporalValue;
import org.neo4j.values.storable.TextValue;
import org.neo4j.values.storable.TimeValue;
import org.neo4j.values.storable.Value;
import org.neo4j.values.storable.Values;
import org.neo4j.values.virtual.ListValue;
import org.neo4j.values.virtual.MapValue;
import org.neo4j.values.virtual.PathValue;
import org.neo4j.values.virtual.VirtualNodeValue;
import org.neo4j.values.virtual.VirtualRelationshipValue;
import org.neo4j.values.virtual.VirtualValues;

import static org.neo4j.values.storable.Values.NO_VALUE;
Expand All @@ -57,6 +69,8 @@
*/
public final class ExpressionMethods
{
private static final BooleanMapper BOOLEAN_MAPPER = new BooleanMapper();

private ExpressionMethods()
{
throw new UnsupportedOperationException( "Do not instantiate" );
Expand Down Expand Up @@ -126,12 +140,12 @@ else if ( rhs instanceof ListValue )
// exclude them
if ( !(rhs instanceof TemporalValue || rhs instanceof DurationValue || rhs instanceof PointValue) )
{
return stringValue( ((TextValue) lhs).stringValue() + ((Value) rhs).prettyPrint());
return stringValue( ((TextValue) lhs).stringValue() + ((Value) rhs).prettyPrint() );
}
else
{
//TODO this seems wrong but it is what we currently do in compiled runtime
return stringValue(((TextValue) lhs).stringValue() + String.valueOf( rhs ));
return stringValue( ((TextValue) lhs).stringValue() + String.valueOf( rhs ) );
}
}
}
Expand All @@ -143,12 +157,12 @@ else if ( rhs instanceof ListValue )
// exclude them
if ( !(lhs instanceof TemporalValue || lhs instanceof DurationValue || lhs instanceof PointValue) )
{
return stringValue( ((Value) lhs).prettyPrint() + ((TextValue) rhs).stringValue());
return stringValue( ((Value) lhs).prettyPrint() + ((TextValue) rhs).stringValue() );
}
else
{
//TODO this seems wrong but it is what we currently do in compiled runtime
return stringValue(String.valueOf( lhs ) + ((TextValue) rhs).stringValue() );
return stringValue( String.valueOf( lhs ) + ((TextValue) rhs).stringValue() );
}
}
}
Expand All @@ -174,7 +188,7 @@ else if ( rhs instanceof ListValue )
}

throw new CypherTypeException(
String.format( "Don't know how to add `%s` and `%s`", lhs, rhs), null );
String.format( "Don't know how to add `%s` and `%s`", lhs, rhs ), null );
}

public static AnyValue subtract( AnyValue lhs, AnyValue rhs )
Expand Down Expand Up @@ -205,7 +219,7 @@ public static AnyValue subtract( AnyValue lhs, AnyValue rhs )
}

throw new CypherTypeException(
String.format( "Don't know how to subtract `%s` and `%s`", lhs, rhs), null );
String.format( "Don't know how to subtract `%s` and `%s`", lhs, rhs ), null );
}

public static AnyValue multiply( AnyValue lhs, AnyValue rhs )
Expand Down Expand Up @@ -234,9 +248,10 @@ public static AnyValue multiply( AnyValue lhs, AnyValue rhs )
}
}
throw new CypherTypeException(
String.format( "Don't know how to subtract `%s` and `%s`", lhs, rhs), null );
String.format( "Don't know how to subtract `%s` and `%s`", lhs, rhs ), null );
}

//data access
public static Value nodeProperty( Transaction tx, long node, int property )
{
CursorFactory cursors = tx.cursors();
Expand Down Expand Up @@ -269,7 +284,132 @@ public static Value relationshipProperty( Transaction tx, long relationship, int
}
}

private static Value property(PropertyCursor properties, int property)
//boolean operations

public static Value or( AnyValue... args )
{
for ( AnyValue arg : args )
{
if ( arg == NO_VALUE )
{
return NO_VALUE;
}

if ( arg.map( BOOLEAN_MAPPER ) )
{
return Values.TRUE;
}
}
return Values.FALSE;
}

private static final class BooleanMapper implements ValueMapper<Boolean>
{

@Override
public Boolean mapPath( PathValue value )
{
return value.size() > 0;
}

@Override
public Boolean mapNode( VirtualNodeValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);
}

@Override
public Boolean mapRelationship( VirtualRelationshipValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);
}

@Override
public Boolean mapMap( MapValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);
}

@Override
public Boolean mapNoValue()
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + NO_VALUE, null);
}

@Override
public Boolean mapSequence( SequenceValue value )
{
return value.length() > 0;
}

@Override
public Boolean mapText( TextValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);
}

@Override
public Boolean mapBoolean( BooleanValue value )
{
return value.booleanValue();
}

@Override
public Boolean mapNumber( NumberValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);

}

@Override
public Boolean mapDateTime( DateTimeValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);

}

@Override
public Boolean mapLocalDateTime( LocalDateTimeValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);

}

@Override
public Boolean mapDate( DateValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);
}

@Override
public Boolean mapTime( TimeValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);

}

@Override
public Boolean mapLocalTime( LocalTimeValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);

}

@Override
public Boolean mapDuration( DurationValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);

}

@Override
public Boolean mapPoint( PointValue value )
{
throw new CypherTypeException( "Don't know how to treat that as a boolean: " + value, null);
}
}

private static Value property( PropertyCursor properties, int property )
{
while ( properties.next() )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ object CodeGeneration {
case TRUE => getStatic(staticField(VALUES, classOf[BooleanValue], "TRUE"))
//Values.FALSE
case FALSE => getStatic(staticField(VALUES, classOf[BooleanValue], "FALSE"))
//Loads an AnyValue[]
case ArrayLiteral(values) => Expression.newArray(TypeReference.typeReference(classOf[AnyValue]),
values.map(v => compileExpression(v, block)):_*)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext
import org.neo4j.internal.kernel.api.Transaction
import org.neo4j.values.AnyValue
import org.neo4j.values.storable.Values.{doubleValue, longValue}
import org.neo4j.values.storable.{DoubleValue, Value, Values}
import org.neo4j.values.storable.{BooleanValue, DoubleValue, Value, Values}
import org.neo4j.values.virtual.MapValue
import org.opencypher.v9_0.expressions
import org.opencypher.v9_0.expressions._
Expand Down Expand Up @@ -95,7 +95,7 @@ object IntermediateCodeGeneration {
case Or(lhs, rhs) =>
(compile(lhs), compile(rhs)) match {
case (Some(l), Some(r)) =>
Some(invokeStatic(method[ExpressionMethods, AnyValue, AnyValue, AnyValue]("or"), l, r))
Some(invokeStatic(method[ExpressionMethods, Value, Array[AnyValue]]("or"), arrayOf(l, r)))
case _ => None
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ case object TRUE extends IntermediateRepresentation
*/
case object FALSE extends IntermediateRepresentation

case class ArrayLiteral(values: Array[IntermediateRepresentation]) extends IntermediateRepresentation

/**
* Defines a method
*
Expand Down Expand Up @@ -159,4 +161,6 @@ object IntermediateRepresentation {
def falsy: IntermediateRepresentation = FALSE

def constantJavaValue(value: Any): IntermediateRepresentation = Constant(value)

def arrayOf(values: IntermediateRepresentation*): IntermediateRepresentation = ArrayLiteral(values.toArray)
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport

test("True") {
// Given
val expression = t()
val expression = t

// When
val compiled = compile(expression)
Expand All @@ -178,7 +178,7 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport

test("False") {
// Given
val expression = f()
val expression = f

// When
val compiled = compile(expression)
Expand All @@ -189,14 +189,13 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport


test("or") {
// Given
val expression = or(t(), t())

// When
val compiled = compile(expression)

// Then
compiled.compute(ctx, tx, EMPTY_MAP) should equal(Values.TRUE)
compile(or(t, t)).compute(ctx, tx, EMPTY_MAP) should equal(Values.TRUE)
compile(or(f, t)).compute(ctx, tx, EMPTY_MAP) should equal(Values.TRUE)
compile(or(t, f)).compute(ctx, tx, EMPTY_MAP) should equal(Values.TRUE)
compile(or(f, f)).compute(ctx, tx, EMPTY_MAP) should equal(Values.FALSE)
compile(or(noValue, t)).compute(ctx, tx, EMPTY_MAP) should equal(Values.NO_VALUE)
compile(or(f, noValue)).compute(ctx, tx, EMPTY_MAP) should equal(Values.NO_VALUE)
compile(or(t, noValue)).compute(ctx, tx, EMPTY_MAP) should equal(Values.TRUE)
}

private def compile(e: Expression) =
Expand All @@ -218,9 +217,9 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport

private def noValue = Null()(pos)

private def t() = True()(pos)
private def t = True()(pos)

private def f() = False()(pos)
private def f = False()(pos)

private def or(l: Expression, r: Expression) = Or(l, r)(pos)
}

0 comments on commit 62db1fb

Please sign in to comment.