diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 6d3f42d36..5d4895e75 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -177,12 +177,13 @@ def ping(cls, address, *, timeout=None, **config): return protocol_version @classmethod - def open(cls, address, *, auth=None, timeout=None, **pool_config): + def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config): """ Open a new Bolt connection to a given server address. :param address: :param auth: - :param timeout: The connection timeout + :param timeout: the connection timeout in seconds + :param routing_context: dict containing routing context :param pool_config: :return: :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version. @@ -200,15 +201,15 @@ def open(cls, address, *, auth=None, timeout=None, **pool_config): if pool_config.protocol_version == (3, 0): # Carry out Bolt subclass imports locally to avoid circular dependency issues. from neo4j.io._bolt3 import Bolt3 - connection = Bolt3(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent) + connection = Bolt3(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context) elif pool_config.protocol_version == (4, 0): # Carry out Bolt subclass imports locally to avoid circular dependency issues. from neo4j.io._bolt4x0 import Bolt4x0 - connection = Bolt4x0(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent) + connection = Bolt4x0(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context) elif pool_config.protocol_version == (4, 1): # Carry out Bolt subclass imports locally to avoid circular dependency issues. from neo4j.io._bolt4x1 import Bolt4x1 - connection = Bolt4x1(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent) + connection = Bolt4x1(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context) else: log.debug("[#%04X] S: ", s.getpeername()[1]) s.shutdown(SHUT_RDWR) @@ -499,27 +500,35 @@ def close(self): class BoltPool(IOPool): @classmethod - def open(cls, address, *, auth, pool_config, workspace_config): + def open(cls, address, *, auth, pool_config, workspace_config, routing_context=None): """Create a new BoltPool :param address: :param auth: :param pool_config: :param workspace_config: + :param routing_context: :return: BoltPool """ + if routing_context is None: + routing_context = {} + elif "address" in routing_context: + raise ConfigurationError("The key 'address' is reserved for routing context.") + routing_context["address"] = str(address) + def opener(addr, timeout): - return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config) + return Bolt.open(addr, auth=auth, timeout=timeout, routing_context=routing_context, **pool_config) - pool = cls(opener, pool_config, workspace_config, address) + pool = cls(opener, pool_config, workspace_config, routing_context, address) seeds = [pool.acquire() for _ in range(pool_config.init_size)] pool.release(*seeds) return pool - def __init__(self, opener, pool_config, workspace_config, address): + def __init__(self, opener, pool_config, workspace_config, routing_context, address): super(BoltPool, self).__init__(opener, pool_config, workspace_config) self.address = address + self.routing_context = routing_context def __repr__(self): return "<{} address={!r}>".format(self.__class__.__name__, self.address) @@ -545,10 +554,17 @@ def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=N :return: Neo4jPool """ + address = addresses[0] + if routing_context is None: + routing_context = {} + elif "address" in routing_context: + raise ConfigurationError("The key 'address' is reserved for routing context.") + routing_context["address"] = str(address) + def opener(addr, timeout): - return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config) + return Bolt.open(addr, auth=auth, timeout=timeout, routing_context=routing_context, **pool_config) - pool = cls(opener, pool_config, workspace_config, routing_context, addresses) + pool = cls(opener, pool_config, workspace_config, routing_context, address) try: pool.update_routing_table(database=workspace_config.database) @@ -558,7 +574,7 @@ def opener(addr, timeout): else: return pool - def __init__(self, opener, pool_config, workspace_config, routing_context, addresses): + def __init__(self, opener, pool_config, workspace_config, routing_context, address): """ :param opener: @@ -569,15 +585,10 @@ def __init__(self, opener, pool_config, workspace_config, routing_context, addre """ super(Neo4jPool, self).__init__(opener, pool_config, workspace_config) # Each database have a routing table, the default database is a special case. - log.debug("[#0000] C: routing addresses %r", addresses) - self.init_address = addresses[0] - self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=addresses)} + log.debug("[#0000] C: routing address %r", address) + self.address = address + self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} self.routing_context = routing_context - if self.routing_context is None: - self.routing_context = {} - elif "address" in self.routing_context: - raise ConfigurationError("The key 'address' is reserved for routing context.") - self.routing_context["address"] = str(self.init_address) self.refresh_lock = Lock() def __repr__(self): @@ -621,7 +632,7 @@ def fetch_routing_info(self, *, address, timeout, database): :param address: router address :param timeout: seconds :param database: the data base name to get routing table for - :param init_address: the address by which the client initially contacted the server as a hint for inclusion in the returned routing table. + :param address: the address by which the client initially contacted the server as a hint for inclusion in the returned routing table. :return: list of routing records or None if no connection could be established diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index cbd63fae6..b874a6d75 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -76,7 +76,7 @@ class Bolt3(Bolt): #: The pool of which this connection is a member pool = None - def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None): + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): self.unresolved_address = unresolved_address self.socket = sock self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION) @@ -90,6 +90,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self.supports_multiple_results = False self.supports_multiple_databases = False self._is_reset = True + self.routing_context = routing_context # Determine the user agent if user_agent: diff --git a/neo4j/io/_bolt4x0.py b/neo4j/io/_bolt4x0.py index ac3d73e18..0c75ef74d 100644 --- a/neo4j/io/_bolt4x0.py +++ b/neo4j/io/_bolt4x0.py @@ -75,7 +75,7 @@ class Bolt4x0(Bolt): #: The pool of which this connection is a member pool = None - def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None): + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): self.unresolved_address = unresolved_address self.socket = sock self.server_info = ServerInfo(Address(sock.getpeername()), Bolt4x0.PROTOCOL_VERSION) @@ -89,6 +89,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self.supports_multiple_results = True self.supports_multiple_databases = True self._is_reset = True + self.routing_context = routing_context # Determine the user agent if user_agent: diff --git a/neo4j/io/_bolt4x1.py b/neo4j/io/_bolt4x1.py index 15a056995..5236d9517 100644 --- a/neo4j/io/_bolt4x1.py +++ b/neo4j/io/_bolt4x1.py @@ -75,7 +75,7 @@ class Bolt4x1(Bolt): #: The pool of which this connection is a member pool = None - def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None): + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): self.unresolved_address = unresolved_address self.socket = sock self.server_info = ServerInfo(Address(sock.getpeername()), Bolt4x1.PROTOCOL_VERSION) @@ -89,6 +89,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self.supports_multiple_results = True self.supports_multiple_databases = True self._is_reset = True + self.routing_context = routing_context # Determine the user agent if user_agent: @@ -135,6 +136,7 @@ def local_port(self): def hello(self): headers = {"user_agent": self.user_agent} headers.update(self.auth_dict) + headers["routing"] = self.routing_context logged_headers = dict(headers) if "credentials" in logged_headers: logged_headers["credentials"] = "*******" diff --git a/tests/stub/conftest.py b/tests/stub/conftest.py index e0298db21..6aab0d5b5 100644 --- a/tests/stub/conftest.py +++ b/tests/stub/conftest.py @@ -21,6 +21,7 @@ import subprocess import os +import time from threading import Thread from time import sleep @@ -44,18 +45,20 @@ def __init__(self, port, script): def run(self): self._process = subprocess.Popen(["python", "-m", "boltkit", "stub", "-v", "-l", ":{}".format(str(self.port)), "-t", "10", self.script], stdout=subprocess.PIPE) # Need verbose for this to work - line = self._process.stdout.readline() + line = self._process.stdout.readline().decode("utf-8") + log.debug("started stub server {}".format(self.port)) + log.debug(line.strip("\n")) def wait(self): - try: - returncode = self._process.wait(2) - if returncode != 0: - log.debug("stubserver return code {}".format(returncode)) - log.debug("check for miss match in script") - return returncode == 0 - except subprocess.TimeoutExpired: - log.debug("stubserver timeout!") - return False + while True: + return_code = self._process.poll() + if return_code is not None: + line = self._process.stdout.readline().decode("utf-8") + if line == "": + break + log.debug(line.strip("\n")) + + return True def kill(self): # Kill process if not already dead @@ -129,6 +132,7 @@ class DefaultBoltStubService(BoltStubService): class StubCluster(StubCluster): def __init__(self, *servers): + print("") scripts = [os.path.join(os.path.dirname(__file__), "scripts", server) for server in servers] bss = DefaultBoltStubService.load(*scripts) diff --git a/tests/stub/scripts/v3/empty.script b/tests/stub/scripts/v3/empty.script index e52d27d17..3b38a3ca1 100644 --- a/tests/stub/scripts/v3/empty.script +++ b/tests/stub/scripts/v3/empty.script @@ -2,3 +2,4 @@ !: AUTO HELLO !: AUTO GOODBYE !: AUTO RESET +!: PORT 9001 diff --git a/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script index f0390f93f..4ed2b804c 100644 --- a/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script @@ -1,4 +1,5 @@ !: BOLT 3 +!: PORT 9001 C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} S: SUCCESS {"server": "Neo4j/3.5.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} diff --git a/tests/stub/scripts/v4x0/empty.script b/tests/stub/scripts/v4x0/empty.script index ff88ead00..6e97eca9e 100644 --- a/tests/stub/scripts/v4x0/empty.script +++ b/tests/stub/scripts/v4x0/empty.script @@ -1,4 +1,5 @@ !: BOLT 4 !: AUTO HELLO !: AUTO GOODBYE -!: AUTO RESET \ No newline at end of file +!: AUTO RESET +!: PORT 9001 \ No newline at end of file diff --git a/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script index 0c530cca8..29957f14a 100644 --- a/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script @@ -1,4 +1,5 @@ !: BOLT 4 +!: PORT 9001 C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} S: SUCCESS {"server": "Neo4j/4.0.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} diff --git a/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script index 84da1223d..20043c9a1 100644 --- a/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script @@ -1,6 +1,7 @@ !: BOLT 4.1 +!: PORT 9001 -C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} +C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test", "routing": {"address": "localhost:9001"}} S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} C: GOODBYE S: \ No newline at end of file diff --git a/tests/stub/scripts/v4x1/hello_with_routing_context_return_1_port_9002.script b/tests/stub/scripts/v4x1/hello_with_routing_context_return_1_port_9002.script new file mode 100644 index 000000000..8bd2038a1 --- /dev/null +++ b/tests/stub/scripts/v4x1/hello_with_routing_context_return_1_port_9002.script @@ -0,0 +1,12 @@ +!: BOLT 4.1 +!: AUTO GOODBYE +!: AUTO RESET +!: PORT 9002 + +C: HELLO {"scheme": "basic", "principal": "test", "credentials": "test", "user_agent": "test", "routing": {"address": "localhost:9001", "policy": "my_policy", "region": "china"}} +S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "bolt-123456789"} +C: RUN "RETURN 1 AS x" {} {"mode": "r"} + PULL {"n": -1} +S: SUCCESS {"fields": ["x"]} + RECORD [1] + SUCCESS {"bookmark": "neo4j:bookmark-test-2", "type": "r", "t_last": 5, "db": "system"} diff --git a/tests/stub/scripts/v4x1/return_1_noop_port_9001.script b/tests/stub/scripts/v4x1/return_1_noop_port_9001.script index 41e5f1710..2d0f52639 100644 --- a/tests/stub/scripts/v4x1/return_1_noop_port_9001.script +++ b/tests/stub/scripts/v4x1/return_1_noop_port_9001.script @@ -3,7 +3,7 @@ !: AUTO RESET !: PORT 9001 -C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} +C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test", "routing": {"address": "localhost:9001"}} S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} C: RUN "RETURN 1 AS x" {} {"mode": "r"} PULL {"n": 2} diff --git a/tests/stub/scripts/v4x1/router_get_routing_table_with_context.script b/tests/stub/scripts/v4x1/router_get_routing_table_with_context.script new file mode 100644 index 000000000..8d7428f67 --- /dev/null +++ b/tests/stub/scripts/v4x1/router_get_routing_table_with_context.script @@ -0,0 +1,12 @@ +!: BOLT 4.1 +!: AUTO GOODBYE +!: AUTO RESET +!: PORT 9001 + +C: HELLO {"scheme": "basic", "principal": "test", "credentials": "test", "user_agent": "test", "routing": {"address": "localhost:9001", "policy": "my_policy", "region": "china"}} +S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "bolt-123456789"} +C: RUN "CALL dbms.routing.getRoutingTable($context)" {"context": {"address": "localhost:9001", "policy": "my_policy", "region": "china"}} {"mode": "r", "db": "system"} + PULL {"n": -1} +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [4321, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "READ"}, {"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "system"} diff --git a/tests/stub/test_directdriver.py b/tests/stub/test_directdriver.py index 4661f2c14..0c486d8b1 100644 --- a/tests/stub/test_directdriver.py +++ b/tests/stub/test_directdriver.py @@ -100,7 +100,7 @@ def test_bolt_uri_constructs_bolt_driver(driver_info, test_script): def test_direct_driver_handshake_negotiation(driver_info, test_script, test_expected): # python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_driver_handshake_negotiation with StubCluster(test_script): - uri = "bolt://127.0.0.1:9001" + uri = "bolt://localhost:9001" if test_expected: with pytest.raises(test_expected) as error: driver = GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config) @@ -586,7 +586,7 @@ def test_bolt_driver_explicit_transaction_consume_result_case_b(driver_info, tes def test_direct_can_handle_noop(driver_info, test_script): # python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_can_handle_noop with StubCluster(test_script): - uri = "bolt://127.0.0.1:9001" + uri = "bolt://localhost:9001" with GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config) as driver: assert isinstance(driver, BoltDriver) with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py index 370286b92..c62d2decd 100644 --- a/tests/stub/test_routingdriver.py +++ b/tests/stub/test_routingdriver.py @@ -648,3 +648,23 @@ def test_forgets_address_on_database_unavailable_error(driver_info, test_scripts # reader 127.0.0.1:9004 should've been forgotten because of an raised assert not table.readers assert table.writers == {('127.0.0.1', 9006)} + + +@pytest.mark.parametrize( + "test_scripts", + [ + ("v4x1/router_get_routing_table_with_context.script", "v4x1/hello_with_routing_context_return_1_port_9002.script",), + ] +) +def test_hello_routing(driver_info, test_scripts): + # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_hello_routing + with StubCluster(*test_scripts): + uri = "neo4j://localhost:9001/?region=china&policy=my_policy" + with GraphDatabase.driver(uri, auth=driver_info["auth_token"], user_agent="test") as driver: + with driver.session(default_access_mode=READ_ACCESS, fetch_size=-1) as session: + result = session.run("RETURN 1 AS x") + for record in result: + assert record["x"] == 1 + address = result.consume().server.address + assert address.host == "127.0.0.1" + assert address.port == 9002