-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
compound.py
302 lines (259 loc) · 10.2 KB
/
compound.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import copy
from argparse import Namespace
from itertools import cycle
from typing import Optional, Dict, List, Union, Set
from contextlib import ExitStack
from .. import BasePod
from .. import Pea
from .. import Pod
from ..networking import get_connect_host
from ... import helper
from ...enums import SocketType, SchedulerType, PollingType
from ...helper import random_identity
class CompoundPod(BasePod, ExitStack):
    """A CompoundPod is an immutable set of pods, which run in parallel.

    A CompoundPod is an abstraction using a composable pattern to abstract the usage of parallel Pods that act as shards.
    CompoundPod will make sure to add a `HeadPea` and a `TailPea` to serve as routing/merging pattern for the different Pod shards

    :param args: pod arguments parsed from the CLI. These arguments will be used for each of the shards
    :param needs: pod names of preceding pods, the output of these pods are going into the input of this pod
    """

    head_args = None
    tail_args = None

    def __init__(
        self, args: Union['Namespace', Dict], needs: Optional[Set[str]] = None
    ):
        super().__init__()
        # NOTE(review): annotated as ``Union[Namespace, Dict]`` but only attribute
        # access is used below, so a plain dict would fail here -- confirm callers.
        args.upload_files = BasePod._set_upload_files(args)
        self.args = args
        self.needs = (
            needs or set()
        )  #: used in the :class:`jina.flow.Flow` to build the graph

        # we will see how to have `CompoundPods` in remote later when we add tests for it
        self.is_head_router = True
        self.is_tail_router = True
        self.head_args = BasePod._copy_to_head_args(args, args.polling)
        self.tail_args = BasePod._copy_to_tail_args(self.args, self.args.polling)
        # uses before with shards apply to shards and not to replicas
        self.shards = []  # type: List['Pod']
        # BACKWARDS COMPATIBILITY:
        self.args.parallel = self.args.shards
        self.assign_shards()

    def assign_shards(self):
        """(Re)build the shard :class:`Pod` objects of this CompoundPod.

        Clears any existing shards, derives one argument namespace per shard from a
        shallow copy of ``self.args`` (with ``uses_before``/``uses_after`` removed),
        stores them in ``self.shards_args`` and creates one :class:`Pod` per entry.
        """
        self.shards.clear()
        cargs = copy.copy(self.args)
        # shards are built without their own uses_before/uses_after executors
        cargs.uses_before = None
        cargs.uses_after = None
        self.shards_args = self._set_shard_args(cargs, self.head_args, self.tail_args)
        noblock = getattr(self.args, 'noblock_on_start', False)
        for _args in self.shards_args:
            if noblock:
                # propagate non-blocking start to every shard
                _args.noblock_on_start = True
            self.shards.append(Pod(_args))

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        super().__exit__(exc_type, exc_val, exc_tb)
        self.join()

    @property
    def port_jinad(self) -> int:
        """Get the JinaD remote port

        .. # noqa: DAR201
        """
        return self.head_args.port_jinad

    @property
    def host(self) -> str:
        """Get the host name of this Pod

        .. # noqa: DAR201
        """
        return self.head_args.host

    def _parse_pod_args(self, args: Namespace) -> List[Namespace]:
        """Compute per-shard arguments for ``args`` using this pod's head/tail args.

        :param args: arguments configured by the user for the shards
        :return: list of arguments for the shards
        """
        return self._set_shard_args(
            args,
            head_args=self.head_args,
            tail_args=self.tail_args,
        )

    @property
    def num_peas(self) -> int:
        """
        Get the number of running :class:`Pea` objects

        :return: total number of peas of all shards plus the head and tail peas
        """
        # +2 accounts for the head and tail peas owned by the CompoundPod itself
        return sum(shard.num_peas for shard in self.shards) + 2

    def __eq__(self, other: 'CompoundPod'):
        """Two pods are equal when they have the same name and number of peas.

        :param other: the object to compare against
        :return: equality result, or ``NotImplemented`` for non-pod objects
        """
        if not isinstance(other, BasePod):
            return NotImplemented
        return self.num_peas == other.num_peas and self.name == other.name

    def _enter_pea(self, pea: 'Pea') -> None:
        """Enter ``pea`` as a context on this ExitStack so ``close()`` exits it."""
        self.enter_context(pea)

    def start(self) -> 'CompoundPod':
        """
        Start to run all :class:`Pod` and :class:`Pea` in this CompoundPod.

        :return: started CompoundPod

        .. note::
            If one of the :class:`Pod` fails to start, make sure that all of them
            are properly closed.
        """
        noblock = getattr(self.args, 'noblock_on_start', False)
        try:
            if noblock:
                self.head_args.noblock_on_start = True
            self.head_pea = Pea(self.head_args)
            self._enter_pea(self.head_pea)
            for shard in self.shards:
                self._enter_shard(shard)
                if not noblock:
                    # NOTE(review): shards are only activated here in blocking mode;
                    # confirm where activation happens when noblock_on_start is set.
                    shard.activate()
            if noblock:
                self.tail_args.noblock_on_start = True
            self.tail_pea = Pea(self.tail_args)
            self._enter_pea(self.tail_pea)
        except BaseException:
            # close every context entered so far, in both blocking and
            # non-blocking mode, before propagating the failure
            self.close()
            raise
        return self

    def wait_start_success(self) -> None:
        """
        Block until all pods and peas start successfully.

        If not successful, it will raise an error hoping the outer function to catch it

        :raises ValueError: if called while ``noblock_on_start`` is not enabled
        """
        if not getattr(self.args, 'noblock_on_start', False):
            raise ValueError(
                f'{self.wait_start_success!r} should only be called when `noblock_on_start` is set to True'
            )
        try:
            self.head_pea.wait_start_success()
            self.tail_pea.wait_start_success()
            for p in self.shards:
                p.wait_start_success()
        except:
            self.close()
            raise

    def _enter_shard(self, shard: 'Pod') -> None:
        """Enter ``shard`` as a context on this ExitStack so ``close()`` exits it."""
        self.enter_context(shard)

    def join(self):
        """Wait until all pods and peas exit."""
        try:
            # head/tail peas only exist once start() has created them
            if getattr(self, 'head_pea', None):
                self.head_pea.join()
            if getattr(self, 'tail_pea', None):
                self.tail_pea.join()
            for p in self.shards:
                p.join()
        except KeyboardInterrupt:
            pass

    @property
    def is_ready(self) -> bool:
        """
        Checks if Pod is ready.

        :return: true if the peas and pods are ready to serve requests

        .. note::
            A Pod is ready when all the Peas it contains are ready
        """
        return all(
            [p.is_ready.is_set() for p in [self.head_pea, self.tail_pea]]
            + [p.is_ready for p in self.shards]
        )

    @staticmethod
    def _set_shard_args(
        args: Namespace,
        head_args: Namespace,
        tail_args: Namespace,
    ) -> List[Namespace]:
        """
        Sets the arguments of the shards in the compound pod.

        :param args: arguments configured by the user for the shards
        :param head_args: head args from the compound pod
        :param tail_args: tail args from the compound pod

        :return: list of arguments for the shards
        """
        result = []
        _host_list = (
            args.peas_hosts
            if args.peas_hosts
            else [
                args.host,
            ]
        )
        # hosts are handed out round-robin across all replicas of all shards
        host_generator = cycle(_host_list)
        for idx in range(args.shards):
            _args = copy.deepcopy(args)
            pod_host_list = [
                host for _, host in zip(range(args.replicas), host_generator)
            ]
            _args.peas_hosts = pod_host_list
            _args.shard_id = idx
            # BACKWARDS COMPATIBILITY:
            _args.pea_id = _args.shard_id
            _args.identity = random_identity()
            if _args.name:
                _args.name += f'/shard-{idx}'
            else:
                _args.name = f'{idx}'

            # shards connect to the CompoundPod head (input) and tail (output)
            _args.port_in = head_args.port_out
            _args.port_out = tail_args.port_in
            _args.port_ctrl = helper.random_port()
            if args.polling.is_push:
                if args.scheduling == SchedulerType.ROUND_ROBIN:
                    _args.socket_in = SocketType.PULL_CONNECT
                elif args.scheduling == SchedulerType.LOAD_BALANCE:
                    _args.socket_in = SocketType.DEALER_CONNECT
                else:
                    raise ValueError(
                        f'{args.scheduling} is not supported as a SchedulerType!'
                    )
            else:
                _args.socket_in = SocketType.SUB_CONNECT
            _args.socket_out = SocketType.PUSH_CONNECT

            _args.dynamic_routing = False

            # ugly trick to avoid Head of shard to have wrong host in:
            # a shard with replicas has its own head pea, which is not the
            # (possibly dockerized) executor, so compute the connect host as if
            # connecting to a plain local pea in that case.
            tmp_args = copy.deepcopy(_args)
            if _args.replicas > 1:
                tmp_args.runs_in_docker = False
                tmp_args.uses = ''
            _args.host_in = get_connect_host(
                bind_host=head_args.host,
                bind_expose_public=head_args.expose_public,
                connect_args=tmp_args,
            )
            _args.host_out = get_connect_host(
                bind_host=tail_args.host,
                bind_expose_public=tail_args.expose_public,
                connect_args=tmp_args,
            )
            result.append(_args)
        return result

    def rolling_update(
        self, dump_path: Optional[str] = None, *, uses_with: Optional[Dict] = None
    ):
        """Reload all Pods of this Compound Pod.

        :param dump_path: **backwards compatibility** This function was only accepting dump_path as the only potential arg to override
        :param uses_with: a Dictionary of arguments to restart the executor with
        """
        for shard in self.shards:
            shard.rolling_update(dump_path=dump_path, uses_with=uses_with)

    @property
    def _mermaid_str(self) -> List[str]:
        """String that will be used to represent the Pod graphically when `Flow.plot()` is invoked

        .. # noqa: DAR201
        """
        mermaid_graph = [f'subgraph {self.name};\n', 'direction LR;\n']

        head_name = self.head_args.name
        tail_name = self.tail_args.name
        pod_names = []
        for shard in self.shards:
            pod_names.append(shard.name)
            # turn each shard's ';'-terminated statements into separate lines
            shard_mermaid_graph = [
                node.replace(';', '\n') for node in shard._mermaid_str
            ]
            mermaid_graph.extend(shard_mermaid_graph)
            mermaid_graph.append('\n')

        for name in pod_names:
            mermaid_graph.append(f'{head_name}:::HEADTAIL --> {name};')
            mermaid_graph.append(f'{name} --> {tail_name}:::HEADTAIL;')
        mermaid_graph.append('end;')
        return mermaid_graph