Skip to content

Commit

Permalink
CallableUserFunction works with AnyValue.
Browse files Browse the repository at this point in the history
Introduced a ValueMapper type for this. ValueMapper translates (Any)Value to
other type systems, in this case the java classes of the embedded API.
  • Loading branch information
thobe committed Jan 30, 2018
1 parent 195d1ae commit 2cdf76a
Show file tree
Hide file tree
Showing 66 changed files with 1,480 additions and 129 deletions.
Expand Up @@ -154,7 +154,7 @@ class ExceptionTranslatingQueryContext(val inner: QueryContext) extends QueryCon
override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] =
translateIterator(inner.callDbmsProcedure(name, args, allowed))

override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]) =
override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]) =
translateException(inner.callFunction(name, args, allowed))


Expand Down
Expand Up @@ -56,6 +56,8 @@ import org.neo4j.kernel.api.schema.index.IndexDescriptorFactory
import org.neo4j.kernel.api.{exceptions, _}
import org.neo4j.kernel.impl.core.EmbeddedProxySPI
import org.neo4j.kernel.impl.locking.ResourceTypes
import org.neo4j.kernel.impl.util.ValueUtils
import org.neo4j.values.AnyValue
import org.neo4j.values.storable.Values

import scala.collection.Iterator
Expand Down Expand Up @@ -593,7 +595,7 @@ final class TransactionBoundQueryContext(txContext: TransactionalContextWrapper)
}

type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]
type KernelFunctionCall = (KernelQualifiedName, Array[AnyRef]) => AnyRef
type KernelFunctionCall = (KernelQualifiedName, Array[AnyValue]) => AnyValue

private def shouldElevate(allowed: Array[String]): Boolean = {
// We have to be careful with elevation, since we cannot elevate permissions in a nested procedure call
Expand Down Expand Up @@ -646,17 +648,18 @@ final class TransactionBoundQueryContext(txContext: TransactionalContextWrapper)
override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelFunctionCall =
if (shouldElevate(allowed))
txContext.statement.procedureCallOperations.functionCallOverride(_, _)
(name, args) => txContext.statement.procedureCallOperations.functionCallOverride(name, args)
else
txContext.statement.procedureCallOperations.functionCall(_, _)
(name, args) => txContext.statement.procedureCallOperations.functionCall(name, args)
callFunction(name, args, call)
}

private def callFunction(name: QualifiedName, args: Seq[Any],
call: KernelFunctionCall) = {
val kn = new KernelQualifiedName(name.namespace.asJava, name.name)
val toArray = args.map(_.asInstanceOf[AnyRef]).toArray
call(kn, toArray)
val argArray = args.map(ValueUtils.of).toArray
val result = call(kn, argArray)
result.map(txContext.statement.procedureCallOperations.valueMapper)
}

override def isGraphKernelResultValue(v: Any): Boolean = v.isInstanceOf[PropertyContainer] || v.isInstanceOf[Path]
Expand Down
Expand Up @@ -221,7 +221,7 @@ abstract class DelegatingQueryContext(val inner: QueryContext) extends QueryCont
override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) =
inner.callDbmsProcedure(name, args, allowed)

override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]) =
override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]) =
singleDbHit(inner.callFunction(name, args, allowed))

override def aggregateFunction(name: QualifiedName,
Expand Down
Expand Up @@ -25,7 +25,7 @@ import java.util.function.Predicate
import org.neo4j.collection.RawIterator
import org.neo4j.collection.primitive.{PrimitiveLongIterator, PrimitiveLongResourceIterator}
import org.neo4j.cypher.InternalException
import org.neo4j.cypher.internal.javacompat.{GraphDatabaseCypherService, ValueToObjectSerializer}
import org.neo4j.cypher.internal.javacompat.GraphDatabaseCypherService
import org.neo4j.cypher.internal.planner.v3_4.spi.{IdempotentResult, IndexDescriptor}
import org.neo4j.cypher.internal.runtime._
import org.neo4j.cypher.internal.runtime.interpreted.CypherOrdering.{BY_NUMBER, BY_STRING, BY_VALUE}
Expand Down Expand Up @@ -59,10 +59,10 @@ import org.neo4j.kernel.impl.coreapi.PropertyContainerLocker
import org.neo4j.kernel.impl.locking.ResourceTypes
import org.neo4j.kernel.impl.query.Neo4jTransactionalContext
import org.neo4j.kernel.impl.util.ValueUtils.{fromNodeProxy, fromRelationshipProxy}
import org.neo4j.kernel.impl.util.{NodeProxyWrappingNodeValue, RelationshipProxyWrappingValue}
import org.neo4j.values.AnyValue
import org.neo4j.kernel.impl.util.{DefaultValueMapper, NodeProxyWrappingNodeValue, RelationshipProxyWrappingValue}
import org.neo4j.values.{AnyValue, ValueMapper}
import org.neo4j.values.storable.{TextValue, Value, Values}
import org.neo4j.values.virtual.{RelationshipValue, ListValue, NodeValue, VirtualValues}
import org.neo4j.values.virtual.{ListValue, NodeValue, RelationshipValue, VirtualValues}

import scala.collection.Iterator
import scala.collection.JavaConverters._
Expand All @@ -77,6 +77,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
override val relationshipOps: RelationshipOperations = new RelationshipOperations
override lazy val entityAccessor: EmbeddedProxySPI =
transactionalContext.graph.getDependencyResolver.resolveDependency(classOf[EmbeddedProxySPI])
private lazy val valueMapper: ValueMapper[java.lang.Object] = new DefaultValueMapper(entityAccessor)

override def setLabelsOnNode(node: Long, labelIds: Iterator[Int]): Int = labelIds.foldLeft(0) {
case (count, labelId) => if (writes().nodeAddLabel(node, labelId)) count + 1 else count
Expand Down Expand Up @@ -429,13 +430,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
value match {
case node: NodeProxyWrappingNodeValue => node.nodeProxy
case edge: RelationshipProxyWrappingValue => edge.relationshipProxy
case _ =>

val converter = new ValueToObjectSerializer(entityAccessor)
//TODO this is not very nice, but I need a transaction here and this is what
// I ended up with.
withAnyOpenQueryContext(_ => value.writeTo(converter))
converter.value()
case _ => withAnyOpenQueryContext(_=>value.map(valueMapper))
}
}

Expand Down Expand Up @@ -845,7 +840,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
}

type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]
type KernelFunctionCall = (KernelQualifiedName, Array[AnyRef]) => AnyRef
type KernelFunctionCall = (KernelQualifiedName, Array[AnyValue]) => AnyValue
type KernelAggregationFunctionCall = (KernelQualifiedName) => Aggregator

private def shouldElevate(allowed: Array[String]): Boolean = {
Expand Down Expand Up @@ -898,7 +893,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
}
}

override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]) = {
val call: KernelFunctionCall =
if (shouldElevate(allowed))
transactionalContext.statement.procedureCallOperations.functionCallOverride(_, _)
Expand All @@ -916,11 +911,10 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
callAggregationFunction(name, call)
}

private def callFunction(name: QualifiedName, args: Seq[Any],
private def callFunction(name: QualifiedName, args: Seq[AnyValue],
call: KernelFunctionCall) = {
val kn = new KernelQualifiedName(name.namespace.asJava, name.name)
val toArray = args.map(_.asInstanceOf[AnyRef]).toArray
call(kn, toArray)
call(kn, args.toArray)
}

private def callAggregationFunction(name: QualifiedName,
Expand Down
Expand Up @@ -19,24 +19,20 @@
*/
package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions

import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext
import org.neo4j.cypher.internal.runtime.interpreted.ValueConversion
import org.neo4j.cypher.internal.runtime.interpreted.GraphElementPropertyFunctions
import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, GraphElementPropertyFunctions}
import org.neo4j.cypher.internal.runtime.interpreted.pipes.QueryState
import org.neo4j.cypher.internal.v3_4.logical.plans.UserFunctionSignature
import org.neo4j.values._

case class FunctionInvocation(signature: UserFunctionSignature, arguments: IndexedSeq[Expression])
extends Expression with GraphElementPropertyFunctions {
private val valueConverter = ValueConversion.getValueConverter(signature.outputType)

override def apply(ctx: ExecutionContext, state: QueryState): AnyValue = {
val query = state.query
val argValues = arguments.map(arg => {
query.asObject(arg(ctx, state))
arg(ctx, state)
})
val result = query.callFunction(signature.name, argValues, signature.allowed)
valueConverter(result)
query.callFunction(signature.name, argValues, signature.allowed)
}

override def rewrite(f: (Expression) => Expression) =
Expand Down
Expand Up @@ -153,7 +153,7 @@ trait QueryContextAdaptation {

override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = ???

override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]): AnyRef = ???
override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]): AnyValue = ???

override def aggregateFunction(name: QualifiedName,
allowed: Array[String]): UserDefinedAggregator = ???
Expand Down
Expand Up @@ -190,7 +190,7 @@ trait QueryContext extends TokenContext {

def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]]

def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]): AnyRef
def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]): AnyValue

def aggregateFunction(name: QualifiedName, allowed: Array[String]): UserDefinedAggregator

Expand Down
Expand Up @@ -24,6 +24,8 @@
import org.neo4j.kernel.api.proc.CallableUserAggregationFunction;
import org.neo4j.kernel.api.proc.QualifiedName;
import org.neo4j.internal.kernel.api.security.AccessMode;
import org.neo4j.values.AnyValue;
import org.neo4j.values.ValueMapper;

/**
* Specifies procedure call operations for the three types of procedure calls that can be made.
Expand Down Expand Up @@ -99,15 +101,15 @@ RawIterator<Object[], ProcedureException> procedureCallSchemaOverride( Qualified
* @param arguments the function arguments.
* @throws ProcedureException if there was an exception thrown during function execution.
*/
Object functionCall( QualifiedName name, Object[] arguments ) throws ProcedureException;
AnyValue functionCall( QualifiedName name, AnyValue[] arguments ) throws ProcedureException;

/** Invoke a read-only function by name, and set the transaction's access mode to
* {@link AccessMode.Static#READ READ} for the duration of the function execution.
* @param name the name of the function.
* @param arguments the function arguments.
* @throws ProcedureException if there was an exception thrown during function execution.
*/
Object functionCallOverride( QualifiedName name, Object[] arguments ) throws ProcedureException;
AnyValue functionCallOverride( QualifiedName name, AnyValue[] arguments ) throws ProcedureException;

/**
* Create a read-only aggregation function by name
Expand All @@ -123,4 +125,10 @@ RawIterator<Object[], ProcedureException> procedureCallSchemaOverride( Qualified
* @throws ProcedureException if there was an exception thrown during function execution.
*/
CallableUserAggregationFunction.Aggregator aggregationFunctionOverride( QualifiedName name ) throws ProcedureException;

/**
* Retrieve a value mapper for mapping values to regular Java objects.
* @return a value mapper that maps to Java objects.
*/
ValueMapper<Object> valueMapper();
}
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.proc.QualifiedName;
import org.neo4j.internal.kernel.api.security.SecurityContext;
import org.neo4j.values.AnyValue;

/**
* Defines all types of system-oriented operations - i.e. those which do not read from or
Expand All @@ -44,9 +45,9 @@ RawIterator<Object[],ProcedureException> procedureCallDbms(
) throws ProcedureException;

/** Invoke a DBMS function by name */
Object functionCallDbms(
AnyValue functionCallDbms(
QualifiedName name,
Object[] input,
AnyValue[] input,
SecurityContext securityContext
) throws ProcedureException;
}
Expand Up @@ -20,11 +20,12 @@
package org.neo4j.kernel.api.proc;

import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.values.AnyValue;

public interface CallableUserFunction
{
UserFunctionSignature signature();
Object apply( Context ctx, Object[] input ) throws ProcedureException;
AnyValue apply( Context ctx, AnyValue[] input ) throws ProcedureException;

abstract class BasicUserFunction implements CallableUserFunction
{
Expand All @@ -42,6 +43,6 @@ public UserFunctionSignature signature()
}

@Override
public abstract Object apply( Context ctx, Object[] input ) throws ProcedureException;
public abstract AnyValue apply( Context ctx, AnyValue[] input ) throws ProcedureException;
}
}
Expand Up @@ -19,9 +19,9 @@
*/
package org.neo4j.kernel.api.proc;

import org.neo4j.collection.RawIterator;
import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.values.AnyValue;

public class FailedLoadFunction extends CallableUserFunction.BasicUserFunction
{
Expand All @@ -31,7 +31,7 @@ public FailedLoadFunction( UserFunctionSignature signature )
}

@Override
public RawIterator<Object[],ProcedureException> apply( Context ctx, Object[] input ) throws ProcedureException
public AnyValue apply( Context ctx, AnyValue[] input ) throws ProcedureException
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
signature().description().orElse( "Failed to load " + signature().name().toString() ) );
Expand Down
Expand Up @@ -112,6 +112,8 @@
import org.neo4j.storageengine.api.lock.ResourceType;
import org.neo4j.storageengine.api.schema.PopulationProgress;
import org.neo4j.storageengine.api.schema.SchemaRule;
import org.neo4j.values.AnyValue;
import org.neo4j.values.ValueMapper;
import org.neo4j.values.storable.Value;
import org.neo4j.values.storable.Values;
import org.neo4j.values.virtual.MapValue;
Expand Down Expand Up @@ -1463,7 +1465,7 @@ public Object[] next() throws ProcedureException
}

@Override
public Object functionCall( QualifiedName name, Object[] arguments ) throws ProcedureException
public AnyValue functionCall( QualifiedName name, AnyValue[] arguments ) throws ProcedureException
{
if ( !tx.securityContext().mode().allowsReads() )
{
Expand All @@ -1475,7 +1477,7 @@ public Object functionCall( QualifiedName name, Object[] arguments ) throws Proc
}

@Override
public Object functionCallOverride( QualifiedName name, Object[] arguments ) throws ProcedureException
public AnyValue functionCallOverride( QualifiedName name, AnyValue[] arguments ) throws ProcedureException
{
return callFunction( name, arguments,
new OverriddenAccessMode( tx.securityContext().mode(), AccessMode.Static.READ ) );
Expand All @@ -1499,7 +1501,13 @@ public CallableUserAggregationFunction.Aggregator aggregationFunctionOverride( Q
new OverriddenAccessMode( tx.securityContext().mode(), AccessMode.Static.READ ) );
}

private Object callFunction( QualifiedName name, Object[] input, final AccessMode mode ) throws ProcedureException
@Override
public ValueMapper<Object> valueMapper()
{
return procedures.valueMapper();
}

private AnyValue callFunction( QualifiedName name, AnyValue[] input, final AccessMode mode ) throws ProcedureException
{
statement.assertOpen();

Expand Down
Expand Up @@ -27,6 +27,7 @@
import org.neo4j.kernel.api.proc.QualifiedName;
import org.neo4j.internal.kernel.api.security.SecurityContext;
import org.neo4j.kernel.impl.proc.Procedures;
import org.neo4j.values.AnyValue;

public class NonTransactionalDbmsOperations implements DbmsOperations
{
Expand All @@ -51,9 +52,9 @@ public RawIterator<Object[],ProcedureException> procedureCallDbms(
}

@Override
public Object functionCallDbms(
public AnyValue functionCallDbms(
QualifiedName name,
Object[] input,
AnyValue[] input,
SecurityContext securityContext
) throws ProcedureException
{
Expand Down
Expand Up @@ -45,4 +45,6 @@ public interface EmbeddedProxySPI
GraphPropertiesProxy newGraphPropertiesProxy();

RelationshipType getRelationshipTypeById( int type );

int getRelationshipTypeIdByName( String typeName );
}

0 comments on commit 2cdf76a

Please sign in to comment.