Skip to content

Commit

Permalink
Move allowsProcedureWith to AccessMode
Browse files Browse the repository at this point in the history
- Rename AccessModeSnapshot create method
  • Loading branch information
henriknyman committed Oct 12, 2016
1 parent 4921503 commit d9385d1
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 28 deletions.
Expand Up @@ -588,32 +588,29 @@ final class TransactionBoundQueryContext(val transactionalContext: Transactional
type KernelProcedureCall = (KernelQualifiedName, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException]

override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) =>
val call: KernelProcedureCall =
if (allowed.nonEmpty && transactionalContext.accessMode.getOriginalAccessMode.allowsProcedureWith(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _, AccessMode.Static.OVERRIDE_READ)
case _ =>
else
transactionalContext.statement.procedureCallOperations.procedureCallRead(_, _)
}
callProcedure(name, args, call)
}

override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) =>
val call: KernelProcedureCall =
if (allowed.nonEmpty && transactionalContext.accessMode.getOriginalAccessMode.allowsProcedureWith(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _, AccessMode.Static.OVERRIDE_WRITE)
case _ =>
else
transactionalContext.statement.procedureCallOperations.procedureCallWrite(_, _)
}
callProcedure(name, args, call)
}

override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = {
val call: KernelProcedureCall = transactionalContext.accessMode.getOriginalAccessMode match {
case a: AuthSubject if a.allowsProcedureWith(allowed) =>
val call: KernelProcedureCall =
if (allowed.nonEmpty && transactionalContext.accessMode.getOriginalAccessMode.allowsProcedureWith(allowed))
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _, AccessMode.Static.OVERRIDE_SCHEMA)
case _ =>
else
transactionalContext.statement.procedureCallOperations.procedureCallSchema(_, _)
}
callProcedure(name, args, call)
}

Expand Down
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.kernel.api.security;

import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.kernel.api.exceptions.InvalidArgumentsException;
import org.neo4j.kernel.api.exceptions.Status;

/** Controls the capabilities of a KernelTransaction. */
Expand Down Expand Up @@ -340,5 +341,18 @@ default AccessMode getOriginalAccessMode()
return this;
}

/**
* Determines whether this subject is allowed to execute a procedure with the parameter string in its
* procedure annotation.
*
* @param roleNames
* @return
* @throws InvalidArgumentsException
*/
default boolean allowsProcedureWith( String[] roleNames ) throws InvalidArgumentsException
{
return false;
}

AccessMode getSnapshot();
}
Expand Up @@ -47,14 +47,6 @@ public interface AuthSubject extends AccessMode
*/
void setPasswordChangeNoLongerRequired();

/**
* Determines whether this subject is allowed to execute a procedure with the parameter string in its procedure annotation.
* @param roleNames
* @return
* @throws InvalidArgumentsException
*/
boolean allowsProcedureWith( String[] roleNames ) throws InvalidArgumentsException;

/**
* @param username a username
* @return true if the provided username is the underlying user name of this subject
Expand All @@ -74,7 +66,7 @@ default void ensureUserExistsWithName( String username ) throws InvalidArguments
@Override
default AccessMode getSnapshot()
{
return AccessModeSnapshot.createAccessModeSnapshot( this );
return AccessModeSnapshot.create( this );
}

abstract class StaticAccessModeAdapter implements AuthSubject
Expand Down
Expand Up @@ -31,7 +31,7 @@ public class AccessModeSnapshot implements AccessMode

private final AccessMode originalMode;

public static AccessMode createAccessModeSnapshot( AccessMode accessMode )
public static AccessMode create( AccessMode accessMode )
{
// TODO: Use flyweight pattern instead of always creating new objects, when we have a proper
// security context and do not need to obtain the original mode through this object
Expand Down
Expand Up @@ -94,6 +94,6 @@ public AccessMode getOriginalAccessMode()
@Override
public AccessMode getSnapshot()
{
return AccessModeSnapshot.createAccessModeSnapshot( this );
return AccessModeSnapshot.create( this );
}
}
Expand Up @@ -42,7 +42,7 @@ public void shouldCorrectlyReflectAccessMode()

private void testAccessMode( AccessMode originalAccessMode )
{
AccessMode accessModeSnapshot = AccessModeSnapshot.createAccessModeSnapshot( originalAccessMode );
AccessMode accessModeSnapshot = AccessModeSnapshot.create( originalAccessMode );
assertEquals( accessModeSnapshot.allowsReads(), originalAccessMode.allowsReads() );
assertEquals( accessModeSnapshot.allowsWrites(), originalAccessMode.allowsWrites() );
assertEquals( accessModeSnapshot.allowsSchemaWrites(), originalAccessMode.allowsSchemaWrites() );
Expand Down
Expand Up @@ -193,8 +193,8 @@ public String username()
public AccessMode getSnapshot()
{
// NOTE: allowsProcedureWith() is delegated to the original access mode (=this)
// so we just need to store the authorization info, and call the default
// so we just need to store the authorization info, and create a normal access mode snapshot
authorizationInfoSnapshot = authManager.getAuthorizationInfo( shiroSubject.getPrincipals() );
return AccessModeSnapshot.createAccessModeSnapshot( this );
return AccessModeSnapshot.create( this );
}
}
Expand Up @@ -19,6 +19,7 @@
*/
package org.neo4j.server.security.enterprise.auth;

import org.apache.directory.api.util.Strings;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -335,8 +336,14 @@ void assertPasswordChangeWhenPasswordChangeRequired( S subject, String newPasswo
void assertFail( S subject, String call, String partOfErrorMsg )
{
String err = assertCallEmpty( subject, call );
assertThat( err, not( equalTo( "" ) ) );
assertThat( err, containsString( partOfErrorMsg ) );
if ( Strings.isEmpty( partOfErrorMsg ) )
{
assertThat( err, not( equalTo( "" ) ) );
}
else
{
assertThat( err, containsString( partOfErrorMsg ) );
}
}

private void assertFail( S subject, String call, String partOfErrorMsg1, String partOfErrorMsg2 )
Expand Down

0 comments on commit d9385d1

Please sign in to comment.