Skip to content

Commit

Permalink
Move allowsProcedureWith to AccessMode
Browse files Browse the repository at this point in the history
- Rename AccessModeSnapshot create method
  • Loading branch information
henriknyman committed Oct 12, 2016
1 parent 4921503 commit d9385d1
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 28 deletions.
Expand Up @@ -588,32 +588,29 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException] type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]


override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match { val call: KernelProcedureCall =
case a: AuthSubject if a.allowsProcedureWith(allowed) => if (allowed.nonEmpty && transactionalContext.accessMode.getOriginalAccessMode.allowsProcedureWith(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ) transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ)
case _ => else
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _) transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _)
}
callProcedure(name, args, call) callProcedure(name, args, call)
} }


override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match { val call: KernelProcedureCall =
case a: AuthSubject if a.allowsProcedureWith(allowed) => if (allowed.nonEmpty && transactionalContext.accessMode.getOriginalAccessMode.allowsProcedureWith(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE) transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE)
case _ => else
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _) transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _)
}
callProcedure(name, args, call) callProcedure(name, args, call)
} }


override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match { val call: KernelProcedureCall =
case a: AuthSubject if a.allowsProcedureWith(allowed) => if (allowed.nonEmpty && transactionalContext.accessMode.getOriginalAccessMode.allowsProcedureWith(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA) transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA)
case _ => else
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _) transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _)
}
callProcedure(name, args, call) callProcedure(name, args, call)
} }


Expand Down
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.kernel.api.security; package org.neo4j.kernel.api.security;


import org.neo4j.graphdb.security.AuthorizationViolationException; import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.kernel.api.exceptions.InvalidArgumentsException;
import org.neo4j.kernel.api.exceptions.Status; import org.neo4j.kernel.api.exceptions.Status;


/** Controls the capabilities of a KernelTransaction. */ /** Controls the capabilities of a KernelTransaction. */
Expand Down Expand Up @@ -340,5 +341,18 @@ default AccessMode getOriginalAccessMode()
return this; return this;
} }


/**
* Determines whether this subject is allowed to execute a procedure with the parameter string in its
* procedure annotation.
*
* @param roleNames
* @return
* @throws InvalidArgumentsException
*/
default boolean allowsProcedureWith( String[] roleNames ) throws InvalidArgumentsException
{
return false;
}

AccessMode getSnapshot(); AccessMode getSnapshot();
} }
Expand Up @@ -47,14 +47,6 @@ public interface AuthSubject extends AccessMode
*/ */
void setPasswordChangeNoLongerRequired(); void setPasswordChangeNoLongerRequired();


/**
* Determines whether this subject is allowed to execute a procedure with the parameter string in its procedure annotation.
* @param roleNames
* @return
* @throws InvalidArgumentsException
*/
boolean allowsProcedureWith( String[] roleNames ) throws InvalidArgumentsException;

/** /**
* @param username a username * @param username a username
* @return true if the provided username is the underlying user name of this subject * @return true if the provided username is the underlying user name of this subject
Expand All @@ -74,7 +66,7 @@ default void ensureUserExistsWithName( String username ) throws InvalidArguments
@Override @Override
default AccessMode getSnapshot() default AccessMode getSnapshot()
{ {
return AccessModeSnapshot.createAccessModeSnapshot( this ); return AccessModeSnapshot.create( this );
} }


abstract class StaticAccessModeAdapter implements AuthSubject abstract class StaticAccessModeAdapter implements AuthSubject
Expand Down
Expand Up @@ -31,7 +31,7 @@ public class AccessModeSnapshot implements AccessMode


private final AccessMode originalMode; private final AccessMode originalMode;


public static AccessMode createAccessModeSnapshot( AccessMode accessMode ) public static AccessMode create( AccessMode accessMode )
{ {
// TODO: Use flyweight pattern instead of always creating new objects, when we have a proper // TODO: Use flyweight pattern instead of always creating new objects, when we have a proper
// security context and do not need to obtain the original mode through this object // security context and do not need to obtain the original mode through this object
Expand Down
Expand Up @@ -94,6 +94,6 @@ public AccessMode getOriginalAccessMode()
@Override @Override
public AccessMode getSnapshot() public AccessMode getSnapshot()
{ {
return AccessModeSnapshot.createAccessModeSnapshot( this ); return AccessModeSnapshot.create( this );
} }
} }
Expand Up @@ -42,7 +42,7 @@ public void shouldCorrectlyReflectAccessMode()


private void testAccessMode( AccessMode originalAccessMode ) private void testAccessMode( AccessMode originalAccessMode )
{ {
AccessMode accessModeSnapshot = AccessModeSnapshot.createAccessModeSnapshot( originalAccessMode ); AccessMode accessModeSnapshot = AccessModeSnapshot.create( originalAccessMode );
assertEquals( accessModeSnapshot.allowsReads(), originalAccessMode.allowsReads() ); assertEquals( accessModeSnapshot.allowsReads(), originalAccessMode.allowsReads() );
assertEquals( accessModeSnapshot.allowsWrites(), originalAccessMode.allowsWrites() ); assertEquals( accessModeSnapshot.allowsWrites(), originalAccessMode.allowsWrites() );
assertEquals( accessModeSnapshot.allowsSchemaWrites(), originalAccessMode.allowsSchemaWrites() ); assertEquals( accessModeSnapshot.allowsSchemaWrites(), originalAccessMode.allowsSchemaWrites() );
Expand Down
Expand Up @@ -193,8 +193,8 @@ public String username()
public AccessMode getSnapshot() public AccessMode getSnapshot()
{ {
// NOTE: allowsProcedureWith() is delegated to the original access mode (=this) // NOTE: allowsProcedureWith() is delegated to the original access mode (=this)
// so we just need to store the authorization info, and call the default // so we just need to store the authorization info, and create a normal access mode snapshot
authorizationInfoSnapshot = authManager.getAuthorizationInfo( shiroSubject.getPrincipals() ); authorizationInfoSnapshot = authManager.getAuthorizationInfo( shiroSubject.getPrincipals() );
return AccessModeSnapshot.createAccessModeSnapshot( this ); return AccessModeSnapshot.create( this );
} }
} }
Expand Up @@ -19,6 +19,7 @@
*/ */
package org.neo4j.server.security.enterprise.auth; package org.neo4j.server.security.enterprise.auth;


import org.apache.directory.api.util.Strings;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
Expand Down Expand Up @@ -335,8 +336,14 @@ void assertPasswordChangeWhenPasswordChangeRequired( S subject, String newPasswo
void assertFail( S subject, String call, String partOfErrorMsg ) void assertFail( S subject, String call, String partOfErrorMsg )
{ {
String err = assertCallEmpty( subject, call ); String err = assertCallEmpty( subject, call );
assertThat( err, not( equalTo( "" ) ) ); if ( Strings.isEmpty( partOfErrorMsg ) )
assertThat( err, containsString( partOfErrorMsg ) ); {
assertThat( err, not( equalTo( "" ) ) );
}
else
{
assertThat( err, containsString( partOfErrorMsg ) );
}
} }


private void assertFail( S subject, String call, String partOfErrorMsg1, String partOfErrorMsg2 ) private void assertFail( S subject, String call, String partOfErrorMsg1, String partOfErrorMsg2 )
Expand Down

0 comments on commit d9385d1

Please sign in to comment.