Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
]

import abc
from collections import deque
from collections import (
defaultdict,
deque,
)
from logging import getLogger
from random import choice
from select import select
Expand Down Expand Up @@ -618,7 +621,7 @@ def __init__(self, opener, pool_config, workspace_config):
self.opener = opener
self.pool_config = pool_config
self.workspace_config = workspace_config
self.connections = {}
self.connections = defaultdict(deque)
self.lock = RLock()
self.cond = Condition(self.lock)

Expand All @@ -640,35 +643,44 @@ def _acquire(self, address, timeout):
timeout = self.workspace_config.connection_acquisition_timeout

with self.lock:
try:
connections = self.connections[address]
except KeyError:
connections = self.connections[address] = deque()

def time_remaining():
t = timeout - (perf_counter() - t0)
return t if t > 0 else 0

while True:
# try to find a free connection in pool
for connection in list(connections):
for connection in list(self.connections.get(address, [])):
if (connection.closed() or connection.defunct()
or connection.stale()):
# `close` is a noop on already closed connections.
# This is to make sure that the connection is gracefully
# closed, e.g. if it's just marked as `stale` but still
# alive.
connection.close()
connections.remove(connection)
try:
self.connections.get(address, []).remove(connection)
except ValueError:
# If closure fails (e.g. because the server went
# down), all connections to the same address will
# be removed. Therefore, we silently ignore if the
# connection isn't in the pool anymore.
pass
continue
if not connection.in_use:
connection.in_use = True
return connection
# all connections in pool are in-use
infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf"))
can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size
connections = self.connections[address]
max_pool_size = self.pool_config.max_connection_pool_size
infinite_pool_size = (max_pool_size < 0
or max_pool_size == float("inf"))
can_create_new_connection = (
infinite_pool_size
or len(connections) < max_pool_size
)
if can_create_new_connection:
timeout = min(self.pool_config.connection_timeout, time_remaining())
timeout = min(self.pool_config.connection_timeout,
time_remaining())
try:
connection = self.opener(address, timeout)
except ServiceUnavailable:
Expand Down
82 changes: 65 additions & 17 deletions tests/unit/io/test_neo4j_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@

from ..work import FakeConnection

from neo4j import READ_ACCESS
from neo4j import (
READ_ACCESS,
WRITE_ACCESS,
)
from neo4j.addressing import ResolvedAddress
from neo4j.conf import (
PoolConfig,
Expand All @@ -35,23 +38,24 @@
from neo4j.io import Neo4jPool


ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")


@pytest.fixture()
def opener():
def open_(*_, **__):
def open_(addr, timeout):
connection = FakeConnection()
connection.addr = addr
connection.timeout = timeout
route_mock = Mock()
route_mock.return_value = [{
"ttl": 1000,
"servers": [
{"addresses": ["1.2.3.1:9001"], "role": "ROUTE"},
{
"addresses": ["1.2.3.10:9010", "1.2.3.11:9011"],
"role": "READ"
},
{
"addresses": ["1.2.3.20:9020", "1.2.3.21:9021"],
"role": "WRITE"
},
{"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"},
{"addresses": [str(READER_ADDRESS)], "role": "READ"},
{"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
],
}]
connection.attach_mock(route_mock, "route")
Expand All @@ -65,8 +69,7 @@ def open_(*_, **__):


def test_acquires_new_routing_table_if_deleted(opener):
address = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address)
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx)
assert pool.routing_tables.get("test_db")
Expand All @@ -79,8 +82,7 @@ def test_acquires_new_routing_table_if_deleted(opener):


def test_acquires_new_routing_table_if_stale(opener):
address = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address)
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx)
assert pool.routing_tables.get("test_db")
Expand All @@ -94,8 +96,7 @@ def test_acquires_new_routing_table_if_stale(opener):


def test_removes_old_routing_table(opener):
address = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address)
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx = pool.acquire(READ_ACCESS, 30, "test_db1", None)
pool.release(cx)
assert pool.routing_tables.get("test_db1")
Expand All @@ -113,3 +114,50 @@ def test_removes_old_routing_table(opener):
assert pool.routing_tables["test_db1"].last_updated_time > old_value
assert "test_db2" not in pool.routing_tables


@pytest.mark.parametrize("type_", ("r", "w"))
def test_chooses_right_connection_type(opener, type_):
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
30, "test_db", None)
pool.release(cx1)
if type_ == "r":
assert cx1.addr == READER_ADDRESS
else:
assert cx1.addr == WRITER_ADDRESS


def test_reuses_connection(opener):
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
assert cx1 is cx2


@pytest.mark.parametrize("break_on_close", (True, False))
def test_closes_stale_connections(opener, break_on_close):
def break_connection():
pool.deactivate(cx1.addr)

if cx_close_mock_side_effect:
cx_close_mock_side_effect()

pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
assert cx1 in pool.connections[cx1.addr]
# simulate connection going stale (e.g. exceeding) and than breaking when
# the pool tries to close the connection
cx1.stale.return_value = True
cx_close_mock = cx1.close
if break_on_close:
cx_close_mock_side_effect = cx_close_mock.side_effect
cx_close_mock.side_effect = break_connection
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx2)
assert cx1.close.called_once()
assert cx2 is not cx1
assert cx2.addr == cx1.addr
assert cx1 not in pool.connections[cx1.addr]
assert cx2 in pool.connections[cx2.addr]
1 change: 0 additions & 1 deletion tests/unit/work/_fake_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def callback():
return parent.__getattr__(name)



@pytest.fixture
def fake_connection():
return FakeConnection()