Skip to content

Commit

Permalink
Handle recursive overridden AccessMode
Browse files Browse the repository at this point in the history
- Move getOriginalAccessMode() to AccessMode base class
- Move username() to AccessMode base class to get rid of static downcast
method. When we have a proper security context we can remove it again.
  • Loading branch information
henriknyman committed Oct 12, 2016
1 parent 081eed5 commit 8b246b7
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 58 deletions.
Expand Up @@ -53,7 +53,6 @@ import org.neo4j.kernel.api.exceptions.schema.{AlreadyConstrainedException, Alre
import org.neo4j.kernel.api.index.{IndexDescriptor, InternalIndexState} import org.neo4j.kernel.api.index.{IndexDescriptor, InternalIndexState}
import org.neo4j.kernel.api.proc.{QualifiedName => KernelQualifiedName} import org.neo4j.kernel.api.proc.{QualifiedName => KernelQualifiedName}
import org.neo4j.kernel.api.security.{AccessMode, AuthSubject} import org.neo4j.kernel.api.security.{AccessMode, AuthSubject}
import org.neo4j.kernel.impl.api.security.{AccessModeSnapshot, OverriddenAccessMode}
import org.neo4j.kernel.impl.core.NodeManager import org.neo4j.kernel.impl.core.NodeManager
import org.neo4j.kernel.impl.locking.ResourceTypes import org.neo4j.kernel.impl.locking.ResourceTypes


Expand Down Expand Up @@ -588,17 +587,8 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional


type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException] type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]


private def originalAccessMode: Object = {
// TODO: Write a test for recursive procedure calls with allowed annotation
transactionalContext.accessMode match {
case a: AccessModeSnapshot => a.getOriginalAccessMode
case a: OverriddenAccessMode => a.getOriginalAccessMode
case a => a
}
}

override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = originalAccessMode match { val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) => case a: AuthSubject if a.allowsProcedureWith(allowed) =>
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ) transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ)
case _ => case _ =>
Expand All @@ -608,7 +598,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
} }


override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = originalAccessMode match { val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) => case a: AuthSubject if a.allowsProcedureWith(allowed) =>
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE) transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE)
case _ => case _ =>
Expand All @@ -618,7 +608,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
} }


override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = originalAccessMode match { val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) => case a: AuthSubject if a.allowsProcedureWith(allowed) =>
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA) transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA)
case _ => case _ =>
Expand Down
Expand Up @@ -366,4 +366,14 @@ public AuthorizationViolationException onViolation( String msg )
boolean overrideOriginalMode(); boolean overrideOriginalMode();
AuthorizationViolationException onViolation( String msg ); AuthorizationViolationException onViolation( String msg );
String name(); String name();

default String username()
{
return ""; // Should never clash with a valid username
}

default AccessMode getOriginalAccessMode()
{
return this;
}
} }
Expand Up @@ -57,6 +57,7 @@ public interface AuthSubject extends AccessMode
/** /**
* @return A string representing the primary principal of this subject * @return A string representing the primary principal of this subject
*/ */
@Override
String username(); String username();


/** /**
Expand Down
Expand Up @@ -42,8 +42,6 @@
import org.neo4j.kernel.impl.proc.Procedures; import org.neo4j.kernel.impl.proc.Procedures;
import org.neo4j.storageengine.api.StorageStatement; import org.neo4j.storageengine.api.StorageStatement;


import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;

/** /**
* A resource efficient implementation of {@link Statement}. Designed to be reused within a * A resource efficient implementation of {@link Statement}. Designed to be reused within a
* {@link KernelTransactionImplementation} instance, even across transactions since this instances itself * {@link KernelTransactionImplementation} instance, even across transactions since this instances itself
Expand Down Expand Up @@ -216,7 +214,7 @@ final void forceClose()


final Optional<String> username() final Optional<String> username()
{ {
String username = getUsernameFromAccessMode( transaction.mode() ); String username = transaction.mode().username();
return Optional.of( username ); return Optional.of( username );
} }


Expand Down
Expand Up @@ -27,7 +27,6 @@
import org.neo4j.kernel.api.proc.QualifiedName; import org.neo4j.kernel.api.proc.QualifiedName;
import org.neo4j.kernel.api.security.AccessMode; import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.kernel.api.security.AuthSubject; import org.neo4j.kernel.api.security.AuthSubject;
import org.neo4j.kernel.impl.api.security.AccessModeSnapshot;
import org.neo4j.kernel.impl.proc.Procedures; import org.neo4j.kernel.impl.proc.Procedures;


public class NonTransactionalDbmsOperations implements DbmsOperations public class NonTransactionalDbmsOperations implements DbmsOperations
Expand All @@ -48,10 +47,8 @@ public RawIterator<Object[],ProcedureException> procedureCallDbms(
) throws ProcedureException ) throws ProcedureException
{ {
BasicContext ctx = new BasicContext(); BasicContext ctx = new BasicContext();
AccessMode originalMode = (mode instanceof AccessModeSnapshot) ?
((AccessModeSnapshot) mode).getOriginalAccessMode() :
mode;


AccessMode originalMode = mode.getOriginalAccessMode();
if ( originalMode instanceof AuthSubject ) if ( originalMode instanceof AuthSubject )
{ {
ctx.put( Context.AUTH_SUBJECT, (AuthSubject) originalMode ); ctx.put( Context.AUTH_SUBJECT, (AuthSubject) originalMode );
Expand Down
Expand Up @@ -29,10 +29,12 @@ public class AccessModeSnapshot implements AccessMode
private final boolean allowsSchemaWrites; private final boolean allowsSchemaWrites;
private final boolean overrideOriginalMode; private final boolean overrideOriginalMode;


private final AccessMode accessMode; private final AccessMode originalMode;


public static AccessMode createAccessModeSnapshot( AccessMode accessMode ) public static AccessMode createAccessModeSnapshot( AccessMode accessMode )
{ {
// TODO: Use flyweight pattern instead of always creating new objects, when we have a proper
// security context and do not need to obtain the original mode through this object
return new AccessModeSnapshot( accessMode ); return new AccessModeSnapshot( accessMode );
} }


Expand All @@ -43,10 +45,13 @@ private AccessModeSnapshot( AccessMode accessMode )
allowsSchemaWrites = accessMode.allowsSchemaWrites(); allowsSchemaWrites = accessMode.allowsSchemaWrites();
overrideOriginalMode = accessMode.overrideOriginalMode(); overrideOriginalMode = accessMode.overrideOriginalMode();


// We use this for onViolation() and name() // We use this for delegation of all the remaining methods
this.accessMode = accessMode; this.originalMode = accessMode;
} }


//---------------------------------------
// Snapshot permissions

@Override @Override
public boolean allowsReads() public boolean allowsReads()
{ {
Expand All @@ -71,21 +76,30 @@ public boolean overrideOriginalMode()
return overrideOriginalMode; return overrideOriginalMode;
} }


//---------------------------------------
// Delegate remaining methods to original

@Override @Override
public AuthorizationViolationException onViolation( String msg ) public AuthorizationViolationException onViolation( String msg )
{ {
return accessMode.onViolation( msg ); return originalMode.onViolation( msg );
} }


@Override @Override
public String name() public String name()
{ {
return accessMode.name(); return originalMode.name();
} }


// TODO: Move this to AccessMode interface with default implementation to support recursive case @Override
public String username()
{
return originalMode.username();
}

@Override
public AccessMode getOriginalAccessMode() public AccessMode getOriginalAccessMode()
{ {
return accessMode; return originalMode.getOriginalAccessMode();
} }
} }
Expand Up @@ -21,7 +21,6 @@


import org.neo4j.graphdb.security.AuthorizationViolationException; import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.kernel.api.security.AccessMode; import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.kernel.api.security.AuthSubject;


public class OverriddenAccessMode implements AccessMode public class OverriddenAccessMode implements AccessMode
{ {
Expand Down Expand Up @@ -80,35 +79,16 @@ public String name()
} }
} }


@Override
public String username() public String username()
{ {
return getUsernameFromAccessMode( originalMode ); return originalMode.username();
} }


// TODO: Move this to AccessMode interface with default implementation to support recursive case @Override
// OR move allowsProcedureWith() to AccessMode and override that here with recursive implementation
public AccessMode getOriginalAccessMode() public AccessMode getOriginalAccessMode()
{ {
return originalMode; return originalMode.getOriginalAccessMode();
} }


public static String getUsernameFromAccessMode( AccessMode accessMode )
{
if ( accessMode instanceof AuthSubject )
{
return ((AuthSubject) accessMode).username();
}
else if ( accessMode instanceof OverriddenAccessMode )
{
return ((OverriddenAccessMode) accessMode).username();
}
else if ( accessMode instanceof AccessModeSnapshot )
{
return getUsernameFromAccessMode( ((AccessModeSnapshot) accessMode).getOriginalAccessMode() );
}
else
{
return ""; // Should never clash with a valid username
}
}
} }
Expand Up @@ -66,7 +66,6 @@
import static org.neo4j.graphdb.security.AuthorizationViolationException.PERMISSION_DENIED; import static org.neo4j.graphdb.security.AuthorizationViolationException.PERMISSION_DENIED;
import static org.neo4j.kernel.enterprise.builtinprocs.QueryId.fromExternalString; import static org.neo4j.kernel.enterprise.builtinprocs.QueryId.fromExternalString;
import static org.neo4j.kernel.enterprise.builtinprocs.QueryId.ofInternalId; import static org.neo4j.kernel.enterprise.builtinprocs.QueryId.ofInternalId;
import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;
import static org.neo4j.procedure.Mode.DBMS; import static org.neo4j.procedure.Mode.DBMS;


@SuppressWarnings( "unused" ) @SuppressWarnings( "unused" )
Expand Down Expand Up @@ -123,7 +122,7 @@ public Stream<TransactionResult> listTransactions()
getActiveTransactions( graph.getDependencyResolver() ) getActiveTransactions( graph.getDependencyResolver() )
.stream() .stream()
.filter( tx -> !tx.terminationReason().isPresent() ) .filter( tx -> !tx.terminationReason().isPresent() )
.map( tx -> getUsernameFromAccessMode( tx.mode() ) ) .map( tx -> tx.mode().username() )
); );
} }


Expand Down Expand Up @@ -280,7 +279,7 @@ public static Stream<TransactionTerminationResult> terminateTransactionsForValid
{ {
long terminatedCount = getActiveTransactions( dependencyResolver ) long terminatedCount = getActiveTransactions( dependencyResolver )
.stream() .stream()
.filter( tx -> getUsernameFromAccessMode( tx.mode() ).equals( username ) && .filter( tx -> tx.mode().username().equals( username ) &&
!tx.isUnderlyingTransaction( currentTx ) ) !tx.isUnderlyingTransaction( currentTx ) )
.map( tx -> tx.markForTermination( Status.Transaction.Terminated ) ) .map( tx -> tx.markForTermination( Status.Transaction.Terminated ) )
.filter( marked -> marked ) .filter( marked -> marked )
Expand Down
Expand Up @@ -39,7 +39,6 @@


import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static org.neo4j.graphdb.security.AuthorizationViolationException.PERMISSION_DENIED; import static org.neo4j.graphdb.security.AuthorizationViolationException.PERMISSION_DENIED;
import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;


@SuppressWarnings( "WeakerAccess" ) @SuppressWarnings( "WeakerAccess" )
public class AuthProceduresBase public class AuthProceduresBase
Expand Down Expand Up @@ -79,7 +78,7 @@ protected void terminateTransactionsForValidUser( String username )
getActiveTransactions() getActiveTransactions()
.stream() .stream()
.filter( tx -> .filter( tx ->
getUsernameFromAccessMode( tx.mode() ).equals( username ) && tx.mode().username().equals( username ) &&
!tx.isUnderlyingTransaction( currentTx ) !tx.isUnderlyingTransaction( currentTx )
).forEach( tx -> tx.markForTermination( Status.Transaction.Terminated ) ); ).forEach( tx -> tx.markForTermination( Status.Transaction.Terminated ) );
} }
Expand Down
Expand Up @@ -73,7 +73,6 @@
import static org.neo4j.bolt.v1.messaging.util.MessageMatchers.msgSuccess; import static org.neo4j.bolt.v1.messaging.util.MessageMatchers.msgSuccess;
import static org.neo4j.bolt.v1.transport.integration.TransportTestUtil.eventuallyReceives; import static org.neo4j.bolt.v1.transport.integration.TransportTestUtil.eventuallyReceives;
import static org.neo4j.helpers.collection.MapUtil.map; import static org.neo4j.helpers.collection.MapUtil.map;
import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;
import static org.neo4j.procedure.Mode.READ; import static org.neo4j.procedure.Mode.READ;
import static org.neo4j.procedure.Mode.WRITE; import static org.neo4j.procedure.Mode.WRITE;
import static org.neo4j.server.security.enterprise.auth.plugin.api.PredefinedRoles.ADMIN; import static org.neo4j.server.security.enterprise.auth.plugin.api.PredefinedRoles.ADMIN;
Expand Down Expand Up @@ -459,7 +458,7 @@ private Map<String,Long> countTransactionsByUsername()
neo.getLocalGraph().getDependencyResolver() neo.getLocalGraph().getDependencyResolver()
).stream() ).stream()
.filter( tx -> !tx.terminationReason().isPresent() ) .filter( tx -> !tx.terminationReason().isPresent() )
.map( tx -> getUsernameFromAccessMode( tx.mode() ) ) .map( tx -> tx.mode().username() )
).collect( Collectors.toMap( r -> r.username, r -> r.activeTransactions ) ); ).collect( Collectors.toMap( r -> r.username, r -> r.activeTransactions ) );
} }


Expand Down

0 comments on commit 8b246b7

Please sign in to comment.