Skip to content

Commit

Permalink
Extended Procedure sandboxing to involve user defined functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
eebus committed Mar 6, 2017
1 parent d84abc9 commit 18d7f4a
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 53 deletions.
Expand Up @@ -524,9 +524,9 @@ public enum LabelIndex
public static final Setting<File> auth_store =
pathSetting( "unsupported.dbms.security.auth_store.location", NO_DEFAULT );

@Description( "A list of procedures (comma separated) that are allowed full access to the database. " +
"The list may contain both fully-qualified procedure names, and partial names with the wildcard '*'. " +
"Note that this enables these procedures to bypass security. Use with caution." )
@Description( "A list of procedures and user defined functions (comma separated) that are allowed full access to " +
"the database. The list may contain both fully-qualified procedure names, and partial names with the " +
"wildcard '*'. Note that this enables these procedures to bypass security. Use with caution." )
public static final Setting<String> procedure_unrestricted =
setting( "dbms.security.procedures.unrestricted", Settings.STRING, "" );

Expand Down
Expand Up @@ -19,9 +19,9 @@
*/
package org.neo4j.kernel.api.exceptions;

public class ProcedureInjectionException extends ProcedureException
public class ComponentInjectionException extends ProcedureException
{
public ProcedureInjectionException( Status statusCode, String message,
public ComponentInjectionException( Status statusCode, String message,
Object... parameters )
{
super( statusCode, message, parameters );
Expand Down
Expand Up @@ -49,6 +49,6 @@ public UserFunctionSignature signature()
}

@Override
public abstract Aggregator create( Context ctx);
public abstract Aggregator create( Context ctx) throws ProcedureException;
}
}
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.kernel.api.proc;

import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.exceptions.Status;

public class LoadFailAggregatedFunction extends CallableUserAggregationFunction.BasicUserAggregationFunction
{
public LoadFailAggregatedFunction( UserFunctionSignature signature )
{
super( signature );
}

@Override
public Aggregator create( Context ctx ) throws ProcedureException
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
signature().description().orElse( "Failed to load " + signature().name().toString() ) );
}
}
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.kernel.api.proc;

import org.neo4j.collection.RawIterator;
import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.exceptions.Status;

public class LoadFailFunction extends CallableUserFunction.BasicUserFunction
{
public LoadFailFunction( UserFunctionSignature signature )
{
super( signature );
}

@Override
public RawIterator<Object[],ProcedureException> apply( Context ctx, Object[] input ) throws ProcedureException
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
signature().description().orElse( "Failed to load " + signature().name().toString() ) );
}
}
Expand Up @@ -26,7 +26,7 @@
import java.util.LinkedList;
import java.util.List;

import org.neo4j.kernel.api.exceptions.ProcedureInjectionException;
import org.neo4j.kernel.api.exceptions.ComponentInjectionException;
import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.procedure.Context;
Expand Down Expand Up @@ -122,10 +122,10 @@ private FieldSetter createInjector( Class<?> cls, Field field ) throws Procedure
ComponentRegistry.Provider<?> provider = components.providerFor( field.getType() );
if( provider == null )
{
throw new ProcedureInjectionException( Status.Procedure.ProcedureRegistrationFailed,
throw new ComponentInjectionException( Status.Procedure.ProcedureRegistrationFailed,
"Unable to set up injection for procedure `%s`, the field `%s` " +
"has type `%s` which is not a known injectable component.",
cls.getSimpleName(), field.getName(), field.getType());
cls.getSimpleName(), field.getName(), field.getType());
}

MethodHandle setter = MethodHandles.lookup().unreflectSetter( field );
Expand Down
Expand Up @@ -34,7 +34,7 @@
import java.util.stream.Stream;

import org.neo4j.collection.RawIterator;
import org.neo4j.kernel.api.exceptions.ProcedureInjectionException;
import org.neo4j.kernel.api.exceptions.ComponentInjectionException;
import org.neo4j.kernel.api.exceptions.KernelException;
import org.neo4j.kernel.api.exceptions.ProcedureException;
import org.neo4j.kernel.api.exceptions.Status;
Expand All @@ -43,6 +43,8 @@
import org.neo4j.kernel.api.proc.CallableUserFunction;
import org.neo4j.kernel.api.proc.Context;
import org.neo4j.kernel.api.proc.FieldSignature;
import org.neo4j.kernel.api.proc.LoadFailAggregatedFunction;
import org.neo4j.kernel.api.proc.LoadFailFunction;
import org.neo4j.kernel.api.proc.LoadFailProcedure;
import org.neo4j.kernel.api.proc.ProcedureSignature;
import org.neo4j.kernel.api.proc.QualifiedName;
Expand Down Expand Up @@ -231,32 +233,32 @@ private CallableProcedure compileProcedure( Class<?> procDefinition, MethodHandl
Optional<String> deprecated = deprecated( method, procedure::deprecatedBy,
"Use of @Procedure(deprecatedBy) without @Deprecated in " + procName );

List<FieldInjections.FieldSetter> setters;
setters = allFieldInjections.setters( procDefinition );
ProcedureSignature signature;

List<FieldInjections.FieldSetter> setters = allFieldInjections.setters( procDefinition );
if ( !fullAccess && !config.fullAccessFor( procName.toString() ) )
{
try
{
setters = safeFieldInjections.setters( procDefinition );
}
catch ( ProcedureInjectionException e )
catch ( ComponentInjectionException e )
{
description = Optional.of( procName.toString() +
" is not available due to not having unrestricted access rights, check configuration." );
log.warn( description.get());
signature = new ProcedureSignature( procName, inputSignature, outputMapper.signature(), Mode.DEFAULT,
Optional.empty(), new String[0], description, warning );
log.warn( description.get() );
ProcedureSignature signature =
new ProcedureSignature( procName, inputSignature, outputMapper.signature(), Mode.DEFAULT,
Optional.empty(), new String[0], description, warning );
return new LoadFailProcedure( signature );
}
}
signature = new ProcedureSignature( procName, inputSignature, outputMapper.signature(), mode, deprecated,
config.rolesFor( procName.toString() ), description, warning );

ProcedureSignature signature =
new ProcedureSignature( procName, inputSignature, outputMapper.signature(), mode, deprecated,
config.rolesFor( procName.toString() ), description, warning );
return new ReflectiveProcedure( signature, constructor, procedureMethod, outputMapper, setters );
}

private ReflectiveUserFunction compileFunction( Class<?> procDefinition, MethodHandle constructor, Method method )
private CallableUserFunction compileFunction( Class<?> procDefinition, MethodHandle constructor, Method method )
throws ProcedureException, IllegalAccessException
{
String valueName = method.getAnnotation( UserFunction.class ).value();
Expand All @@ -274,29 +276,45 @@ private ReflectiveUserFunction compileFunction( Class<?> procDefinition, MethodH
Class<?> returnType = method.getReturnType();
TypeMappers.NeoValueConverter valueConverter = typeMappers.converterFor( returnType );
MethodHandle procedureMethod = lookup.unreflect( method );
List<FieldInjections.FieldSetter> setters = safeFieldInjections.setters( procDefinition );

Optional<String> description = description( method );
UserFunction function = method.getAnnotation( UserFunction.class );

Optional<String> deprecated = deprecated( method, function::deprecatedBy,
"Use of @UserFunction(deprecatedBy) without @Deprecated in " + procName );

List<FieldInjections.FieldSetter> setters = allFieldInjections.setters( procDefinition );
if ( !config.fullAccessFor( procName.toString() ) )
{
try
{
setters = safeFieldInjections.setters( procDefinition );
}
catch ( ComponentInjectionException e )
{
description = Optional.of( procName.toString() +
" is not available due to not having unrestricted access rights, check configuration." );
log.warn( description.get() );
UserFunctionSignature signature =
new UserFunctionSignature( procName, inputSignature, valueConverter.type(), deprecated,
config.rolesFor( procName.toString() ), description );
return new LoadFailFunction( signature );
}
}

UserFunctionSignature signature =
new UserFunctionSignature( procName, inputSignature, valueConverter.type(), deprecated,
config.rolesFor( procName.toString() ), description );

return new ReflectiveUserFunction( signature, constructor, procedureMethod, valueConverter, setters );
}

private ReflectiveUserAggregationFunction compileAggregationFunction( Class<?> definition, MethodHandle constructor, Method method )
throws ProcedureException, IllegalAccessException
private CallableUserAggregationFunction compileAggregationFunction( Class<?> definition, MethodHandle constructor,
Method method ) throws ProcedureException, IllegalAccessException
{
String valueName = method.getAnnotation( UserAggregationFunction.class ).value();
String definedName = method.getAnnotation( UserAggregationFunction.class ).name();
QualifiedName funcName = extractName( definition, method, valueName, definedName );

if (funcName.namespace() == null || funcName.namespace().length == 0)
if ( funcName.namespace() == null || funcName.namespace().length == 0 )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"It is not allowed to define functions in the root namespace please use a namespace, e.g. `@UserFunction(\"org.example.com.%s\")",
Expand All @@ -309,18 +327,18 @@ private ReflectiveUserAggregationFunction compileAggregationFunction( Class<?> d
Class<?> aggregator = method.getReturnType();
for ( Method m : aggregator.getDeclaredMethods() )
{
if (m.isAnnotationPresent( UserAggregationUpdate.class ))
{
if ( update != null )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Class '%s' contains multiple methods annotated with '@%s'.", aggregator.getSimpleName(),
UserAggregationUpdate.class.getSimpleName() );
}
update = m;

}
if (m.isAnnotationPresent( UserAggregationResult.class ))
if ( m.isAnnotationPresent( UserAggregationUpdate.class ) )
{
if ( update != null )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Class '%s' contains multiple methods annotated with '@%s'.", aggregator.getSimpleName(),
UserAggregationUpdate.class.getSimpleName() );
}
update = m;

}
if ( m.isAnnotationPresent( UserAggregationResult.class ) )
{
if ( result != null )
{
Expand All @@ -331,39 +349,37 @@ private ReflectiveUserAggregationFunction compileAggregationFunction( Class<?> d
result = m;
}
}
if ( result == null || update == null)
if ( result == null || update == null )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Class '%s' must contain methods annotated with both '@%s' as well as '@%s'.",
aggregator.getSimpleName(),
UserAggregationResult.class.getSimpleName(),
aggregator.getSimpleName(), UserAggregationResult.class.getSimpleName(),
UserAggregationUpdate.class.getSimpleName() );
}
if (update.getReturnType() != void.class)
if ( update.getReturnType() != void.class )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Update method '%s' in %s has type '%s' but must have return type 'void'.", update.getName(),
aggregator.getSimpleName(), update.getReturnType().getSimpleName() );

}
if ( !Modifier.isPublic(method.getModifiers()))
if ( !Modifier.isPublic( method.getModifiers() ) )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Aggregation method '%s' in %s must be public.", method.getName(),
definition.getSimpleName() );
"Aggregation method '%s' in %s must be public.", method.getName(), definition.getSimpleName() );
}
if ( !Modifier.isPublic(aggregator.getModifiers()))
if ( !Modifier.isPublic( aggregator.getModifiers() ) )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Aggregation class '%s' must be public.", aggregator.getSimpleName() );
}
if ( !Modifier.isPublic(update.getModifiers()))
if ( !Modifier.isPublic( update.getModifiers() ) )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Aggregation update method '%s' in %s must be public.", method.getName(),
aggregator.getSimpleName() );
}
if ( !Modifier.isPublic(result.getModifiers()))
if ( !Modifier.isPublic( result.getModifiers() ) )
{
throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed,
"Aggregation result method '%s' in %s must be public.", method.getName(),
Expand All @@ -376,19 +392,39 @@ private ReflectiveUserAggregationFunction compileAggregationFunction( Class<?> d
MethodHandle creator = lookup.unreflect( method );
MethodHandle updateMethod = lookup.unreflect( update );
MethodHandle resultMethod = lookup.unreflect( result );
List<FieldInjections.FieldSetter> setters = safeFieldInjections.setters( definition );

Optional<String> description = description( method );
UserAggregationFunction function = method.getAnnotation( UserAggregationFunction.class );

Optional<String> deprecated = deprecated( method, function::deprecatedBy,
"Use of @UserAggregationFunction(deprecatedBy) without @Deprecated in " + funcName );

List<FieldInjections.FieldSetter> setters = allFieldInjections.setters( definition );
if ( !config.fullAccessFor( funcName.toString() ) )
{
try
{
setters = safeFieldInjections.setters( definition );
}
catch ( ComponentInjectionException e )
{
description = Optional.of( funcName.toString() +
" is not available due to not having unrestricted access rights, check configuration." );
log.warn( description.get() );
UserFunctionSignature signature =
new UserFunctionSignature( funcName, inputSignature, valueConverter.type(), deprecated,
config.rolesFor( funcName.toString() ), description );

return new LoadFailAggregatedFunction( signature );
}
}

UserFunctionSignature signature =
new UserFunctionSignature( funcName, inputSignature, valueConverter.type(), deprecated,
config.rolesFor( funcName.toString() ), description );

return new ReflectiveUserAggregationFunction( signature, constructor, creator, updateMethod, resultMethod, valueConverter, setters );
return new ReflectiveUserAggregationFunction( signature, constructor, creator, updateMethod, resultMethod,
valueConverter, setters );
}

private Optional<String> deprecated( Method method, Supplier<String> supplier, String warning )
Expand Down
Expand Up @@ -69,7 +69,7 @@ public class ReflectiveUserAggregationFunctionTest
public void setUp() throws Exception
{
components = new ComponentRegistry();
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, new ComponentRegistry(),
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, components,
NullLog.getInstance(), ProcedureConfig.DEFAULT );
}

Expand Down
Expand Up @@ -65,7 +65,7 @@ public class ReflectiveUserFunctionTest
public void setUp() throws Exception
{
components = new ComponentRegistry();
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, new ComponentRegistry(),
procedureCompiler = new ReflectiveProcedureCompiler( new TypeMappers(), components, components,
NullLog.getInstance(), ProcedureConfig.DEFAULT );
}

Expand Down

0 comments on commit 18d7f4a

Please sign in to comment.