diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index c1791fdd0..be1cbb00b 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -226,10 +226,7 @@ def on_success(self, metadata): connection.server = ServerInfo(address, version) def on_failure(self, metadata): - code = metadata.get("code") - error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else - ServiceUnavailable) - raise error(metadata.get("message", "INIT failed")) + raise ServiceUnavailable(metadata.get("message", "INIT failed"), metadata.get("code")) class Connection(object): @@ -629,12 +626,11 @@ class ServiceUnavailable(Exception): """ Raised when no database service is available. """ + def __init__(self, message, code=None): + super(ServiceUnavailable, self).__init__(message) + self.code = code + class ProtocolError(Exception): """ Raised when an unexpected or unsupported protocol event occurs. """ - - -class Unauthorized(Exception): - """ Raised when an action is not permitted. - """ \ No newline at end of file diff --git a/neo4j/v1/bolt.py b/neo4j/v1/bolt.py index bfdb4107f..d2f275ca1 100644 --- a/neo4j/v1/bolt.py +++ b/neo4j/v1/bolt.py @@ -26,7 +26,7 @@ fix_statement, fix_parameters, \ CypherError, SessionExpired, SessionError from .routing import RoutingConnectionPool -from .security import SecurityPlan +from .security import SecurityPlan, Unauthorized from .summary import ResultSummary from .types import Record @@ -45,7 +45,12 @@ def __init__(self, uri, **config): Driver.__init__(self, pool) def session(self, access_mode=None): - return BoltSession(self.pool.acquire(self.address)) + try: + return BoltSession(self.pool.acquire(self.address)) + except ServiceUnavailable as error: + if error.code == "Neo.ClientError.Security.Unauthorized": + raise Unauthorized(error.args[0]) + raise GraphDatabase.uri_schemes["bolt"] = DirectDriver diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index b9583026c..415d459ee 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -178,7 +178,10 @@ def fetch_routing_info(self, address): raise ServiceUnavailable("Server %r does not support routing" % (address,)) else: raise ServiceUnavailable("Routing support broken on server %r" % (address,)) - except ServiceUnavailable: + except ServiceUnavailable as error: + if error.code == "Neo.ClientError.Security.Unauthorized": + from neo4j.v1.security import Unauthorized + raise Unauthorized(error.args[0]) self.remove(address) return None @@ -270,7 +273,10 @@ def acquire_for_read(self): address = next(self.routing_table.readers) try: connection = self.acquire(address) - except ServiceUnavailable: + except ServiceUnavailable as error: + if error.code == "Neo.ClientError.Security.Unauthorized": + from neo4j.v1.security import Unauthorized + raise Unauthorized(error.args[0]) self.remove(address) else: return connection @@ -285,7 +291,10 @@ def acquire_for_write(self): address = next(self.routing_table.writers) try: connection = self.acquire(address) - except ServiceUnavailable: + except ServiceUnavailable as error: + if error.code == "Neo.ClientError.Security.Unauthorized": + from neo4j.v1.security import Unauthorized + raise Unauthorized(error.args[0]) self.remove(address) else: return connection diff --git a/neo4j/v1/security.py b/neo4j/v1/security.py index 8124f2fd3..f69253b3d 100644 --- a/neo4j/v1/security.py +++ b/neo4j/v1/security.py @@ -126,3 +126,8 @@ def _encryption_default(): "so communications are not secure") _warned_about_insecure_default = True return ENCRYPTION_DEFAULT + + +class Unauthorized(Exception): + """ Raised when an action is not permitted. + """ diff --git a/test/test_driver.py b/test/test_driver.py index 39d9f1ce1..3024768f7 100644 --- a/test/test_driver.py +++ b/test/test_driver.py @@ -25,7 +25,7 @@ from neo4j.v1 import ServiceUnavailable, ProtocolError, READ_ACCESS, WRITE_ACCESS, \ TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES, GraphDatabase, basic_auth, \ - custom_auth, SSL_AVAILABLE, SessionExpired, DirectDriver, RoutingDriver + custom_auth, SSL_AVAILABLE, SessionExpired, DirectDriver, RoutingDriver, Unauthorized from test.util import ServerTestCase, StubCluster BOLT_URI = "bolt://localhost:7687" @@ -231,3 +231,8 @@ def test_custom_ca_not_implemented(self): with self.assertRaises(NotImplementedError): _ = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, trust=TRUST_CUSTOM_CA_SIGNED_CERTIFICATES) + + def test_should_fail_on_incorrect_password(self): + with self.assertRaises(Unauthorized): + with GraphDatabase.driver(BOLT_URI, auth=("neo4j", "wrong-password")) as driver: + _ = driver.session()