Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ def ping(cls, address, *, timeout=None, **config):
return protocol_version

@classmethod
def open(cls, address, *, auth=None, timeout=None, **pool_config):
def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config):
""" Open a new Bolt connection to a given server address.

:param address:
:param auth:
:param timeout: The connection timeout
:param timeout: the connection timeout in seconds
:param routing_context: dict containing routing context
:param pool_config:
:return:
:raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
Expand All @@ -200,15 +201,15 @@ def open(cls, address, *, auth=None, timeout=None, **pool_config):
if pool_config.protocol_version == (3, 0):
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
from neo4j.io._bolt3 import Bolt3
connection = Bolt3(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent)
connection = Bolt3(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context)
elif pool_config.protocol_version == (4, 0):
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
from neo4j.io._bolt4x0 import Bolt4x0
connection = Bolt4x0(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent)
connection = Bolt4x0(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context)
elif pool_config.protocol_version == (4, 1):
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
from neo4j.io._bolt4x1 import Bolt4x1
connection = Bolt4x1(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent)
connection = Bolt4x1(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context)
else:
log.debug("[#%04X] S: <CLOSE>", s.getpeername()[1])
s.shutdown(SHUT_RDWR)
Expand Down Expand Up @@ -499,27 +500,35 @@ def close(self):
class BoltPool(IOPool):

@classmethod
def open(cls, address, *, auth, pool_config, workspace_config):
def open(cls, address, *, auth, pool_config, workspace_config, routing_context=None):
"""Create a new BoltPool

:param address:
:param auth:
:param pool_config:
:param workspace_config:
:param routing_context:
:return: BoltPool
"""

if routing_context is None:
routing_context = {}
elif "address" in routing_context:
raise ConfigurationError("The key 'address' is reserved for routing context.")
routing_context["address"] = str(address)

def opener(addr, timeout):
return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config)
return Bolt.open(addr, auth=auth, timeout=timeout, routing_context=routing_context, **pool_config)

pool = cls(opener, pool_config, workspace_config, address)
pool = cls(opener, pool_config, workspace_config, routing_context, address)
seeds = [pool.acquire() for _ in range(pool_config.init_size)]
pool.release(*seeds)
return pool

def __init__(self, opener, pool_config, workspace_config, address):
def __init__(self, opener, pool_config, workspace_config, routing_context, address):
super(BoltPool, self).__init__(opener, pool_config, workspace_config)
self.address = address
self.routing_context = routing_context

def __repr__(self):
return "<{} address={!r}>".format(self.__class__.__name__, self.address)
Expand All @@ -545,10 +554,17 @@ def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=N
:return: Neo4jPool
"""

address = addresses[0]
if routing_context is None:
routing_context = {}
elif "address" in routing_context:
raise ConfigurationError("The key 'address' is reserved for routing context.")
routing_context["address"] = str(address)

def opener(addr, timeout):
return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config)
return Bolt.open(addr, auth=auth, timeout=timeout, routing_context=routing_context, **pool_config)

pool = cls(opener, pool_config, workspace_config, routing_context, addresses)
pool = cls(opener, pool_config, workspace_config, routing_context, address)

try:
pool.update_routing_table(database=workspace_config.database)
Expand All @@ -558,7 +574,7 @@ def opener(addr, timeout):
else:
return pool

def __init__(self, opener, pool_config, workspace_config, routing_context, addresses):
def __init__(self, opener, pool_config, workspace_config, routing_context, address):
"""

:param opener:
Expand All @@ -569,15 +585,10 @@ def __init__(self, opener, pool_config, workspace_config, routing_context, addre
"""
super(Neo4jPool, self).__init__(opener, pool_config, workspace_config)
# Each database have a routing table, the default database is a special case.
log.debug("[#0000] C: <NEO4J POOL> routing addresses %r", addresses)
self.init_address = addresses[0]
self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=addresses)}
log.debug("[#0000] C: <NEO4J POOL> routing address %r", address)
self.address = address
self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
self.routing_context = routing_context
if self.routing_context is None:
self.routing_context = {}
elif "address" in self.routing_context:
raise ConfigurationError("The key 'address' is reserved for routing context.")
self.routing_context["address"] = str(self.init_address)
self.refresh_lock = Lock()

def __repr__(self):
Expand Down Expand Up @@ -621,7 +632,7 @@ def fetch_routing_info(self, *, address, timeout, database):
:param address: router address
:param timeout: seconds
:param database: the data base name to get routing table for
:param init_address: the address by which the client initially contacted the server as a hint for inclusion in the returned routing table.
:param address: the address by which the client initially contacted the server as a hint for inclusion in the returned routing table.

:return: list of routing records or
None if no connection could be established
Expand Down
3 changes: 2 additions & 1 deletion neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Bolt3(Bolt):
#: The pool of which this connection is a member
pool = None

def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
self.unresolved_address = unresolved_address
self.socket = sock
self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION)
Expand All @@ -90,6 +90,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self.supports_multiple_results = False
self.supports_multiple_databases = False
self._is_reset = True
self.routing_context = routing_context

# Determine the user agent
if user_agent:
Expand Down
3 changes: 2 additions & 1 deletion neo4j/io/_bolt4x0.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Bolt4x0(Bolt):
#: The pool of which this connection is a member
pool = None

def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
self.unresolved_address = unresolved_address
self.socket = sock
self.server_info = ServerInfo(Address(sock.getpeername()), Bolt4x0.PROTOCOL_VERSION)
Expand All @@ -89,6 +89,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self.supports_multiple_results = True
self.supports_multiple_databases = True
self._is_reset = True
self.routing_context = routing_context

# Determine the user agent
if user_agent:
Expand Down
4 changes: 3 additions & 1 deletion neo4j/io/_bolt4x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Bolt4x1(Bolt):
#: The pool of which this connection is a member
pool = None

def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
self.unresolved_address = unresolved_address
self.socket = sock
self.server_info = ServerInfo(Address(sock.getpeername()), Bolt4x1.PROTOCOL_VERSION)
Expand All @@ -89,6 +89,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self.supports_multiple_results = True
self.supports_multiple_databases = True
self._is_reset = True
self.routing_context = routing_context

# Determine the user agent
if user_agent:
Expand Down Expand Up @@ -135,6 +136,7 @@ def local_port(self):
def hello(self):
headers = {"user_agent": self.user_agent}
headers.update(self.auth_dict)
headers["routing"] = self.routing_context
logged_headers = dict(headers)
if "credentials" in logged_headers:
logged_headers["credentials"] = "*******"
Expand Down
24 changes: 14 additions & 10 deletions tests/stub/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import subprocess
import os
import time

from threading import Thread
from time import sleep
Expand All @@ -44,18 +45,20 @@ def __init__(self, port, script):
def run(self):
self._process = subprocess.Popen(["python", "-m", "boltkit", "stub", "-v", "-l", ":{}".format(str(self.port)), "-t", "10", self.script], stdout=subprocess.PIPE)
# Need verbose for this to work
line = self._process.stdout.readline()
line = self._process.stdout.readline().decode("utf-8")
log.debug("started stub server {}".format(self.port))
log.debug(line.strip("\n"))

def wait(self):
try:
returncode = self._process.wait(2)
if returncode != 0:
log.debug("stubserver return code {}".format(returncode))
log.debug("check for miss match in script")
return returncode == 0
except subprocess.TimeoutExpired:
log.debug("stubserver timeout!")
return False
while True:
return_code = self._process.poll()
if return_code is not None:
line = self._process.stdout.readline().decode("utf-8")
if line == "":
break
log.debug(line.strip("\n"))

return True

def kill(self):
# Kill process if not already dead
Expand Down Expand Up @@ -129,6 +132,7 @@ class DefaultBoltStubService(BoltStubService):
class StubCluster(StubCluster):

def __init__(self, *servers):
print("")
scripts = [os.path.join(os.path.dirname(__file__), "scripts", server) for server in servers]

bss = DefaultBoltStubService.load(*scripts)
Expand Down
1 change: 1 addition & 0 deletions tests/stub/scripts/v3/empty.script
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
!: AUTO HELLO
!: AUTO GOODBYE
!: AUTO RESET
!: PORT 9001
1 change: 1 addition & 0 deletions tests/stub/scripts/v3/empty_explicit_hello_goodbye.script
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
!: BOLT 3
!: PORT 9001

C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
S: SUCCESS {"server": "Neo4j/3.5.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
Expand Down
3 changes: 2 additions & 1 deletion tests/stub/scripts/v4x0/empty.script
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
!: BOLT 4
!: AUTO HELLO
!: AUTO GOODBYE
!: AUTO RESET
!: AUTO RESET
!: PORT 9001
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
!: BOLT 4
!: PORT 9001

C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
S: SUCCESS {"server": "Neo4j/4.0.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
Expand Down
3 changes: 2 additions & 1 deletion tests/stub/scripts/v4x1/empty_explicit_hello_goodbye.script
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
!: BOLT 4.1
!: PORT 9001

C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test", "routing": {"address": "localhost:9001"}}
S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
C: GOODBYE
S: <EXIT>
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
!: BOLT 4.1
!: AUTO GOODBYE
!: AUTO RESET
!: PORT 9002

C: HELLO {"scheme": "basic", "principal": "test", "credentials": "test", "user_agent": "test", "routing": {"address": "localhost:9001", "policy": "my_policy", "region": "china"}}
S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "bolt-123456789"}
C: RUN "RETURN 1 AS x" {} {"mode": "r"}
PULL {"n": -1}
S: SUCCESS {"fields": ["x"]}
RECORD [1]
SUCCESS {"bookmark": "neo4j:bookmark-test-2", "type": "r", "t_last": 5, "db": "system"}
2 changes: 1 addition & 1 deletion tests/stub/scripts/v4x1/return_1_noop_port_9001.script
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
!: AUTO RESET
!: PORT 9001

C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test", "routing": {"address": "localhost:9001"}}
S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
C: RUN "RETURN 1 AS x" {} {"mode": "r"}
PULL {"n": 2}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
!: BOLT 4.1
!: AUTO GOODBYE
!: AUTO RESET
!: PORT 9001

C: HELLO {"scheme": "basic", "principal": "test", "credentials": "test", "user_agent": "test", "routing": {"address": "localhost:9001", "policy": "my_policy", "region": "china"}}
S: SUCCESS {"server": "Neo4j/4.1.0", "connection_id": "bolt-123456789"}
C: RUN "CALL dbms.routing.getRoutingTable($context)" {"context": {"address": "localhost:9001", "policy": "my_policy", "region": "china"}} {"mode": "r", "db": "system"}
PULL {"n": -1}
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [4321, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "READ"}, {"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "system"}
4 changes: 2 additions & 2 deletions tests/stub/test_directdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_bolt_uri_constructs_bolt_driver(driver_info, test_script):
def test_direct_driver_handshake_negotiation(driver_info, test_script, test_expected):
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_driver_handshake_negotiation
with StubCluster(test_script):
uri = "bolt://127.0.0.1:9001"
uri = "bolt://localhost:9001"
if test_expected:
with pytest.raises(test_expected) as error:
driver = GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config)
Expand Down Expand Up @@ -586,7 +586,7 @@ def test_bolt_driver_explicit_transaction_consume_result_case_b(driver_info, tes
def test_direct_can_handle_noop(driver_info, test_script):
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_can_handle_noop
with StubCluster(test_script):
uri = "bolt://127.0.0.1:9001"
uri = "bolt://localhost:9001"
with GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config) as driver:
assert isinstance(driver, BoltDriver)
with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session:
Expand Down
20 changes: 20 additions & 0 deletions tests/stub/test_routingdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,3 +648,23 @@ def test_forgets_address_on_database_unavailable_error(driver_info, test_scripts
# reader 127.0.0.1:9004 should've been forgotten because of an raised
assert not table.readers
assert table.writers == {('127.0.0.1', 9006)}


@pytest.mark.parametrize(
"test_scripts",
[
("v4x1/router_get_routing_table_with_context.script", "v4x1/hello_with_routing_context_return_1_port_9002.script",),
]
)
def test_hello_routing(driver_info, test_scripts):
# python -m pytest tests/stub/test_routingdriver.py -s -v -k test_hello_routing
with StubCluster(*test_scripts):
uri = "neo4j://localhost:9001/?region=china&policy=my_policy"
with GraphDatabase.driver(uri, auth=driver_info["auth_token"], user_agent="test") as driver:
with driver.session(default_access_mode=READ_ACCESS, fetch_size=-1) as session:
result = session.run("RETURN 1 AS x")
for record in result:
assert record["x"] == 1
address = result.consume().server.address
assert address.host == "127.0.0.1"
assert address.port == 9002