Skip to content

Commit

Permalink
Expose userAgent in listConnections procedure
Browse files Browse the repository at this point in the history
For bolt connections it is the same as user agent provided in
INIT/HELLO message. For HTTP/HTTPS connections it is extracted from
the User-Agent request header. This additional column is meant to
provide a bit more details about the client.
  • Loading branch information
lutovich committed Oct 12, 2018
1 parent 8a08917 commit 04e56c1
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 32 deletions.
Expand Up @@ -26,9 +26,11 @@
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.HttpHeaders;


import org.neo4j.graphdb.security.AuthorizationViolationException; import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.internal.kernel.api.security.LoginContext; import org.neo4j.internal.kernel.api.security.LoginContext;
import org.neo4j.server.web.JettyHttpConnection;


import static javax.servlet.http.HttpServletRequest.BASIC_AUTH; import static javax.servlet.http.HttpServletRequest.BASIC_AUTH;


Expand All @@ -46,8 +48,13 @@ public void doFilter( ServletRequest servletRequest, ServletResponse servletResp


try try
{ {
LoginContext loginContext = getAuthDisabledLoginContext();
String userAgent = request.getHeader( HttpHeaders.USER_AGENT );

JettyHttpConnection.updateUserForCurrentConnection( loginContext.subject().username(), userAgent );

filterChain.doFilter( filterChain.doFilter(
new AuthorizedRequestWrapper( BASIC_AUTH, "neo4j", request, getAuthDisabledLoginContext() ), new AuthorizedRequestWrapper( BASIC_AUTH, "neo4j", request, loginContext ),
servletResponse ); servletResponse );
} }
catch ( AuthorizationViolationException e ) catch ( AuthorizationViolationException e )
Expand Down
Expand Up @@ -80,6 +80,10 @@ public void doFilter( ServletRequest servletRequest, ServletResponse servletResp
final HttpServletRequest request = (HttpServletRequest) servletRequest; final HttpServletRequest request = (HttpServletRequest) servletRequest;
final HttpServletResponse response = (HttpServletResponse) servletResponse; final HttpServletResponse response = (HttpServletResponse) servletResponse;


String userAgent = request.getHeader( HttpHeaders.USER_AGENT );
// username is only known after authentication, make connection aware of the user-agent
JettyHttpConnection.updateUserForCurrentConnection( null, userAgent );

final String path = request.getContextPath() + ( request.getPathInfo() == null ? "" : request.getPathInfo() ); final String path = request.getContextPath() + ( request.getPathInfo() == null ? "" : request.getPathInfo() );


if ( request.getMethod().equals( "OPTIONS" ) || whitelisted( path ) ) if ( request.getMethod().equals( "OPTIONS" ) || whitelisted( path ) )
Expand Down Expand Up @@ -110,7 +114,8 @@ public void doFilter( ServletRequest servletRequest, ServletResponse servletResp
try try
{ {
LoginContext securityContext = authenticate( username, password ); LoginContext securityContext = authenticate( username, password );
JettyHttpConnection.updateUserForCurrentConnection( username, request.getHeader( HttpHeaders.USER_AGENT ) ); // username is now known, make connection aware of both username and user-agent
JettyHttpConnection.updateUserForCurrentConnection( username, userAgent );


switch ( securityContext.subject().getAuthenticationResult() ) switch ( securityContext.subject().getAuthenticationResult() )
{ {
Expand Down
Expand Up @@ -33,6 +33,7 @@ public class ListConnectionResult
public final String connectTime; public final String connectTime;
public final String connector; public final String connector;
public final String username; public final String username;
public final String userAgent;
public final String serverAddress; public final String serverAddress;
public final String clientAddress; public final String clientAddress;


Expand All @@ -42,6 +43,7 @@ public class ListConnectionResult
connectTime = ProceduresTimeFormatHelper.formatTime( connection.connectTime(), timeZone ); connectTime = ProceduresTimeFormatHelper.formatTime( connection.connectTime(), timeZone );
connector = connection.connector(); connector = connection.connector();
username = connection.username(); username = connection.username();
userAgent = connection.userAgent();
serverAddress = SocketAddress.format( connection.serverAddress() ); serverAddress = SocketAddress.format( connection.serverAddress() );
clientAddress = SocketAddress.format( connection.clientAddress() ); clientAddress = SocketAddress.format( connection.clientAddress() );
} }
Expand Down
100 changes: 70 additions & 30 deletions integrationtests/src/test/java/org/neo4j/net/ConnectionTrackingIT.java
Expand Up @@ -32,6 +32,7 @@
import java.net.URI; import java.net.URI;
import java.time.OffsetDateTime; import java.time.OffsetDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
Expand All @@ -41,6 +42,7 @@
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import javax.ws.rs.core.HttpHeaders;


import org.neo4j.bolt.v1.messaging.request.InitMessage; import org.neo4j.bolt.v1.messaging.request.InitMessage;
import org.neo4j.bolt.v1.messaging.request.PullAllMessage; import org.neo4j.bolt.v1.messaging.request.PullAllMessage;
Expand Down Expand Up @@ -68,7 +70,6 @@
import org.neo4j.values.storable.Value; import org.neo4j.values.storable.Value;


import static java.time.format.DateTimeFormatter.ISO_OFFSET_DATE_TIME; import static java.time.format.DateTimeFormatter.ISO_OFFSET_DATE_TIME;
import static java.util.Arrays.asList;
import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toList;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -87,6 +88,9 @@
import static org.neo4j.helpers.collection.MapUtil.map; import static org.neo4j.helpers.collection.MapUtil.map;
import static org.neo4j.kernel.api.exceptions.Status.Transaction.Terminated; import static org.neo4j.kernel.api.exceptions.Status.Transaction.Terminated;
import static org.neo4j.kernel.impl.enterprise.configuration.OnlineBackupSettings.online_backup_enabled; import static org.neo4j.kernel.impl.enterprise.configuration.OnlineBackupSettings.online_backup_enabled;
import static org.neo4j.net.ConnectionTrackingIT.TestConnector.BOLT;
import static org.neo4j.net.ConnectionTrackingIT.TestConnector.HTTP;
import static org.neo4j.net.ConnectionTrackingIT.TestConnector.HTTPS;
import static org.neo4j.server.configuration.ServerSettings.webserver_max_threads; import static org.neo4j.server.configuration.ServerSettings.webserver_max_threads;
import static org.neo4j.test.assertion.Assert.assertEventually; import static org.neo4j.test.assertion.Assert.assertEventually;
import static org.neo4j.test.server.HTTP.RawPayload; import static org.neo4j.test.server.HTTP.RawPayload;
Expand All @@ -103,6 +107,9 @@ public class ConnectionTrackingIT
private static final String OTHER_USER = "otherUser"; private static final String OTHER_USER = "otherUser";
private static final String OTHER_USER_PWD = "test"; private static final String OTHER_USER_PWD = "test";


private static final List<String> LIST_CONNECTIONS_PROCEDURE_COLUMNS = Arrays.asList(
"connectionId", "connectTime", "connector", "username", "userAgent", "serverAddress", "clientAddress" );

@ClassRule @ClassRule
public static final Neo4jRule neo4j = new EnterpriseNeo4jRule() public static final Neo4jRule neo4j = new EnterpriseNeo4jRule()
.withConfig( auth_enabled, "true" ) .withConfig( auth_enabled, "true" )
Expand Down Expand Up @@ -155,9 +162,9 @@ public void afterEach() throws Exception
@Test @Test
public void shouldListNoConnectionsWhenIdle() throws Exception public void shouldListNoConnectionsWhenIdle() throws Exception
{ {
verifyConnectionCount( "http", null, 0 ); verifyConnectionCount( HTTP, null, 0 );
verifyConnectionCount( "https", null, 0 ); verifyConnectionCount( HTTPS, null, 0 );
verifyConnectionCount( "bolt", null, 0 ); verifyConnectionCount( BOLT, null, 0 );
} }


@Test @Test
Expand Down Expand Up @@ -199,8 +206,8 @@ public void shouldListAuthenticatedHttpConnections() throws Exception
} }


awaitNumberOfAuthenticatedConnectionsToBe( 7 ); awaitNumberOfAuthenticatedConnectionsToBe( 7 );
verifyConnectionCount( "http", "neo4j", 4 ); verifyAuthenticatedConnectionCount( HTTP, "neo4j", 4 );
verifyConnectionCount( "http", OTHER_USER, 3 ); verifyAuthenticatedConnectionCount( HTTP, OTHER_USER, 3 );
} ); } );
} }


Expand All @@ -219,8 +226,8 @@ public void shouldListAuthenticatedHttpsConnections() throws Exception
} }


awaitNumberOfAuthenticatedConnectionsToBe( 9 ); awaitNumberOfAuthenticatedConnectionsToBe( 9 );
verifyConnectionCount( "https", "neo4j", 4 ); verifyAuthenticatedConnectionCount( HTTPS, "neo4j", 4 );
verifyConnectionCount( "https", OTHER_USER, 5 ); verifyAuthenticatedConnectionCount( HTTPS, OTHER_USER, 5 );
} ); } );
} }


Expand All @@ -239,8 +246,8 @@ public void shouldListAuthenticatedBoltConnections() throws Exception
} }


awaitNumberOfAuthenticatedConnectionsToBe( 7 ); awaitNumberOfAuthenticatedConnectionsToBe( 7 );
verifyConnectionCount( "bolt", "neo4j", 2 ); verifyAuthenticatedConnectionCount( BOLT, "neo4j", 2 );
verifyConnectionCount( "bolt", OTHER_USER, 5 ); verifyAuthenticatedConnectionCount( BOLT, OTHER_USER, 5 );
} ); } );
} }


Expand All @@ -263,28 +270,28 @@ public void shouldListAuthenticatedConnections() throws Exception
} }


awaitNumberOfAuthenticatedConnectionsToBe( 10 ); awaitNumberOfAuthenticatedConnectionsToBe( 10 );
verifyConnectionCount( "bolt", OTHER_USER, 4 ); verifyConnectionCount( BOLT, OTHER_USER, 4 );
verifyConnectionCount( "http", "neo4j", 1 ); verifyConnectionCount( HTTP, "neo4j", 1 );
verifyConnectionCount( "https", "neo4j", 5 ); verifyConnectionCount( HTTPS, "neo4j", 5 );
} ); } );
} }


@Test @Test
public void shouldKillHttpConnection() throws Exception public void shouldKillHttpConnection() throws Exception
{ {
testKillingOfConnections( neo4j.httpURI(), "http", 4 ); testKillingOfConnections( neo4j.httpURI(), HTTP, 4 );
} }


@Test @Test
public void shouldKillHttpsConnection() throws Exception public void shouldKillHttpsConnection() throws Exception
{ {
testKillingOfConnections( neo4j.httpsURI(), "https", 2 ); testKillingOfConnections( neo4j.httpsURI(), HTTPS, 2 );
} }


@Test @Test
public void shouldKillBoltConnection() throws Exception public void shouldKillBoltConnection() throws Exception
{ {
testKillingOfConnections( neo4j.boltURI(), "bolt", 3 ); testKillingOfConnections( neo4j.boltURI(), BOLT, 3 );
} }


private void testListingOfUnauthenticatedConnections( int httpCount, int httpsCount, int boltCount ) throws Exception private void testListingOfUnauthenticatedConnections( int httpCount, int httpsCount, int boltCount ) throws Exception
Expand All @@ -306,12 +313,12 @@ private void testListingOfUnauthenticatedConnections( int httpCount, int httpsCo


awaitNumberOfAcceptedConnectionsToBe( httpCount + httpsCount + boltCount ); awaitNumberOfAcceptedConnectionsToBe( httpCount + httpsCount + boltCount );


verifyConnectionCount( "http", null, httpCount ); verifyConnectionCount( HTTP, null, httpCount );
verifyConnectionCount( "https", null, httpsCount ); verifyConnectionCount( HTTPS, null, httpsCount );
verifyConnectionCount( "bolt", null, boltCount ); verifyConnectionCount( BOLT, null, boltCount );
} }


private void testKillingOfConnections( URI uri, String connector, int count ) throws Exception private void testKillingOfConnections( URI uri, TestConnector connector, int count ) throws Exception
{ {
List<TransportConnection> socketConnections = new ArrayList<>(); List<TransportConnection> socketConnections = new ArrayList<>();
for ( int i = 0; i < count; i++ ) for ( int i = 0; i < count; i++ )
Expand Down Expand Up @@ -353,17 +360,28 @@ private static void awaitNumberOfAcceptedConnectionsToBe( int n ) throws Interru
1, MINUTES ); 1, MINUTES );
} }


private static void verifyConnectionCount( String connector, String username, int expectedCount ) throws InterruptedException private static void verifyConnectionCount( TestConnector connector, String username, int expectedCount ) throws InterruptedException
{
verifyConnectionCount( connector, username, expectedCount, false );
}

private static void verifyAuthenticatedConnectionCount( TestConnector connector, String username, int expectedCount ) throws InterruptedException
{
verifyConnectionCount( connector, username, expectedCount, true );
}

private static void verifyConnectionCount( TestConnector connector, String username, int expectedCount, boolean expectAuthenticated )
throws InterruptedException
{ {
assertEventually( connections -> "Unexpected number of listed connections: " + connections, assertEventually( connections -> "Unexpected number of listed connections: " + connections,
() -> listMatchingConnection( connector, username ), hasSize( expectedCount ), () -> listMatchingConnection( connector, username, expectAuthenticated ), hasSize( expectedCount ),
1, MINUTES ); 1, MINUTES );
} }


private static List<Map<String,Object>> listMatchingConnection( String connector, String username ) private static List<Map<String,Object>> listMatchingConnection( TestConnector connector, String username, boolean expectAuthenticated )
{ {
Result result = neo4j.getGraphDatabaseService().execute( "CALL dbms.listConnections()" ); Result result = neo4j.getGraphDatabaseService().execute( "CALL dbms.listConnections()" );
assertEquals( asList( "connectionId", "connectTime", "connector", "username", "serverAddress", "clientAddress" ), result.columns() ); assertEquals( LIST_CONNECTIONS_PROCEDURE_COLUMNS, result.columns() );
List<Map<String,Object>> records = result.stream().collect( toList() ); List<Map<String,Object>> records = result.stream().collect( toList() );


List<Map<String,Object>> matchingRecords = new ArrayList<>(); List<Map<String,Object>> matchingRecords = new ArrayList<>();
Expand All @@ -372,8 +390,13 @@ private static List<Map<String,Object>> listMatchingConnection( String connector
String actualConnector = record.get( "connector" ).toString(); String actualConnector = record.get( "connector" ).toString();
assertNotNull( actualConnector ); assertNotNull( actualConnector );
Object actualUsername = record.get( "username" ); Object actualUsername = record.get( "username" );
if ( Objects.equals( connector, actualConnector ) && Objects.equals( username, actualUsername ) ) if ( Objects.equals( connector.name, actualConnector ) && Objects.equals( username, actualUsername ) )
{ {
if ( expectAuthenticated )
{
assertEquals( connector.userAgent, record.get( "userAgent" ) );
}

matchingRecords.add( record ); matchingRecords.add( record );
} }


Expand Down Expand Up @@ -463,11 +486,12 @@ private Future<Response> updateNodeViaHttps( long id, String username, String pa
private Future<Response> updateNodeViaHttp( long id, boolean encrypted, String username, String password ) private Future<Response> updateNodeViaHttp( long id, boolean encrypted, String username, String password )
{ {
String uri = txCommitUri( encrypted ); String uri = txCommitUri( encrypted );
String userAgent = encrypted ? HTTPS.userAgent : HTTP.userAgent;

return executor.submit( () -> return executor.submit( () ->
{ withBasicAuth( username, password )
return withBasicAuth( username, password ) .withHeaders( HttpHeaders.USER_AGENT, userAgent )
.POST( uri, query( "MATCH (n) WHERE id(n) = " + id + " SET n.prop = 42" ) ); .POST( uri, query( "MATCH (n) WHERE id(n) = " + id + " SET n.prop = 42" ) )
}
); );
} }


Expand Down Expand Up @@ -566,6 +590,22 @@ private static RawPayload query( String statement )
private static InitMessage initMessage( String username, String password ) private static InitMessage initMessage( String username, String password )
{ {
Map<String,Object> authToken = map( "scheme", "basic", "principal", username, "credentials", password ); Map<String,Object> authToken = map( "scheme", "basic", "principal", username, "credentials", password );
return new InitMessage( "TestClient", authToken ); return new InitMessage( BOLT.userAgent, authToken );
}

enum TestConnector
{
HTTP( "http", "http-user-agent" ),
HTTPS( "https", "https-user-agent" ),
BOLT( "bolt", "bolt-user-agent" );

final String name;
final String userAgent;

TestConnector( String name, String userAgent )
{
this.name = name;
this.userAgent = userAgent;
}
} }
} }

0 comments on commit 04e56c1

Please sign in to comment.