Skip to content

Commit

Permalink
feat: remove ControlRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed May 12, 2022
1 parent 08b58dd commit b04f9df
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 341 deletions.
31 changes: 1 addition & 30 deletions jina/clients/request/helper.py
@@ -1,10 +1,8 @@
"""Module for helper functions for clients."""
from typing import Tuple

from docarray import DocumentArray, Document

from docarray import Document, DocumentArray
from jina.enums import DataInputType
from jina.excepts import BadRequestType
from jina.types.request.data import DataRequest


Expand Down Expand Up @@ -72,30 +70,3 @@ def _add_docs(req, batch, data_type, _kwargs):
d, data_type = _new_doc_from_data(content, data_type, **_kwargs)
da.append(d)
req.data.docs = da


def _add_control_propagate(req, kwargs):
from jina.proto import jina_pb2

extra_kwargs = kwargs[
'extra_kwargs'
] #: control command and args are stored inside extra_kwargs
_available_commands = dict(
jina_pb2.RequestProto.ControlRequestProto.DESCRIPTOR.enum_values_by_name
)

if 'command' in extra_kwargs:
command = extra_kwargs['command']
else:
raise BadRequestType(
'sending ControlRequest from Client must contain the field `command`'
)

if command in _available_commands:
req.control.command = getattr(
jina_pb2.RequestProto.ControlRequestProto, command
)
else:
raise ValueError(
f'command "{command}" is not supported, must be one of {_available_commands}'
)
19 changes: 0 additions & 19 deletions jina/orchestrate/pods/helper.py
Expand Up @@ -3,13 +3,9 @@
from functools import partial
from typing import TYPE_CHECKING

from grpc import RpcError

from jina.enums import GatewayProtocolType, PodRoleType
from jina.hubble.helper import is_valid_huburi
from jina.hubble.hubio import HubIO
from jina.serve.networking import GrpcConnectionPool
from jina.types.request.control import ControlRequest

if TYPE_CHECKING:
from argparse import Namespace
Expand Down Expand Up @@ -94,18 +90,3 @@ def update_runtime_cls(args, copy=False) -> 'Namespace':
_args.runtime_cls = 'HeadRuntime'

return _args


def is_ready(address: str) -> bool:
"""
TODO: make this async
Check if status is ready.
:param address: the address where the control message needs to be sent
:return: True if status is ready else False.
"""

try:
GrpcConnectionPool.send_request_sync(ControlRequest('STATUS'), address)
except RpcError:
return False
return True
30 changes: 0 additions & 30 deletions jina/proto/serializer.py
Expand Up @@ -2,39 +2,9 @@
from typing import Iterable, List, Union

from jina.proto import jina_pb2
from jina.types.request.control import ControlRequest
from jina.types.request.data import DataRequest


class ControlRequestProto:
"""This class is a drop-in replacement for gRPC default serializer.
It replace default serializer to make sure we always work with `Request`
"""

@staticmethod
def SerializeToString(x: 'ControlRequest'):
"""
# noqa: DAR101
# noqa: DAR102
# noqa: DAR201
"""
return x.proto.SerializePartialToString()

@staticmethod
def FromString(x: bytes):
"""
# noqa: DAR101
# noqa: DAR102
# noqa: DAR201
"""
proto = jina_pb2.ControlRequestProto()
proto.ParseFromString(x)

return ControlRequest(request=proto)


class DataRequestProto:
"""This class is a drop-in replacement for gRPC default serializer.
Expand Down
24 changes: 8 additions & 16 deletions jina/resources/health_check/pod.py
Expand Up @@ -3,22 +3,14 @@ def check_health_pod(addr: str):
:param addr: the address on which the pod is serving ex : localhost:1234
"""
import grpc

from jina.serve.networking import GrpcConnectionPool
from jina.types.request.control import ControlRequest

try:
GrpcConnectionPool.send_request_sync(
request=ControlRequest('STATUS'),
target=addr,
)
except grpc.RpcError as e:
print('The pod is unhealthy')
print(e)
raise e

print('The pod is healthy')
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime

is_ready = AsyncNewLoopRuntime.is_ready(addr)

if not is_ready:
raise Exception('Pod is unhealthy')

print('The Pod is healthy')


if __name__ == '__main__':
Expand Down
135 changes: 50 additions & 85 deletions jina/serve/networking.py
Expand Up @@ -8,6 +8,7 @@

import grpc
from grpc.aio import AioRpcError
from grpc_health.v1 import health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest
from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub

Expand All @@ -16,7 +17,6 @@
from jina.logging.logger import JinaLogger
from jina.proto import jina_pb2, jina_pb2_grpc
from jina.types.request import Request
from jina.types.request.control import ControlRequest
from jina.types.request.data import DataRequest

TLS_PROTOCOL_SCHEMES = ['grpcs', 'https', 'wss']
Expand All @@ -26,6 +26,8 @@
if TYPE_CHECKING:
from prometheus_client import CollectorRegistry

from jina.types.request.control import ControlRequest


class ReplicaList:
"""
Expand Down Expand Up @@ -153,7 +155,6 @@ class ConnectionStubs:
"""

STUB_MAPPING = {
'jina.JinaControlRequestRPC': jina_pb2_grpc.JinaControlRequestRPCStub,
'jina.JinaDataRequestRPC': jina_pb2_grpc.JinaDataRequestRPCStub,
'jina.JinaSingleDataRequestRPC': jina_pb2_grpc.JinaSingleDataRequestRPCStub,
'jina.JinaDiscoverEndpointsRPC': jina_pb2_grpc.JinaDiscoverEndpointsRPCStub,
Expand All @@ -175,7 +176,6 @@ async def _init_stubs(self):
stubs = defaultdict(lambda: None)
for service in available_services:
stubs[service] = self.STUB_MAPPING[service](self.channel)
self.control_stub = stubs['jina.JinaControlRequestRPC']
self.data_list_stub = stubs['jina.JinaDataRequestRPC']
self.single_data_stub = stubs['jina.JinaSingleDataRequestRPC']
self.stream_stub = stubs['jina.JinaRPC']
Expand Down Expand Up @@ -265,20 +265,6 @@ async def send_requests(
raise ValueError(
'Can not send list of DataRequests. gRPC endpoint not available.'
)
elif request_type == ControlRequest:
if self.control_stub:
call_result = self.control_stub.process_control(
requests[0], timeout=timeout
)
metadata, response = (
await call_result.trailing_metadata(),
await call_result,
)
return response, metadata
else:
raise ValueError(
'Can not send ControlRequest. gRPC endpoint not available.'
)
else:
raise ValueError(f'Unsupported request type {type(requests[0])}')

Expand Down Expand Up @@ -514,7 +500,7 @@ def send_requests(
) -> List[asyncio.Task]:
"""Send a request to target via one or all of the pooled connections, depending on polling_type
:param requests: request (DataRequest/ControlRequest) to send
:param requests: request (DataRequest) to send
:param deployment: name of the Jina deployment to send the request to
:param head: If True it is send to the head, otherwise to the worker pods
:param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
Expand Down Expand Up @@ -792,7 +778,7 @@ def activate_worker_sync(
worker_port: int,
target_head: str,
shard_id: Optional[int] = None,
) -> ControlRequest:
):
"""
Register a given worker to a head by sending an activate request
Expand All @@ -802,6 +788,8 @@ def activate_worker_sync(
:param shard_id: id of the shard the worker belongs to
:returns: the response request
"""
from jina.types.request.control import ControlRequest

activate_request = ControlRequest(command='ACTIVATE')
activate_request.add_related_entity(
'worker', worker_host, worker_port, shard_id
Expand All @@ -813,54 +801,6 @@ def activate_worker_sync(

return GrpcConnectionPool.send_request_sync(activate_request, target_head)

@staticmethod
async def activate_worker(
worker_host: str,
worker_port: int,
target_head: str,
shard_id: Optional[int] = None,
) -> ControlRequest:
"""
Register a given worker to a head by sending an activate request
:param worker_host: the host address of the worker
:param worker_port: the port of the worker
:param target_head: address of the head to send the activate request to
:param shard_id: id of the shard the worker belongs to
:returns: the response request
"""
activate_request = ControlRequest(command='ACTIVATE')
activate_request.add_related_entity(
'worker', worker_host, worker_port, shard_id
)
return await GrpcConnectionPool.send_request_async(
activate_request, target_head
)

@staticmethod
async def deactivate_worker(
worker_host: str,
worker_port: int,
target_head: str,
shard_id: Optional[int] = None,
) -> ControlRequest:
"""
Remove a given worker to a head by sending a deactivate request
:param worker_host: the host address of the worker
:param worker_port: the port of the worker
:param target_head: address of the head to send the deactivate request to
:param shard_id: id of the shard the worker belongs to
:returns: the response request
"""
activate_request = ControlRequest(command='DEACTIVATE')
activate_request.add_related_entity(
'worker', worker_host, worker_port, shard_id
)
return await GrpcConnectionPool.send_request_async(
activate_request, target_head
)

@staticmethod
def send_request_sync(
request: Request,
Expand Down Expand Up @@ -890,22 +830,51 @@ def send_request_sync(
tls=tls,
root_certificates=root_certificates,
) as channel:
if type(request) == DataRequest:
metadata = (('endpoint', endpoint),) if endpoint else None
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
response, call = stub.process_single_data.with_call(
request,
timeout=timeout,
metadata=metadata,
)
elif type(request) == ControlRequest:
stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel)
response = stub.process_control(request, timeout=timeout)
metadata = (('endpoint', endpoint),) if endpoint else None
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
response, call = stub.process_single_data.with_call(
request,
timeout=timeout,
metadata=metadata,
)
return response
except grpc.RpcError as e:
if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2:
raise

@staticmethod
def send_health_check_sync(
target: str,
timeout=100.0,
tls=False,
root_certificates: Optional[str] = None,
) -> health_pb2.HealthCheckResponse:
"""
Sends a request synchronously to the target via grpc
:param target: where to send the request to, like 127.0.0.1:8080
:param timeout: timeout for the send
:param tls: if True, use tls encryption for the grpc channel
:param root_certificates: the path to the root certificates for tls, only used if tls is True
:returns: the response health check
"""

for i in range(3):
try:
with GrpcConnectionPool.get_grpc_channel(
target,
tls=tls,
root_certificates=root_certificates,
) as channel:
health_check_req = health_pb2.HealthCheckRequest()
health_check_req.service = ''
stub = health_pb2_grpc.HealthStub(channel)
return stub.Check(health_check_req, timeout=timeout)
except grpc.RpcError as e:
if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2:
raise

@staticmethod
def send_requests_sync(
requests: List[Request],
Expand Down Expand Up @@ -985,12 +954,8 @@ async def send_request_async(
tls=tls,
root_certificates=root_certificates,
) as channel:
if type(request) == DataRequest:
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
return await stub.process_single_data(request, timeout=timeout)
elif type(request) == ControlRequest:
stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel)
return await stub.process_control(request, timeout=timeout)
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
return await stub.process_single_data(request, timeout=timeout)

@staticmethod
def create_async_channel_stub(
Expand All @@ -1004,7 +969,7 @@ def create_async_channel_stub(
:param root_certificates: the path to the root certificates for tls, only u
:param summary: Optional Prometheus summary object
:returns: DataRequest/ControlRequest stubs and an async grpc channel
:returns: DataRequest stubs and an async grpc channel
"""
channel = GrpcConnectionPool.get_grpc_channel(
address,
Expand Down

0 comments on commit b04f9df

Please sign in to comment.