Skip to content

Commit

Permalink
Configurable CORS header
Browse files Browse the repository at this point in the history
  • Loading branch information
technige authored and ali-ince committed Apr 25, 2018
1 parent f836516 commit ab973ef
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 14 deletions.
Expand Up @@ -850,6 +850,13 @@ public enum LabelIndex
public static final Setting<String> default_advertised_address =
setting( "dbms.connectors.default_advertised_address", STRING, "localhost" );

@Description( "Value of the Access-Control-Allow-Origin header sent over any HTTP or HTTPS " +
"connector. This defaults to '*', which allows broadest compatibility but is " +
"least secure. Note that any URI provided here limits HTTP/HTTPS access to " +
"that URI only." )
public static final Setting<String> access_control_allow_origin =
setting( "dbms.connectors.access_control_allow_origin", STRING, "*" );

@Internal
public static final Setting<Boolean> bolt_logging_enabled = setting( "unsupported.dbms.logs.bolt.enabled",
BOOLEAN, FALSE );
Expand Down
Expand Up @@ -19,6 +19,7 @@
*/
package org.neo4j.server.modules;

import java.util.Collections;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import javax.servlet.Filter;
Expand Down Expand Up @@ -63,7 +64,7 @@ public void start()
authorizationFilter = createAuthorizationDisabledFilter();
}

webServer.addFilter( authorizationFilter, "/*" );
webServer.addFilter( authorizationFilter, "/*", Collections.emptyMap() );
}

@Override
Expand Down
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.server.modules;

import java.net.URI;
import java.util.Collections;
import java.util.List;

import org.neo4j.concurrent.RecentK;
Expand All @@ -42,6 +43,7 @@
import org.neo4j.udc.UsageDataKeys;

import static java.util.Arrays.asList;
import static org.neo4j.graphdb.factory.GraphDatabaseSettings.access_control_allow_origin;

/**
* Mounts the database REST API.
Expand Down Expand Up @@ -69,8 +71,9 @@ public void start()
{
URI restApiUri = restApiUri( );

webServer.addFilter( new CollectUserAgentFilter( clientNames() ), "/*" );
webServer.addFilter( new CorsFilter( logProvider ), "/*" );
webServer.addFilter( new CollectUserAgentFilter( clientNames() ), "/*", Collections.emptyMap() );
webServer.addFilter( new CorsFilter( logProvider ), "/*", Collections.singletonMap(
"access_control_allow_origin", config.get( access_control_allow_origin ) ) );
webServer.addJAXRSClasses( getClassNames(), restApiUri.toString(), null );
loadPlugins();
}
Expand Down
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.server.modules;

import java.util.ArrayList;
import java.util.Collections;

import org.neo4j.helpers.collection.Iterables;
import org.neo4j.kernel.configuration.Config;
Expand Down Expand Up @@ -53,7 +54,7 @@ public void start()
{
mountedFilter = new SecurityFilter( securityRules );

webServer.addFilter( mountedFilter, "/*" );
webServer.addFilter( mountedFilter, "/*", Collections.emptyMap() );

for ( SecurityRule rule : securityRules )
{
Expand Down
Expand Up @@ -48,9 +48,12 @@ public class CorsFilter implements Filter
public static final String ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers";
public static final String ACCESS_CONTROL_REQUEST_METHOD = "Access-Control-Request-Method";
public static final String ACCESS_CONTROL_REQUEST_HEADERS = "Access-Control-Request-Headers";
public static final String VARY = "Vary";

private final Log log;

private FilterConfig filterConfig;

public CorsFilter( LogProvider logProvider )
{
this.log = logProvider.getLog( getClass() );
Expand All @@ -59,6 +62,7 @@ public CorsFilter( LogProvider logProvider )
@Override
public void init( FilterConfig filterConfig ) throws ServletException
{
this.filterConfig = filterConfig;
}

@Override
Expand All @@ -68,7 +72,22 @@ public void doFilter( ServletRequest servletRequest, ServletResponse servletResp
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;

response.setHeader( ACCESS_CONTROL_ALLOW_ORIGIN, "*" );
String uri = "*";
if ( filterConfig != null )
{
uri = filterConfig.getInitParameter( "access_control_allow_origin" );
}
response.setHeader( ACCESS_CONTROL_ALLOW_ORIGIN, uri );
if ( !"*".equals( uri ) )
{
// If the server specifies an origin host rather than "*", then it must also include Origin in
// the Vary response header to indicate to clients that server responses will differ based on
// the value of the Origin request header.
//
// -- https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
//
response.setHeader( VARY, "Origin" );
}

Enumeration<String> requestMethodEnumeration = request.getHeaders( ACCESS_CONTROL_REQUEST_METHOD );
if ( requestMethodEnumeration != null )
Expand Down Expand Up @@ -96,6 +115,7 @@ public void doFilter( ServletRequest servletRequest, ServletResponse servletResp
@Override
public void destroy()
{
this.filterConfig = null;
}

private void addAllowedMethodIfValid( String methodName, HttpServletResponse response )
Expand Down
Expand Up @@ -260,9 +260,9 @@ public void removeJAXRSClasses( List<String> classNames, String serverMountPoint
}

@Override
public void addFilter( Filter filter, String pathSpec )
public void addFilter( Filter filter, String pathSpec, Map<String, String> initParameters )
{
filters.add( new FilterDefinition( filter, pathSpec ) );
filters.add( new FilterDefinition( filter, pathSpec, initParameters ) );
}

@Override
Expand Down Expand Up @@ -511,9 +511,9 @@ private void addFiltersTo( ServletContextHandler context )
{
for ( FilterDefinition filterDef : filters )
{
context.addFilter( new FilterHolder( filterDef.getFilter() ),
filterDef.getPathSpec(), EnumSet.allOf( DispatcherType.class )
);
FilterHolder filterHolder = new FilterHolder( filterDef.getFilter() );
filterHolder.setInitParameters( filterDef.initParameters );
context.addFilter( filterHolder, filterDef.getPathSpec(), EnumSet.allOf( DispatcherType.class ) );
}
}

Expand All @@ -526,11 +526,13 @@ private static class FilterDefinition
{
private final Filter filter;
private final String pathSpec;
private final Map<String, String> initParameters;

FilterDefinition( Filter filter, String pathSpec )
FilterDefinition( Filter filter, String pathSpec, Map<String, String> initParameters )
{
this.filter = filter;
this.pathSpec = pathSpec;
this.initParameters = initParameters;
}

public boolean matches( Filter filter, String pathSpec )
Expand Down
Expand Up @@ -26,6 +26,7 @@
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import javax.servlet.Filter;
Expand Down Expand Up @@ -60,7 +61,7 @@ public interface WebServer
void addJAXRSClasses( List<String> classNames, String serverMountPoint, Collection<Injectable<?>> injectables );
void removeJAXRSClasses( List<String> classNames, String serverMountPoint );

void addFilter( Filter filter, String pathSpec );
void addFilter( Filter filter, String pathSpec, Map<String, String> initParameters );

void removeFilter( Filter filter, String pathSpec );

Expand Down
Expand Up @@ -107,6 +107,18 @@ public void shouldAddCorsMethodsHeader() throws Exception
testCorsAllowMethods( DELETE );
}

@Test
public void shouldAddCorsHeaderWhenConfigured() throws Exception
{
String origin = "https://example.com:7687";
startServer( false, origin );

testCorsAllowMethods( POST, origin );
testCorsAllowMethods( GET, origin );
testCorsAllowMethods( PATCH, origin );
testCorsAllowMethods( DELETE, origin );
}

@Test
public void shouldAddCorsRequestHeaders() throws Exception
{
Expand All @@ -122,13 +134,18 @@ public void shouldAddCorsRequestHeaders() throws Exception
}

private void testCorsAllowMethods( HttpMethod method ) throws Exception
{
testCorsAllowMethods( method, "*" );
}

private void testCorsAllowMethods( HttpMethod method, String origin ) throws Exception
{
HTTP.Builder requestBuilder = requestWithHeaders( "authDisabled", "authDisabled" )
.withHeaders( ACCESS_CONTROL_REQUEST_METHOD, method.toString() );
HTTP.Response response = runQuery( requestBuilder );

assertEquals( OK.getStatusCode(), response.status() );
assertCorsHeaderPresent( response );
assertCorsHeaderEquals( response, origin );
assertEquals( method, HttpMethod.valueOf( response.header( ACCESS_CONTROL_ALLOW_METHODS ) ) );
}

Expand Down Expand Up @@ -160,6 +177,11 @@ HttpHeaders.AUTHORIZATION, basicAuthHeader( username, password )

private static void assertCorsHeaderPresent( HTTP.Response response )
{
assertEquals( "*", response.header( ACCESS_CONTROL_ALLOW_ORIGIN ) );
assertCorsHeaderEquals( response, "*" );
}

private static void assertCorsHeaderEquals( HTTP.Response response, String origin )
{
assertEquals( origin, response.header( ACCESS_CONTROL_ALLOW_ORIGIN ) );
}
}
Expand Up @@ -50,6 +50,15 @@ protected void startServer( boolean authEnabled ) throws IOException
server.start();
}

protected void startServer( boolean authEnabled, String accessControlAllowOrigin ) throws IOException
{
server = CommunityServerBuilder.serverOnRandomPorts()
.withProperty( GraphDatabaseSettings.auth_enabled.name(), Boolean.toString( authEnabled ) )
.withProperty( GraphDatabaseSettings.access_control_allow_origin.name(), accessControlAllowOrigin )
.build();
server.start();
}

protected String basicAuthHeader( String username, String password )
{
String usernamePassword = username + ':' + password;
Expand Down
Expand Up @@ -61,6 +61,12 @@ dbms.directories.import=import
# individual connectors below.
#dbms.connectors.default_advertised_address=localhost

# Value of the Access-Control-Allow-Origin header sent over any HTTP or HTTPS
# connector. This defaults to '*', which allows broadest compatibility but is
# least secure. Note that any URI provided here limits HTTP/HTTPS access to
# that URI only.
#dbms.connectors.access_control_allow_origin=*

# You can also choose a specific advertised hostname or IP address, and
# configure an advertised port for each connector, by setting their
# individual advertised_address.
Expand Down

0 comments on commit ab973ef

Please sign in to comment.