diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py
index 1967deadb..38410dba6 100644
--- a/neo4j/io/__init__.py
+++ b/neo4j/io/__init__.py
@@ -35,7 +35,10 @@
 ]
 
 import abc
-from collections import deque
+from collections import (
+    defaultdict,
+    deque,
+)
 from logging import getLogger
 from random import choice
 from select import select
@@ -618,7 +621,7 @@ def __init__(self, opener, pool_config, workspace_config):
         self.opener = opener
         self.pool_config = pool_config
         self.workspace_config = workspace_config
-        self.connections = {}
+        self.connections = defaultdict(deque)
         self.lock = RLock()
         self.cond = Condition(self.lock)
 
@@ -640,18 +643,13 @@ def _acquire(self, address, timeout):
             timeout = self.workspace_config.connection_acquisition_timeout
 
         with self.lock:
-            try:
-                connections = self.connections[address]
-            except KeyError:
-                connections = self.connections[address] = deque()
-
             def time_remaining():
                 t = timeout - (perf_counter() - t0)
                 return t if t > 0 else 0
 
             while True:
                 # try to find a free connection in pool
-                for connection in list(connections):
+                for connection in list(self.connections.get(address, [])):
                     if (connection.closed() or connection.defunct()
                             or connection.stale()):
                         # `close` is a noop on already closed connections.
@@ -659,16 +657,30 @@ def time_remaining():
                         # closed, e.g. if it's just marked as `stale` but still
                         # alive.
                         connection.close()
-                        connections.remove(connection)
+                        try:
+                            self.connections.get(address, []).remove(connection)
+                        except ValueError:
+                            # If closure fails (e.g. because the server went
+                            # down), all connections to the same address will
+                            # be removed. Therefore, we silently ignore if the
+                            # connection isn't in the pool anymore.
+                            pass
                         continue
                     if not connection.in_use:
                         connection.in_use = True
                         return connection
                 # all connections in pool are in-use
-                infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf"))
-                can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size
+                connections = self.connections[address]
+                max_pool_size = self.pool_config.max_connection_pool_size
+                infinite_pool_size = (max_pool_size < 0
+                                      or max_pool_size == float("inf"))
+                can_create_new_connection = (
+                    infinite_pool_size
+                    or len(connections) < max_pool_size
+                )
                 if can_create_new_connection:
-                    timeout = min(self.pool_config.connection_timeout, time_remaining())
+                    timeout = min(self.pool_config.connection_timeout,
+                                  time_remaining())
                     try:
                         connection = self.opener(address, timeout)
                     except ServiceUnavailable:
diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py
index 0ce2af492..458118057 100644
--- a/tests/unit/io/test_neo4j_pool.py
+++ b/tests/unit/io/test_neo4j_pool.py
@@ -25,7 +25,10 @@
 
 from ..work import FakeConnection
 
-from neo4j import READ_ACCESS
+from neo4j import (
+    READ_ACCESS,
+    WRITE_ACCESS,
+)
 from neo4j.addressing import ResolvedAddress
 from neo4j.conf import (
     PoolConfig,
@@ -35,23 +38,24 @@
 from neo4j.io import Neo4jPool
 
 
+ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
+READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
+WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")
+
+
 @pytest.fixture()
 def opener():
-    def open_(*_, **__):
+    def open_(addr, timeout):
         connection = FakeConnection()
+        connection.addr = addr
+        connection.timeout = timeout
         route_mock = Mock()
         route_mock.return_value = [{
             "ttl": 1000,
             "servers": [
-                {"addresses": ["1.2.3.1:9001"], "role": "ROUTE"},
"ROUTE"}, - { - "addresses": ["1.2.3.10:9010", "1.2.3.11:9011"], - "role": "READ" - }, - { - "addresses": ["1.2.3.20:9020", "1.2.3.21:9021"], - "role": "WRITE" - }, + {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, ], }] connection.attach_mock(route_mock, "route") @@ -65,8 +69,7 @@ def open_(*_, **__): def test_acquires_new_routing_table_if_deleted(opener): - address = ResolvedAddress(("1.2.3.1", 9001), host_name="host") - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address) + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -79,8 +82,7 @@ def test_acquires_new_routing_table_if_deleted(opener): def test_acquires_new_routing_table_if_stale(opener): - address = ResolvedAddress(("1.2.3.1", 9001), host_name="host") - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address) + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -94,8 +96,7 @@ def test_acquires_new_routing_table_if_stale(opener): def test_removes_old_routing_table(opener): - address = ResolvedAddress(("1.2.3.1", 9001), host_name="host") - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address) + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx = pool.acquire(READ_ACCESS, 30, "test_db1", None) pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -113,3 +114,50 @@ def test_removes_old_routing_table(opener): assert pool.routing_tables["test_db1"].last_updated_time > old_value assert "test_db2" not in pool.routing_tables + +@pytest.mark.parametrize("type_", ("r", "w")) +def test_chooses_right_connection_type(opener, type_): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, + 30, "test_db", None) + pool.release(cx1) + if type_ == "r": + assert cx1.addr == READER_ADDRESS + else: + assert cx1.addr == WRITER_ADDRESS + + +def test_reuses_connection(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 is cx2 + + +@pytest.mark.parametrize("break_on_close", (True, False)) +def test_closes_stale_connections(opener, break_on_close): + def break_connection(): + pool.deactivate(cx1.addr) + + if cx_close_mock_side_effect: + cx_close_mock_side_effect() + + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. 
+    # simulate connection going stale (e.g. by exceeding its lifetime) and
+    # then breaking when the pool tries to close the connection
+    cx1.stale.return_value = True
+    cx_close_mock = cx1.close
+    if break_on_close:
+        cx_close_mock_side_effect = cx_close_mock.side_effect
+        cx_close_mock.side_effect = break_connection
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx2)
+    cx1.close.assert_called()
+    assert cx2 is not cx1
+    assert cx2.addr == cx1.addr
+    assert cx1 not in pool.connections[cx1.addr]
+    assert cx2 in pool.connections[cx2.addr]
diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py
index 4d4550627..25b272fea 100644
--- a/tests/unit/work/_fake_connection.py
+++ b/tests/unit/work/_fake_connection.py
@@ -87,7 +87,6 @@ def callback():
         return parent.__getattr__(name)
 
 
-
 @pytest.fixture
 def fake_connection():
     return FakeConnection()
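
Editor's note: the sketch below is not part of the patch. It illustrates, with simplified stand-ins (ToyPool and BrokenOnCloseConnection are hypothetical classes, not the driver's), why the new try/except ValueError around remove() is needed: closing a broken connection can trigger deactivate() for that address, which drops every idle connection before control returns to _acquire, so the later remove() may no longer find the connection in the deque.

from collections import defaultdict, deque


class BrokenOnCloseConnection:
    """Stand-in connection whose close() reports the whole server as dead."""

    def __init__(self, pool, addr):
        self.pool = pool
        self.addr = addr
        self.in_use = False

    def close(self):
        # Closing the dead connection makes the pool drop every connection
        # to this address, mirroring what Neo4jPool.deactivate() does.
        self.pool.deactivate(self.addr)


class ToyPool:
    """Minimal sketch of the defaultdict(deque)-backed connection store."""

    def __init__(self):
        self.connections = defaultdict(deque)

    def deactivate(self, addr):
        self.connections[addr].clear()

    def evict_stale(self, addr):
        for connection in list(self.connections.get(addr, [])):
            connection.close()  # may empty the deque behind our back
            try:
                self.connections.get(addr, []).remove(connection)
            except ValueError:
                # Already removed by deactivate(); ignore, just like the
                # patched _acquire() does.
                pass


pool = ToyPool()
addr = ("1.2.3.1", 9002)
pool.connections[addr].append(BrokenOnCloseConnection(pool, addr))
pool.evict_stale(addr)  # would raise ValueError without the guard
assert not pool.connections[addr]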