Skip to content

Commit

Permalink
Handle recursive overridden AccessMode
Browse files Browse the repository at this point in the history
- Move getOriginalAccessMode() to AccessMode base class
- Move username() to AccessMode base class to get rid of static downcast
method. When we have a proper security context we can remove it again.
  • Loading branch information
henriknyman committed Oct 12, 2016
1 parent 081eed5 commit 8b246b7
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 58 deletions.
Expand Up @@ -53,7 +53,6 @@ import org.neo4j.kernel.api.exceptions.schema.{AlreadyConstrainedException, Alre
import org.neo4j.kernel.api.index.{IndexDescriptor, InternalIndexState}
import org.neo4j.kernel.api.proc.{QualifiedName => KernelQualifiedName}
import org.neo4j.kernel.api.security.{AccessMode, AuthSubject}
import org.neo4j.kernel.impl.api.security.{AccessModeSnapshot, OverriddenAccessMode}
import org.neo4j.kernel.impl.core.NodeManager
import org.neo4j.kernel.impl.locking.ResourceTypes

Expand Down Expand Up @@ -588,17 +587,8 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional

type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]

private def originalAccessMode: Object = {
// TODO: Write a test for recursive procedure calls with allowed annotation
transactionalContext.accessMode match {
case a: AccessModeSnapshot => a.getOriginalAccessMode
case a: OverriddenAccessMode => a.getOriginalAccessMode
case a => a
}
}

override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = originalAccessMode match {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) =>
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ)
case _ =>
Expand All @@ -608,7 +598,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
}

override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = originalAccessMode match {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) =>
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE)
case _ =>
Expand All @@ -618,7 +608,7 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
}

override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = originalAccessMode match {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) =>
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA)
case _ =>
Expand Down
Expand Up @@ -366,4 +366,14 @@ public AuthorizationViolationException onViolation( String msg )
boolean overrideOriginalMode();
AuthorizationViolationException onViolation( String msg );
String name();

default String username()
{
return ""; // Should never clash with a valid username
}

default AccessMode getOriginalAccessMode()
{
return this;
}
}
Expand Up @@ -57,6 +57,7 @@ public interface AuthSubject extends AccessMode
/**
* @return A string representing the primary principal of this subject
*/
@Override
String username();

/**
Expand Down
Expand Up @@ -42,8 +42,6 @@
import org.neo4j.kernel.impl.proc.Procedures;
import org.neo4j.storageengine.api.StorageStatement;

import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;

/**
* A resource efficient implementation of {@link Statement}. Designed to be reused within a
* {@link KernelTransactionImplementation} instance, even across transactions since this instances itself
Expand Down Expand Up @@ -216,7 +214,7 @@ final void forceClose()

final Optional<String> username()
{
String username = getUsernameFromAccessMode( transaction.mode() );
String username = transaction.mode().username();
return Optional.of( username );
}

Expand Down
Expand Up @@ -27,7 +27,6 @@
import org.neo4j.kernel.api.proc.QualifiedName;
import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.kernel.api.security.AuthSubject;
import org.neo4j.kernel.impl.api.security.AccessModeSnapshot;
import org.neo4j.kernel.impl.proc.Procedures;

public class NonTransactionalDbmsOperations implements DbmsOperations
Expand All @@ -48,10 +47,8 @@ public RawIterator<Object[],ProcedureException> procedureCallDbms(
) throws ProcedureException
{
BasicContext ctx = new BasicContext();
AccessMode originalMode = (mode instanceof AccessModeSnapshot) ?
((AccessModeSnapshot) mode).getOriginalAccessMode() :
mode;

AccessMode originalMode = mode.getOriginalAccessMode();
if ( originalMode instanceof AuthSubject )
{
ctx.put( Context.AUTH_SUBJECT, (AuthSubject) originalMode );
Expand Down
Expand Up @@ -29,10 +29,12 @@ public class AccessModeSnapshot implements AccessMode
private final boolean allowsSchemaWrites;
private final boolean overrideOriginalMode;

private final AccessMode accessMode;
private final AccessMode originalMode;

public static AccessMode createAccessModeSnapshot( AccessMode accessMode )
{
// TODO: Use flyweight pattern instead of always creating new objects, when we have a proper
// security context and do not need to obtain the original mode through this object
return new AccessModeSnapshot( accessMode );
}

Expand All @@ -43,10 +45,13 @@ private AccessModeSnapshot( AccessMode accessMode )
allowsSchemaWrites = accessMode.allowsSchemaWrites();
overrideOriginalMode = accessMode.overrideOriginalMode();

// We use this for onViolation() and name()
this.accessMode = accessMode;
// We use this for delegation of all the remaining methods
this.originalMode = accessMode;
}

//---------------------------------------
// Snapshot permissions

@Override
public boolean allowsReads()
{
Expand All @@ -71,21 +76,30 @@ public boolean overrideOriginalMode()
return overrideOriginalMode;
}

//---------------------------------------
// Delegate remaining methods to original

@Override
public AuthorizationViolationException onViolation( String msg )
{
return accessMode.onViolation( msg );
return originalMode.onViolation( msg );
}

@Override
public String name()
{
return accessMode.name();
return originalMode.name();
}

// TODO: Move this to AccessMode interface with default implementation to support recursive case
@Override
public String username()
{
return originalMode.username();
}

@Override
public AccessMode getOriginalAccessMode()
{
return accessMode;
return originalMode.getOriginalAccessMode();
}
}
Expand Up @@ -21,7 +21,6 @@

import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.kernel.api.security.AuthSubject;

public class OverriddenAccessMode implements AccessMode
{
Expand Down Expand Up @@ -80,35 +79,16 @@ public String name()
}
}

@Override
public String username()
{
return getUsernameFromAccessMode( originalMode );
return originalMode.username();
}

// TODO: Move this to AccessMode interface with default implementation to support recursive case
// OR move allowsProcedureWith() to AccessMode and override that here with recursive implementation
@Override
public AccessMode getOriginalAccessMode()
{
return originalMode;
return originalMode.getOriginalAccessMode();
}

public static String getUsernameFromAccessMode( AccessMode accessMode )
{
if ( accessMode instanceof AuthSubject )
{
return ((AuthSubject) accessMode).username();
}
else if ( accessMode instanceof OverriddenAccessMode )
{
return ((OverriddenAccessMode) accessMode).username();
}
else if ( accessMode instanceof AccessModeSnapshot )
{
return getUsernameFromAccessMode( ((AccessModeSnapshot) accessMode).getOriginalAccessMode() );
}
else
{
return ""; // Should never clash with a valid username
}
}
}
Expand Up @@ -66,7 +66,6 @@
import static org.neo4j.graphdb.security.AuthorizationViolationException.PERMISSION_DENIED;
import static org.neo4j.kernel.enterprise.builtinprocs.QueryId.fromExternalString;
import static org.neo4j.kernel.enterprise.builtinprocs.QueryId.ofInternalId;
import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;
import static org.neo4j.procedure.Mode.DBMS;

@SuppressWarnings( "unused" )
Expand Down Expand Up @@ -123,7 +122,7 @@ public Stream<TransactionResult> listTransactions()
getActiveTransactions( graph.getDependencyResolver() )
.stream()
.filter( tx -> !tx.terminationReason().isPresent() )
.map( tx -> getUsernameFromAccessMode( tx.mode() ) )
.map( tx -> tx.mode().username() )
);
}

Expand Down Expand Up @@ -280,7 +279,7 @@ public static Stream<TransactionTerminationResult> terminateTransactionsForValid
{
long terminatedCount = getActiveTransactions( dependencyResolver )
.stream()
.filter( tx -> getUsernameFromAccessMode( tx.mode() ).equals( username ) &&
.filter( tx -> tx.mode().username().equals( username ) &&
!tx.isUnderlyingTransaction( currentTx ) )
.map( tx -> tx.markForTermination( Status.Transaction.Terminated ) )
.filter( marked -> marked )
Expand Down
Expand Up @@ -39,7 +39,6 @@

import static java.util.Collections.emptyList;
import static org.neo4j.graphdb.security.AuthorizationViolationException.PERMISSION_DENIED;
import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;

@SuppressWarnings( "WeakerAccess" )
public class AuthProceduresBase
Expand Down Expand Up @@ -79,7 +78,7 @@ protected void terminateTransactionsForValidUser( String username )
getActiveTransactions()
.stream()
.filter( tx ->
getUsernameFromAccessMode( tx.mode() ).equals( username ) &&
tx.mode().username().equals( username ) &&
!tx.isUnderlyingTransaction( currentTx )
).forEach( tx -> tx.markForTermination( Status.Transaction.Terminated ) );
}
Expand Down
Expand Up @@ -73,7 +73,6 @@
import static org.neo4j.bolt.v1.messaging.util.MessageMatchers.msgSuccess;
import static org.neo4j.bolt.v1.transport.integration.TransportTestUtil.eventuallyReceives;
import static org.neo4j.helpers.collection.MapUtil.map;
import static org.neo4j.kernel.impl.api.security.OverriddenAccessMode.getUsernameFromAccessMode;
import static org.neo4j.procedure.Mode.READ;
import static org.neo4j.procedure.Mode.WRITE;
import static org.neo4j.server.security.enterprise.auth.plugin.api.PredefinedRoles.ADMIN;
Expand Down Expand Up @@ -459,7 +458,7 @@ private Map<String,Long> countTransactionsByUsername()
neo.getLocalGraph().getDependencyResolver()
).stream()
.filter( tx -> !tx.terminationReason().isPresent() )
.map( tx -> getUsernameFromAccessMode( tx.mode() ) )
.map( tx -> tx.mode().username() )
).collect( Collectors.toMap( r -> r.username, r -> r.activeTransactions ) );
}

Expand Down

0 comments on commit 8b246b7

Please sign in to comment.