Skip to content

Commit

Permalink
Fix node_pool_class override (#2581)
Browse files Browse the repository at this point in the history
Co-authored-by: Quentin Pradet <quentin.pradet@elastic.co>
  • Loading branch information
tallakh and pquentin committed May 31, 2024
1 parent 9a650e6 commit 7e9ea0d
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 4 deletions.
2 changes: 1 addition & 1 deletion elasticsearch/_async/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __init__(
if node_class is not DEFAULT:
transport_kwargs["node_class"] = node_class
if node_pool_class is not DEFAULT:
transport_kwargs["node_pool_class"] = node_class
transport_kwargs["node_pool_class"] = node_pool_class
if randomize_nodes_in_pool is not DEFAULT:
transport_kwargs["randomize_nodes_in_pool"] = randomize_nodes_in_pool
if node_selector_class is not DEFAULT:
Expand Down
2 changes: 1 addition & 1 deletion elasticsearch/_sync/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __init__(
if node_class is not DEFAULT:
transport_kwargs["node_class"] = node_class
if node_pool_class is not DEFAULT:
transport_kwargs["node_pool_class"] = node_class
transport_kwargs["node_pool_class"] = node_pool_class
if randomize_nodes_in_pool is not DEFAULT:
transport_kwargs["randomize_nodes_in_pool"] = randomize_nodes_in_pool
if node_selector_class is not DEFAULT:
Expand Down
37 changes: 36 additions & 1 deletion test_elasticsearch/test_async/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
from typing import Any, Dict, Optional, Union

import pytest
from elastic_transport import ApiResponseMeta, BaseAsyncNode, HttpHeaders, NodeConfig
from elastic_transport import (
ApiResponseMeta,
BaseAsyncNode,
HttpHeaders,
NodeConfig,
NodePool,
)
from elastic_transport._node import NodeApiResponse
from elastic_transport.client_utils import DEFAULT

Expand Down Expand Up @@ -73,6 +79,14 @@ async def close(self):
self.closed = True


class NoTimeoutConnectionPool(NodePool):
def mark_dead(self, connection):
pass

def mark_live(self, connection):
pass


CLUSTER_NODES = """{
"_nodes" : {
"total" : 1,
Expand Down Expand Up @@ -345,6 +359,27 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success(self):
assert len(client.transport.node_pool._alive_nodes) == 1
assert len(client.transport.node_pool._dead_consecutive_failures) == 1

async def test_override_mark_dead_mark_live(self):
client = AsyncElasticsearch(
[
NodeConfig("http", "localhost", 9200),
NodeConfig("http", "localhost", 9201),
],
node_class=DummyNode,
node_pool_class=NoTimeoutConnectionPool,
)
node1 = client.transport.node_pool.get()
node2 = client.transport.node_pool.get()
assert node1 is not node2
client.transport.node_pool.mark_dead(node1)
client.transport.node_pool.mark_dead(node2)
assert len(client.transport.node_pool._alive_nodes) == 2

await client.info()

assert len(client.transport.node_pool._alive_nodes) == 2
assert len(client.transport.node_pool._dead_consecutive_failures) == 0

@pytest.mark.parametrize(
["nodes_info_response", "node_host"],
[(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")],
Expand Down
37 changes: 36 additions & 1 deletion test_elasticsearch/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from typing import Any, Dict, Optional, Union

import pytest
from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig
from elastic_transport import (
ApiResponseMeta,
BaseNode,
HttpHeaders,
NodeConfig,
NodePool,
)
from elastic_transport._node import NodeApiResponse
from elastic_transport.client_utils import DEFAULT

Expand Down Expand Up @@ -64,6 +70,14 @@ def perform_request(self, *args, **kwargs):
)


class NoTimeoutConnectionPool(NodePool):
def mark_dead(self, connection):
pass

def mark_live(self, connection):
pass


CLUSTER_NODES = """{
"_nodes" : {
"total" : 1,
Expand Down Expand Up @@ -376,6 +390,27 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self):
assert len(client.transport.node_pool._alive_nodes) == 1
assert len(client.transport.node_pool._dead_consecutive_failures) == 1

def test_override_mark_dead_mark_live(self):
client = Elasticsearch(
[
NodeConfig("http", "localhost", 9200),
NodeConfig("http", "localhost", 9201),
],
node_class=DummyNode,
node_pool_class=NoTimeoutConnectionPool,
)
node1 = client.transport.node_pool.get()
node2 = client.transport.node_pool.get()
assert node1 is not node2
client.transport.node_pool.mark_dead(node1)
client.transport.node_pool.mark_dead(node2)
assert len(client.transport.node_pool._alive_nodes) == 2

client.info()

assert len(client.transport.node_pool._alive_nodes) == 2
assert len(client.transport.node_pool._dead_consecutive_failures) == 0

@pytest.mark.parametrize(
["nodes_info_response", "node_host"],
[(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")],
Expand Down

0 comments on commit 7e9ea0d

Please sign in to comment.