From dae113f45a685803118a4b346d919e7965a0be9f Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 1 Oct 2021 14:10:55 +0200 Subject: [PATCH] Fix removing connection twice from pool. When trying to close a stale connection the driver count realize that the connection is dead on trying to send GOODBYE. This would cause the connection to make sure that all connections to the same address would get removed from the pool as well. Since this removal only happens as a side effect of `connection.close()` and does not always happen, the driver would still try to remove the (now already removed) connection form the pool after closure. Fixes: `ValueError: deque.remove(x): x not in deque` --- neo4j/io/__init__.py | 36 ++++++++----- tests/unit/io/test_neo4j_pool.py | 82 +++++++++++++++++++++++------ tests/unit/work/_fake_connection.py | 1 - 3 files changed, 89 insertions(+), 30 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 1967deadb..38410dba6 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -35,7 +35,10 @@ ] import abc -from collections import deque +from collections import ( + defaultdict, + deque, +) from logging import getLogger from random import choice from select import select @@ -618,7 +621,7 @@ def __init__(self, opener, pool_config, workspace_config): self.opener = opener self.pool_config = pool_config self.workspace_config = workspace_config - self.connections = {} + self.connections = defaultdict(deque) self.lock = RLock() self.cond = Condition(self.lock) @@ -640,18 +643,13 @@ def _acquire(self, address, timeout): timeout = self.workspace_config.connection_acquisition_timeout with self.lock: - try: - connections = self.connections[address] - except KeyError: - connections = self.connections[address] = deque() - def time_remaining(): t = timeout - (perf_counter() - t0) return t if t > 0 else 0 while True: # try to find a free connection in pool - for connection in list(connections): + for connection in list(self.connections.get(address, [])): if (connection.closed() or connection.defunct() or connection.stale()): # `close` is a noop on already closed connections. @@ -659,16 +657,30 @@ def time_remaining(): # closed, e.g. if it's just marked as `stale` but still # alive. connection.close() - connections.remove(connection) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass continue if not connection.in_use: connection.in_use = True return connection # all connections in pool are in-use - infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf")) - can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size + connections = self.connections[address] + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + can_create_new_connection = ( + infinite_pool_size + or len(connections) < max_pool_size + ) if can_create_new_connection: - timeout = min(self.pool_config.connection_timeout, time_remaining()) + timeout = min(self.pool_config.connection_timeout, + time_remaining()) try: connection = self.opener(address, timeout) except ServiceUnavailable: diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index 0ce2af492..458118057 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -25,7 +25,10 @@ from ..work import FakeConnection -from neo4j import READ_ACCESS +from neo4j import ( + READ_ACCESS, + WRITE_ACCESS, +) from neo4j.addressing import ResolvedAddress from neo4j.conf import ( PoolConfig, @@ -35,23 +38,24 @@ from neo4j.io import Neo4jPool +ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") + + @pytest.fixture() def opener(): - def open_(*_, **__): + def open_(addr, timeout): connection = FakeConnection() + connection.addr = addr + connection.timeout = timeout route_mock = Mock() route_mock.return_value = [{ "ttl": 1000, "servers": [ - {"addresses": ["1.2.3.1:9001"], "role": "ROUTE"}, - { - "addresses": ["1.2.3.10:9010", "1.2.3.11:9011"], - "role": "READ" - }, - { - "addresses": ["1.2.3.20:9020", "1.2.3.21:9021"], - "role": "WRITE" - }, + {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, ], }] connection.attach_mock(route_mock, "route") @@ -65,8 +69,7 @@ def open_(*_, **__): def test_acquires_new_routing_table_if_deleted(opener): - address = ResolvedAddress(("1.2.3.1", 9001), host_name="host") - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address) + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -79,8 +82,7 @@ def test_acquires_new_routing_table_if_deleted(opener): def test_acquires_new_routing_table_if_stale(opener): - address = ResolvedAddress(("1.2.3.1", 9001), host_name="host") - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address) + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -94,8 +96,7 @@ def test_acquires_new_routing_table_if_stale(opener): def test_removes_old_routing_table(opener): - address = ResolvedAddress(("1.2.3.1", 9001), host_name="host") - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address) + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx = pool.acquire(READ_ACCESS, 30, "test_db1", None) pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -113,3 +114,50 @@ def test_removes_old_routing_table(opener): assert pool.routing_tables["test_db1"].last_updated_time > old_value assert "test_db2" not in pool.routing_tables + +@pytest.mark.parametrize("type_", ("r", "w")) +def test_chooses_right_connection_type(opener, type_): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, + 30, "test_db", None) + pool.release(cx1) + if type_ == "r": + assert cx1.addr == READER_ADDRESS + else: + assert cx1.addr == WRITER_ADDRESS + + +def test_reuses_connection(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 is cx2 + + +@pytest.mark.parametrize("break_on_close", (True, False)) +def test_closes_stale_connections(opener, break_on_close): + def break_connection(): + pool.deactivate(cx1.addr) + + if cx_close_mock_side_effect: + cx_close_mock_side_effect() + + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. exceeding) and than breaking when + # the pool tries to close the connection + cx1.stale.return_value = True + cx_close_mock = cx1.close + if break_on_close: + cx_close_mock_side_effect = cx_close_mock.side_effect + cx_close_mock.side_effect = break_connection + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx2) + assert cx1.close.called_once() + assert cx2 is not cx1 + assert cx2.addr == cx1.addr + assert cx1 not in pool.connections[cx1.addr] + assert cx2 in pool.connections[cx2.addr] diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py index 4d4550627..25b272fea 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/work/_fake_connection.py @@ -87,7 +87,6 @@ def callback(): return parent.__getattr__(name) - @pytest.fixture def fake_connection(): return FakeConnection()