Skip to content

Commit

Permalink
CoerceTo list must be done recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke authored and craigtaverner committed Jul 24, 2018
1 parent 44c2e80 commit e96bab0
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 48 deletions.
Expand Up @@ -21,8 +21,14 @@


import org.opencypher.v9_0.util.CypherTypeException; import org.opencypher.v9_0.util.CypherTypeException;


import java.util.HashMap;
import java.util.Map;

import org.neo4j.cypher.internal.runtime.DbAccess; import org.neo4j.cypher.internal.runtime.DbAccess;
import org.neo4j.internal.kernel.api.procs.Neo4jTypes;
import org.neo4j.values.AnyValue; import org.neo4j.values.AnyValue;
import org.neo4j.values.SequenceValue;
import org.neo4j.values.ValueMapper;
import org.neo4j.values.storable.ArrayValue; import org.neo4j.values.storable.ArrayValue;
import org.neo4j.values.storable.BooleanValue; import org.neo4j.values.storable.BooleanValue;
import org.neo4j.values.storable.DateTimeValue; import org.neo4j.values.storable.DateTimeValue;
Expand All @@ -42,10 +48,30 @@
import org.neo4j.values.virtual.NodeValue; import org.neo4j.values.virtual.NodeValue;
import org.neo4j.values.virtual.PathValue; import org.neo4j.values.virtual.PathValue;
import org.neo4j.values.virtual.RelationshipValue; import org.neo4j.values.virtual.RelationshipValue;
import org.neo4j.values.virtual.VirtualNodeValue;
import org.neo4j.values.virtual.VirtualRelationshipValue;
import org.neo4j.values.virtual.VirtualValues; import org.neo4j.values.virtual.VirtualValues;


import static java.lang.String.format; import static java.lang.String.format;
import static org.neo4j.values.storable.Values.NO_VALUE; import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTAny;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTBoolean;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTDate;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTDateTime;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTDuration;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTFloat;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTGeometry;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTInteger;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTLocalDateTime;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTLocalTime;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTMap;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTNode;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTNumber;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTPath;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTPoint;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTRelationship;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTString;
import static org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTTime;
import static org.neo4j.values.SequenceValue.IterationPreference.RANDOM_ACCESS;


@SuppressWarnings( {"unused", "WeakerAccess"} ) @SuppressWarnings( {"unused", "WeakerAccess"} )
public final class CypherCoercions public final class CypherCoercions
Expand Down Expand Up @@ -210,52 +236,219 @@ else if ( value instanceof RelationshipValue )
} }
} }


public static ListValue asListValueFailOnPaths( AnyValue value ) public static ListValue asList( AnyValue value, Neo4jTypes.AnyType innerType, DbAccess access )
{
return new ListCoercer().apply( value, innerType, access );
}

private static CypherTypeException cantCoerce( AnyValue value, String type )
{ {
if ( value instanceof PathValue ) return new CypherTypeException( format( "Can't coerce `%s` to %s", value, type ), null );
}

private static class ListMapper implements ValueMapper<ListValue>
{

@Override
public ListValue mapPath( PathValue value )
{ {
throw cantCoerce( value, "List" ); return null;
} }
else
@Override
public ListValue mapNode( VirtualNodeValue value )
{ {
return asList( value ); return null;
} }
}


public static ListValue asListValueSupportPaths( AnyValue value ) @Override
{ public ListValue mapRelationship( VirtualRelationshipValue value )
if ( value instanceof PathValue )
{ {
return ((PathValue) value).asList(); return null;
} }
else
@Override
public ListValue mapMap( MapValue value )
{
return null;
}

@Override
public ListValue mapNoValue()
{
return null;
}

@Override
public ListValue mapSequence( SequenceValue value )
{
return null;
}

@Override
public ListValue mapText( TextValue value )
{
return null;
}

@Override
public ListValue mapBoolean( BooleanValue value )
{
return null;
}

@Override
public ListValue mapNumber( NumberValue value )
{
return null;
}

@Override
public ListValue mapDateTime( DateTimeValue value )
{
return null;
}

@Override
public ListValue mapLocalDateTime( LocalDateTimeValue value )
{
return null;
}

@Override
public ListValue mapDate( DateValue value )
{ {
return asList( value ); return null;
}

@Override
public ListValue mapTime( TimeValue value )
{
return null;
}

@Override
public ListValue mapLocalTime( LocalTimeValue value )
{
return null;
}

@Override
public ListValue mapDuration( DurationValue value )
{
return null;
}

@Override
public ListValue mapPoint( PointValue value )
{
return null;
} }
} }


private static CypherTypeException cantCoerce( AnyValue value, String type )
@FunctionalInterface
interface Coercer
{ {
return new CypherTypeException( format( "Can't coerce `%s` to %s", value, type ), null ); AnyValue apply( AnyValue value, Neo4jTypes.AnyType coerceTo, DbAccess access );
}

private final static Map<Class<? extends Neo4jTypes.AnyType>,Coercer> CONVERTERS = new HashMap<>();

private AnyValue coerceTo( AnyValue value, DbAccess access, Neo4jTypes.AnyType types )
{
Coercer function = CONVERTERS.get( types.getClass() );

return function.apply( value, types, access );
} }


private static ListValue asList( AnyValue value ) private static class ListCoercer implements Coercer
{
@Override
public ListValue apply( AnyValue value, Neo4jTypes.AnyType innerType, DbAccess access )
{
//Fast route
if ( innerType == NTAny )
{
return fastListConversion( value );
}

//slow route, recursively convert the list
if ( !(value instanceof SequenceValue) )
{
throw cantCoerce( value, "List" );
}
SequenceValue listValue = (SequenceValue) value;
Coercer innerCoercer = CONVERTERS.get( innerType.getClass() );
AnyValue[] coercedValues = new AnyValue[listValue.length()];
Neo4jTypes.AnyType nextInner = nextInner( innerType );
if ( listValue.iterationPreference() == RANDOM_ACCESS )
{
for ( int i = 0; i < coercedValues.length; i++ )
{
coercedValues[i] = innerCoercer.apply( listValue.value( i ), nextInner, access );
}
}
else
{
int i = 0;
for ( AnyValue anyValue : listValue )
{
coercedValues[i++] = innerCoercer.apply( anyValue, nextInner, access );
}
}
return VirtualValues.list( coercedValues );
}
}

private static Neo4jTypes.AnyType nextInner( Neo4jTypes.AnyType type )
{
if (type instanceof Neo4jTypes.ListType )
{
return ((Neo4jTypes.ListType) type).innerType();
}
else
{
return type;
}
}

private static ListValue fastListConversion( AnyValue value )
{ {
if ( value instanceof ListValue ) if ( value instanceof ListValue )
{ {
return (ListValue) value; return (ListValue) value;
} }
else if ( value instanceof ArrayValue ) else if (value instanceof ArrayValue )
{ {
return VirtualValues.fromArray( (ArrayValue) value ); return VirtualValues.fromArray( (ArrayValue) value );
} }
else if ( value == NO_VALUE ) else if ( value instanceof PathValue )
{
return VirtualValues.EMPTY_LIST;
}
else
{ {
return VirtualValues.list( value ); return ((PathValue) value).asList();
} }
throw cantCoerce( value, "List" );
}

static {
CONVERTERS.put( NTAny.getClass(), (a, ignore1, ignore2) -> a );
CONVERTERS.put( NTString.getClass(), (a, ignore1, ignore2) -> asTextValue(a) );
CONVERTERS.put( NTNumber.getClass(), (a, ignore1, ignore2) -> asNumberValue(a) );
CONVERTERS.put( NTInteger.getClass(), (a, ignore1, ignore2) -> asIntegralValue(a) );
CONVERTERS.put( NTFloat.getClass(), (a, ignore1, ignore2) -> asFloatingPointValue(a) );
CONVERTERS.put( NTBoolean.getClass(), (a, ignore1, ignore2) -> asBooleanValue(a) );
CONVERTERS.put( NTMap.getClass(), (a, ignore, c) -> asMapValue(a, c) );
CONVERTERS.put( NTNode.getClass(), (a, ignore1, ignore2) ->asNodeValue(a) );
CONVERTERS.put( NTRelationship.getClass(), (a, ignore1, ignore2) -> asRelationshipValue(a) );
CONVERTERS.put( NTPath.getClass(), (a, ignore1, ignore2) -> asPathValue(a) );
CONVERTERS.put( NTGeometry.getClass(), (a, ignore1, ignore2) ->asPointValue(a) );
CONVERTERS.put( NTPoint.getClass(), (a, ignore1, ignore2) ->asPointValue(a) );
CONVERTERS.put( NTDateTime.getClass(),(a, ignore1, ignore2) -> asDateTimeValue(a) );
CONVERTERS.put( NTLocalDateTime.getClass(), (a, ignore1, ignore2) -> asLocalDateTimeValue(a) );
CONVERTERS.put( NTDate.getClass(), (a, ignore1, ignore2) -> asDateValue(a) );
CONVERTERS.put( NTTime.getClass(), (a, ignore1, ignore2) -> asTimeValue(a) );
CONVERTERS.put( NTLocalTime.getClass(), (a, ignore1, ignore2) -> asLocalTimeValue(a) );
CONVERTERS.put( NTDuration.getClass(), (a, ignore1, ignore2) -> asDurationValue(a) );
CONVERTERS.put( Neo4jTypes.ListType.class, new ListCoercer() );
} }
} }
Expand Up @@ -33,12 +33,14 @@ import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext
import org.neo4j.cypher.internal.runtime.interpreted.pipes.NestedPipeExpression import org.neo4j.cypher.internal.runtime.interpreted.pipes.NestedPipeExpression
import org.neo4j.cypher.internal.v3_5.logical.plans.{CoerceToPredicate, NestedPlanExpression} import org.neo4j.cypher.internal.v3_5.logical.plans.{CoerceToPredicate, NestedPlanExpression}
import org.neo4j.cypher.operations.{CypherBoolean, CypherCoercions, CypherFunctions, CypherMath} import org.neo4j.cypher.operations.{CypherBoolean, CypherCoercions, CypherFunctions, CypherMath}
import org.neo4j.internal.kernel.api.procs.Neo4jTypes
import org.neo4j.internal.kernel.api.procs.Neo4jTypes.AnyType
import org.neo4j.values.AnyValue import org.neo4j.values.AnyValue
import org.neo4j.values.storable._ import org.neo4j.values.storable._
import org.neo4j.values.virtual._ import org.neo4j.values.virtual._
import org.opencypher.v9_0.expressions import org.opencypher.v9_0.expressions
import org.opencypher.v9_0.expressions._ import org.opencypher.v9_0.expressions._
import org.opencypher.v9_0.util.symbols.{CTAny, CTBoolean, CTDate, CTDateTime, CTDuration, CTFloat, CTGeometry, CTInteger, CTLocalDateTime, CTLocalTime, CTMap, CTNode, CTNumber, CTPath, CTPoint, CTRelationship, CTString, CTTime, ListType} import org.opencypher.v9_0.util.symbols.{CTAny, CTBoolean, CTDate, CTDateTime, CTDuration, CTFloat, CTGeometry, CTInteger, CTLocalDateTime, CTLocalTime, CTMap, CTNode, CTNumber, CTPath, CTPoint, CTRelationship, CTString, CTTime, CypherType, ListType}
import org.opencypher.v9_0.util.{CypherTypeException, InternalException} import org.opencypher.v9_0.util.{CypherTypeException, InternalException}


/** /**
Expand Down Expand Up @@ -321,14 +323,11 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
noValueCheck(e)(invokeStatic(method[CypherCoercions, MapValue, AnyValue, DbAccess]("asMapValue"), e.ir, DB_ACCESS)), noValueCheck(e)(invokeStatic(method[CypherCoercions, MapValue, AnyValue, DbAccess]("asMapValue"), e.ir, DB_ACCESS)),
nullable = false, e.fields) nullable = false, e.fields)


case t: ListType if t.innerType == CTNode || t.innerType == CTRelationship => case l: ListType =>
IntermediateExpression( val typ = asNeoType(l.innerType)
noValueCheck(e)(invokeStatic(method[CypherCoercions, ListValue, AnyValue]("asListValueFailOnPaths"), e.ir)),
nullable = false, e.fields)


case _: ListType =>
IntermediateExpression( IntermediateExpression(
noValueCheck(e)(invokeStatic(method[CypherCoercions, ListValue, AnyValue]("asListValueSupportPaths"), e.ir)), noValueCheck(e)(invokeStatic(method[CypherCoercions, ListValue, AnyValue, AnyType, DbAccess]("asList"), e.ir, typ, DB_ACCESS)),
nullable = false, e.fields) nullable = false, e.fields)


case CTBoolean => case CTBoolean =>
Expand Down Expand Up @@ -413,7 +412,7 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
Some(IntermediateExpression( Some(IntermediateExpression(
block( block(
condition(equal(loadField(f), constant(-1)))( condition(equal(loadField(f), constant(-1)))(
setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("getPropertyKeyId"), constant(key)))), setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("propertyKeyId"), constant(key)))),
ternary( ternary(
invoke(DB_ACCESS, method[DbAccess, Boolean, Long, Int]("nodeHasProperty"), invoke(DB_ACCESS, method[DbAccess, Boolean, Long, Int]("nodeHasProperty"),
getLongAt(offset), loadField(f)), truthValue, falseValue)), nullable = false, Seq(f))) getLongAt(offset), loadField(f)), truthValue, falseValue)), nullable = false, Seq(f)))
Expand All @@ -428,7 +427,7 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
Some(IntermediateExpression( Some(IntermediateExpression(
block( block(
condition(equal(loadField(f), constant(-1)))( condition(equal(loadField(f), constant(-1)))(
setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("getPropertyKeyId"), constant(key)))), setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("propertyKeyId"), constant(key)))),
invoke(DB_ACCESS, method[DbAccess, Value, Long, Int]("relationshipProperty"), invoke(DB_ACCESS, method[DbAccess, Value, Long, Int]("relationshipProperty"),
getLongAt(offset), loadField(f))), nullable = true, Seq(f))) getLongAt(offset), loadField(f))), nullable = true, Seq(f)))


Expand All @@ -446,7 +445,7 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
Some(IntermediateExpression( Some(IntermediateExpression(
block( block(
condition(equal(loadField(f), constant(-1)))( condition(equal(loadField(f), constant(-1)))(
setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("getPropertyKeyId"), constant(key)))), setField(f, invoke(DB_ACCESS, method[DbAccess, Int, String]("propertyKeyId"), constant(key)))),
ternary( ternary(
invoke(DB_ACCESS, method[DbAccess, Boolean, Long, Int]("relationshipHasProperty"), invoke(DB_ACCESS, method[DbAccess, Boolean, Long, Int]("relationshipHasProperty"),
getLongAt(offset), loadField(f)), getLongAt(offset), loadField(f)),
Expand Down Expand Up @@ -1091,6 +1090,28 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
if (nullable) ternary(load(seenNull), noValue, load(returnValue)) else load(returnValue)): _*) if (nullable) ternary(load(seenNull), noValue, load(returnValue)) else load(returnValue)): _*)
IntermediateExpression(ir, nullable, expressions.foldLeft(Seq.empty[Field])((a,b) => a ++ b.fields)) IntermediateExpression(ir, nullable, expressions.foldLeft(Seq.empty[Field])((a,b) => a ++ b.fields))
} }

private def asNeoType(ct: CypherType): IntermediateRepresentation = ct match {
case CTString => getStatic[Neo4jTypes, Neo4jTypes.TextType]("NTString")
case CTInteger => getStatic[Neo4jTypes, Neo4jTypes.IntegerType]("NTInteger")
case CTFloat => getStatic[Neo4jTypes, Neo4jTypes.FloatType]("NTFloat")
case CTNumber => getStatic[Neo4jTypes, Neo4jTypes.NumberType]("NTNumber")
case CTBoolean => getStatic[Neo4jTypes, Neo4jTypes.BooleanType]("NTBoolean")
case l: ListType => invokeStatic(method[Neo4jTypes , Neo4jTypes.ListType, AnyType]("NTList"), asNeoType(l.innerType))
case CTDateTime => getStatic[Neo4jTypes, Neo4jTypes.DateTimeType]("NTDateTime")
case CTLocalDateTime => getStatic[Neo4jTypes, Neo4jTypes.LocalDateTimeType]("NTLocalDateTime")
case CTDate => getStatic[Neo4jTypes, Neo4jTypes.DateType]("NTDate")
case CTTime => getStatic[Neo4jTypes, Neo4jTypes.TimeType]("NTTime")
case CTLocalTime => getStatic[Neo4jTypes, Neo4jTypes.LocalTimeType]("NTLocalTime")
case CTDuration =>getStatic[Neo4jTypes, Neo4jTypes.DurationType]("NTDuration")
case CTPoint => getStatic[Neo4jTypes, Neo4jTypes.PointType]("NTPoint")
case CTNode =>getStatic[Neo4jTypes, Neo4jTypes.NodeType]("NTNode")
case CTRelationship => getStatic[Neo4jTypes, Neo4jTypes.RelationshipType]("NTRelationship")
case CTPath => getStatic[Neo4jTypes, Neo4jTypes.PathType]("NTPath")
case CTGeometry => getStatic[Neo4jTypes, Neo4jTypes.GeometryType]("NTGeometry")
case CTMap =>getStatic[Neo4jTypes, Neo4jTypes.MapType]("NTMap")
case CTAny => getStatic[Neo4jTypes, Neo4jTypes.AnyType]("NTAny")
}
} }


object IntermediateCodeGeneration { object IntermediateCodeGeneration {
Expand Down

0 comments on commit e96bab0

Please sign in to comment.