Skip to content

Commit

Permalink
Change login method take an auth token as a map
Browse files Browse the repository at this point in the history
Bolt already gets authentication information from the client as a
Map<String,Object>.
Instead of extracting username and password within Bolt BasicAuthentication,
we now pass the map to AuthManager.login(), so that we can cater for more
advanced authentication mechanisms.
  • Loading branch information
henriknyman committed Jun 13, 2016
1 parent 4eefc68 commit be44158
Show file tree
Hide file tree
Showing 17 changed files with 339 additions and 155 deletions.
8 changes: 7 additions & 1 deletion community/bolt/pom.xml
Expand Up @@ -51,7 +51,7 @@

<dependency>
<groupId>org.neo4j</groupId>
<artifactId>neo4j-security</artifactId>
<artifactId>neo4j-kernel</artifactId>
<version>${project.version}</version>
</dependency>

Expand Down Expand Up @@ -89,6 +89,12 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.neo4j</groupId>
<artifactId>neo4j-security</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.neo4j</groupId>
<artifactId>neo4j-io</artifactId>
Expand Down
Expand Up @@ -49,9 +49,4 @@ public interface Authentication
* Allows all tokens to authenticate.
*/
Authentication NONE = authToken -> AuthenticationResult.AUTH_DISABLED;

String SCHEME_KEY = "scheme";
String PRINCIPAL = "principal";
String CREDENTIALS = "credentials";
String NEW_CREDENTIALS = "new_credentials";
}
Expand Up @@ -25,12 +25,17 @@

import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider;
import org.neo4j.kernel.api.security.AuthSubject;
import org.neo4j.kernel.api.security.AuthManager;
import org.neo4j.kernel.api.security.AuthSubject;
import org.neo4j.kernel.api.security.AuthToken;
import org.neo4j.kernel.api.security.exception.IllegalCredentialsException;
import org.neo4j.kernel.api.security.exception.InvalidAuthTokenException;
import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider;

import static org.neo4j.kernel.api.security.AuthToken.NEW_CREDENTIALS;
import static org.neo4j.kernel.api.security.AuthToken.PRINCIPAL;
import static org.neo4j.kernel.api.security.AuthToken.SCHEME_KEY;

/**
* Performs basic authentication with user name and password.
Expand Down Expand Up @@ -58,80 +63,82 @@ public AuthenticationResult authenticate( Map<String,Object> authToken ) throws
"Authentication token must contain: '" + SCHEME_KEY + " : " + SCHEME + "'" );
}

String user = safeCast( PRINCIPAL, authToken );
String password = safeCast( CREDENTIALS, authToken );
if ( authToken.containsKey( NEW_CREDENTIALS ) )
{
return update( user, password, safeCast( NEW_CREDENTIALS, authToken ) );
return update( authToken );
}
else
{
return authenticate( user, password );
return doAuthenticate( authToken );
}
}

private AuthenticationResult authenticate( String user, String password ) throws AuthenticationException
private AuthenticationResult doAuthenticate( Map<String,Object> authToken ) throws AuthenticationException
{
AuthSubject authSubject = authManager.login( user, password );
boolean credentialsExpired = false;
switch ( authSubject.getAuthenticationResult() )
try
{
case SUCCESS:
break;
case PASSWORD_CHANGE_REQUIRED:
credentialsExpired = true;
break;
case TOO_MANY_ATTEMPTS:
throw new AuthenticationException( Status.Security.AuthenticationRateLimit, identifier.get() );
default:
log.warn( "Failed authentication attempt for '%s'", user);
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get() );
AuthSubject authSubject = authManager.login( authToken );

boolean credentialsExpired = false;
switch ( authSubject.getAuthenticationResult() )
{
case SUCCESS:
break;
case PASSWORD_CHANGE_REQUIRED:
credentialsExpired = true;
break;
case TOO_MANY_ATTEMPTS:
throw new AuthenticationException( Status.Security.AuthenticationRateLimit, identifier.get() );
default:
log.warn( "Failed authentication attempt for '%s'", AuthToken.safeCast( PRINCIPAL, authToken ) );
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get() );
}

return new BasicAuthenticationResult( authSubject, credentialsExpired );
}
catch ( InvalidAuthTokenException e )
{
throw new AuthenticationException( e.status(), identifier.get(), e.getMessage() );
}
return new BasicAuthenticationResult( authSubject, credentialsExpired );
}

private AuthenticationResult update( String user, String password, String newPassword ) throws AuthenticationException
private AuthenticationResult update( Map<String,Object> authToken ) throws AuthenticationException
{
AuthSubject authSubject = authManager.login( user, password );
switch ( authSubject.getAuthenticationResult() )
try
{
case SUCCESS:
case PASSWORD_CHANGE_REQUIRED:
try
String newPassword = AuthToken.safeCast( NEW_CREDENTIALS, authToken );

AuthSubject authSubject = authManager.login( authToken );

switch ( authSubject.getAuthenticationResult() )
{
case SUCCESS:
case PASSWORD_CHANGE_REQUIRED:
authSubject.setPassword( newPassword );
//re-authenticate user
authSubject = authManager.login( user, newPassword );
}
catch ( AuthorizationViolationException e )
{
throw new AuthenticationException( Status.Security.Forbidden, identifier.get(), e.getMessage(), e );
authSubject = authManager.login( authToken );
break;
default:
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get() );
}
catch ( IOException e )
{
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get(), e.getMessage(), e );
}
catch ( IllegalCredentialsException e )
{
throw new AuthenticationException(e.status(), identifier.get(), e.getMessage(), e );
}
break;
default:
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get() );
}
return new BasicAuthenticationResult( authSubject, false );
}

private String safeCast( String key, Map<String,Object> authToken ) throws AuthenticationException
{
Object value = authToken.get( key );
if ( value == null || !(value instanceof String) )
return new BasicAuthenticationResult( authSubject, false );
}
catch ( AuthorizationViolationException e )
{
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get(),
"The value associated with the key `" + key + "` must be a String but was: " +
(value == null ? "null" : value.getClass().getSimpleName()));
throw new AuthenticationException( e.status(), identifier.get(), e.getMessage(), e );
}
catch ( IOException e )
{
throw new AuthenticationException( Status.Security.Unauthorized, identifier.get(), e.getMessage(), e );
}
catch ( IllegalCredentialsException e )
{
throw new AuthenticationException( e.status(), identifier.get(), e.getMessage(), e );
}
catch ( InvalidAuthTokenException e )
{
throw new AuthenticationException( e.status(), identifier.get(), e.getMessage() );
}

return (String) value;
}
}
Expand Up @@ -38,6 +38,7 @@
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.kernel.api.security.AuthSubject;
import org.neo4j.kernel.api.security.AuthToken;
import org.neo4j.kernel.impl.core.ThreadToStatementContextBridge;
import org.neo4j.kernel.impl.coreapi.InternalTransaction;
import org.neo4j.kernel.impl.coreapi.PropertyContainerLocker;
Expand Down Expand Up @@ -80,7 +81,7 @@ public State init( SessionStateMachine ctx, String clientName, Map<String,Object
ctx.credentialsExpired = authResult.credentialsExpired();
ctx.result( authResult.credentialsExpired() );
ctx.spi.udcRegisterClient( clientName );
ctx.setQuerySourceFromClientNameAndPrincipal( clientName, authToken.get( Authentication.PRINCIPAL ) );
ctx.setQuerySourceFromClientNameAndPrincipal( clientName, authToken.get( AuthToken.PRINCIPAL ) );
return IDLE;
}
catch ( AuthenticationException e )
Expand Down
Expand Up @@ -35,7 +35,7 @@
import org.neo4j.server.security.auth.BasicAuthSubject;

import static java.util.Collections.singletonList;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -49,13 +49,13 @@ public class BasicAuthenticationTest
private final Supplier<String> identifier = () -> "UNIQUE";

@Test
public void shouldNotDoAnythingOnSuccess() throws AuthenticationException
public void shouldNotDoAnythingOnSuccess() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.SUCCESS );

//Expect nothing
Expand All @@ -65,7 +65,7 @@ public void shouldNotDoAnythingOnSuccess() throws AuthenticationException
}

@Test
public void shouldThrowAndLogOnFailure() throws AuthenticationException
public void shouldThrowAndLogOnFailure() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
Expand All @@ -74,7 +74,7 @@ public void shouldThrowAndLogOnFailure() throws AuthenticationException
LogProvider logProvider = mock( LogProvider.class );
when( logProvider.getLog( BasicAuthentication.class ) ).thenReturn( log );
BasicAuthentication authentication = new BasicAuthentication( manager, logProvider, identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.FAILURE );

// Expect
Expand All @@ -90,13 +90,13 @@ public void shouldThrowAndLogOnFailure() throws AuthenticationException
}

@Test
public void shouldIndicateThatCredentialsExpired() throws AuthenticationException
public void shouldIndicateThatCredentialsExpired() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.PASSWORD_CHANGE_REQUIRED );

// Expect
Expand All @@ -110,13 +110,13 @@ public void shouldIndicateThatCredentialsExpired() throws AuthenticationExceptio
}

@Test
public void shouldFailWhenTooManyAttempts() throws AuthenticationException
public void shouldFailWhenTooManyAttempts() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.TOO_MANY_ATTEMPTS );

// Expect
Expand All @@ -129,13 +129,13 @@ public void shouldFailWhenTooManyAttempts() throws AuthenticationException
}

@Test
public void shouldBeAbleToUpdateCredentials() throws AuthenticationException
public void shouldBeAbleToUpdateCredentials() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.SUCCESS );

//Expect nothing
Expand All @@ -146,13 +146,13 @@ public void shouldBeAbleToUpdateCredentials() throws AuthenticationException
}

@Test
public void shouldBeAbleToUpdateExpiredCredentials() throws AuthenticationException
public void shouldBeAbleToUpdateExpiredCredentials() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.PASSWORD_CHANGE_REQUIRED );

//Expect nothing
Expand All @@ -163,13 +163,13 @@ public void shouldBeAbleToUpdateExpiredCredentials() throws AuthenticationExcept
}

@Test
public void shouldNotBeAbleToUpdateCredentialsIfOldCredentialsAreInvalid() throws AuthenticationException
public void shouldNotBeAbleToUpdateCredentialsIfOldCredentialsAreInvalid() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.FAILURE );

// Expect
Expand All @@ -184,13 +184,13 @@ public void shouldNotBeAbleToUpdateCredentialsIfOldCredentialsAreInvalid() throw
}

@Test
public void shouldFailOnUnknownScheme() throws AuthenticationException
public void shouldFailOnUnknownScheme() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.SUCCESS );

// Expect
Expand All @@ -203,13 +203,13 @@ public void shouldFailOnUnknownScheme() throws AuthenticationException
}

@Test
public void shouldFailOnMalformedToken() throws AuthenticationException
public void shouldFailOnMalformedToken() throws Exception
{
// Given
BasicAuthManager manager = mock( BasicAuthManager.class );
BasicAuthSubject authSubject = mock( BasicAuthSubject.class );
BasicAuthentication authentication = new BasicAuthentication( manager, mock( LogProvider.class ), identifier );
when( manager.login( anyString(), anyString() ) ).thenReturn( authSubject );
when( manager.login( anyMap() ) ).thenReturn( authSubject );
when( authSubject.getAuthenticationResult() ).thenReturn( AuthenticationResult.SUCCESS );

// Expect
Expand Down
Expand Up @@ -130,6 +130,22 @@ public void shouldFailIfWrongCredentials() throws Throwable
String.format( "The client is unauthorized due to authentication failure. (ID:%s)", server.uniqueIdentier()) ) ) );
}

@Test
public void shouldFailIfMalformedAuthToken() throws Throwable
{
// When
client.connect( address )
.send( TransportTestUtil.acceptedVersions( 1, 0, 0, 0 ) )
.send( TransportTestUtil.chunk(
init( "TestClient/1.1",
map( "principal", "neo4j", "this-should-have-been-credentials", "neo4j", "scheme", "basic" ) ) ) );

// Then
assertThat( client, eventuallyRecieves( new byte[]{0, 0, 0, 1} ) );
assertThat( client, eventuallyRecieves( msgFailure( Status.Security.Unauthorized,
String.format( "The value associated with the key `credentials` must be a String but was: null (ID:%s)", server.uniqueIdentier()) ) ) );
}

@Test
public void shouldBeAbleToUpdateCredentials() throws Throwable
{
Expand Down

0 comments on commit be44158

Please sign in to comment.