diff --git a/elasticsearch/_async/transport.py b/elasticsearch/_async/transport.py
index 47f2df20d..f7a2b2f51 100644
--- a/elasticsearch/_async/transport.py
+++ b/elasticsearch/_async/transport.py
@@ -316,20 +316,26 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
                     retry = True
 
                 if retry:
-                    # only mark as dead if we are retrying
-                    self.mark_dead(connection)
+                    try:
+                        # only mark as dead if we are retrying
+                        self.mark_dead(connection)
+                    except TransportError:
+                        # If sniffing on failure, it could fail too. Catch the
+                        # exception not to interrupt the retries.
+                        pass
                     # raise exception on last retry
                     if attempt == self.max_retries:
-                        raise
+                        raise e
                 else:
-                    raise
+                    raise e
 
             else:
+                # connection didn't fail, confirm it's live status
+                self.connection_pool.mark_live(connection)
+
                 if method == "HEAD":
                     return 200 <= status < 300
 
-                # connection didn't fail, confirm it's live status
-                self.connection_pool.mark_live(connection)
                 if data:
                     data = self.deserializer.loads(data, headers.get("content-type"))
                 return data
diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py
index af81e8da2..79471b7e8 100644
--- a/elasticsearch/transport.py
+++ b/elasticsearch/transport.py
@@ -378,13 +378,18 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
                     retry = True
 
                 if retry:
-                    # only mark as dead if we are retrying
-                    self.mark_dead(connection)
+                    try:
+                        # only mark as dead if we are retrying
+                        self.mark_dead(connection)
+                    except TransportError:
+                        # If sniffing on failure, it could fail too. Catch the
+                        # exception not to interrupt the retries.
+                        pass
                     # raise exception on last retry
                     if attempt == self.max_retries:
-                        raise
+                        raise e
                 else:
-                    raise
+                    raise e
 
             else:
                 # connection didn't fail, confirm it's live status
diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py
index 0bd654c7b..293363ffd 100644
--- a/test_elasticsearch/test_async/test_transport.py
+++ b/test_elasticsearch/test_async/test_transport.py
@@ -18,13 +18,14 @@
 from __future__ import unicode_literals
 
 import asyncio
+import json
 
 from mock import patch
 import pytest
 
 from elasticsearch import AsyncTransport
 from elasticsearch.connection import Connection
 from elasticsearch.connection_pool import DummyConnectionPool
-from elasticsearch.exceptions import ConnectionError
+from elasticsearch.exceptions import ConnectionError, TransportError
 
 pytestmark = pytest.mark.asyncio
@@ -273,16 +274,17 @@ async def test_failed_connection_will_be_marked_as_dead(self):
         assert 0 == len(t.connection_pool.connections)
 
     async def test_resurrected_connection_will_be_marked_as_live_on_success(self):
-        t = AsyncTransport([{}, {}], connection_class=DummyConnection)
-        await t._async_call()
-        con1 = t.connection_pool.get_connection()
-        con2 = t.connection_pool.get_connection()
-        t.connection_pool.mark_dead(con1)
-        t.connection_pool.mark_dead(con2)
-
-        await t.perform_request("GET", "/")
-        assert 1 == len(t.connection_pool.connections)
-        assert 1 == len(t.connection_pool.dead_count)
+        for method in ("GET", "HEAD"):
+            t = AsyncTransport([{}, {}], connection_class=DummyConnection)
+            await t._async_call()
+            con1 = t.connection_pool.get_connection()
+            con2 = t.connection_pool.get_connection()
+            t.connection_pool.mark_dead(con1)
+            t.connection_pool.mark_dead(con2)
+
+            await t.perform_request(method, "/")
+            assert 1 == len(t.connection_pool.connections)
+            assert 1 == len(t.connection_pool.dead_count)
 
     async def test_sniff_will_use_seed_connections(self):
         t = AsyncTransport([{"data": CLUSTER_NODES}], connection_class=DummyConnection)
@@ -368,6 +370,25 @@ async def test_sniff_on_fail_triggers_sniffing_on_fail(self):
         assert 1 == len(t.connection_pool.connections)
         assert "http://1.1.1.1:123" == t.get_connection().host
 
+    @patch("elasticsearch._async.transport.AsyncTransport.sniff_hosts")
+    async def test_sniff_on_fail_failing_does_not_prevent_retries(self, sniff_hosts):
+        sniff_hosts.side_effect = [TransportError("sniff failed")]
+        t = AsyncTransport(
+            [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}],
+            connection_class=DummyConnection,
+            sniff_on_connection_fail=True,
+            max_retries=3,
+            randomize_hosts=False,
+        )
+        await t._async_init()
+
+        conn_err, conn_data = t.connection_pool.connections
+        response = await t.perform_request("GET", "/")
+        assert json.loads(CLUSTER_NODES) == response
+        assert 1 == sniff_hosts.call_count
+        assert 1 == len(conn_err.calls)
+        assert 1 == len(conn_data.calls)
+
     async def test_sniff_after_n_seconds(self, event_loop):
         t = AsyncTransport(
             [{"data": CLUSTER_NODES}],
diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py
index cbc1f716f..50e68c97f 100644
--- a/test_elasticsearch/test_transport.py
+++ b/test_elasticsearch/test_transport.py
@@ -17,13 +17,14 @@
 # under the License.
 from __future__ import unicode_literals
 
+import json
 import time
 
 from mock import patch
 
 from elasticsearch.transport import Transport, get_host_info
 from elasticsearch.connection import Connection
 from elasticsearch.connection_pool import DummyConnectionPool
-from elasticsearch.exceptions import ConnectionError
+from elasticsearch.exceptions import ConnectionError, TransportError
 
 from .test_cases import TestCase
@@ -254,15 +255,16 @@ def test_failed_connection_will_be_marked_as_dead(self):
         self.assertEqual(0, len(t.connection_pool.connections))
 
     def test_resurrected_connection_will_be_marked_as_live_on_success(self):
-        t = Transport([{}, {}], connection_class=DummyConnection)
-        con1 = t.connection_pool.get_connection()
-        con2 = t.connection_pool.get_connection()
-        t.connection_pool.mark_dead(con1)
-        t.connection_pool.mark_dead(con2)
+        for method in ("GET", "HEAD"):
+            t = Transport([{}, {}], connection_class=DummyConnection)
+            con1 = t.connection_pool.get_connection()
+            con2 = t.connection_pool.get_connection()
+            t.connection_pool.mark_dead(con1)
+            t.connection_pool.mark_dead(con2)
 
-        t.perform_request("GET", "/")
-        self.assertEqual(1, len(t.connection_pool.connections))
-        self.assertEqual(1, len(t.connection_pool.dead_count))
+            t.perform_request(method, "/")
+            self.assertEqual(1, len(t.connection_pool.connections))
+            self.assertEqual(1, len(t.connection_pool.dead_count))
 
     def test_sniff_will_use_seed_connections(self):
         t = Transport([{"data": CLUSTER_NODES}], connection_class=DummyConnection)
@@ -330,6 +332,24 @@ def test_sniff_on_fail_triggers_sniffing_on_fail(self):
         self.assertEqual(1, len(t.connection_pool.connections))
         self.assertEqual("http://1.1.1.1:123", t.get_connection().host)
 
+    @patch("elasticsearch.transport.Transport.sniff_hosts")
+    def test_sniff_on_fail_failing_does_not_prevent_retries(self, sniff_hosts):
+        sniff_hosts.side_effect = [TransportError("sniff failed")]
+        t = Transport(
+            [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}],
+            connection_class=DummyConnection,
+            sniff_on_connection_fail=True,
+            max_retries=3,
+            randomize_hosts=False,
+        )
+
+        conn_err, conn_data = t.connection_pool.connections
+        response = t.perform_request("GET", "/")
+        self.assertEqual(json.loads(CLUSTER_NODES), response)
+        self.assertEqual(1, sniff_hosts.call_count)
+        self.assertEqual(1, len(conn_err.calls))
+        self.assertEqual(1, len(conn_data.calls))
+
     def test_sniff_after_n_seconds(self):
         t = Transport(
             [{"data": CLUSTER_NODES}],
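
Note on the fix. With sniff_on_connection_fail=True, mark_dead() triggers sniff_hosts(), which can itself raise a TransportError; before this patch that error escaped perform_request and aborted the remaining retries. The patch swallows it, and the switch from a bare raise to raise e ensures the original request error, not the sniffing error, surfaces on the last attempt (after the nested handler runs, a bare raise could re-raise the sniffing error instead, notably on Python 2, where the active exception is not restored when the inner handler exits). Below is a minimal, self-contained sketch of the resulting retry flow; FlakySniffTransport and its helpers are hypothetical stand-ins, not elasticsearch-py's real API.

# Minimal sketch of the retry behavior the patch establishes, assuming a
# simplified transport. The exception names mirror elasticsearch.exceptions;
# everything else is a hypothetical stand-in.


class TransportError(Exception):
    pass


class ConnectionError(TransportError):
    pass


class FlakySniffTransport:
    def __init__(self, max_retries=3):
        self.max_retries = max_retries

    def mark_dead(self, connection):
        # Stands in for Transport.mark_dead: with sniff_on_connection_fail
        # enabled it triggers sniffing, which can itself raise.
        raise TransportError("sniff failed")

    def _send(self, attempt):
        # Hypothetical connection: fails on the first attempt only.
        if attempt == 0:
            raise ConnectionError("abandon ship")
        return {"ok": True}

    def perform_request(self):
        for attempt in range(self.max_retries + 1):
            connection = object()  # placeholder for a pooled connection
            try:
                return self._send(attempt)
            except TransportError as e:
                retry = isinstance(e, ConnectionError)
                if retry:
                    try:
                        # only mark as dead if we are retrying
                        self.mark_dead(connection)
                    except TransportError:
                        # The fix: a failing sniff must not abort the
                        # remaining retries.
                        pass
                    # raise exception on last retry
                    if attempt == self.max_retries:
                        raise e
                else:
                    raise e


# The request succeeds on the second attempt despite the sniffing failure.
assert FlakySniffTransport().perform_request() == {"ok": True}

The new test_sniff_on_fail_failing_does_not_prevent_retries tests above exercise the same path against the real transports: the failing connection and the healthy one are each hit exactly once, sniff_hosts raises once and is ignored, and the request still returns the cluster data.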