diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 1fe02cafa..e913ff237 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -348,18 +348,11 @@ def supports_multi_db(self): :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. :rtype: bool """ - from neo4j.io._bolt4x0 import Bolt4x0 - - multi_database = False cx = self._pool.acquire(access_mode=READ_ACCESS, timeout=self._pool.workspace_config.connection_acquisition_timeout, database=self._pool.workspace_config.database) - - # TODO: This logic should be inside the Bolt subclasses, because it can change depending on Bolt Protocol Version. - if cx.PROTOCOL_VERSION >= Bolt4x0.PROTOCOL_VERSION and cx.server_info.version_info() >= Version(4, 0, 0): - multi_database = True - + support = cx.supports_multiple_databases self._pool.release(cx) - return multi_database + return support class BoltDriver(Direct, Driver): diff --git a/neo4j/api.py b/neo4j/api.py index 18236c3e5..101c1fe9b 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -23,6 +23,7 @@ parse_qs, ) from.exceptions import ( + DriverError, ConfigurationError, ) @@ -149,12 +150,38 @@ def __init__(self, address, protocol_version): @property def agent(self): + """The server agent string the server responded with. + + :return: Server agent string + :rtype: string + """ + # Example "Neo4j/4.0.5" + # Example "Neo4j/4" return self.metadata.get("server") def version_info(self): + """Return the server version if available. + + :return: Server Version or None + :rtype: tuple + """ if not self.agent: return None - _, _, value = self.agent.partition("/") + # Note: Confirm that the server agent string begins with "Neo4j/" and fail gracefully if not. + # This is intended to help prevent drivers working for non-genuine Neo4j instances. + + neo4j, _, value = self.agent.partition("/") + try: + assert neo4j == "Neo4j" + except AssertionError: + raise DriverError("Server name does not start with Neo4j/") + + try: + if self.protocol_version >= (4, 0): + return self.protocol_version + except TypeError: + pass + value = value.replace("-", ".").split(".") for i, v in enumerate(value): try: diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 8e0f1f06b..c1ddece0c 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -88,6 +88,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self._max_connection_lifetime = max_connection_lifetime self._creation_timestamp = perf_counter() self.supports_multiple_results = False + self.supports_multiple_databases = False self._is_reset = True # Determine the user agent diff --git a/neo4j/io/_bolt4x0.py b/neo4j/io/_bolt4x0.py index 3fb3bd9f5..035979720 100644 --- a/neo4j/io/_bolt4x0.py +++ b/neo4j/io/_bolt4x0.py @@ -87,6 +87,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self._max_connection_lifetime = max_connection_lifetime # self.pool_config.max_connection_lifetime self._creation_timestamp = perf_counter() self.supports_multiple_results = True + self.supports_multiple_databases = True self._is_reset = True # Determine the user agent diff --git a/tests/integration/test_bolt_driver.py b/tests/integration/test_bolt_driver.py index a9150ee87..49a7cf87d 100644 --- a/tests/integration/test_bolt_driver.py +++ b/tests/integration/test_bolt_driver.py @@ -31,6 +31,7 @@ unit_of_work, Transaction, Result, + ServerInfo, ) from neo4j.exceptions import ( ServiceUnavailable, @@ -39,6 +40,7 @@ ClientError, ) from neo4j._exceptions import BoltHandshakeError +from neo4j.io._bolt3 import Bolt3 # python -m pytest tests/integration/test_bolt_driver.py -s -v @@ -139,21 +141,28 @@ def test_supports_multi_db(bolt_uri, auth): with driver.session() as session: result = session.run("RETURN 1") - value = result.single().value() # Consumes the result + _ = result.single().value() # Consumes the result summary = result.consume() server_info = summary.server + assert isinstance(summary, ResultSummary) + assert isinstance(server_info, ServerInfo) + assert server_info.version_info() is not None + assert isinstance(server_info.protocol_version, Version) + result = driver.supports_multi_db() driver.close() - if server_info.version_info() >= Version(4, 0, 0) and server_info.protocol_version >= Version(4, 0): - assert result is True - assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server - assert summary.query_type == "r" - else: + if server_info.protocol_version == Bolt3.PROTOCOL_VERSION: assert result is False assert summary.database is None assert summary.query_type == "r" + else: + assert result is True + assert server_info.version_info() >= Version(4, 0) + assert server_info.protocol_version >= Version(4, 0) + assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server + assert summary.query_type == "r" def test_test_multi_db_specify_database(bolt_uri, auth): diff --git a/tests/integration/test_neo4j_driver.py b/tests/integration/test_neo4j_driver.py index 9bc97b3ad..18208922c 100644 --- a/tests/integration/test_neo4j_driver.py +++ b/tests/integration/test_neo4j_driver.py @@ -27,6 +27,7 @@ Version, READ_ACCESS, ResultSummary, + ServerInfo, ) from neo4j.exceptions import ( ServiceUnavailable, @@ -39,6 +40,7 @@ from neo4j.conf import ( RoutingConfig, ) +from neo4j.io._bolt3 import Bolt3 # python -m pytest tests/integration/test_neo4j_driver.py -s -v @@ -72,21 +74,28 @@ def test_supports_multi_db(neo4j_uri, auth, target): with driver.session() as session: result = session.run("RETURN 1") - value = result.single().value() # Consumes the result + _ = result.single().value() # Consumes the result summary = result.consume() server_info = summary.server + assert isinstance(summary, ResultSummary) + assert isinstance(server_info, ServerInfo) + assert server_info.version_info() is not None + assert isinstance(server_info.protocol_version, Version) + result = driver.supports_multi_db() driver.close() - if server_info.version_info() >= Version(4, 0, 0) and server_info.protocol_version >= Version(4, 0): - assert result is True - assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server - assert summary.query_type == "r" - else: + if server_info.protocol_version == Bolt3.PROTOCOL_VERSION: assert result is False assert summary.database is None assert summary.query_type == "r" + else: + assert result is True + assert server_info.version_info() >= Version(4, 0) + assert server_info.protocol_version >= Version(4, 0) + assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server + assert summary.query_type == "r" def test_test_multi_db_specify_database(neo4j_uri, auth, target):