Skip to content

Commit

Permalink
Tidied up session interface to make access_mode a kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
technige committed Feb 1, 2019
1 parent dc08642 commit 9d28038
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -4,3 +4,4 @@
- Package can now no longer be installed as `neo4j-driver`; use `pip install neo4j` instead
- Support dropped for Python 2.7; explicit support added for Python 3.7 and 3.8
- The `neo4j.v1` subpackage is now no longer available; all imports should be taken from the `neo4j` package instead
- Changed `session(access_mode)` from a positional to a keyword argument
31 changes: 20 additions & 11 deletions neo4j/__init__.py
Expand Up @@ -153,18 +153,25 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.close()

def session(self, access_mode=None, **parameters):
def _assert_open(self):
if self.closed():
raise DriverError("Driver closed")

def session(self, **parameters):
""" Create a new :class:`.Session` object based on this
:class:`.Driver`.
:param access_mode: default access mode (read or write) for
transactions in this session
:param parameters: custom session parameters (see
:class:`.Session` for details)
:returns: new :class:`.Session` object
"""
if self.closed():
raise DriverError("Driver closed")
raise NotImplementedError("Blocking sessions are not implemented for the %s class" % type(self).__name__)

def async_session(self, **parameters):
raise NotImplementedError("Asynchronous sessions are not implemented for the %s class" % type(self).__name__)

def rx_session(self, **parameters):
raise NotImplementedError("Reactive sessions are not implemented for the %s class" % type(self).__name__)

def close(self):
""" Shut down, closing any open connections in the pool.
Expand Down Expand Up @@ -220,10 +227,11 @@ def connector(address, **kwargs):
instance._max_retry_time = config.get("max_retry_time", default_config["max_retry_time"])
return instance

def session(self, access_mode=None, **parameters):
def session(self, **parameters):
self._assert_open()
if "max_retry_time" not in parameters:
parameters["max_retry_time"] = self._max_retry_time
return Session(self._pool.acquire, access_mode, **parameters)
return Session(self._pool.acquire, **parameters)


class RoutingDriver(Driver):
Expand Down Expand Up @@ -267,10 +275,11 @@ def connector(address, **kwargs):
instance._max_retry_time = config.get("max_retry_time", default_config["max_retry_time"])
return instance

def session(self, access_mode=None, **parameters):
def session(self, **parameters):
self._assert_open()
if "max_retry_time" not in parameters:
parameters["max_retry_time"] = self._max_retry_time
return Session(self._pool.acquire, access_mode, **parameters)
return Session(self._pool.acquire, **parameters)


class Session(object):
Expand Down Expand Up @@ -332,9 +341,9 @@ class Session(object):

_closed = False

def __init__(self, acquirer, access_mode, **parameters):
def __init__(self, acquirer, **parameters):
self._acquirer = acquirer
self._default_access_mode = access_mode
self._default_access_mode = parameters.get("access_mode")
for key, value in parameters.items():
if key == "bookmark":
if value:
Expand Down
14 changes: 7 additions & 7 deletions test/stub/test_routingdriver.py
Expand Up @@ -188,7 +188,7 @@ def test_should_call_get_routing_table_procedure(self):
with StubCluster({9001: "get_routing_table.script", 9002: "return_1.script"}):
uri = "bolt+routing://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(READ_ACCESS) as session:
with driver.session(access_mode=READ_ACCESS) as session:
result = session.run("RETURN $x", {"x": 1})
for record in result:
assert record["x"] == 1
Expand All @@ -198,7 +198,7 @@ def test_should_call_get_routing_table_with_context(self):
with StubCluster({9001: "get_routing_table_with_context.script", 9002: "return_1.script"}):
uri = "bolt+routing://127.0.0.1:9001/?name=molly&age=1"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(READ_ACCESS) as session:
with driver.session(access_mode=READ_ACCESS) as session:
result = session.run("RETURN $x", {"x": 1})
for record in result:
assert record["x"] == 1
Expand All @@ -208,7 +208,7 @@ def test_should_serve_read_when_missing_writer(self):
with StubCluster({9001: "router_no_writers.script", 9005: "return_1.script"}):
uri = "bolt+routing://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(READ_ACCESS) as session:
with driver.session(access_mode=READ_ACCESS) as session:
result = session.run("RETURN $x", {"x": 1})
for record in result:
assert record["x"] == 1
Expand Down Expand Up @@ -249,7 +249,7 @@ def test_forgets_address_on_not_a_leader_error(self):
with StubCluster({9001: "router.script", 9006: "not_a_leader.script"}):
uri = "bolt+routing://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(WRITE_ACCESS) as session:
with driver.session(access_mode=WRITE_ACCESS) as session:
with self.assertRaises(ClientError):
_ = session.run("CREATE (n {name:'Bob'})")

Expand All @@ -267,7 +267,7 @@ def test_forgets_address_on_forbidden_on_read_only_database_error(self):
with StubCluster({9001: "router.script", 9006: "forbidden_on_read_only_database.script"}):
uri = "bolt+routing://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(WRITE_ACCESS) as session:
with driver.session(access_mode=WRITE_ACCESS) as session:
with self.assertRaises(ClientError):
_ = session.run("CREATE (n {name:'Bob'})")

Expand All @@ -285,7 +285,7 @@ def test_forgets_address_on_service_unavailable_error(self):
with StubCluster({9001: "router.script", 9004: "rude_reader.script"}):
uri = "bolt+routing://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(READ_ACCESS) as session:
with driver.session(access_mode=READ_ACCESS) as session:
with self.assertRaises(SessionExpired):
_ = session.run("RETURN 1")

Expand All @@ -309,7 +309,7 @@ def test_forgets_address_on_database_unavailable_error(self):
with StubCluster({9001: "router.script", 9004: "database_unavailable.script"}):
uri = "bolt+routing://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
with driver.session(READ_ACCESS) as session:
with driver.session(access_mode=READ_ACCESS) as session:
with self.assertRaises(TransientError):
_ = session.run("RETURN 1")

Expand Down

0 comments on commit 9d28038

Please sign in to comment.