Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,18 +348,11 @@ def supports_multi_db(self):
:return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false.
:rtype: bool
"""
from neo4j.io._bolt4x0 import Bolt4x0

multi_database = False
cx = self._pool.acquire(access_mode=READ_ACCESS, timeout=self._pool.workspace_config.connection_acquisition_timeout, database=self._pool.workspace_config.database)

# TODO: This logic should be inside the Bolt subclasses, because it can change depending on Bolt Protocol Version.
if cx.PROTOCOL_VERSION >= Bolt4x0.PROTOCOL_VERSION and cx.server_info.version_info() >= Version(4, 0, 0):
multi_database = True

support = cx.supports_multiple_databases
self._pool.release(cx)

return multi_database
return support


class BoltDriver(Direct, Driver):
Expand Down
29 changes: 28 additions & 1 deletion neo4j/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
parse_qs,
)
from.exceptions import (
DriverError,
ConfigurationError,
)

Expand Down Expand Up @@ -149,12 +150,38 @@ def __init__(self, address, protocol_version):

@property
def agent(self):
"""The server agent string the server responded with.

:return: Server agent string
:rtype: string
"""
# Example "Neo4j/4.0.5"
# Example "Neo4j/4"
return self.metadata.get("server")

def version_info(self):
"""Return the server version if available.

:return: Server Version or None
:rtype: tuple
"""
if not self.agent:
return None
_, _, value = self.agent.partition("/")
# Note: Confirm that the server agent string begins with "Neo4j/" and fail gracefully if not.
# This is intended to help prevent drivers working for non-genuine Neo4j instances.

neo4j, _, value = self.agent.partition("/")
try:
assert neo4j == "Neo4j"
except AssertionError:
raise DriverError("Server name does not start with Neo4j/")

try:
if self.protocol_version >= (4, 0):
return self.protocol_version
except TypeError:
pass

value = value.replace("-", ".").split(".")
for i, v in enumerate(value):
try:
Expand Down
1 change: 1 addition & 0 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self._max_connection_lifetime = max_connection_lifetime
self._creation_timestamp = perf_counter()
self.supports_multiple_results = False
self.supports_multiple_databases = False
self._is_reset = True

# Determine the user agent
Expand Down
1 change: 1 addition & 0 deletions neo4j/io/_bolt4x0.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self._max_connection_lifetime = max_connection_lifetime # self.pool_config.max_connection_lifetime
self._creation_timestamp = perf_counter()
self.supports_multiple_results = True
self.supports_multiple_databases = True
self._is_reset = True

# Determine the user agent
Expand Down
21 changes: 15 additions & 6 deletions tests/integration/test_bolt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
unit_of_work,
Transaction,
Result,
ServerInfo,
)
from neo4j.exceptions import (
ServiceUnavailable,
Expand All @@ -39,6 +40,7 @@
ClientError,
)
from neo4j._exceptions import BoltHandshakeError
from neo4j.io._bolt3 import Bolt3

# python -m pytest tests/integration/test_bolt_driver.py -s -v

Expand Down Expand Up @@ -139,21 +141,28 @@ def test_supports_multi_db(bolt_uri, auth):

with driver.session() as session:
result = session.run("RETURN 1")
value = result.single().value() # Consumes the result
_ = result.single().value() # Consumes the result
summary = result.consume()
server_info = summary.server

assert isinstance(summary, ResultSummary)
assert isinstance(server_info, ServerInfo)
assert server_info.version_info() is not None
assert isinstance(server_info.protocol_version, Version)

result = driver.supports_multi_db()
driver.close()

if server_info.version_info() >= Version(4, 0, 0) and server_info.protocol_version >= Version(4, 0):
assert result is True
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
assert summary.query_type == "r"
else:
if server_info.protocol_version == Bolt3.PROTOCOL_VERSION:
assert result is False
assert summary.database is None
assert summary.query_type == "r"
else:
assert result is True
assert server_info.version_info() >= Version(4, 0)
assert server_info.protocol_version >= Version(4, 0)
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
assert summary.query_type == "r"


def test_test_multi_db_specify_database(bolt_uri, auth):
Expand Down
21 changes: 15 additions & 6 deletions tests/integration/test_neo4j_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Version,
READ_ACCESS,
ResultSummary,
ServerInfo,
)
from neo4j.exceptions import (
ServiceUnavailable,
Expand All @@ -39,6 +40,7 @@
from neo4j.conf import (
RoutingConfig,
)
from neo4j.io._bolt3 import Bolt3

# python -m pytest tests/integration/test_neo4j_driver.py -s -v

Expand Down Expand Up @@ -72,21 +74,28 @@ def test_supports_multi_db(neo4j_uri, auth, target):

with driver.session() as session:
result = session.run("RETURN 1")
value = result.single().value() # Consumes the result
_ = result.single().value() # Consumes the result
summary = result.consume()
server_info = summary.server

assert isinstance(summary, ResultSummary)
assert isinstance(server_info, ServerInfo)
assert server_info.version_info() is not None
assert isinstance(server_info.protocol_version, Version)

result = driver.supports_multi_db()
driver.close()

if server_info.version_info() >= Version(4, 0, 0) and server_info.protocol_version >= Version(4, 0):
assert result is True
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
assert summary.query_type == "r"
else:
if server_info.protocol_version == Bolt3.PROTOCOL_VERSION:
assert result is False
assert summary.database is None
assert summary.query_type == "r"
else:
assert result is True
assert server_info.version_info() >= Version(4, 0)
assert server_info.protocol_version >= Version(4, 0)
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
assert summary.query_type == "r"


def test_test_multi_db_specify_database(neo4j_uri, auth, target):
Expand Down