From 7e9ea0d6d8e283ecb8ef08fba626a19bd21813bc Mon Sep 17 00:00:00 2001 From: Tallak Hellebust Date: Fri, 31 May 2024 15:51:50 +0200 Subject: [PATCH] Fix node_pool_class override (#2581) Co-authored-by: Quentin Pradet --- elasticsearch/_async/client/__init__.py | 2 +- elasticsearch/_sync/client/__init__.py | 2 +- .../test_async/test_transport.py | 37 ++++++++++++++++++- test_elasticsearch/test_transport.py | 37 ++++++++++++++++++- 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 28eb58477..3661d769d 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -352,7 +352,7 @@ def __init__( if node_class is not DEFAULT: transport_kwargs["node_class"] = node_class if node_pool_class is not DEFAULT: - transport_kwargs["node_pool_class"] = node_class + transport_kwargs["node_pool_class"] = node_pool_class if randomize_nodes_in_pool is not DEFAULT: transport_kwargs["randomize_nodes_in_pool"] = randomize_nodes_in_pool if node_selector_class is not DEFAULT: diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index 80b01e580..97ca512cb 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -352,7 +352,7 @@ def __init__( if node_class is not DEFAULT: transport_kwargs["node_class"] = node_class if node_pool_class is not DEFAULT: - transport_kwargs["node_pool_class"] = node_class + transport_kwargs["node_pool_class"] = node_pool_class if randomize_nodes_in_pool is not DEFAULT: transport_kwargs["randomize_nodes_in_pool"] = randomize_nodes_in_pool if node_selector_class is not DEFAULT: diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py index 4d34e23bc..918e19e57 100644 --- a/test_elasticsearch/test_async/test_transport.py +++ b/test_elasticsearch/test_async/test_transport.py @@ -24,7 +24,13 @@ from typing import Any, Dict, Optional, Union import pytest -from elastic_transport import ApiResponseMeta, BaseAsyncNode, HttpHeaders, NodeConfig +from elastic_transport import ( + ApiResponseMeta, + BaseAsyncNode, + HttpHeaders, + NodeConfig, + NodePool, +) from elastic_transport._node import NodeApiResponse from elastic_transport.client_utils import DEFAULT @@ -73,6 +79,14 @@ async def close(self): self.closed = True +class NoTimeoutConnectionPool(NodePool): + def mark_dead(self, connection): + pass + + def mark_live(self, connection): + pass + + CLUSTER_NODES = """{ "_nodes" : { "total" : 1, @@ -345,6 +359,27 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success(self): assert len(client.transport.node_pool._alive_nodes) == 1 assert len(client.transport.node_pool._dead_consecutive_failures) == 1 + async def test_override_mark_dead_mark_live(self): + client = AsyncElasticsearch( + [ + NodeConfig("http", "localhost", 9200), + NodeConfig("http", "localhost", 9201), + ], + node_class=DummyNode, + node_pool_class=NoTimeoutConnectionPool, + ) + node1 = client.transport.node_pool.get() + node2 = client.transport.node_pool.get() + assert node1 is not node2 + client.transport.node_pool.mark_dead(node1) + client.transport.node_pool.mark_dead(node2) + assert len(client.transport.node_pool._alive_nodes) == 2 + + await client.info() + + assert len(client.transport.node_pool._alive_nodes) == 2 + assert len(client.transport.node_pool._dead_consecutive_failures) == 0 + @pytest.mark.parametrize( ["nodes_info_response", "node_host"], [(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")], diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py index b3b6de097..ce8b7f901 100644 --- a/test_elasticsearch/test_transport.py +++ b/test_elasticsearch/test_transport.py @@ -22,7 +22,13 @@ from typing import Any, Dict, Optional, Union import pytest -from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig +from elastic_transport import ( + ApiResponseMeta, + BaseNode, + HttpHeaders, + NodeConfig, + NodePool, +) from elastic_transport._node import NodeApiResponse from elastic_transport.client_utils import DEFAULT @@ -64,6 +70,14 @@ def perform_request(self, *args, **kwargs): ) +class NoTimeoutConnectionPool(NodePool): + def mark_dead(self, connection): + pass + + def mark_live(self, connection): + pass + + CLUSTER_NODES = """{ "_nodes" : { "total" : 1, @@ -376,6 +390,27 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self): assert len(client.transport.node_pool._alive_nodes) == 1 assert len(client.transport.node_pool._dead_consecutive_failures) == 1 + def test_override_mark_dead_mark_live(self): + client = Elasticsearch( + [ + NodeConfig("http", "localhost", 9200), + NodeConfig("http", "localhost", 9201), + ], + node_class=DummyNode, + node_pool_class=NoTimeoutConnectionPool, + ) + node1 = client.transport.node_pool.get() + node2 = client.transport.node_pool.get() + assert node1 is not node2 + client.transport.node_pool.mark_dead(node1) + client.transport.node_pool.mark_dead(node2) + assert len(client.transport.node_pool._alive_nodes) == 2 + + client.info() + + assert len(client.transport.node_pool._alive_nodes) == 2 + assert len(client.transport.node_pool._dead_consecutive_failures) == 0 + @pytest.mark.parametrize( ["nodes_info_response", "node_host"], [(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")],