-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
helper.py
92 lines (74 loc) · 2.84 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import multiprocessing
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING
from jina.enums import GatewayProtocolType, PodRoleType
from jina.hubble.helper import is_valid_huburi
from jina.hubble.hubio import HubIO
if TYPE_CHECKING:
from argparse import Namespace
def _get_event(obj) -> multiprocessing.Event:
    """Create a ``multiprocessing`` event suited to the kind of process ``obj`` is.

    :param obj: a ``multiprocessing`` process object (it does not need to be started)
    :return: a new, unset event; for a ``SpawnProcess`` the event comes from the
        ``'spawn'`` context, otherwise from the default context
    :raises TypeError: if ``obj`` is not one of the supported process classes
    """
    # ``ForkProcess`` only exists on platforms that support ``fork`` (it is not
    # defined on Windows), so look it up defensively instead of raising
    # ``AttributeError`` for every process there.
    fork_process_cls = getattr(multiprocessing.context, 'ForkProcess', None)
    if isinstance(obj, multiprocessing.Process) or (
        fork_process_cls is not None and isinstance(obj, fork_process_cls)
    ):
        return multiprocessing.Event()
    elif isinstance(obj, multiprocessing.context.SpawnProcess):
        return multiprocessing.get_context('spawn').Event()
    else:
        raise TypeError(f'{obj} is not an instance of "multiprocessing.Process"')
class ConditionalEvent:
    """
    :class:`ConditionalEvent` provides a common interface to an event (multiprocessing or threading event)
    that gets triggered when any of the events provided in input is triggered (OR logic)

    The input events are patched in place: their ``set``/``clear`` methods are wrapped
    so that every state change re-evaluates :attr:`event`. The same input event may be
    shared by several :class:`ConditionalEvent` instances.

    :param events_list: The list of events that compose this composable event
    """

    def __init__(self, events_list):
        super().__init__()
        # Same object as ``multiprocessing.synchronize.Event(ctx=multiprocessing.get_context())``,
        # but does not depend on the ``multiprocessing.synchronize`` submodule having been
        # imported elsewhere first (``import multiprocessing`` alone does not load it).
        self.event = multiprocessing.get_context().Event()
        self.event_list = events_list
        for e in events_list:
            self._setup(e, self._state_changed)
        # initialise the aggregate event from the inputs' current state
        self._state_changed()

    def _state_changed(self):
        """Set :attr:`event` if any input event is set, otherwise clear it."""
        if any(e.is_set() for e in self.event_list):
            self.event.set()
        else:
            self.event.clear()

    def _custom_set(self, e):
        """Replacement for ``e.set``: set ``e``, then notify all registered listeners."""
        e._set()
        e._state_changed()

    def _custom_clear(self, e):
        """Replacement for ``e.clear``: clear ``e``, then notify all registered listeners."""
        e._clear()
        e._state_changed()

    def _setup(self, e, changed_callback):
        """Patch ``e`` so that ``changed_callback`` runs after each ``set``/``clear``.

        :param e: input event, patched in place
        :param changed_callback: zero-argument callable invoked on every state change
        """
        callbacks = getattr(e, '_state_changed_callbacks', None)
        if callbacks is None:
            # First time this event is wrapped: capture the ORIGINAL methods exactly
            # once. Re-capturing on a later wrap would store the previous wrapper as
            # ``e._set`` / ``e._clear`` and make ``e.set()`` recurse forever.
            callbacks = []
            e._state_changed_callbacks = callbacks

            def _notify_all():
                for callback in callbacks:
                    callback()

            e._set = e.set
            e._clear = e.clear
            e._state_changed = _notify_all
            e.set = partial(self._custom_set, e)
            e.clear = partial(self._custom_clear, e)
        callbacks.append(changed_callback)
def update_runtime_cls(args, copy=False) -> 'Namespace':
    """Resolve the runtime class name for pod/deployment args and store it on them.

    Steps, in order:

    - if ``runtime_cls`` is ``'WorkerRuntime'`` and ``uses`` is a valid Hub URI,
      ``uses`` is replaced by the result of ``HubIO(...).pull()``;
    - if the args have a ``protocol`` attribute, ``runtime_cls`` becomes the gateway
      runtime matching that protocol;
    - if ``pod_role`` is ``PodRoleType.HEAD``, ``runtime_cls`` becomes ``'HeadRuntime'``
      (this takes precedence over the gateway mapping).

    :param args: pod/deployment namespace args
    :param copy: True if args shouldn't be modified in-place
    :return: the updated namespace -- ``args`` itself, or a deep copy of it when
        ``copy`` is True -- with ``runtime_cls`` set to the runtime class name (a string)
    :raises KeyError: if ``args.protocol`` is not one of the supported gateway protocols
    """
    _args = deepcopy(args) if copy else args
    gateway_runtime_dict = {
        GatewayProtocolType.GRPC: 'GRPCGatewayRuntime',
        GatewayProtocolType.WEBSOCKET: 'WebSocketGatewayRuntime',
        GatewayProtocolType.HTTP: 'HTTPGatewayRuntime',
    }
    if _args.runtime_cls == 'WorkerRuntime' and is_valid_huburi(_args.uses):
        # pull from a separate copy so the Hub-specific fields (uri, no_usage)
        # do not leak into the returned namespace
        _hub_args = deepcopy(_args)
        _hub_args.uri = _args.uses
        _hub_args.no_usage = True
        _args.uses = HubIO(_hub_args).pull()

    if hasattr(_args, 'protocol'):
        _args.runtime_cls = gateway_runtime_dict[_args.protocol]
    if _args.pod_role == PodRoleType.HEAD:
        _args.runtime_cls = 'HeadRuntime'

    return _args