Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions neo4j/bolt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,7 @@ def on_success(self, metadata):
connection.server = ServerInfo(address, version)

def on_failure(self, metadata):
code = metadata.get("code")
error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
ServiceUnavailable)
raise error(metadata.get("message", "INIT failed"))
raise ServiceUnavailable(metadata.get("message", "INIT failed"), metadata.get("code"))


class Connection(object):
Expand Down Expand Up @@ -629,12 +626,11 @@ class ServiceUnavailable(Exception):
""" Raised when no database service is available.
"""

def __init__(self, message, code=None):
super(ServiceUnavailable, self).__init__(message)
self.code = code


class ProtocolError(Exception):
""" Raised when an unexpected or unsupported protocol event occurs.
"""


class Unauthorized(Exception):
""" Raised when an action is not permitted.
"""
9 changes: 7 additions & 2 deletions neo4j/v1/bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
fix_statement, fix_parameters, \
CypherError, SessionExpired, SessionError
from .routing import RoutingConnectionPool
from .security import SecurityPlan
from .security import SecurityPlan, Unauthorized
from .summary import ResultSummary
from .types import Record

Expand All @@ -45,7 +45,12 @@ def __init__(self, uri, **config):
Driver.__init__(self, pool)

def session(self, access_mode=None):
return BoltSession(self.pool.acquire(self.address))
try:
return BoltSession(self.pool.acquire(self.address))
except ServiceUnavailable as error:
if error.code == "Neo.ClientError.Security.Unauthorized":
raise Unauthorized(error.args[0])
raise


GraphDatabase.uri_schemes["bolt"] = DirectDriver
Expand Down
15 changes: 12 additions & 3 deletions neo4j/v1/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def fetch_routing_info(self, address):
raise ServiceUnavailable("Server %r does not support routing" % (address,))
else:
raise ServiceUnavailable("Routing support broken on server %r" % (address,))
except ServiceUnavailable:
except ServiceUnavailable as error:
if error.code == "Neo.ClientError.Security.Unauthorized":
from neo4j.v1.security import Unauthorized
raise Unauthorized(error.args[0])
self.remove(address)
return None

Expand Down Expand Up @@ -270,7 +273,10 @@ def acquire_for_read(self):
address = next(self.routing_table.readers)
try:
connection = self.acquire(address)
except ServiceUnavailable:
except ServiceUnavailable as error:
if error.code == "Neo.ClientError.Security.Unauthorized":
from neo4j.v1.security import Unauthorized
raise Unauthorized(error.args[0])
self.remove(address)
else:
return connection
Expand All @@ -285,7 +291,10 @@ def acquire_for_write(self):
address = next(self.routing_table.writers)
try:
connection = self.acquire(address)
except ServiceUnavailable:
except ServiceUnavailable as error:
if error.code == "Neo.ClientError.Security.Unauthorized":
from neo4j.v1.security import Unauthorized
raise Unauthorized(error.args[0])
self.remove(address)
else:
return connection
Expand Down
5 changes: 5 additions & 0 deletions neo4j/v1/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,8 @@ def _encryption_default():
"so communications are not secure")
_warned_about_insecure_default = True
return ENCRYPTION_DEFAULT


class Unauthorized(Exception):
""" Raised when an action is not permitted.
"""
7 changes: 6 additions & 1 deletion test/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from neo4j.v1 import ServiceUnavailable, ProtocolError, READ_ACCESS, WRITE_ACCESS, \
TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES, GraphDatabase, basic_auth, \
custom_auth, SSL_AVAILABLE, SessionExpired, DirectDriver, RoutingDriver
custom_auth, SSL_AVAILABLE, SessionExpired, DirectDriver, RoutingDriver, Unauthorized
from test.util import ServerTestCase, StubCluster

BOLT_URI = "bolt://localhost:7687"
Expand Down Expand Up @@ -231,3 +231,8 @@ def test_custom_ca_not_implemented(self):
with self.assertRaises(NotImplementedError):
_ = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN,
trust=TRUST_CUSTOM_CA_SIGNED_CERTIFICATES)

def test_should_fail_on_incorrect_password(self):
with self.assertRaises(Unauthorized):
with GraphDatabase.driver(BOLT_URI, auth=("neo4j", "wrong-password")) as driver:
_ = driver.session()