Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 30 additions & 26 deletions neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,23 @@

try:
from neobolt.exceptions import (
ConnectionExpired,
CypherError,
TransientError,
IncompleteCommitError,
ServiceUnavailable,
TransientError,
)
except ImportError:
# We allow this to fail because this module can be imported implicitly
# during setup. At that point, dependencies aren't available.
pass
else:
__all__.extend([
"ConnectionExpired",
"CypherError",
"TransientError",
"IncompleteCommitError",
"ServiceUnavailable",
"TransientError",
])


Expand Down Expand Up @@ -363,35 +367,31 @@ def _connect(self, access_mode=None):
if access_mode is None:
access_mode = self._default_access_mode
if self._connection:
self._disconnect(sync=True)
self._connection.sync()
self._disconnect()
self._connection = self._acquirer(access_mode)

def _disconnect(self, sync):
from neobolt.exceptions import ConnectionExpired, ServiceUnavailable
def _disconnect(self):
if self._connection:
if sync:
try:
self._connection.sync()
except (SessionError, ConnectionExpired, ServiceUnavailable):
pass
if self._connection:
self._connection.in_use = False
self._connection = None
self._connection.in_use = False
self._connection = None

def close(self):
""" Close the session. This will release any borrowed resources,
such as connections, and will roll back any outstanding transactions.
"""
from neobolt.exceptions import ConnectionExpired, CypherError, ServiceUnavailable
try:
if self.has_transaction():
try:
self.rollback_transaction()
except (CypherError, TransactionError, SessionError, ConnectionExpired, ServiceUnavailable):
pass
finally:
self._closed = True
self._disconnect(sync=True)
if self._connection:
if self._transaction:
self._connection.rollback()
self._transaction = None
try:
self._connection.sync()
except (ConnectionExpired, CypherError, TransactionError,
ServiceUnavailable, SessionError):
pass
finally:
self._disconnect()
self._closed = True

def closed(self):
""" Indicator for whether or not this session has been closed.
Expand Down Expand Up @@ -554,7 +554,7 @@ def detach(self, result, sync=True):
if self._last_result is result:
self._last_result = None
if not self.has_transaction():
self._disconnect(sync=False)
self._disconnect()

result._session = None
return count
Expand Down Expand Up @@ -620,8 +620,11 @@ def commit_transaction(self):
metadata = {}
try:
self._connection.commit(on_success=metadata.update)
self._connection.sync()
except IncompleteCommitError:
raise ServiceUnavailable("Connection closed during commit")
finally:
self._disconnect(sync=True)
self._disconnect()
self._transaction = None
bookmark = metadata.get("bookmark")
self._bookmarks_in = tuple([bookmark])
Expand All @@ -641,8 +644,9 @@ def rollback_transaction(self):
metadata = {}
try:
cx.rollback(on_success=metadata.update)
cx.sync()
finally:
self._disconnect(sync=True)
self._disconnect()
self._transaction = None

def _run_transaction(self, access_mode, unit_of_work, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
neobolt<2,>=1.7.4
neobolt<2,>=1.7.6
neotime<2,>=1.7.1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from neo4j.meta import package, version

install_requires = [
"neobolt<2,>=1.7.4",
"neobolt<2,>=1.7.6",
"neotime<2,>=1.7.1",
]
classifiers = [
Expand Down
18 changes: 0 additions & 18 deletions test/integration/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,24 +396,6 @@ def test_broken_transaction_should_not_break_session(self):
with session.begin_transaction() as tx:
tx.run("RETURN 1")

def test_last_run_statement_should_be_cleared_on_failure(self):
if not self.at_least_server_version(3, 2):
raise SkipTest("Statement reuse is not supported before server 3.2")

with self.driver.session() as session:
tx = session.begin_transaction()
tx.run("RETURN 1").consume()
connection_1 = session._connection
assert connection_1._last_run_statement == "RETURN 1"
with self.assertRaises(CypherSyntaxError):
result = tx.run("X")
connection_2 = session._connection
result.consume()
# connection_2 = session._connection
assert connection_2 is connection_1
assert connection_2._last_run_statement is None
tx.close()

def test_statement_object_not_supported(self):
with self.driver.session() as session:
with session.begin_transaction() as tx:
Expand Down
12 changes: 12 additions & 0 deletions test/stub/scripts/connection_error_on_commit.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
!: BOLT 3
!: AUTO HELLO
!: AUTO RESET

C: BEGIN {}
RUN "CREATE (n {name:'Bob'})" {} {}
PULL_ALL
S: SUCCESS {}
SUCCESS {}
SUCCESS {}
C: COMMIT
S: <EXIT>
6 changes: 3 additions & 3 deletions test/stub/scripts/return_1_four_times.script
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ S: SUCCESS {"fields": ["x"]}
RECORD [1]
SUCCESS {}

C: RUN "" {"x": 1}
C: RUN "RETURN $x" {"x": 1}
PULL_ALL
S: SUCCESS {"fields": ["x"]}
RECORD [1]
SUCCESS {}

C: RUN "" {"x": 1}
C: RUN "RETURN $x" {"x": 1}
PULL_ALL
S: SUCCESS {"fields": ["x"]}
RECORD [1]
SUCCESS {}

C: RUN "" {"x": 1}
C: RUN "RETURN $x" {"x": 1}
PULL_ALL
S: SUCCESS {"fields": ["x"]}
RECORD [1]
Expand Down
2 changes: 1 addition & 1 deletion test/stub/scripts/return_1_in_tx_twice.script
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ C: RUN "BEGIN" {"bookmark": "bookmark:1", "bookmarks": ["bookmark:1"]}
S: SUCCESS {"fields": []}
SUCCESS {}

C: RUN "" {}
C: RUN "RETURN 1" {}
PULL_ALL
S: SUCCESS {"fields": ["1"]}
RECORD [1]
Expand Down
2 changes: 1 addition & 1 deletion test/stub/scripts/return_1_twice.script
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ S: SUCCESS {"fields": ["x"]}
RECORD [1]
SUCCESS {}

C: RUN "" {"x": 1}
C: RUN "RETURN $x" {"x": 1}
PULL_ALL
S: SUCCESS {"fields": ["x"]}
RECORD [1]
Expand Down
2 changes: 1 addition & 1 deletion test/stub/scripts/return_1_twice_in_tx.script
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ S: SUCCESS {"fields": ["x"]}
RECORD [1]
SUCCESS {}

C: RUN "" {"x": 1}
C: RUN "RETURN $x" {"x": 1}
PULL_ALL
S: SUCCESS {"fields": ["x"]}
RECORD [1]
Expand Down
51 changes: 51 additions & 0 deletions test/stub/test_transactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) 2002-2019 "Neo4j,"
# Neo4j Sweden AB [http://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from neobolt.exceptions import ServiceUnavailable

from neo4j import GraphDatabase

from test.stub.tools import StubTestCase, StubCluster


class TransactionTestCase(StubTestCase):

@staticmethod
def create_bob(tx):
tx.run("CREATE (n {name:'Bob'})").data()

def test_connection_error_on_explicit_commit(self):
with StubCluster({9001: "connection_error_on_commit.script"}):
uri = "bolt://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, max_retry_time=0) as driver:
with driver.session() as session:
tx = session.begin_transaction()
tx.run("CREATE (n {name:'Bob'})").data()
with self.assertRaises(ServiceUnavailable):
tx.commit()

def test_connection_error_on_commit(self):
with StubCluster({9001: "connection_error_on_commit.script"}):
uri = "bolt://127.0.0.1:9001"
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, max_retry_time=0) as driver:
with driver.session() as session:
with self.assertRaises(ServiceUnavailable):
session.write_transaction(self.create_bob)