Skip to content

Commit

Permalink
Refactor OverriddenAccessMode
Browse files Browse the repository at this point in the history
Separate it into two implementation classes to clarify the difference between
restricting and elevating access mode,

- Replace OverriddenAccessMode with abstract LayeredAccessMode
- Separate implementation classes RestrictedAccessMode and ElevatedAccessMode
- Rename overrideOriginalMode() to isElevated(), with default implementation in
base interface.

Also fixes function call allowed, which had a bug of not reverting overridden
access mode
  • Loading branch information
henriknyman committed Oct 12, 2016
1 parent 68b1680 commit 0662879
Show file tree
Hide file tree
Showing 17 changed files with 212 additions and 277 deletions.
Expand Up @@ -150,8 +150,8 @@ class ExecutionEngine(val queryService: GraphDatabaseQueryService, logProvider:
val tc = externalTransactionalContext.getOrBeginNewIfClosed() val tc = externalTransactionalContext.getOrBeginNewIfClosed()


// Temporarily change access mode during query planning // Temporarily change access mode during query planning
// NOTE: The OVERRIDE_READ mode will force read access even if the current transaction did not have it // NOTE: This will force read access even if the current transaction did not have it
val revertable = tc.restrictCurrentTransaction(AccessMode.Static.OVERRIDE_READ) val revertable = tc.restrictCurrentTransaction(AccessMode.Static.READ)


val ((plan: ExecutionPlan, extractedParameters), touched) = try { val ((plan: ExecutionPlan, extractedParameters), touched) = try {
// fetch plan cache // fetch plan cache
Expand Down
Expand Up @@ -52,7 +52,6 @@ import org.neo4j.kernel.api.exceptions.ProcedureException
import org.neo4j.kernel.api.exceptions.schema.{AlreadyConstrainedException, AlreadyIndexedException} import org.neo4j.kernel.api.exceptions.schema.{AlreadyConstrainedException, AlreadyIndexedException}
import org.neo4j.kernel.api.index.{IndexDescriptor, InternalIndexState} import org.neo4j.kernel.api.index.{IndexDescriptor, InternalIndexState}
import org.neo4j.kernel.api.proc.{QualifiedName => KernelQualifiedName} import org.neo4j.kernel.api.proc.{QualifiedName => KernelQualifiedName}
import org.neo4j.kernel.api.security.{AccessMode, AuthSubject}
import org.neo4j.kernel.impl.core.NodeManager import org.neo4j.kernel.impl.core.NodeManager
import org.neo4j.kernel.impl.locking.ResourceTypes import org.neo4j.kernel.impl.locking.ResourceTypes


Expand Down Expand Up @@ -586,36 +585,37 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
} }


type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException] type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]
type KernelFunctionCall = (KernelQualifiedName, Array[AnyRef]) => AnyRef

private def shouldElevate(allowed: Array[String]): Boolean = {
// We have to be careful with elevation, since we cannot elevate permissions in a nested procedure call
// above the original allowed procedure mode. With enforce this by checking if mode is already an overridden mode.
val mode = transactionalContext.accessMode
allowed.nonEmpty && !mode.isElevated && mode.getOriginalAccessMode.allowsProcedureWith(allowed)
}


override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = val call: KernelProcedureCall =
if (allowsOverride(allowed)) if (shouldElevate(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ) transactionalContext.statement.procedureCallOperations.procedureCallReadElevated(_, _)
else else
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _) transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _)
callProcedure(name, args, call) callProcedure(name, args, call)
} }


private def allowsOverride(allowed: Array[String]): Boolean = {
// We have to be careful with override, since we cannot elevate permissions in a nested procedure call
// above the original allowed procedure mode. With enforce this by checking if mode is already an overridden mode.
val mode = transactionalContext.accessMode
allowed.nonEmpty && !mode.overrideOriginalMode && mode.getOriginalAccessMode.allowsProcedureWith(allowed)
}

override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = val call: KernelProcedureCall =
if (allowsOverride(allowed)) if (shouldElevate(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE) transactionalContext.statement.procedureCallOperations.procedureCallWriteElevated(_, _)
else else
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _) transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _)
callProcedure(name, args, call) callProcedure(name, args, call)
} }


override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = val call: KernelProcedureCall =
if (allowsOverride(allowed)) if (shouldElevate(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA) transactionalContext.statement.procedureCallOperations.procedureCallSchemaElevated(_, _)
else else
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _) transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _)
callProcedure(name, args, call) callProcedure(name, args, call)
Expand All @@ -636,17 +636,16 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
} }


override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callFunction(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val revertable = transactionalContext.accessMode match { val call: KernelFunctionCall =
case a: AuthSubject if a.allowsProcedureWith(allowed) => if (shouldElevate(allowed))
Some(transactionalContext.restrictCurrentTransaction(AccessMode.Static.OVERRIDE_READ)) transactionalContext.statement.procedureCallOperations.functionCallElevated(_, _)
case _ => None else
} transactionalContext.statement.procedureCallOperations.functionCall(_, _)
callFunction(name, args, transactionalContext.statement.readOperations().functionCall, revertable.foreach(_.close)) callFunction(name, args, call)
} }


private def callFunction(name: QualifiedName, args: Seq[Any], private def callFunction(name: QualifiedName, args: Seq[Any],
call: (KernelQualifiedName, Array[AnyRef]) => AnyRef, call: KernelFunctionCall) = {
onClose: => Unit) = {
val kn = new KernelQualifiedName(name.namespace.asJava, name.name) val kn = new KernelQualifiedName(name.namespace.asJava, name.name)
val toArray = args.map(_.asInstanceOf[AnyRef]).toArray val toArray = args.map(_.asInstanceOf[AnyRef]).toArray
call(kn, toArray) call(kn, toArray)
Expand Down
Expand Up @@ -22,7 +22,6 @@
import org.neo4j.collection.RawIterator; import org.neo4j.collection.RawIterator;
import org.neo4j.kernel.api.exceptions.ProcedureException; import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.proc.QualifiedName; import org.neo4j.kernel.api.proc.QualifiedName;
import org.neo4j.kernel.api.security.AccessMode;


/** /**
* Specifies procedure call operations for the three types of procedure calls that can be made. * Specifies procedure call operations for the three types of procedure calls that can be made.
Expand All @@ -40,15 +39,14 @@ RawIterator<Object[], ProcedureException> procedureCallRead( QualifiedName name,
throws ProcedureException; throws ProcedureException;


/** /**
* Invoke a read-only procedure by name, and override the transaction's access mode with * Invoke a read-only procedure by name, and elevate the transaction's access mode to
* the given access mode for the duration of the procedure execution. * the given access mode for the duration of the procedure execution.
* @param name the name of the procedure. * @param name the name of the procedure.
* @param arguments the procedure arguments. * @param arguments the procedure arguments.
* @param override the access mode to be used for the execution of the procedure.
* @return an iterator containing the procedure results. * @return an iterator containing the procedure results.
* @throws ProcedureException if there was an exception thrown during procedure execution. * @throws ProcedureException if there was an exception thrown during procedure execution.
*/ */
RawIterator<Object[], ProcedureException> procedureCallRead( QualifiedName name, Object[] arguments, AccessMode override ) RawIterator<Object[], ProcedureException> procedureCallReadElevated( QualifiedName name, Object[] arguments )
throws ProcedureException; throws ProcedureException;


/** /**
Expand All @@ -61,15 +59,14 @@ RawIterator<Object[], ProcedureException> procedureCallRead( QualifiedName name,
RawIterator<Object[], ProcedureException> procedureCallWrite( QualifiedName name, Object[] arguments ) RawIterator<Object[], ProcedureException> procedureCallWrite( QualifiedName name, Object[] arguments )
throws ProcedureException; throws ProcedureException;
/** /**
* Invoke a read/write procedure by name, and override the transaction's access mode with * Invoke a read/write procedure by name, and elevate the transaction's access mode to
* the given access mode for the duration of the procedure execution. * the given access mode for the duration of the procedure execution.
* @param name the name of the procedure. * @param name the name of the procedure.
* @param arguments the procedure arguments. * @param arguments the procedure arguments.
* @param override the access mode to be used for the execution of the procedure.
* @return an iterator containing the procedure results. * @return an iterator containing the procedure results.
* @throws ProcedureException if there was an exception thrown during procedure execution. * @throws ProcedureException if there was an exception thrown during procedure execution.
*/ */
RawIterator<Object[], ProcedureException> procedureCallWrite( QualifiedName name, Object[] arguments, AccessMode override ) RawIterator<Object[], ProcedureException> procedureCallWriteElevated( QualifiedName name, Object[] arguments )
throws ProcedureException; throws ProcedureException;


/** /**
Expand All @@ -82,14 +79,28 @@ RawIterator<Object[], ProcedureException> procedureCallWrite( QualifiedName name
RawIterator<Object[], ProcedureException> procedureCallSchema( QualifiedName name, Object[] arguments ) RawIterator<Object[], ProcedureException> procedureCallSchema( QualifiedName name, Object[] arguments )
throws ProcedureException; throws ProcedureException;
/** /**
* Invoke a schema write procedure by name, and override the transaction's access mode with * Invoke a schema write procedure by name, and elevate the transaction's access mode to
* the given access mode for the duration of the procedure execution. * the given access mode for the duration of the procedure execution.
* @param name the name of the procedure. * @param name the name of the procedure.
* @param arguments the procedure arguments. * @param arguments the procedure arguments.
* @param override the access mode to be used for the execution of the procedure.
* @return an iterator containing the procedure results. * @return an iterator containing the procedure results.
* @throws ProcedureException if there was an exception thrown during procedure execution. * @throws ProcedureException if there was an exception thrown during procedure execution.
*/ */
RawIterator<Object[], ProcedureException> procedureCallSchema( QualifiedName name, Object[] arguments, AccessMode override ) RawIterator<Object[], ProcedureException> procedureCallSchemaElevated( QualifiedName name, Object[] arguments )
throws ProcedureException; throws ProcedureException;

/** Invoke a read-only function by name
* @param name the name of the function.
* @param arguments the function arguments.
* @throws ProcedureException if there was an exception thrown during function execution.
*/
Object functionCall( QualifiedName name, Object[] arguments ) throws ProcedureException;

/** Invoke a read-only function by name, and elevate the transaction's access mode to
* the given access mode for the duration of the function execution.
* @param name the name of the function.
* @param arguments the function arguments.
* @throws ProcedureException if there was an exception thrown during function execution.
*/
Object functionCallElevated( QualifiedName name, Object[] arguments ) throws ProcedureException;
} }
Expand Up @@ -566,7 +566,4 @@ DoubleLongRegister indexSample( IndexDescriptor index, DoubleLongRegister target


/** Fetch all registered procedures */ /** Fetch all registered procedures */
Set<ProcedureSignature> proceduresGetAll(); Set<ProcedureSignature> proceduresGetAll();

/** Invoke a read-only procedure by name */
Object functionCall( QualifiedName name, Object[] input ) throws ProcedureException;
} }

0 comments on commit 0662879

Please sign in to comment.