diff --git a/jina/clients/request/helper.py b/jina/clients/request/helper.py index 6dc653e709cd1..e00b2ec06bf57 100644 --- a/jina/clients/request/helper.py +++ b/jina/clients/request/helper.py @@ -1,10 +1,8 @@ """Module for helper functions for clients.""" from typing import Tuple -from docarray import DocumentArray, Document - +from docarray import Document, DocumentArray from jina.enums import DataInputType -from jina.excepts import BadRequestType from jina.types.request.data import DataRequest @@ -72,30 +70,3 @@ def _add_docs(req, batch, data_type, _kwargs): d, data_type = _new_doc_from_data(content, data_type, **_kwargs) da.append(d) req.data.docs = da - - -def _add_control_propagate(req, kwargs): - from jina.proto import jina_pb2 - - extra_kwargs = kwargs[ - 'extra_kwargs' - ] #: control command and args are stored inside extra_kwargs - _available_commands = dict( - jina_pb2.RequestProto.ControlRequestProto.DESCRIPTOR.enum_values_by_name - ) - - if 'command' in extra_kwargs: - command = extra_kwargs['command'] - else: - raise BadRequestType( - 'sending ControlRequest from Client must contain the field `command`' - ) - - if command in _available_commands: - req.control.command = getattr( - jina_pb2.RequestProto.ControlRequestProto, command - ) - else: - raise ValueError( - f'command "{command}" is not supported, must be one of {_available_commands}' - ) diff --git a/jina/orchestrate/pods/helper.py b/jina/orchestrate/pods/helper.py index efc99fb60a8da..50397c2e60d38 100644 --- a/jina/orchestrate/pods/helper.py +++ b/jina/orchestrate/pods/helper.py @@ -3,13 +3,9 @@ from functools import partial from typing import TYPE_CHECKING -from grpc import RpcError - from jina.enums import GatewayProtocolType, PodRoleType from jina.hubble.helper import is_valid_huburi from jina.hubble.hubio import HubIO -from jina.serve.networking import GrpcConnectionPool -from jina.types.request.control import ControlRequest if TYPE_CHECKING: from argparse import Namespace @@ -94,18 +90,3 
@@ def update_runtime_cls(args, copy=False) -> 'Namespace': _args.runtime_cls = 'HeadRuntime' return _args - - -def is_ready(address: str) -> bool: - """ - TODO: make this async - Check if status is ready. - :param address: the address where the control message needs to be sent - :return: True if status is ready else False. - """ - - try: - GrpcConnectionPool.send_request_sync(ControlRequest('STATUS'), address) - except RpcError: - return False - return True diff --git a/jina/proto/serializer.py b/jina/proto/serializer.py index 5b60a535690d0..594dec42da121 100644 --- a/jina/proto/serializer.py +++ b/jina/proto/serializer.py @@ -2,39 +2,9 @@ from typing import Iterable, List, Union from jina.proto import jina_pb2 -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest -class ControlRequestProto: - """This class is a drop-in replacement for gRPC default serializer. - - It replace default serializer to make sure we always work with `Request` - - """ - - @staticmethod - def SerializeToString(x: 'ControlRequest'): - """ - # noqa: DAR101 - # noqa: DAR102 - # noqa: DAR201 - """ - return x.proto.SerializePartialToString() - - @staticmethod - def FromString(x: bytes): - """ - # noqa: DAR101 - # noqa: DAR102 - # noqa: DAR201 - """ - proto = jina_pb2.ControlRequestProto() - proto.ParseFromString(x) - - return ControlRequest(request=proto) - - class DataRequestProto: """This class is a drop-in replacement for gRPC default serializer. 
diff --git a/jina/resources/health_check/pod.py b/jina/resources/health_check/pod.py index faec7192f6b3b..ffd6705ca90b1 100644 --- a/jina/resources/health_check/pod.py +++ b/jina/resources/health_check/pod.py @@ -3,22 +3,14 @@ def check_health_pod(addr: str): :param addr: the address on which the pod is serving ex : localhost:1234 """ - import grpc - - from jina.serve.networking import GrpcConnectionPool - from jina.types.request.control import ControlRequest - - try: - GrpcConnectionPool.send_request_sync( - request=ControlRequest('STATUS'), - target=addr, - ) - except grpc.RpcError as e: - print('The pod is unhealthy') - print(e) - raise e - - print('The pod is healthy') + from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime + + is_ready = AsyncNewLoopRuntime.is_ready(addr) + + if not is_ready: + raise Exception('Pod is unhealthy') + + print('The Pod is healthy') if __name__ == '__main__': diff --git a/jina/serve/networking.py b/jina/serve/networking.py index 37f72430f16a2..3c1c4b07d6aae 100644 --- a/jina/serve/networking.py +++ b/jina/serve/networking.py @@ -8,6 +8,7 @@ import grpc from grpc.aio import AioRpcError +from grpc_health.v1 import health_pb2, health_pb2_grpc from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub @@ -16,7 +17,6 @@ from jina.logging.logger import JinaLogger from jina.proto import jina_pb2, jina_pb2_grpc from jina.types.request import Request -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest TLS_PROTOCOL_SCHEMES = ['grpcs', 'https', 'wss'] @@ -26,6 +26,8 @@ if TYPE_CHECKING: from prometheus_client import CollectorRegistry + from jina.types.request.control import ControlRequest + class ReplicaList: """ @@ -153,7 +155,6 @@ class ConnectionStubs: """ STUB_MAPPING = { - 'jina.JinaControlRequestRPC': jina_pb2_grpc.JinaControlRequestRPCStub, 'jina.JinaDataRequestRPC': 
jina_pb2_grpc.JinaDataRequestRPCStub, 'jina.JinaSingleDataRequestRPC': jina_pb2_grpc.JinaSingleDataRequestRPCStub, 'jina.JinaDiscoverEndpointsRPC': jina_pb2_grpc.JinaDiscoverEndpointsRPCStub, @@ -175,7 +176,6 @@ async def _init_stubs(self): stubs = defaultdict(lambda: None) for service in available_services: stubs[service] = self.STUB_MAPPING[service](self.channel) - self.control_stub = stubs['jina.JinaControlRequestRPC'] self.data_list_stub = stubs['jina.JinaDataRequestRPC'] self.single_data_stub = stubs['jina.JinaSingleDataRequestRPC'] self.stream_stub = stubs['jina.JinaRPC'] @@ -265,20 +265,6 @@ async def send_requests( raise ValueError( 'Can not send list of DataRequests. gRPC endpoint not available.' ) - elif request_type == ControlRequest: - if self.control_stub: - call_result = self.control_stub.process_control( - requests[0], timeout=timeout - ) - metadata, response = ( - await call_result.trailing_metadata(), - await call_result, - ) - return response, metadata - else: - raise ValueError( - 'Can not send ControlRequest. gRPC endpoint not available.' 
- ) else: raise ValueError(f'Unsupported request type {type(requests[0])}') @@ -514,7 +500,7 @@ def send_requests( ) -> List[asyncio.Task]: """Send a request to target via one or all of the pooled connections, depending on polling_type - :param requests: request (DataRequest/ControlRequest) to send + :param requests: request (DataRequest) to send :param deployment: name of the Jina deployment to send the request to :param head: If True it is send to the head, otherwise to the worker pods :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL @@ -792,7 +778,7 @@ def activate_worker_sync( worker_port: int, target_head: str, shard_id: Optional[int] = None, - ) -> ControlRequest: + ): """ Register a given worker to a head by sending an activate request @@ -802,6 +788,8 @@ def activate_worker_sync( :param shard_id: id of the shard the worker belongs to :returns: the response request """ + from jina.types.request.control import ControlRequest + activate_request = ControlRequest(command='ACTIVATE') activate_request.add_related_entity( 'worker', worker_host, worker_port, shard_id @@ -813,54 +801,6 @@ def activate_worker_sync( return GrpcConnectionPool.send_request_sync(activate_request, target_head) - @staticmethod - async def activate_worker( - worker_host: str, - worker_port: int, - target_head: str, - shard_id: Optional[int] = None, - ) -> ControlRequest: - """ - Register a given worker to a head by sending an activate request - - :param worker_host: the host address of the worker - :param worker_port: the port of the worker - :param target_head: address of the head to send the activate request to - :param shard_id: id of the shard the worker belongs to - :returns: the response request - """ - activate_request = ControlRequest(command='ACTIVATE') - activate_request.add_related_entity( - 'worker', worker_host, worker_port, shard_id - ) - return await GrpcConnectionPool.send_request_async( - activate_request, target_head - ) - - @staticmethod 
- async def deactivate_worker( - worker_host: str, - worker_port: int, - target_head: str, - shard_id: Optional[int] = None, - ) -> ControlRequest: - """ - Remove a given worker to a head by sending a deactivate request - - :param worker_host: the host address of the worker - :param worker_port: the port of the worker - :param target_head: address of the head to send the deactivate request to - :param shard_id: id of the shard the worker belongs to - :returns: the response request - """ - activate_request = ControlRequest(command='DEACTIVATE') - activate_request.add_related_entity( - 'worker', worker_host, worker_port, shard_id - ) - return await GrpcConnectionPool.send_request_async( - activate_request, target_head - ) - @staticmethod def send_request_sync( request: Request, @@ -890,22 +830,51 @@ def send_request_sync( tls=tls, root_certificates=root_certificates, ) as channel: - if type(request) == DataRequest: - metadata = (('endpoint', endpoint),) if endpoint else None - stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) - response, call = stub.process_single_data.with_call( - request, - timeout=timeout, - metadata=metadata, - ) - elif type(request) == ControlRequest: - stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel) - response = stub.process_control(request, timeout=timeout) + metadata = (('endpoint', endpoint),) if endpoint else None + stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) + response, call = stub.process_single_data.with_call( + request, + timeout=timeout, + metadata=metadata, + ) return response except grpc.RpcError as e: if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2: raise + @staticmethod + def send_health_check_sync( + target: str, + timeout=100.0, + tls=False, + root_certificates: Optional[str] = None, + ) -> health_pb2.HealthCheckResponse: + """ + Sends a request synchronously to the target via grpc + + :param target: where to send the request to, like 127.0.0.1:8080 + :param timeout: timeout for the send + 
:param tls: if True, use tls encryption for the grpc channel + :param root_certificates: the path to the root certificates for tls, only used if tls is True + + :returns: the response health check + """ + + for i in range(3): + try: + with GrpcConnectionPool.get_grpc_channel( + target, + tls=tls, + root_certificates=root_certificates, + ) as channel: + health_check_req = health_pb2.HealthCheckRequest() + health_check_req.service = '' + stub = health_pb2_grpc.HealthStub(channel) + return stub.Check(health_check_req, timeout=timeout) + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2: + raise + @staticmethod def send_requests_sync( requests: List[Request], @@ -985,12 +954,8 @@ async def send_request_async( tls=tls, root_certificates=root_certificates, ) as channel: - if type(request) == DataRequest: - stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) - return await stub.process_single_data(request, timeout=timeout) - elif type(request) == ControlRequest: - stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel) - return await stub.process_control(request, timeout=timeout) + stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) + return await stub.process_single_data(request, timeout=timeout) @staticmethod def create_async_channel_stub( @@ -1004,7 +969,7 @@ def create_async_channel_stub( :param root_certificates: the path to the root certificates for tls, only u :param summary: Optional Prometheus summary object - :returns: DataRequest/ControlRequest stubs and an async grpc channel + :returns: DataRequest stubs and an async grpc channel """ channel = GrpcConnectionPool.get_grpc_channel( address, diff --git a/jina/serve/runtimes/asyncio.py b/jina/serve/runtimes/asyncio.py index 22fd5fc326716..f717e9644723c 100644 --- a/jina/serve/runtimes/asyncio.py +++ b/jina/serve/runtimes/asyncio.py @@ -12,7 +12,6 @@ from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.base import BaseRuntime from 
jina.serve.runtimes.monitoring import MonitoringMixin -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest if TYPE_CHECKING: @@ -148,12 +147,15 @@ def is_ready(ctrl_address: str, **kwargs) -> bool: """ try: - GrpcConnectionPool.send_request_sync( - ControlRequest('STATUS'), ctrl_address, timeout=1.0 + from grpc_health.v1 import health_pb2 + + response = GrpcConnectionPool.send_health_check_sync( + ctrl_address, timeout=1.0 ) - except RpcError as e: + # status is a ServingStatus enum value (int), never the string 'SERVING' + return response.status == health_pb2.HealthCheckResponse.SERVING + except RpcError: return False - return True @staticmethod def wait_for_ready_or_shutdown( @@ -181,16 +183,8 @@ def wait_for_ready_or_shutdown( time.sleep(0.1) return False - def _log_info_msg(self, request: Union[ControlRequest, DataRequest]): - if type(request) == DataRequest: - self._log_data_request(request) - elif type(request) == ControlRequest: - self._log_control_request(request) - - def _log_control_request(self, request: ControlRequest): - self.logger.debug( - f'recv ControlRequest {request.command} with id: {request.header.request_id}' - ) + def _log_info_msg(self, request: DataRequest): + self._log_data_request(request) def _log_data_request(self, request: DataRequest): self.logger.debug( diff --git a/jina/serve/runtimes/gateway/grpc/__init__.py b/jina/serve/runtimes/gateway/grpc/__init__.py index d25a33a06050b..9d6453a72b838 100644 --- a/jina/serve/runtimes/gateway/grpc/__init__.py +++ b/jina/serve/runtimes/gateway/grpc/__init__.py @@ -1,6 +1,8 @@ +import argparse import os import grpc +from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection from jina import __default_host__ @@ -13,12 +15,22 @@ __all__ = ['GRPCGatewayRuntime'] -from jina.types.request.control import ControlRequest - class GRPCGatewayRuntime(GatewayRuntime): """Gateway Runtime for gRPC.""" + def __init__( + self, + args:
argparse.Namespace, + **kwargs, + ): + """Initialize the runtime + :param args: args from CLI + :param kwargs: keyword args + """ + self._health_servicer = health.HealthServicer(experimental_non_blocking=True) + super().__init__(args, **kwargs) + async def async_setup(self): """ The async method to setup. @@ -59,13 +71,16 @@ async def _async_setup_server(self): self.streamer.Call = self.streamer.stream jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server) - jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(self, self.server) service_names = ( jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaControlRequestRPC'].full_name, reflection.SERVICE_NAME, ) + # Mark all services as healthy. + health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server) + + for service in service_names: + self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) reflection.enable_server_reflection(service_names, self.server) bind_addr = f'{__default_host__}:{self.args.port}' @@ -100,6 +115,7 @@ async def async_teardown(self): """Close the connection pool""" # usually async_cancel should already have been called, but then its a noop # if the runtime is stopped without a sigterm (e.g. as a context manager, this can happen) + self._health_servicer.enter_graceful_shutdown() await self.async_cancel() await self._connection_pool.close() @@ -111,19 +127,3 @@ async def async_run_forever(self): """The async running of server.""" self._connection_pool.start() await self.server.wait_for_termination() - - async def process_control(self, request: ControlRequest, *args) -> ControlRequest: - """ - Should be used to check readiness by sending STATUS ControlRequests. - Throws for any other command than STATUS. 
- - :param request: the ControlRequest, should have command 'STATUS' - :param args: additional arguments in the grpc call, ignored - :returns: will be the original request - """ - - if self.logger.debug_enabled: - self._log_control_request(request) - if request.command != 'STATUS': - raise ValueError('gateway only support STATUS ControlRequests') - return request diff --git a/jina/serve/runtimes/head/__init__.py b/jina/serve/runtimes/head/__init__.py index f7d6234e17b84..84a76cacfde75 100644 --- a/jina/serve/runtimes/head/__init__.py +++ b/jina/serve/runtimes/head/__init__.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple import grpc +from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection from jina.enums import PollingType @@ -16,7 +17,6 @@ from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest @@ -36,6 +36,8 @@ def __init__( :param args: args from CLI :param kwargs: keyword args """ + self._health_servicer = health.HealthServicer(experimental_non_blocking=True) + super().__init__(args, **kwargs) if args.name is None: args.name = '' @@ -147,19 +149,22 @@ async def async_setup(self): self, self._grpc_server ) jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server) - jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server( - self, self._grpc_server - ) jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server( self, self._grpc_server ) service_names = ( jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name, jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaControlRequestRPC'].full_name, 
jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name, reflection.SERVICE_NAME, ) + # Mark all services as healthy. + health_pb2_grpc.add_HealthServicer_to_server( + self._health_servicer, self._grpc_server + ) + + for service in service_names: + self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) reflection.enable_server_reflection(service_names, self._grpc_server) bind_addr = f'0.0.0.0:{self.args.port}' @@ -180,6 +185,7 @@ async def async_cancel(self): async def async_teardown(self): """Close the connection pool""" + self._health_servicer.enter_graceful_shutdown() await self.async_cancel() await self.connection_pool.close() @@ -218,48 +224,6 @@ async def process_data(self, requests: List[DataRequest], context) -> DataReques context.set_trailing_metadata((('is-error', 'true'),)) return requests[0] - async def process_control(self, request: ControlRequest, *args) -> ControlRequest: - """ - Process the received control request and return the input request - - :param request: the data request to process - :param args: additional arguments in the grpc call, ignored - :returns: the input request - """ - try: - if self.logger.debug_enabled: - self._log_control_request(request) - - if request.command == 'ACTIVATE': - - for relatedEntity in request.relatedEntities: - connection_string = f'{relatedEntity.address}:{relatedEntity.port}' - - self.connection_pool.add_connection( - deployment=self._deployment_name, - address=connection_string, - shard_id=relatedEntity.shard_id - if relatedEntity.HasField('shard_id') - else None, - ) - elif request.command == 'DEACTIVATE': - for relatedEntity in request.relatedEntities: - connection_string = f'{relatedEntity.address}:{relatedEntity.port}' - await self.connection_pool.remove_connection( - deployment=self._deployment_name, - address=connection_string, - shard_id=relatedEntity.shard_id, - ) - return request - except (RuntimeError, Exception) as ex: - self.logger.error( - f'{ex!r}' + f'\n 
add "--quiet-error" to suppress the exception details' - if not self.args.quiet_error - else '', - exc_info=not self.args.quiet_error, - ) - raise - async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: """ USes the connection pool to send a discover endpoint call to the workers diff --git a/jina/serve/runtimes/worker/__init__.py b/jina/serve/runtimes/worker/__init__.py index 86b4cba31360a..05d3c6b83691f 100644 --- a/jina/serve/runtimes/worker/__init__.py +++ b/jina/serve/runtimes/worker/__init__.py @@ -26,8 +26,8 @@ def __init__( :param args: args from CLI :param kwargs: keyword args """ - super().__init__(args, **kwargs) self._health_servicer = health.HealthServicer(experimental_non_blocking=True) + super().__init__(args, **kwargs) async def async_setup(self): """ diff --git a/jina/types/request/control.py b/jina/types/request/control.py deleted file mode 100644 index 025e579e291d6..0000000000000 --- a/jina/types/request/control.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Optional - -from jina.helper import random_identity, typename -from jina.proto import jina_pb2 -from jina.types.request import Request - -_available_commands = dict(jina_pb2.ControlRequestProto.DESCRIPTOR.enum_values_by_name) - - -class ControlRequest(Request): - """ - :class:`ControlRequest` is one of the **primitive data type** in Jina. - - It offers a Pythonic interface to allow users access and manipulate - :class:`jina.jina_pb2.ControlRequestProto` object without working with Protobuf itself. - - A container for serialized :class:`jina_pb2.ControlRequestProto` that only triggers deserialization - and decompression when receives the first read access to its member. - - It overrides :meth:`__getattr__` to provide the same get/set interface as an - :class:`jina_pb2.ControlRequestProtoProto` object. - - :param command: the command for this request, can be STATUS, ACTIVATE or DEACTIVATE - :param request: The request. 
- """ - - def __init__( - self, - command: Optional[str] = None, - request: Optional['jina_pb2.jina_pb2.ControlRequestProto'] = None, - ): - - if isinstance(request, jina_pb2.ControlRequestProto): - self._pb_body = request - elif request is not None: - # note ``None`` is not considered as a bad type - raise ValueError(f'{typename(request)} is not recognizable') - if command: - proto = jina_pb2.ControlRequestProto() - proto.header.request_id = random_identity() - if command in _available_commands: - proto.command = getattr(jina_pb2.ControlRequestProto, command) - else: - raise ValueError( - f'command "{command}" is not supported, must be one of {_available_commands}' - ) - self._pb_body = proto - - def add_related_entity( - self, id: str, address: str, port: int, shard_id: Optional[int] = None - ): - """ - Add a related entity to this ControlMessage - - :param id: jina id of the entity - :param address: address of the entity - :param port: Port of the entity - :param shard_id: Optional id of the shard this entity belongs to - """ - self.proto.relatedEntities.append( - jina_pb2.RelatedEntity(id=id, address=address, port=port, shard_id=shard_id) - ) - - @property - def proto(self) -> 'jina_pb2.ControlRequestProto': - """ - Cast ``self`` to a :class:`jina_pb2.ControlRequestProto`. Laziness will be broken and serialization will be recomputed when calling - :meth:`SerializeToString`. - :return: protobuf instance - """ - return self._pb_body - - @property - def command(self) -> str: - """Get the command. - - .. #noqa: DAR201""" - return jina_pb2.ControlRequestProto.Command.Name(self.proto.command)