-
Notifications
You must be signed in to change notification settings - Fork 590
/
sync_rpc_client.py
241 lines (219 loc) · 8.93 KB
/
sync_rpc_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Copyright 2020 The Magma Authors.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import queue
import random
import threading
import time
from typing import Dict, Iterator
import grpc
import magma.magmad.events as magmad_events
from google.protobuf.json_format import MessageToJson
from magma.common.rpc_utils import is_grpc_error_retryable
from magma.common.sentry import EXCLUDE_FROM_ERROR_MONITORING
from magma.common.service_registry import ServiceRegistry
from magma.magmad.proxy_client import ControlProxyHttpClient
from orc8r.protos.sync_rpc_service_pb2 import SyncRPCRequest, SyncRPCResponse
from orc8r.protos.sync_rpc_service_pb2_grpc import SyncRPCServiceStub
class SyncRPCClient(threading.Thread):
    """
    Thread that opens a bidirectional SyncRPC stream with the cloud.

    Requests read from the cloud dispatcher stream are handed to a
    ControlProxyHttpClient on the supplied event loop. Whatever is placed on
    the internal response queue is streamed back to the cloud, and a
    heartbeat is sent whenever the queue stays empty for response_timeout
    seconds.
    """

    RETRY_MAX_DELAY_SECS = 10  # seconds

    def __init__(
        self, loop, response_timeout: int,
        print_grpc_payload: bool = False,
    ):
        """
        Args:
            loop: asyncio event loop on which proxied requests are scheduled
            response_timeout: seconds to wait for an actual SyncRPCResponse
                to become available before sending out a heartBeat
            print_grpc_payload: when True, every incoming request is logged
                as JSON (see _print_grpc)
        """
        super().__init__()
        # a synchronized queue
        self._response_queue: queue.Queue = queue.Queue()
        self._loop = loop
        # NOTE: this runs in the thread that constructs the client, not in
        # the thread started by start()
        asyncio.set_event_loop(self._loop)
        self._response_timeout = response_timeout
        self._proxy_client = ControlProxyHttpClient()
        self.daemon = True
        # base (pre-jitter) number of seconds the next reconnect sleep uses
        self._current_delay = 0
        self._last_conn_time = 0.0
        # mapping of req id -> conn closed
        self._conn_closed_table: Dict[int, bool] = {}
        self._print_grpc_payload = print_grpc_payload

    def run(self):
        """
        This is executed when the thread is started. Each iteration gets a
        connection to the cloud dispatcher and calls its bidirectional
        streaming rpc EstablishSyncRPCStream() through process_streams().
        process_streams() is not expected to return; when it raises, the
        error is logged and a new connection is attempted after a backoff
        sleep (see _retry_connect_sleep) whose base delay is capped at
        RETRY_MAX_DELAY_SECS seconds.
        """
        while True:
            start_time = time.time()
            try:
                chan = ServiceRegistry.get_rpc_channel(
                    'dispatcher',
                    ServiceRegistry.CLOUD,
                )
                client = SyncRPCServiceStub(chan)
                self._set_connect_time()
                self.process_streams(client)
            except grpc.RpcError as err:
                if is_grpc_error_retryable(err):
                    logging.warning(
                        "[SyncRPC] Transient gRPC error, retrying: %s",
                        err.details(),
                    )
                    # transient failure: only back off, keep the conn closed
                    # table and the proxy client connections as they are
                    self._retry_connect_sleep()
                    continue
                logging.error(
                    "[SyncRPC] gRPC error: %s, reconnecting to cloud.",
                    err.details(),
                    extra=EXCLUDE_FROM_ERROR_MONITORING,
                )
                self._cleanup_and_reconnect()
            except Exception as exp:  # pylint: disable=broad-except
                conn_time = time.time() - start_time
                logging.error(
                    "[SyncRPC] Error after %ds: %s",
                    conn_time,
                    exp,
                    extra=EXCLUDE_FROM_ERROR_MONITORING,
                )
                self._cleanup_and_reconnect()

    def process_streams(self, client: "SyncRPCServiceStub") -> None:
        """
        Calls rpc function EstablishSyncRPCStream on the client to establish
        a stream with dispatcher in the cloud, processes all requests from
        the stream, and writes all responses to the stream.

        Args:
            client: a grpc client to dispatcher in the cloud.
        Returns:
            Not expected to return; control leaves through whatever
            exception forward_requests() raises.
        """
        # call to bidirectional streaming grpc takes in an iterator,
        # and returns an iterator
        sync_rpc_requests = client.EstablishSyncRPCStream(
            self.send_sync_rpc_response(),
        )
        magmad_events.established_sync_rpc_stream()
        # forward incoming requests from cloud to control_proxy
        self.forward_requests(sync_rpc_requests)

    def send_sync_rpc_response(self):
        """
        Retrieve SyncRPCResponse from queue. If no response is available yet,
        block for at most response_timeout seconds, and send a heartBeat if
        timeout.

        Returns: A generator of SyncRPCResponse that never terminates
        """
        while True:
            try:
                resp = self._response_queue.get(
                    block=True,
                    timeout=self._response_timeout,
                )
                yield resp
            except queue.Empty:
                # response_queue is empty, send heartbeat.
                # The generator cannot tell whether this is its first
                # iteration, so heartbeats are simply sent every time the
                # wait times out.
                logging.debug("[SyncRPC] Sending heartbeat")
                yield SyncRPCResponse(heartBeat=True)

    def forward_requests(
        self,
        sync_rpc_requests: "Iterator[SyncRPCRequest]",
    ) -> None:
        """
        Send requests in the sync_rpc_requests iterator.

        Args:
            sync_rpc_requests: an iterator of SyncRPCRequest from cloud
        Raises:
            grpc.RpcError: logged and re-raised when reading from the stream
                or forwarding a request fails
            StopIteration: propagated from next() when the iterator is
                exhausted; the loop has no other exit, so this method never
                returns normally
        """
        logging.info("[SyncRPC] Waiting for requests")
        while True:
            try:
                req = next(sync_rpc_requests)
                self.forward_request(req)
            except grpc.RpcError as err:
                logging.error(
                    "[SyncRPC] Failing to forward request, err: %s",
                    err.details(),
                    extra=EXCLUDE_FROM_ERROR_MONITORING,
                )
                raise

    def forward_request(self, request: "SyncRPCRequest") -> None:
        """
        Handle a single message received from the cloud.

        heartBeat messages are only logged, connClosed messages mark the
        request id in the conn closed table, and everything else is
        scheduled on the event loop to be sent through the proxy client.

        Args:
            request: the SyncRPCRequest read from the stream
        """
        # Any message from the cloud shows the stream is really up, so the
        # reconnect backoff starts from scratch at the next failure.
        self._current_delay = 0
        self._print_grpc(request)
        if request.heartBeat:
            logging.info("[SyncRPC] Got heartBeat from cloud")
            return

        if request.connClosed:
            logging.debug("[SyncRPC] Got connClosed from cloud")
            self._conn_closed_table[request.reqId] = True
            return

        logging.debug("[SyncRPC] Got a request")
        # the proxy client coroutine runs on self._loop, not on this thread
        asyncio.run_coroutine_threadsafe(
            self._proxy_client.send(
                request.reqBody,
                request.reqId,
                self._response_queue,
                self._conn_closed_table,
            ),
            self._loop,
        )

    def _retry_connect_sleep(self) -> None:
        """
        Sleep for the current delay plus a random jitter of up to one second,
        then grow the delay for the next call: it is doubled, capped at
        RETRY_MAX_DELAY_SECS and never left below one second.
        """
        jitter = random.randint(0, 1000) / 1000
        sleep_time = self._current_delay + jitter
        self._current_delay = max(
            min(2 * self._current_delay, self.RETRY_MAX_DELAY_SECS),
            1,
        )
        time.sleep(sleep_time)

    def _set_connect_time(self) -> None:
        """
        Record the time at which a new stream to the cloud is opened.
        """
        logging.info("[SyncRPC] Opening stream to cloud")
        # The backoff delay is intentionally NOT reset here: this runs on
        # every loop iteration of run() before the stream is attempted, so
        # resetting it here would cancel the exponential backoff entirely.
        # It is reset in forward_request() once the cloud actually answers.
        self._last_conn_time = time.time()

    def _cleanup_and_reconnect(self):
        """
        If the connection is terminated, wait for a period of time
        before connecting back to the cloud. Also clear the conn
        closed table since cloud may reuse req IDs, and clear
        current proxy client connections
        """
        self._conn_closed_table.clear()
        self._proxy_client.close_all_connections()
        self._retry_connect_sleep()
        magmad_events.disconnected_sync_rpc_stream()

    def _print_grpc(self, message):
        """
        Log the message as indented JSON when print_grpc_payload is enabled.

        Logging is best effort: any failure while rendering the message is
        reported at debug level and otherwise ignored.
        """
        if not self._print_grpc_payload:
            return
        try:
            rendered = "{} {}".format(
                message.DESCRIPTOR.full_name,
                MessageToJson(message),
            )
            # add indentation
            padding = 2 * ' '
            indented = ''.join(
                padding + line
                for line in rendered.splitlines(True)
            )
            logging.info("GRPC message:\n%s", indented)
        except Exception as e:  # pylint: disable=broad-except
            logging.debug("Exception while trying to log GRPC: %s", e)