Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: run warmup on Runtimes and Executor #5579

Merged
merged 40 commits into from
Jan 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
5817aac
feat: add warmup coroutine to Gateway and Worker runtimes
Jan 6, 2023
330dd7a
style: fix overload and cli autocomplete
jina-bot Jan 6, 2023
10947b6
test: reduce traced operations to account for warmup
Jan 6, 2023
967111a
refactor: user threding.Event to signal graceful warmup task cancella…
Jan 9, 2023
02faa77
style: fix overload and cli autocomplete
jina-bot Jan 9, 2023
23e1777
chore: reduce debug logging
Jan 9, 2023
aa9189d
feat: implement warmup using discovery requests
Jan 9, 2023
6824fbf
refactor: use asyncio.sleep
Jan 9, 2023
b70008c
feat: create warmup task per deployment
Jan 10, 2023
c5155e8
feat: implement warmup for HeadRuntime
Jan 10, 2023
280d481
feat: remove executor warmup task
Jan 10, 2023
2885b4e
feat: don't warmup deprecated head uses before and after
Jan 11, 2023
0f2bf67
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 11, 2023
2be40b1
style: fix overload and cli autocomplete
jina-bot Jan 11, 2023
4651f3b
fix: revert changes to unrelated tests
Jan 11, 2023
d1c7e21
Merge remote-tracking branch 'origin/feat-serve-5467-runtime-warmup' …
Jan 11, 2023
111db68
test: start worker deployment before gateway
Jan 11, 2023
8090ea8
fix: don't use asyncio.gather or single task
Jan 11, 2023
fcb122b
Merge branch 'master' into feat-serve-5467-runtime-warmup
girishc13 Jan 11, 2023
367a115
fix: await gather_endpoints coroutine
Jan 12, 2023
34d7d4b
feat: wait until grpc channel is ready before warmup
Jan 12, 2023
4c86668
Revert "feat: wait until grpc channel is ready before warmup"
Jan 12, 2023
a82c128
feat: create JinaInfoRPC stub in the ConnectionStubs for reuse
Jan 12, 2023
87ee55f
feat: enable grpc SO_REUSEPORT for multi process/threading
Jan 12, 2023
017dfea
Revert "feat: enable grpc SO_REUSEPORT for multi process/threading"
Jan 12, 2023
99e0823
Revert "fix: await gather_endpoints coroutine"
Jan 12, 2023
e11a41f
ci: debug warmup task
Jan 12, 2023
ab6119d
fix: pop removed connection channel from dict
Jan 12, 2023
e9b6561
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 12, 2023
479c810
fix: pop removed connection channel from dict
Jan 12, 2023
95a5767
Revert "fix: pop removed connection channel from dict"
Jan 13, 2023
9deda11
Revert "fix: pop removed connection channel from dict"
Jan 13, 2023
dd3516b
Revert "feat: create JinaInfoRPC stub in the ConnectionStubs for reuse"
Jan 13, 2023
ada4e5e
feat: create duplicate stubs for warmup requests
Jan 13, 2023
a23caf1
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 13, 2023
51c7e60
style: fix overload and cli autocomplete
jina-bot Jan 13, 2023
855bd7c
Merge remote-tracking branch 'origin/master' into feat-serve-5467-run…
Jan 13, 2023
7211337
chore: remove debug logging
Jan 13, 2023
463a998
chore: clean up imports
Jan 13, 2023
6232b42
feat: close channels created for warmup stubs
Jan 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
101 changes: 98 additions & 3 deletions jina/serve/networking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import ipaddress
import os
import time
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove unneeded imports

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
Expand Down Expand Up @@ -28,6 +29,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
import threading

from grpc.aio._interceptor import ClientInterceptor
from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
Expand Down Expand Up @@ -127,6 +130,9 @@ def __init__(
self.aio_tracing_client_interceptors = aio_tracing_client_interceptors
self.tracing_client_interceptors = tracing_client_interceptor
self._deployment_name = deployment_name
# a set containing all the ConnectionStubs that will be created using add_connection
# this set is not updated in reset_connection and remove_connection
self._warmup_stubs = set()

async def reset_connection(
self, address: str, deployment_name: str
Expand Down Expand Up @@ -178,6 +184,10 @@ def add_connection(self, address: str, deployment_name: str):
stubs, channel = self._create_connection(address, deployment_name)
self._address_to_channel[address] = channel
self._connections.append(stubs)
# create a new set of stubs and channels for warmup to avoid
# loosing channel during remove_connection or reset_connection
stubs, _ = self._create_connection(address, deployment_name)
self._warmup_stubs.add(stubs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see where they are used

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the usage of the property method self.warmup_stubs. What's the preference when exposing read only properties?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah okey


async def remove_connection(self, address: str) -> Union[grpc.aio.Channel, None]:
"""
Expand Down Expand Up @@ -311,6 +321,16 @@ async def close(self):
self._address_to_connection_idx.clear()
self._connections.clear()
self._rr_counter = 0
for stub in self._warmup_stubs:
await stub.channel.close(0.5)
self._warmup_stubs.clear()

@property
def warmup_stubs(self):
"""Return set of warmup stubs
:returns: Set of stubs. The set doesn't remove any items once added.
"""
return self._warmup_stubs


class GrpcConnectionPool:
Expand Down Expand Up @@ -373,6 +393,7 @@ async def _init_stubs(self):
self.single_data_stub = stubs['jina.JinaSingleDataRequestRPC']
self.stream_stub = stubs['jina.JinaRPC']
self.endpoints_discovery_stub = stubs['jina.JinaDiscoverEndpointsRPC']
self.info_rpc_stub = stubs['jina.JinaInfoRPC']
self._initialized = True

async def send_discover_endpoint(
Expand Down Expand Up @@ -506,6 +527,21 @@ async def send_requests(
else:
raise ValueError(f'Unsupported request type {type(requests[0])}')

async def send_info_rpc(self, timeout: Optional[float] = None):
"""
Use the JinaInfoRPC stub to send request to the _status endpoint exposed by the Runtime
:param timeout: defines timeout for sending request
:returns: JinaInfoProto
"""
if not self._initialized:
await self._init_stubs()

call_result = self.info_rpc_stub._status(
jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(),
timeout=timeout,
)
return await call_result

class _ConnectionPoolMap:
def __init__(
self,
Expand Down Expand Up @@ -613,9 +649,6 @@ def _get_connection_list(
return self._get_connection_list(
deployment, type_, 0, increase_access_count
)
self._logger.debug(
f'did not find a connection for deployment {deployment}, type {type_} and entity_id {entity_id}. There are {len(self._deployments[deployment][type_]) if deployment in self._deployments else 0} available connections for this deployment and type. '
)
return None

def _add_deployment(self, deployment: str):
Expand Down Expand Up @@ -1114,6 +1147,68 @@ async def task_wrapper():

return asyncio.create_task(task_wrapper())

async def warmup(
self,
deployment: str,
stop_event: 'threading.Event',
):
'''Executes JinaInfoRPC against the provided deployment. A single task is created for each replica connection.
:param deployment: deployment name and the replicas that needs to be warmed up.
:param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
'''
self._logger.debug(f'starting warmup task for deployment {deployment}')

async def task_wrapper(target_warmup_responses, stub):
try:
call_result = stub.send_info_rpc(timeout=0.5)
await call_result
target_warmup_responses[stub.address] = True
except Exception:
target_warmup_responses[stub.address] = False

try:
start_time = time.time()
timeout = start_time + 60 * 5 # 5 minutes from now
warmed_up_targets = set()
replicas = self._get_all_replicas(deployment)

while not stop_event.is_set():
replica_warmup_responses = {}
tasks = []

for replica in replicas:
for stub in replica.warmup_stubs:
if stub.address not in warmed_up_targets:
tasks.append(
asyncio.create_task(
task_wrapper(replica_warmup_responses, stub)
)
)

await asyncio.gather(*tasks, return_exceptions=True)
for target, response in replica_warmup_responses.items():
if response:
warmed_up_targets.add(target)

now = time.time()
if now > timeout or all(list(replica_warmup_responses.values())):
self._logger.debug(f'completed warmup task in {now - start_time}s.')
return

await asyncio.sleep(0.2)
except Exception as ex:
self._logger.error(f'error with warmup up task: {ex}')
return

def _get_all_replicas(self, deployment):
replica_set = set()
replica_set.update(self._connections.get_replicas_all_shards(deployment))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can extract out the warmup logic into a Warmer class but the self._connections is the inner class _ConnectionPoolMap of the GrpcConnectionPool. Does the inner class still make sense?

replica_set.add(
self._connections.get_replicas(deployment=deployment, head=True)
)

return set(filter(None, replica_set))

@staticmethod
def __aio_channel_with_tracing_interceptor(
address,
Expand Down
19 changes: 18 additions & 1 deletion jina/serve/runtimes/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import asyncio
import signal
import threading
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unneeded import

import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union
Expand All @@ -17,7 +18,6 @@

if TYPE_CHECKING: # pragma: no cover
import multiprocessing
import threading

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
Expand Down Expand Up @@ -76,6 +76,8 @@ def _cancel(signum, frame):
self._start_time = time.time()
self._loop.run_until_complete(self.async_setup())
self._send_telemetry_event()
self.warmup_task = None
self.warmup_stop_event = threading.Event()

def _send_telemetry_event(self):
send_telemetry_event(event='start', obj=self, entity_id=self._entity_id)
Expand Down Expand Up @@ -161,6 +163,21 @@ async def async_run_forever(self):
"""The async method to run until it is stopped."""
...

async def cancel_warmup_task(self):
'''Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the
task is successful or hasn't reached the max timeout.
'''
if self.warmup_task:
try:
if not self.warmup_task.done():
self.logger.debug(f'Cancelling warmup task.')
self.warmup_stop_event.set()
await self.warmup_task
self.warmup_task.exception()
except Exception as ex:
self.logger.debug(f'exception during warmup task cancellation: {ex}')
pass

# Static methods used by the Pod to communicate with the `Runtime` in the separate process

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions jina/serve/runtimes/gateway/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,22 @@ async def _wait_for_cancel(self):

async def async_teardown(self):
"""Shutdown the server."""
await self.cancel_warmup_task()
await self.gateway.streamer.close()
await self.gateway.shutdown()
await self.async_cancel()

async def async_cancel(self):
"""Stop the server."""
await self.cancel_warmup_task()
await self.gateway.streamer.close()
await self.gateway.shutdown()

async def async_run_forever(self):
"""Running method of the server."""
self.warmup_task = asyncio.create_task(
self.gateway.streamer.warmup(self.warmup_stop_event)
)
await self.gateway.run_server()
self.is_cancel.set()

Expand Down
16 changes: 13 additions & 3 deletions jina/serve/runtimes/gateway/composite/gateway.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import copy
from typing import Any, List, Optional

Expand Down Expand Up @@ -36,18 +37,27 @@ async def setup_server(self):
"""
setup GRPC server
"""
tasks = []
for gateway in self.gateways:
await gateway.setup_server()
tasks.append(asyncio.create_task(gateway.setup_server()))

await asyncio.gather(*tasks)

async def shutdown(self):
"""Free other resources allocated with the server, e.g, gateway object, ..."""
shutdown_tasks = []
for gateway in self.gateways:
await gateway.shutdown()
shutdown_tasks.append(asyncio.create_task(gateway.shutdown()))

await asyncio.gather(*shutdown_tasks)

async def run_server(self):
"""Run GRPC server forever"""
run_server_tasks = []
for gateway in self.gateways:
await gateway.run_server()
run_server_tasks.append(asyncio.create_task(gateway.run_server()))

await asyncio.gather(*run_server_tasks)

@staticmethod
def _deepcopy_with_ignore_attrs(obj: Any, ignore_attrs: List[str]) -> Any:
Expand Down
17 changes: 15 additions & 2 deletions jina/serve/runtimes/head/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import json
import os
from abc import ABC
Expand Down Expand Up @@ -158,24 +159,35 @@ async def async_setup(self):
service, health_pb2.HealthCheckResponse.SERVING
)
reflection.enable_server_reflection(service_names, self._grpc_server)

bind_addr = f'{self.args.host}:{self.args.port}'
self._grpc_server.add_insecure_port(bind_addr)
self.logger.debug(f'start listening on {bind_addr}')
await self._grpc_server.start()

def _warmup(self):
self.warmup_task = asyncio.create_task(
self.request_handler.warmup(
connection_pool=self.connection_pool,
stop_event=self.warmup_stop_event,
deployment=self._deployment_name,
)
)

async def async_run_forever(self):
"""Block until the GRPC server is terminated"""
self._warmup()
await self._grpc_server.wait_for_termination()

async def async_cancel(self):
"""Stop the GRPC server"""
self.logger.debug('cancel HeadRuntime')

await self.cancel_warmup_task()
await self._grpc_server.stop(0)

async def async_teardown(self):
"""Close the connection pool"""
await self.cancel_warmup_task()
await self._health_servicer.enter_graceful_shutdown()
await self.async_cancel()
await self.connection_pool.close()
Expand Down Expand Up @@ -294,6 +306,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
:param context: grpc context
:returns: the response request
"""
self.logger.debug('recv _status request')
infoProto = jina_pb2.JinaInfoProto()
version, env_info = get_full_version()
for k, v in version.items():
Expand Down
26 changes: 25 additions & 1 deletion jina/serve/runtimes/head/request_handling.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.monitoring import MonitoringRequestMixin
from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler

if TYPE_CHECKING: # pragma: no cover
import threading

from opentelemetry.metrics import Meter
from prometheus_client import CollectorRegistry

Expand Down Expand Up @@ -164,7 +167,9 @@ async def _handle_data_request(
elif len(worker_results) > 1 and not reduce:
# worker returned multiple responses, but the head is configured to skip reduction
# just concatenate the docs in this case
response_request.data.docs = WorkerRequestHandler.get_docs_from_request(requests)
response_request.data.docs = WorkerRequestHandler.get_docs_from_request(
requests
)

merged_metadata = self._merge_metadata(
metadata,
Expand All @@ -177,3 +182,22 @@ async def _handle_data_request(
self._update_end_request_metrics(response_request)

return response_request, merged_metadata

async def warmup(
self,
connection_pool: GrpcConnectionPool,
stop_event: 'threading.Event',
deployment: str,
):
'''Executes warmup task against the deployments from the connection pool.
:param connection_pool: GrpcConnectionPool that implements the warmup to the connected deployments.
:param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
:param deployment: deployment name that need to be warmed up.
'''
self.logger.debug(f'Running HeadRuntime warmup')

try:
await connection_pool.warmup(deployment=deployment, stop_event=stop_event)
except Exception as ex:
self.logger.error(f'error with HeadRuntime warmup up task: {ex}')
return
3 changes: 2 additions & 1 deletion jina/serve/runtimes/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def _async_setup_grpc_server(self):
self._health_servicer, self._grpc_server
)

reflection.enable_server_reflection(service_names, self._grpc_server)
reflection.enable_server_reflection(service_names, self._grpc_server)
bind_addr = f'{self.args.host}:{self.args.port}'
self.logger.debug(f'start listening on {bind_addr}')
self._grpc_server.add_insecure_port(bind_addr)
Expand Down Expand Up @@ -306,6 +306,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
:param context: grpc context
:returns: the response request
"""
self.logger.debug('recv _status request')
info_proto = jina_pb2.JinaInfoProto()
version, env_info = get_full_version()
for k, v in version.items():
Expand Down
4 changes: 1 addition & 3 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,7 @@ def get_docs_from_request(
"""
if len(requests) > 1:
result = DocumentArray(
d
for r in reversed(requests)
for d in getattr(r, 'docs')
d for r in reversed(requests) for d in getattr(r, 'docs')
)
else:
result = getattr(requests[0], 'docs')
Expand Down