Skip to content

Commit

Permalink
Restrict access to unsecured components
Browse files Browse the repository at this point in the history
Internal procedures should have access to all components while
external procedures should not be allowed to access components
that could let them bypass security.
  • Loading branch information
OliviaYtterbrink committed Feb 8, 2017
1 parent 1a32929 commit d5dd203
Show file tree
Hide file tree
Showing 15 changed files with 134 additions and 58 deletions.
Expand Up @@ -362,16 +362,16 @@ private Procedures setupProcedures( PlatformModule platform, EditionModule editi

// Register injected public API components
Log proceduresLog = platform.logging.getUserLog( Procedures.class );
procedures.registerComponent( Log.class, ( ctx ) -> proceduresLog );
procedures.registerComponent( Log.class, ( ctx ) -> proceduresLog, true );

Guard guard = platform.dependencies.resolveDependency( Guard.class );
procedures.registerComponent( TerminationGuard.class, new TerminationGuardProvider( guard ) );
procedures.registerComponent( TerminationGuard.class, new TerminationGuardProvider( guard ), true );

// Register injected private API components: useful to have available in procedures to access the kernel etc.
ProcedureGDSFactory gdsFactory = new ProcedureGDSFactory( platform.config, platform.storeDir,
platform.dependencies, storeId, this.queryExecutor, editionModule.coreAPIAvailabilityGuard,
platform.urlAccessRule );
procedures.registerComponent( GraphDatabaseService.class, gdsFactory::apply );
procedures.registerComponent( GraphDatabaseService.class, gdsFactory::apply, true );

// Below components are not public API, but are made available for internal
// procedures to call, and to provide temporary workarounds for the following
Expand All @@ -380,12 +380,12 @@ private Procedures setupProcedures( PlatformModule platform, EditionModule editi
// - Group-transaction writes (same pattern as above, but rather than splitting large transactions,
// combine lots of small ones)
// - Bleeding-edge performance (KernelTransaction, to bypass overhead of working with Core API)
procedures.registerComponent( DependencyResolver.class, ( ctx ) -> platform.dependencies );
procedures.registerComponent( KernelTransaction.class, ( ctx ) -> ctx.get( KERNEL_TRANSACTION ) );
procedures.registerComponent( GraphDatabaseAPI.class, ( ctx ) -> platform.graphDatabaseFacade );
procedures.registerComponent( DependencyResolver.class, ( ctx ) -> platform.dependencies, false );
procedures.registerComponent( KernelTransaction.class, ( ctx ) -> ctx.get( KERNEL_TRANSACTION ), false );
procedures.registerComponent( GraphDatabaseAPI.class, ( ctx ) -> platform.graphDatabaseFacade, false );

// Security procedures
procedures.registerComponent( SecurityContext.class, ctx -> ctx.get( SECURITY_CONTEXT ) );
procedures.registerComponent( SecurityContext.class, ctx -> ctx.get( SECURITY_CONTEXT ), true );

// Edition procedures
try
Expand Down
Expand Up @@ -91,7 +91,7 @@ private Callables loadProcedures( URL jar, ClassLoader loader, Callables target
while ( classes.hasNext() )
{
Class<?> next = classes.next();
target.addAllProcedures( compiler.compileProcedure( next, Optional.empty() ) );
target.addAllProcedures( compiler.compileProcedure( next, Optional.empty(), false ) );
target.addAllFunctions( compiler.compileFunction( next ) );
target.addAllAggregationFunctions( compiler.compileAggregationFunction( next ) );
}
Expand Down
Expand Up @@ -49,7 +49,8 @@ public class Procedures extends LifecycleAdapter
{
private final ProcedureRegistry registry = new ProcedureRegistry();
private final TypeMappers typeMappers = new TypeMappers();
private final ComponentRegistry components = new ComponentRegistry();
private final ComponentRegistry safeComponents = new ComponentRegistry();
private final ComponentRegistry allComponents = new ComponentRegistry();
private final ReflectiveProcedureCompiler compiler;
private final ThrowingConsumer<Procedures, ProcedureException> builtin;
private final File pluginDir;
Expand All @@ -67,7 +68,7 @@ public Procedures( ThrowingConsumer<Procedures,ProcedureException> builtin, File
this.builtin = builtin;
this.pluginDir = pluginDir;
this.log = log;
this.compiler = new ReflectiveProcedureCompiler( typeMappers, components, log, config );
this.compiler = new ReflectiveProcedureCompiler( typeMappers, safeComponents, allComponents, log, config );
}

/**
Expand Down Expand Up @@ -125,7 +126,7 @@ public void register( CallableProcedure proc, boolean overrideCurrentImplementat
}

/**
* Register a new procedure defined with annotations on a java class.
* Register a new internal procedure defined with annotations on a java class.
* @param proc the procedure class
*/
public void registerProcedure( Class<?> proc ) throws KernelException
Expand All @@ -134,7 +135,7 @@ public void registerProcedure( Class<?> proc ) throws KernelException
}

/**
* Register a new procedure defined with annotations on a java class.
* Register a new internal procedure defined with annotations on a java class.
* @param proc the procedure class
* @param overrideCurrentImplementation set to true if procedures within this class should override older procedures with the same name
*/
Expand All @@ -144,7 +145,7 @@ public void registerProcedure( Class<?> proc, boolean overrideCurrentImplementat
}

/**
* Register a new procedure defined with annotations on a java class.
* Register a new internal procedure defined with annotations on a java class.
* @param proc the procedure class
* @param overrideCurrentImplementation set to true if procedures within this class should override older procedures with the same name
* @param warning the warning the procedure should generate when called
Expand All @@ -153,7 +154,7 @@ public void registerProcedure( Class<?> proc, boolean overrideCurrentImplementat
throws
KernelException
{
for ( CallableProcedure procedure : compiler.compileProcedure( proc, warning ) )
for ( CallableProcedure procedure : compiler.compileProcedure( proc, warning, true ) )
{
register( procedure, overrideCurrentImplementation );
}
Expand Down Expand Up @@ -213,13 +214,17 @@ public void registerType( Class<?> javaClass, TypeMappers.NeoValueConverter toNe

/**
* Registers a component, these become available in reflective procedures for injection.
*
* @param cls the type of component to be registered (this is what users 'ask' for in their field declaration)
* @param provider a function that supplies the component, given the context of a procedure invocation
* @param safe set to false if this component can bypass security, true if it respects security
*/
public <T> void registerComponent( Class<T> cls, ComponentRegistry.Provider<T> provider )
public <T> void registerComponent( Class<T> cls, ComponentRegistry.Provider<T> provider, boolean safe )
{
components.register( cls, provider );
if ( safe )
{
safeComponents.register( cls, provider );
}
allComponents.register( cls, provider );
}

public ProcedureSignature procedure( QualifiedName name ) throws ProcedureException
Expand Down
Expand Up @@ -68,17 +68,19 @@ class ReflectiveProcedureCompiler
private final MethodHandles.Lookup lookup = MethodHandles.lookup();
private final OutputMappers outputMappers;
private final MethodSignatureCompiler inputSignatureDeterminer;
private final FieldInjections fieldInjections;
private final FieldInjections safeFieldInjections;
private final FieldInjections allFieldInjections;
private final Log log;
private final TypeMappers typeMappers;
private final ProcedureAllowedConfig config;

ReflectiveProcedureCompiler( TypeMappers typeMappers, ComponentRegistry components, Log log,
ProcedureAllowedConfig config )
ReflectiveProcedureCompiler( TypeMappers typeMappers, ComponentRegistry safeComponents,
ComponentRegistry allComponents, Log log, ProcedureAllowedConfig config )
{
inputSignatureDeterminer = new MethodSignatureCompiler( typeMappers );
outputMappers = new OutputMappers( typeMappers );
this.fieldInjections = new FieldInjections( components );
this.safeFieldInjections = new FieldInjections( safeComponents );
this.allFieldInjections = new FieldInjections( allComponents );
this.log = log;
this.typeMappers = typeMappers;
this.config = config;
Expand Down Expand Up @@ -152,7 +154,8 @@ List<CallableUserAggregationFunction> compileAggregationFunction( Class<?> fcnDe
}
}

List<CallableProcedure> compileProcedure( Class<?> procDefinition, Optional<String> warning ) throws KernelException
List<CallableProcedure> compileProcedure( Class<?> procDefinition, Optional<String> warning,
boolean fullAccess ) throws KernelException
{
try
{
Expand All @@ -170,7 +173,7 @@ List<CallableProcedure> compileProcedure( Class<?> procDefinition, Optional<Stri
ArrayList<CallableProcedure> out = new ArrayList<>( procedureMethods.size() );
for ( Method method : procedureMethods )
{
out.add( compileProcedure( procDefinition, constructor, method, warning ) );
out.add( compileProcedure( procDefinition, constructor, method, warning, fullAccess ) );
}
out.sort( Comparator.comparing( a -> a.signature().name().toString() ) );
return out;
Expand All @@ -187,7 +190,7 @@ List<CallableProcedure> compileProcedure( Class<?> procDefinition, Optional<Stri
}

private ReflectiveProcedure compileProcedure( Class<?> procDefinition, MethodHandle constructor, Method method,
Optional<String> warning )
Optional<String> warning, boolean fullAccess )
throws ProcedureException, IllegalAccessException
{
String valueName = method.getAnnotation( Procedure.class ).value();
Expand All @@ -197,7 +200,15 @@ private ReflectiveProcedure compileProcedure( Class<?> procDefinition, MethodHan
List<FieldSignature> inputSignature = inputSignatureDeterminer.signatureFor( method );
OutputMapper outputMapper = outputMappers.mapper( method );
MethodHandle procedureMethod = lookup.unreflect( method );
List<FieldInjections.FieldSetter> setters = fieldInjections.setters( procDefinition );
List<FieldInjections.FieldSetter> setters;
if ( fullAccess )
{
setters = allFieldInjections.setters( procDefinition );
}
else
{
setters = safeFieldInjections.setters( procDefinition );
}

Optional<String> description = description( method );
Procedure procedure = method.getAnnotation( Procedure.class );
Expand Down Expand Up @@ -243,7 +254,7 @@ private ReflectiveUserFunction compileFunction( Class<?> procDefinition, MethodH
Class<?> returnType = method.getReturnType();
TypeMappers.NeoValueConverter valueConverter = typeMappers.converterFor( returnType );
MethodHandle procedureMethod = lookup.unreflect( method );
List<FieldInjections.FieldSetter> setters = fieldInjections.setters( procDefinition );
List<FieldInjections.FieldSetter> setters = safeFieldInjections.setters( procDefinition );

Optional<String> description = description( method );
UserFunction function = method.getAnnotation( UserFunction.class );
Expand Down Expand Up @@ -345,7 +356,7 @@ private ReflectiveUserAggregationFunction compileAggregationFunction( Class<?> d
MethodHandle creator = lookup.unreflect( method );
MethodHandle updateMethod = lookup.unreflect( update );
MethodHandle resultMethod = lookup.unreflect( result );
List<FieldInjections.FieldSetter> setters = fieldInjections.setters( definition);
List<FieldInjections.FieldSetter> setters = safeFieldInjections.setters( definition );

Optional<String> description = description( method );
UserAggregationFunction function = method.getAnnotation( UserAggregationFunction.class );
Expand Down
Expand Up @@ -285,9 +285,9 @@ private Integer token( String name, Map<Integer,String> tokens )
@Before
public void setup() throws Exception
{
procs.registerComponent( KernelTransaction.class, ( ctx ) -> ctx.get( KERNEL_TRANSACTION ) );
procs.registerComponent( DependencyResolver.class, ( ctx ) -> ctx.get( DEPENDENCY_RESOLVER ) );
procs.registerComponent( GraphDatabaseAPI.class, ( ctx ) -> ctx.get( GRAPHDATABASEAPI ) );
procs.registerComponent( KernelTransaction.class, ( ctx ) -> ctx.get( KERNEL_TRANSACTION ), false );
procs.registerComponent( DependencyResolver.class, ( ctx ) -> ctx.get( DEPENDENCY_RESOLVER ), false );
procs.registerComponent( GraphDatabaseAPI.class, ( ctx ) -> ctx.get( GRAPHDATABASEAPI ), false );

procs.registerType( Node.class, new TypeMappers.SimpleConverter( NTNode, Node.class ) );
procs.registerType( Relationship.class, new TypeMappers.SimpleConverter( NTRelationship, Relationship.class ) );
Expand Down
Expand Up @@ -57,7 +57,7 @@ public class ProcedureJarLoaderTest

private final ProcedureJarLoader jarloader =
new ProcedureJarLoader( new ReflectiveProcedureCompiler( new TypeMappers(), new ComponentRegistry(),
NullLog.getInstance(), ProcedureAllowedConfig.DEFAULT ), NullLog.getInstance() );
new ComponentRegistry(), NullLog.getInstance(), ProcedureAllowedConfig.DEFAULT ), NullLog.getInstance() );

@Test
public void shouldLoadProcedureFromJar() throws Throwable
Expand Down
Expand Up @@ -55,6 +55,7 @@
import static org.neo4j.helpers.collection.Iterators.asList;
import static org.neo4j.kernel.api.proc.ProcedureSignature.procedureSignature;

@SuppressWarnings( "WeakerAccess" )
public class ReflectiveProcedureTest
{
@Rule
Expand All @@ -67,7 +68,8 @@ public class ReflectiveProcedureTest
public void setUp() throws Exception
{
components = new ComponentRegistry();
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, NullLog.getInstance(), ProcedureAllowedConfig.DEFAULT );
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, components,
NullLog.getInstance(), ProcedureAllowedConfig.DEFAULT );
}

@Test
Expand All @@ -77,7 +79,7 @@ public void shouldInjectLogging() throws KernelException
Log log = spy( Log.class );
components.register( Log.class, (ctx) -> log );
CallableProcedure procedure =
procedureCompiler.compileProcedure( LoggingProcedure.class, Optional.empty() ).get( 0 );
procedureCompiler.compileProcedure( LoggingProcedure.class, Optional.empty(), true ).get( 0 );

// When
procedure.apply( new BasicContext(), new Object[0] );
Expand Down Expand Up @@ -273,12 +275,12 @@ public void shouldSupportProcedureDeprecation() throws Throwable
{
// Given
Log log = mock(Log.class);
ReflectiveProcedureCompiler procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, log,
ProcedureAllowedConfig.DEFAULT );
ReflectiveProcedureCompiler procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components,
components, log, ProcedureAllowedConfig.DEFAULT );

// When
List<CallableProcedure> procs =
procedureCompiler.compileProcedure( ProcedureWithDeprecation.class, Optional.empty() );
procedureCompiler.compileProcedure( ProcedureWithDeprecation.class, Optional.empty(), true );

// Then
verify( log ).warn( "Use of @Procedure(deprecatedBy) without @Deprecated in badProc" );
Expand Down Expand Up @@ -524,6 +526,6 @@ public void badProc()

private List<CallableProcedure> compile( Class<?> clazz ) throws KernelException
{
return procedureCompiler.compileProcedure( clazz, Optional.empty() );
return procedureCompiler.compileProcedure( clazz, Optional.empty(), true );
}
}
Expand Up @@ -215,7 +215,8 @@ public Stream<MyOutputRecord> defaultValues( @Name( value = "a", defaultValue =

private List<CallableProcedure> compile( Class<?> clazz ) throws KernelException
{
return new ReflectiveProcedureCompiler( new TypeMappers(), new ComponentRegistry(), NullLog.getInstance(),
ProcedureAllowedConfig.DEFAULT ).compileProcedure( clazz, Optional.empty() );
return new ReflectiveProcedureCompiler( new TypeMappers(), new ComponentRegistry(), new ComponentRegistry(),
NullLog.getInstance(),
ProcedureAllowedConfig.DEFAULT ).compileProcedure( clazz, Optional.empty(), true );
}
}
Expand Up @@ -69,7 +69,8 @@ public class ReflectiveUserAggregationFunctionTest
public void setUp() throws Exception
{
components = new ComponentRegistry();
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, NullLog.getInstance(),
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, new ComponentRegistry(),
NullLog.getInstance(),
ProcedureAllowedConfig.DEFAULT );
}

Expand Down Expand Up @@ -344,7 +345,7 @@ public void shouldSupportFunctionDeprecation() throws Throwable
// Given
Log log = mock(Log.class);
ReflectiveProcedureCompiler procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(),
components, log, ProcedureAllowedConfig.DEFAULT );
components, new ComponentRegistry(), log, ProcedureAllowedConfig.DEFAULT );

// When
List<CallableUserAggregationFunction> funcs = procedureCompiler.compileAggregationFunction( FunctionWithDeprecation.class );
Expand Down
Expand Up @@ -65,7 +65,8 @@ public class ReflectiveUserFunctionTest
public void setUp() throws Exception
{
components = new ComponentRegistry();
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, NullLog.getInstance(), ProcedureAllowedConfig.DEFAULT );
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, new ComponentRegistry(),
NullLog.getInstance(), ProcedureAllowedConfig.DEFAULT );
}

@Test
Expand Down Expand Up @@ -249,7 +250,8 @@ public void shouldSupportFunctionDeprecation() throws Throwable
{
// Given
Log log = mock(Log.class);
ReflectiveProcedureCompiler procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, log, ProcedureAllowedConfig.DEFAULT );
ReflectiveProcedureCompiler procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components,
new ComponentRegistry(), log, ProcedureAllowedConfig.DEFAULT );

// When
List<CallableUserFunction> funcs = procedureCompiler.compileFunction( FunctionWithDeprecation.class );
Expand Down

0 comments on commit d5dd203

Please sign in to comment.