This repository has been archived by the owner on Mar 20, 2023. It is now read-only.
/
transport.py
219 lines (184 loc) · 8.57 KB
/
transport.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import asyncio
import time
import logging
from itertools import chain
from elasticsearch import Transport, TransportError, ConnectionTimeout, ConnectionError, SerializationError
from elasticsearch.connection_pool import DummyConnectionPool
from .connection import AIOHttpConnection
from .connection_pool import AsyncConnectionPool, AsyncDummyConnectionPool
from .helpers import ensure_future
logger = logging.getLogger('elasticsearch')
class AsyncTransport(Transport):
    """Asyncio-based drop-in for :class:`elasticsearch.Transport`.

    Connection selection, sniffing and the retry loop mirror the synchronous
    transport, but every network call is a coroutine scheduled on
    ``self.loop`` and :meth:`perform_request` returns a future instead of
    blocking.
    """
    def __init__(self, hosts, connection_class=AIOHttpConnection, loop=None,
                 connection_pool_class=AsyncConnectionPool,
                 sniff_on_start=False, raise_on_sniff_error=True, **kwargs):
        """
        :arg hosts: list of host definitions, passed through to ``Transport``
        :arg connection_class: connection implementation to use
        :arg loop: asyncio event loop to run on; defaults to the current one
        :arg connection_pool_class: pool implementation holding connections
        :arg sniff_on_start: schedule an initial sniffing task on creation
        :arg raise_on_sniff_error: when collecting a finished sniffing task in
            :meth:`initiate_sniff`, re-raise its exception instead of
            swallowing it
        """
        self.raise_on_sniff_error = raise_on_sniff_error
        self.loop = asyncio.get_event_loop() if loop is None else loop
        kwargs['loop'] = self.loop
        # sniff-on-start must be scheduled as a task on the loop, so the
        # base class always gets sniff_on_start=False and we handle it below
        super().__init__(hosts, connection_class=connection_class, sniff_on_start=False,
                         connection_pool_class=connection_pool_class, **kwargs)
        self.sniffing_task = None
        if sniff_on_start:
            # schedule sniff on start
            self.initiate_sniff(True)

    def initiate_sniff(self, initial=False):
        """
        Initiate a sniffing task. Make sure we only have one sniff request
        running at any given time. If a finished sniffing request is around,
        collect its result (which can raise its exception).
        """
        if self.sniffing_task and self.sniffing_task.done():
            try:
                # .result() re-raises whatever exception the task died with
                self.sniffing_task.result()
            except BaseException:
                if self.raise_on_sniff_error:
                    raise
            finally:
                self.sniffing_task = None
        if self.sniffing_task is None:
            self.sniffing_task = ensure_future(self.sniff_hosts(initial), loop=self.loop)

    @asyncio.coroutine
    def close(self):
        """Cancel any in-flight sniffing task and close the connection pool."""
        if self.sniffing_task:
            self.sniffing_task.cancel()
        yield from self.connection_pool.close()

    def set_connections(self, hosts):
        """Instantiate connections for ``hosts``, keeping the pool async.

        The base class silently swaps in a ``DummyConnectionPool`` when only
        a single host is configured; replace it with the async equivalent.
        """
        super().set_connections(hosts)
        if isinstance(self.connection_pool, DummyConnectionPool):
            self.connection_pool = AsyncDummyConnectionPool(self.connection_pool.connection_opts)

    def get_connection(self):
        """Return a live connection from the pool, kicking off a periodic
        sniff first when ``sniffer_timeout`` has elapsed."""
        if self.sniffer_timeout:
            if time.time() >= self.last_sniff + self.sniffer_timeout:
                self.initiate_sniff()
        return self.connection_pool.get_connection()

    def mark_dead(self, connection):
        """Mark ``connection`` as dead in the pool and optionally re-sniff."""
        self.connection_pool.mark_dead(connection)
        if self.sniff_on_connection_fail:
            self.initiate_sniff()

    @asyncio.coroutine
    def _get_sniff_data(self, initial=False):
        """Query ``/_nodes/_all/http`` on all known hosts in parallel and
        return the node list from the first host that answers successfully.

        :arg initial: startup sniff -- ignore ``sniff_timeout`` if ``True``
        :raises TransportError: when no host produced a usable response
        """
        previous_sniff = self.last_sniff
        # reset last_sniff timestamp
        self.last_sniff = time.time()
        # use small timeout for the sniffing request, should be a fast api call
        timeout = self.sniff_timeout if not initial else None
        tasks = [
            c.perform_request('GET', '/_nodes/_all/http', timeout=timeout)
            # go through all current connections as well as the
            # seed_connections for good measure
            for c in chain(self.connection_pool.connections, (c for c in self.seed_connections if c not in self.connection_pool.connections))
        ]
        done = ()
        try:
            while tasks:
                # execute sniff requests in parallel, wait for first to return
                done, tasks = yield from asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, loop=self.loop)
                # go through all the finished tasks
                for t in done:
                    try:
                        _, headers, node_info = t.result()
                        node_info = self.deserializer.loads(node_info, headers.get('content-type'))
                    except (ConnectionError, SerializationError) as e:
                        logger.warning('Sniffing request failed with %r', e)
                        continue
                    node_info = list(node_info['nodes'].values())
                    return node_info
            else:
                # no task has finished completely
                raise TransportError("N/A", "Unable to sniff hosts.")
        except BaseException:
            # keep the previous value on error
            self.last_sniff = previous_sniff
            raise
        finally:
            # clean up pending futures
            for t in chain(done, tasks):
                t.cancel()

    @asyncio.coroutine
    def sniff_hosts(self, initial=False):
        """
        Obtain a list of nodes from the cluster and create a new connection
        pool using the information retrieved.

        To extract the node connection parameters use the ``nodes_to_host_callback``.

        :arg initial: flag indicating if this is during startup
            (``sniff_on_start``), ignore the ``sniff_timeout`` if ``True``
        """
        node_info = yield from self._get_sniff_data(initial)
        hosts = list(filter(None, (self._get_host_info(n) for n in node_info)))
        # we weren't able to get any nodes, maybe using an incompatible
        # transport_schema or host_info_callback blocked all - raise error.
        if not hosts:
            raise TransportError("N/A", "Unable to sniff hosts - no viable hosts found.")
        # remember current live connections
        orig_connections = self.connection_pool.connections[:]
        self.set_connections(hosts)
        # close those connections that are not in use any more
        for c in orig_connections:
            if c not in self.connection_pool.connections:
                yield from c.close()

    @asyncio.coroutine
    def main_loop(self, method, url, params, body, headers=None, ignore=(), timeout=None):
        """Retry loop driving a single logical request.

        Tries up to ``max_retries + 1`` connections; a failing connection is
        marked dead only when the failure class allows a retry
        (timeouts per ``retry_on_timeout``, connection errors always,
        status codes in ``retry_on_status``).

        :returns: the deserialized response body, or a bool for ``HEAD``
        :raises TransportError: when the request fails and cannot be retried
        """
        for attempt in range(self.max_retries + 1):
            connection = self.get_connection()
            try:
                status, headers, data = yield from connection.perform_request(
                    method, url, params, body, headers=headers, ignore=ignore, timeout=timeout)
            except TransportError as e:
                if method == 'HEAD' and e.status_code == 404:
                    return False
                retry = False
                if isinstance(e, ConnectionTimeout):
                    retry = self.retry_on_timeout
                elif isinstance(e, ConnectionError):
                    retry = True
                elif e.status_code in self.retry_on_status:
                    retry = True
                if retry:
                    # only mark as dead if we are retrying
                    self.mark_dead(connection)
                    # raise exception on last retry
                    if attempt == self.max_retries:
                        raise
                else:
                    raise
            else:
                if method == 'HEAD':
                    return 200 <= status < 300
                # connection didn't fail, confirm it's live status
                self.connection_pool.mark_live(connection)
                if data:
                    data = self.deserializer.loads(data, headers.get('content-type'))
                return data

    def perform_request(self, method, url, headers=None, params=None, body=None):
        """Serialize ``body``, apply per-request options (``request_timeout``,
        ``ignore``) and schedule :meth:`main_loop` on the event loop.

        :returns: a future resolving to the deserialized response (or a bool
            for ``HEAD`` requests)
        """
        if body is not None:
            body = self.serializer.dumps(body)
            # some clients or environments don't support sending GET with body
            if method in ('HEAD', 'GET') and self.send_get_body_as != 'GET':
                # send it as post instead
                if self.send_get_body_as == 'POST':
                    method = 'POST'
                # or as source parameter
                elif self.send_get_body_as == 'source':
                    if params is None:
                        params = {}
                    params['source'] = body
                    body = None
        if body is not None:
            try:
                body = body.encode('utf-8')
            except (UnicodeDecodeError, AttributeError):
                # bytes/str - no need to re-encode
                pass
        ignore = ()
        timeout = None
        if params:
            timeout = params.pop('request_timeout', None)
            ignore = params.pop('ignore', ())
            if isinstance(ignore, int):
                ignore = (ignore, )
        return ensure_future(self.main_loop(method, url, params, body,
                                            headers=headers,
                                            ignore=ignore,
                                            timeout=timeout),
                             loop=self.loop)