-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: run warmup on Runtimes and Executor #5579
Changes from all commits
5817aac
330dd7a
10947b6
967111a
02faa77
23e1777
aa9189d
6824fbf
b70008c
c5155e8
280d481
2885b4e
0f2bf67
2be40b1
4651f3b
d1c7e21
111db68
8090ea8
fcb122b
367a115
34d7d4b
4c86668
a82c128
87ee55f
017dfea
99e0823
e11a41f
ab6119d
e9b6561
479c810
95a5767
9deda11
dd3516b
ada4e5e
a23caf1
51c7e60
855bd7c
7211337
463a998
6232b42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import asyncio | ||
import ipaddress | ||
import os | ||
import time | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union | ||
|
@@ -28,6 +29,8 @@ | |
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
import threading | ||
|
||
from grpc.aio._interceptor import ClientInterceptor | ||
from opentelemetry.instrumentation.grpc._client import ( | ||
OpenTelemetryClientInterceptor, | ||
|
@@ -127,6 +130,9 @@ def __init__( | |
self.aio_tracing_client_interceptors = aio_tracing_client_interceptors | ||
self.tracing_client_interceptors = tracing_client_interceptor | ||
self._deployment_name = deployment_name | ||
# a set containing all the ConnectionStubs that will be created using add_connection | ||
# this set is not updated in reset_connection and remove_connection | ||
self._warmup_stubs = set() | ||
|
||
async def reset_connection( | ||
self, address: str, deployment_name: str | ||
|
@@ -178,6 +184,10 @@ def add_connection(self, address: str, deployment_name: str): | |
stubs, channel = self._create_connection(address, deployment_name) | ||
self._address_to_channel[address] = channel | ||
self._connections.append(stubs) | ||
# create a new set of stubs and channels for warmup to avoid | ||
# loosing channel during remove_connection or reset_connection | ||
stubs, _ = self._create_connection(address, deployment_name) | ||
self._warmup_stubs.add(stubs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not see where they are used There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check the usage of the property method There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah okey |
||
|
||
async def remove_connection(self, address: str) -> Union[grpc.aio.Channel, None]: | ||
""" | ||
|
@@ -311,6 +321,16 @@ async def close(self): | |
self._address_to_connection_idx.clear() | ||
self._connections.clear() | ||
self._rr_counter = 0 | ||
for stub in self._warmup_stubs: | ||
await stub.channel.close(0.5) | ||
self._warmup_stubs.clear() | ||
|
||
@property | ||
def warmup_stubs(self): | ||
"""Return set of warmup stubs | ||
:returns: Set of stubs. The set doesn't remove any items once added. | ||
""" | ||
return self._warmup_stubs | ||
|
||
|
||
class GrpcConnectionPool: | ||
|
@@ -373,6 +393,7 @@ async def _init_stubs(self): | |
self.single_data_stub = stubs['jina.JinaSingleDataRequestRPC'] | ||
self.stream_stub = stubs['jina.JinaRPC'] | ||
self.endpoints_discovery_stub = stubs['jina.JinaDiscoverEndpointsRPC'] | ||
self.info_rpc_stub = stubs['jina.JinaInfoRPC'] | ||
self._initialized = True | ||
|
||
async def send_discover_endpoint( | ||
|
@@ -506,6 +527,21 @@ async def send_requests( | |
else: | ||
raise ValueError(f'Unsupported request type {type(requests[0])}') | ||
|
||
async def send_info_rpc(self, timeout: Optional[float] = None): | ||
""" | ||
Use the JinaInfoRPC stub to send request to the _status endpoint exposed by the Runtime | ||
:param timeout: defines timeout for sending request | ||
:returns: JinaInfoProto | ||
""" | ||
if not self._initialized: | ||
await self._init_stubs() | ||
|
||
call_result = self.info_rpc_stub._status( | ||
jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(), | ||
timeout=timeout, | ||
) | ||
return await call_result | ||
|
||
class _ConnectionPoolMap: | ||
def __init__( | ||
self, | ||
|
@@ -613,9 +649,6 @@ def _get_connection_list( | |
return self._get_connection_list( | ||
deployment, type_, 0, increase_access_count | ||
) | ||
self._logger.debug( | ||
f'did not find a connection for deployment {deployment}, type {type_} and entity_id {entity_id}. There are {len(self._deployments[deployment][type_]) if deployment in self._deployments else 0} available connections for this deployment and type. ' | ||
) | ||
return None | ||
|
||
def _add_deployment(self, deployment: str): | ||
|
@@ -1114,6 +1147,68 @@ async def task_wrapper(): | |
|
||
return asyncio.create_task(task_wrapper()) | ||
|
||
async def warmup( | ||
self, | ||
deployment: str, | ||
stop_event: 'threading.Event', | ||
): | ||
'''Executes JinaInfoRPC against the provided deployment. A single task is created for each replica connection. | ||
:param deployment: deployment name and the replicas that needs to be warmed up. | ||
:param stop_event: signal to indicate if an early termination of the task is required for graceful teardown. | ||
''' | ||
self._logger.debug(f'starting warmup task for deployment {deployment}') | ||
|
||
async def task_wrapper(target_warmup_responses, stub): | ||
try: | ||
call_result = stub.send_info_rpc(timeout=0.5) | ||
await call_result | ||
target_warmup_responses[stub.address] = True | ||
except Exception: | ||
target_warmup_responses[stub.address] = False | ||
|
||
try: | ||
start_time = time.time() | ||
timeout = start_time + 60 * 5 # 5 minutes from now | ||
warmed_up_targets = set() | ||
replicas = self._get_all_replicas(deployment) | ||
|
||
while not stop_event.is_set(): | ||
replica_warmup_responses = {} | ||
tasks = [] | ||
|
||
for replica in replicas: | ||
for stub in replica.warmup_stubs: | ||
if stub.address not in warmed_up_targets: | ||
tasks.append( | ||
asyncio.create_task( | ||
task_wrapper(replica_warmup_responses, stub) | ||
) | ||
) | ||
|
||
await asyncio.gather(*tasks, return_exceptions=True) | ||
for target, response in replica_warmup_responses.items(): | ||
if response: | ||
warmed_up_targets.add(target) | ||
|
||
now = time.time() | ||
if now > timeout or all(list(replica_warmup_responses.values())): | ||
self._logger.debug(f'completed warmup task in {now - start_time}s.') | ||
return | ||
|
||
await asyncio.sleep(0.2) | ||
except Exception as ex: | ||
self._logger.error(f'error with warmup up task: {ex}') | ||
return | ||
|
||
def _get_all_replicas(self, deployment): | ||
replica_set = set() | ||
replica_set.update(self._connections.get_replicas_all_shards(deployment)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can extract out the warmup logic into a |
||
replica_set.add( | ||
self._connections.get_replicas(deployment=deployment, head=True) | ||
) | ||
|
||
return set(filter(None, replica_set)) | ||
|
||
@staticmethod | ||
def __aio_channel_with_tracing_interceptor( | ||
address, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import argparse | ||
import asyncio | ||
import signal | ||
import threading | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove unneeded import |
||
import time | ||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING, Optional, Union | ||
|
@@ -17,7 +18,6 @@ | |
|
||
if TYPE_CHECKING: # pragma: no cover | ||
import multiprocessing | ||
import threading | ||
|
||
HANDLED_SIGNALS = ( | ||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. | ||
|
@@ -76,6 +76,8 @@ def _cancel(signum, frame): | |
self._start_time = time.time() | ||
self._loop.run_until_complete(self.async_setup()) | ||
self._send_telemetry_event() | ||
self.warmup_task = None | ||
self.warmup_stop_event = threading.Event() | ||
|
||
def _send_telemetry_event(self): | ||
send_telemetry_event(event='start', obj=self, entity_id=self._entity_id) | ||
|
@@ -161,6 +163,21 @@ async def async_run_forever(self): | |
"""The async method to run until it is stopped.""" | ||
... | ||
|
||
async def cancel_warmup_task(self): | ||
'''Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the | ||
task is successful or hasn't reached the max timeout. | ||
''' | ||
if self.warmup_task: | ||
try: | ||
if not self.warmup_task.done(): | ||
self.logger.debug(f'Cancelling warmup task.') | ||
self.warmup_stop_event.set() | ||
await self.warmup_task | ||
self.warmup_task.exception() | ||
except Exception as ex: | ||
self.logger.debug(f'exception during warmup task cancellation: {ex}') | ||
pass | ||
|
||
# Static methods used by the Pod to communicate with the `Runtime` in the separate process | ||
|
||
@staticmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove unneeded imports