From ef662b529b2a2eecea7bb99759a9f7b9d86d3062 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Wed, 8 Jun 2022 11:15:45 +0200 Subject: [PATCH] feat: add grpc health checking (#4779) --- docs/fundamentals/flow/health-check.md | 275 ++++++++++++++++++ docs/fundamentals/flow/index.md | 2 + extra-requirements.txt | 1 + jina/__init__.py | 3 +- jina/clients/base/__init__.py | 8 + jina/clients/base/grpc.py | 30 +- jina/clients/base/helper.py | 46 +++ jina/clients/base/http.py | 65 +++-- jina/clients/base/websocket.py | 50 ++++ jina/clients/grpc.py | 6 +- jina/clients/helper.py | 2 - jina/clients/http.py | 14 +- jina/clients/mixin.py | 12 + jina/clients/request/helper.py | 31 +- jina/clients/websocket.py | 6 +- jina/orchestrate/deployments/__init__.py | 64 ++-- jina/orchestrate/flow/base.py | 4 +- jina/orchestrate/pods/helper.py | 19 -- jina/proto/jina.proto | 48 +-- jina/proto/jina_pb2.py | 57 ++-- jina/proto/jina_pb2_grpc.py | 166 +++++------ jina/proto/serializer.py | 57 ++-- jina/resources/extra-requirements.txt | 1 + jina/resources/health_check/pod.py | 24 +- jina/serve/executors/__init__.py | 16 +- jina/serve/networking.py | 164 ++++------- jina/serve/runtimes/asyncio.py | 24 +- .../runtimes/gateway/graph/topology_graph.py | 28 +- jina/serve/runtimes/gateway/grpc/__init__.py | 66 ++++- jina/serve/runtimes/gateway/http/app.py | 50 +++- .../runtimes/gateway/request_handling.py | 19 +- jina/serve/runtimes/gateway/websocket/app.py | 95 +++++- jina/serve/runtimes/head/__init__.py | 58 +--- jina/serve/runtimes/worker/__init__.py | 45 +-- jina/types/mixin.py | 12 +- jina/types/request/control.py | 78 ----- jina/types/request/status.py | 60 ++++ tests/integration/flow-dry-run/__init__.py | 0 .../flow-dry-run/test_flow_dry_run.py | 14 + tests/integration/pods/container/test_pod.py | 21 +- tests/integration/pods/test_pod.py | 115 ++++---- .../runtimes/test_gateway_dry_run.py | 99 +++++++ .../runtimes/test_network_failures.py | 15 +- 
tests/integration/runtimes/test_runtimes.py | 133 ++++----- tests/integration/v2_api/test_func_routing.py | 3 +- tests/unit/serve/executors/test_executor.py | 4 + .../gateway/grpc/test_grpc_gateway_runtime.py | 5 +- .../runtimes/gateway/grpc/test_grpc_tls.py | 7 +- .../runtimes/gateway/http/test_models.py | 2 +- .../serve/runtimes/head/test_head_runtime.py | 82 +----- .../runtimes/worker/test_worker_runtime.py | 3 +- tests/unit/test_helper.py | 6 - tests/unit/types/request/test_request.py | 15 - .../unit/types/request/test_status_message.py | 41 +++ tests/unit/yaml/datauriindex.yml | 23 -- 55 files changed, 1361 insertions(+), 933 deletions(-) create mode 100644 docs/fundamentals/flow/health-check.md delete mode 100644 jina/types/request/control.py create mode 100644 jina/types/request/status.py create mode 100644 tests/integration/flow-dry-run/__init__.py create mode 100644 tests/integration/flow-dry-run/test_flow_dry_run.py create mode 100644 tests/integration/runtimes/test_gateway_dry_run.py create mode 100644 tests/unit/types/request/test_status_message.py delete mode 100644 tests/unit/yaml/datauriindex.yml diff --git a/docs/fundamentals/flow/health-check.md b/docs/fundamentals/flow/health-check.md new file mode 100644 index 0000000000000..31391f844f26a --- /dev/null +++ b/docs/fundamentals/flow/health-check.md @@ -0,0 +1,275 @@ +# Health and readiness check +Every Jina Flow consists of a {ref}`number of microservices `, +each of which have to be healthy before the Flow is ready to receive requests. + +Each Flow microservice provides a health check in the form of a [standardized gRPC endpoint](https://github.com/grpc/grpc/blob/master/doc/health-checking.md) that exposes this information to the outside world. +This means that health checks can automatically be performed by Jina itself as well as external tools like Docker Compose, Kubernetes service meshes, or load balancers. 
+ +In most cases, it is most useful to check if an entire Flow is ready to accept requests. +To enable this readiness check, the Jina Gateway can aggregate health check information from all services and provides +a readiness check endpoint for the complete Flow. + +## Readiness of complete Flow + +A lot of times, it is useful to know if a Flow, as a complete set of microservices, is ready to receive requests. This is why the Gateway +exposes an endpoint for each of the supported protocols to know the health and readiness of the entire Flow. + +Jina `Flow` and `Client` offer a convenient API to query these readiness endpoints. You can call `flow.dry_run()` or `client.dry_run()`, which will return `True` if the Flow is healthy and ready, and `False` otherwise: + +````{tab} via Flow +```python +from jina import Flow + +with Flow().add() as f: + print(f.dry_run()) + +print(f.dry_run()) +``` +```text +True +False +``` +```` +````{tab} via Client +```python +from jina import Flow + +with Flow(port=12345).add() as f: + f.block() +``` +```python +from jina import Client + +client = Client(port=12345) +print(client.dry_run()) +``` +```text +True +``` +```` + +### Flow status using third-party clients + +You can check the status of a Flow using any gRPC/HTTP/Websocket client, not just Jina's Client implementation. 
+ +To see how this works, first instantiate the Flow with its corresponding protocol and block it for serving: + +```python +from jina import Flow +import os + +PROTOCOL = 'grpc' # it could also be http or websocket + +os.environ[ + 'JINA_LOG_LEVEL' +] = 'DEBUG' # this way we can check what is the PID of the Executor + +with Flow(protocol=PROTOCOL, port=12345).add() as f: + f.block() +``` + +```text +⠋ Waiting ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/0 -:--:--DEBUG gateway/rep-0@19075 adding connection for deployment executor0/heads/0 to grpc://0.0.0.0:12346 [05/31/22 18:10:16] +DEBUG executor0/rep-0@19074 start listening on 0.0.0.0:12346 [05/31/22 18:10:16] +DEBUG gateway/rep-0@19075 start server bound to 0.0.0.0:12345 [05/31/22 18:10:17] +DEBUG executor0/rep-0@19059 ready and listening [05/31/22 18:10:17] +DEBUG gateway/rep-0@19059 ready and listening [05/31/22 18:10:17] +╭────── 🎉 Flow is ready to serve! ──────╮ +│ 🔗 Protocol GRPC │ +│ 🏠 Local 0.0.0.0:12345 │ +│ 🔒 Private 192.168.1.13:12345 │ +╰────────────────────────────────────────╯ +DEBUG Flow@19059 2 Deployments (i.e. 2 Pods) are running in this Flow +``` + +#### Using gRPC + +When using grpc, you can use [grpcurl](https://github.com/fullstorydev/grpcurl) to hit the Gateway's gRPC service that is responsible for reporting the Flow status. + +```shell +docker pull fullstorydev/grpcurl:latest +docker run --network='host' fullstorydev/grpcurl -plaintext 127.0.0.1:12345 jina.JinaGatewayDryRunRPC/dry_run +``` +The error-free output below signifies a correctly running Flow: +```text +{ + +} +``` + +You can simulate an Executor going offline by killing its process. 
+ +```shell script +kill -9 $EXECUTOR_PID # in this case we can see in the logs that it is 19059 +``` + +Then by doing the same check, you will see that it returns an error: + +```shell +docker run --network='host' fullstorydev/grpcurl -plaintext 127.0.0.1:12345 jina.JinaGatewayDryRunRPC/dry_run +``` + +````{dropdown} Error output +```text +{ + "code": "ERROR", + "description": "failed to connect to all addresses |Gateway: Communication error with deployment at address(es) 0.0.0.0:12346. Head or worker(s) may be down.", + "exception": { + "name": "InternalNetworkError", + "args": [ + "failed to connect to all addresses |Gateway: Communication error with deployment at address(es) 0.0.0.0:12346. Head or worker(s) may be down." + ], + "stacks": [ + "Traceback (most recent call last):\n", + " File \"/home/joan/jina/jina/jina/serve/networking.py\", line 750, in task_wrapper\n timeout=timeout,\n", + " File \"/home/joan/jina/jina/jina/serve/networking.py\", line 197, in send_discover_endpoint\n await self._init_stubs()\n", + " File \"/home/joan/jina/jina/jina/serve/networking.py\", line 174, in _init_stubs\n self.channel\n", + " File \"/home/joan/jina/jina/jina/serve/networking.py\", line 1001, in get_available_services\n async for res in response:\n", + " File \"/home/joan/.local/lib/python3.7/site-packages/grpc/aio/_call.py\", line 326, in _fetch_stream_responses\n await self._raise_for_status()\n", + " File \"/home/joan/.local/lib/python3.7/site-packages/grpc/aio/_call.py\", line 237, in _raise_for_status\n self._cython_call.status())\n", + "grpc.aio._call.AioRpcError: \u003cAioRpcError of RPC that terminated with:\n\tstatus = StatusCode.UNAVAILABLE\n\tdetails = \"failed to connect to all addresses\"\n\tdebug_error_string = \"{\"created\":\"@1654012804.794351252\",\"description\":\"Failed to pick 
subchannel\",\"file\":\"src/core/ext/filters/client_channel/client_channel.cc\",\"file_line\":3134,\"referenced_errors\":[{\"created\":\"@1654012804.794350006\",\"description\":\"failed to connect to all addresses\",\"file\":\"src/core/lib/transport/error_utils.cc\",\"file_line\":163,\"grpc_status\":14}]}\"\n\u003e\n", + "\nDuring handling of the above exception, another exception occurred:\n\n", + "Traceback (most recent call last):\n", + " File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/grpc/__init__.py\", line 155, in dry_run\n async for _ in self.streamer.stream(request_iterator=req_iterator):\n", + " File \"/home/joan/jina/jina/jina/serve/stream/__init__.py\", line 78, in stream\n async for response in async_iter:\n", + " File \"/home/joan/jina/jina/jina/serve/stream/__init__.py\", line 154, in _stream_requests\n response = self._result_handler(future.result())\n", + " File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/request_handling.py\", line 146, in _process_results_at_end_gateway\n await asyncio.gather(gather_endpoints(request_graph))\n", + " File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/request_handling.py\", line 88, in gather_endpoints\n raise err\n", + " File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/request_handling.py\", line 80, in gather_endpoints\n endpoints = await asyncio.gather(*tasks_to_get_endpoints)\n", + " File \"/home/joan/jina/jina/jina/serve/networking.py\", line 754, in task_wrapper\n e=e, retry_i=i, dest_addr=connection.address\n", + " File \"/home/joan/jina/jina/jina/serve/networking.py\", line 697, in _handle_aiorpcerror\n details=e.details(),\n", + "jina.excepts.InternalNetworkError: failed to connect to all addresses |Gateway: Communication error with deployment at address(es) 0.0.0.0:12346. 
Head or worker(s) may be down.\n" + ] + } +} +``` +```` + + +#### Using HTTP or Websocket + +When using HTTP or Websocket as the Gateway protocol, you can use curl to target the `/dry_run` endpoint and get the status of the Flow. + + +```shell +curl http://localhost:12345/dry_run +``` +The error-free output below signifies a correctly running Flow: +```text +{"code":0,"description":"","exception":null}% +``` + +You can simulate an Executor going offline by killing its process. + +```shell script +kill -9 $EXECUTOR_PID # in this case we can see in the logs that it is 19059 +``` + +Then by doing the same check, you will see that the call returns an error: + +```text +{"code":1,"description":"failed to connect to all addresses |Gateway: Communication error with deployment executor0 at address(es) {'0.0.0.0:12346'}. Head or worker(s) may be down.","exception":{"name":"InternalNetworkError","args":["failed to connect to all addresses |Gateway: Communication error with deployment executor0 at address(es) {'0.0.0.0:12346'}. 
Head or worker(s) may be down."],"stacks":["Traceback (most recent call last):\n"," File \"/home/joan/jina/jina/jina/serve/networking.py\", line 726, in task_wrapper\n timeout=timeout,\n"," File \"/home/joan/jina/jina/jina/serve/networking.py\", line 241, in send_requests\n await call_result,\n"," File \"/home/joan/.local/lib/python3.7/site-packages/grpc/aio/_call.py\", line 291, in __await__\n self._cython_call._status)\n","grpc.aio._call.AioRpcError: \n","\nDuring handling of the above exception, another exception occurred:\n\n","Traceback (most recent call last):\n"," File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/http/app.py\", line 142, in _flow_health\n data_type=DataInputType.DOCUMENT,\n"," File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/http/app.py\", line 399, in _get_singleton_result\n async for k in streamer.stream(request_iterator=request_iterator):\n"," File \"/home/joan/jina/jina/jina/serve/stream/__init__.py\", line 78, in stream\n async for response in async_iter:\n"," File \"/home/joan/jina/jina/jina/serve/stream/__init__.py\", line 154, in _stream_requests\n response = self._result_handler(future.result())\n"," File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/request_handling.py\", line 148, in _process_results_at_end_gateway\n partial_responses = await asyncio.gather(*tasks)\n"," File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/graph/topology_graph.py\", line 128, in _wait_previous_and_send\n self._handle_internalnetworkerror(err)\n"," File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/graph/topology_graph.py\", line 70, in _handle_internalnetworkerror\n raise err\n"," File \"/home/joan/jina/jina/jina/serve/runtimes/gateway/graph/topology_graph.py\", line 125, in _wait_previous_and_send\n timeout=self._timeout_send,\n"," File \"/home/joan/jina/jina/jina/serve/networking.py\", line 734, in task_wrapper\n num_retries=num_retries,\n"," File \"/home/joan/jina/jina/jina/serve/networking.py\", line 697, in 
_handle_aiorpcerror\n details=e.details(),\n","jina.excepts.InternalNetworkError: failed to connect to all addresses |Gateway: Communication error with deployment executor0 at address(es) {'0.0.0.0:12346'}. Head or worker(s) may be down.\n"],"executor":""}}% ``` + +(health-check-microservices)= +## Health check of individual microservices + +In addition to performing a readiness check for the entire Flow, it is also possible to check every individual microservice in said Flow, +by utilizing a [standardized gRPC health check endpoint](https://github.com/grpc/grpc/blob/master/doc/health-checking.md). +In most cases this is not necessary, since such checks are performed by Jina, a Kubernetes service mesh or a load balancer under the hood. +Nevertheless, it is possible to perform these checks as a user. + +When performing these checks, you can expect one of the following `ServingStatus` responses: +- **`UNKNOWN` (0)**: The health of the microservice could not be determined +- **`SERVING` (1)**: The microservice is healthy and ready to receive requests +- **`NOT_SERVING` (2)**: The microservice is *not* healthy and *not* ready to receive requests +- **`SERVICE_UNKNOWN` (3)**: The health of the microservice could not be determined while performing streaming + +````{admonition} See Also +:class: seealso + +To learn more about these status codes, and how health checks are performed with gRPC, see [here](https://github.com/grpc/grpc/blob/master/doc/health-checking.md). +```` + +(health-check-executor)= +### Health check of an Executor + +Executors run as microservices exposing gRPC endpoints, and they expose one endpoint for a health and readiness check. 
+ +To see how to use it, you can start a Flow inside a terminal and block it to accept requests: + +```python +from jina import Flow + +f = Flow(protocol='grpc', port=12345).add(port=12346) +with f: + f.block() +``` + +On another terminal, you can use [grpcurl](https://github.com/fullstorydev/grpcurl) to send RPC requests to your services. + +```bash +docker pull fullstorydev/grpcurl:latest +docker run --network='host' fullstorydev/grpcurl -plaintext 127.0.0.1:12346 grpc.health.v1.Health/Check +``` + +```text +{ + "status": "SERVING" +} +``` + +(health-check-gateway)= +### Health check of the Gateway + +Just like each individual Executor, the Gateway also acts as a microservice, and as such it exposes a health check endpoint. + +In contrast to Executors however, a Gateway can use gRPC, HTTP, or Websocket, and the health check endpoint changes accordingly. + + +#### Gateway health check with gRPC + +When using gRPC as the protocol to communicate with the Gateway, the Gateway uses the exact same mechanism as Executors to expose its health status: It exposes the [ standard gRPC health check](https://github.com/grpc/grpc/blob/master/doc/health-checking.md) to the outside world. + +With the same Flow as described before, you can use the same way to check the Gateway status: + +```bash +docker run --network='host' fullstorydev/grpcurl -plaintext 127.0.0.1:12345 grpc.health.v1.Health/Check +``` + +```text +{ + "status": "SERVING" +} +``` + + +#### Gateway health check with HTTP or Websocket + +````{admonition} Caution +:class: caution +For Gateways running with HTTP or Websocket, the gRPC health check response codes outlined {ref}`above ` do not apply. + +Instead, an error free response signifies healthiness. +```` + +When using HTTP or Websocket as the protocol for the Gateway, it exposes the endpoint `'/'` that one can query to check the status. 
+ +First, create a Flow with HTTP or Websocket protocol: + +```python +from jina import Flow + +f = Flow(protocol='http', port=12345).add() +with f: + f.block() +``` +Then, you can query the "empty" endpoint: +```bash +curl http://localhost:12345 +``` + +And you will get a valid empty response indicating the Gateway's ability to serve. +```text +{}% +``` \ No newline at end of file diff --git a/docs/fundamentals/flow/index.md b/docs/fundamentals/flow/index.md index 8697187cbb1f6..3f47ab879c1df 100644 --- a/docs/fundamentals/flow/index.md +++ b/docs/fundamentals/flow/index.md @@ -22,6 +22,7 @@ The most important methods of the `Flow` object are the following: | `.block()` | Blocks execution until the program is terminated. This is useful to keep the Flow alive so it can be used from other places (clients, etc). | | `.to_docker_compose_yaml()` | Generates a Docker-Compose file listing all its Executors as Services. | | `.to_kubernetes_yaml()` | Generates the Kubernetes configuration files in ``. Based on your local Jina version, Jina Hub may rebuild the Docker image during the YAML generation process. If you do not wish to rebuild the image, set the environment variable `JINA_HUB_NO_IMAGE_REBUILD`. | +| `.dry_run()` | Calls the dry run endpoint of the Flow to check if the Flow is ready to process requests. Returns a boolean indicating the readiness | ## Why should you use a Flow? 
@@ -146,6 +147,7 @@ add-executors topologies flow-api monitoring-flow +health-check when-things-go-wrong yaml-spec ``` diff --git a/extra-requirements.txt b/extra-requirements.txt index 5c97f3e87727a..d4cc2f94bf6dc 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -29,6 +29,7 @@ numpy: core protobuf>=3.19.1,<=3.20.1: core grpcio>=1.46.0: core grpcio-reflection>=1.46.0: core +grpcio-health-checking>=1.46.0: core pyyaml>=5.3.1: core packaging>=20.0: core docarray>=0.13.14: core diff --git a/jina/__init__.py b/jina/__init__.py index 017b969f7c23a..c8c365e4b7e80 100644 --- a/jina/__init__.py +++ b/jina/__init__.py @@ -68,7 +68,8 @@ def _warning_on_one_line(message, category, filename, lineno, *args, **kwargs): # do not change this line manually # this is managed by proto/build-proto.sh and updated on every execution -__proto_version__ = '0.1.10' +__proto_version__ = '0.1.11' + try: __docarray_version__ = _docarray.__version__ except AttributeError as e: diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py index b4005b54dab9a..7bb8e2372498b 100644 --- a/jina/clients/base/__init__.py +++ b/jina/clients/base/__init__.py @@ -152,6 +152,14 @@ async def _get_results( ): ... + @abc.abstractmethod + def _dry_run(self, **kwargs) -> bool: + """Sends a dry run to the Flow to validate if the Flow is ready to receive requests + + :param kwargs: potential kwargs received passed from the public interface + """ + ... 
+ @property def client(self: T) -> T: """Return the client object itself diff --git a/jina/clients/base/grpc.py b/jina/clients/base/grpc.py index 280745bda97dd..2f1db7bcf09c2 100644 --- a/jina/clients/base/grpc.py +++ b/jina/clients/base/grpc.py @@ -12,7 +12,7 @@ InternalNetworkError, ) from jina.logging.profile import ProgressBar -from jina.proto import jina_pb2_grpc +from jina.proto import jina_pb2, jina_pb2_grpc from jina.serve.networking import GrpcConnectionPool if TYPE_CHECKING: @@ -25,6 +25,34 @@ class GRPCBaseClient(BaseClient): It manages the asyncio event loop internally, so all interfaces are synchronous from the outside. """ + async def _dry_run(self, **kwargs) -> bool: + """Sends a dry run to the Flow to validate if the Flow is ready to receive requests + + :param kwargs: potential kwargs received passed from the public interface + :return: boolean indicating the health/readiness of the Flow + """ + try: + async with GrpcConnectionPool.get_grpc_channel( + f'{self.args.host}:{self.args.port}', + asyncio=True, + tls=self.args.tls, + ) as channel: + stub = jina_pb2_grpc.JinaGatewayDryRunRPCStub(channel) + self.logger.debug(f'connected to {self.args.host}:{self.args.port}') + call_result = stub.dry_run( + jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(), **kwargs + ) + metadata, response = ( + await call_result.trailing_metadata(), + await call_result, + ) + if response.code == jina_pb2.StatusProto.SUCCESS: + return True + except Exception as e: + self.logger.error(f'Error while getting response from grpc server {e!r}') + + return False + async def _get_results( self, inputs: 'InputType', diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index a55bb635270bc..6ae060052a8c1 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -8,6 +8,7 @@ from jina.importer import ImportExtensions from jina.types.request import Request from jina.types.request.data import DataRequest +from jina.types.request.status import 
StatusMessage if TYPE_CHECKING: from jina.logging.logger import JinaLogger @@ -33,11 +34,20 @@ async def send_message(self): """Send message to Gateway""" ... + @abstractmethod + async def send_dry_run(self): + """Query the dry_run endpoint from Gateway""" + ... + @abstractmethod async def recv_message(self): """Receive message from Gateway""" ... + async def recv_dry_run(self): + """Receive dry run response from Gateway""" + pass + async def __aenter__(self): """enter async context @@ -83,6 +93,12 @@ async def send_message(self, request: 'Request'): req_dict['target_executor'] = req_dict['header']['target_executor'] return await self.session.post(url=self.url, json=req_dict).__aenter__() + async def send_dry_run(self): + """Query the dry_run endpoint from Gateway + :return: send get message + """ + return await self.session.get(url=self.url).__aenter__() + async def recv_message(self): """Receive message for HTTP (sleep) @@ -90,6 +106,13 @@ async def recv_message(self): """ return await asyncio.sleep(1e10) + async def recv_dry_run(self): + """Receive dry run response for HTTP (sleep) + + :return: await sleep + """ + return await asyncio.sleep(1e10) + class WsResponseIter: """ @@ -131,6 +154,17 @@ async def send_message(self, request: 'Request'): except ConnectionResetError: self.logger.critical(f'server connection closed already!') + async def send_dry_run(self): + """Query the dry_run endpoint from Gateway + + :return: send dry_run as bytes awaitable + """ + + try: + return await self.websocket.send_bytes(b'') + except ConnectionResetError: + self.logger.critical(f'server connection closed already!') + async def send_eoi(self): """To confirm end of iteration, we send `bytes(True)` to the server. 
@@ -155,6 +189,18 @@ async def recv_message(self) -> 'DataRequest': async for response in self.response_iter: yield DataRequest(response.data) + async def recv_dry_run(self): + """Receive dry run response in bytes from server + + ..note:: + aiohttp allows only one task which can `receive` concurrently. + we need to make sure we don't create multiple tasks with `recv_message` + + :yield: response objects received from server + """ + async for response in self.response_iter: + yield StatusMessage(response.data) + async def __aenter__(self): await super().__aenter__() self.websocket = await self.session.ws_connect( diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index 56adc19d0018f..57534bc8bb683 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -21,6 +21,52 @@ class HTTPBaseClient(BaseClient): """A MixIn for HTTP Client.""" + def _handle_response_status(self, r_status, r_str, url): + if r_status == status.HTTP_404_NOT_FOUND: + raise BadClient(f'no such endpoint {url}') + elif r_status == status.HTTP_503_SERVICE_UNAVAILABLE: + if ( + 'header' in r_str + and 'status' in r_str['header'] + and 'description' in r_str['header']['status'] + ): + raise ConnectionError(r_str['header']['status']['description']) + else: + raise ValueError(r_str) + elif ( + r_status < status.HTTP_200_OK or r_status > status.HTTP_300_MULTIPLE_CHOICES + ): # failure codes + raise ValueError(r_str) + + async def _dry_run(self, **kwargs) -> bool: + """Sends a dry run to the Flow to validate if the Flow is ready to receive requests + + :param kwargs: potential kwargs received passed from the public interface + :return: boolean indicating the health/readiness of the Flow + """ + from jina.proto import jina_pb2 + + async with AsyncExitStack() as stack: + try: + proto = 'https' if self.args.tls else 'http' + url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' + iolet = await stack.enter_async_context( + HTTPClientlet(url=url, logger=self.logger) + ) + 
+ response = await iolet.send_dry_run() + r_status = response.status + + r_str = await response.json() + self._handle_response_status(r_status, r_str, url) + if r_str['code'] == jina_pb2.StatusProto.SUCCESS: + return True + except Exception as e: + self.logger.error( + f'Error while fetching response from HTTP server {e!r}' + ) + return False + async def _get_results( self, inputs: 'InputType', @@ -77,24 +123,7 @@ def _result_handler(result): r_status = response.status r_str = await response.json() - if r_status == status.HTTP_404_NOT_FOUND: - raise BadClient(f'no such endpoint {url}') - elif r_status == status.HTTP_503_SERVICE_UNAVAILABLE: - if ( - 'header' in r_str - and 'status' in r_str['header'] - and 'description' in r_str['header']['status'] - ): - raise ConnectionError( - r_str['header']['status']['description'] - ) - else: - raise ValueError(r_str) - elif ( - r_status < status.HTTP_200_OK - or r_status > status.HTTP_300_MULTIPLE_CHOICES - ): # failure codes - raise ValueError(r_str) + self._handle_response_status(r_status, r_str, url) da = None if 'data' in r_str and r_str['data'] is not None: diff --git a/jina/clients/base/websocket.py b/jina/clients/base/websocket.py index 8f5cad8407596..ad8e019b41ee5 100644 --- a/jina/clients/base/websocket.py +++ b/jina/clients/base/websocket.py @@ -11,6 +11,7 @@ from jina.helper import get_or_reuse_loop from jina.importer import ImportExtensions from jina.logging.profile import ProgressBar +from jina.proto import jina_pb2 from jina.serve.stream import RequestStreamer if TYPE_CHECKING: @@ -21,6 +22,55 @@ class WebSocketBaseClient(BaseClient): """A Websocket Client.""" + async def _dry_run(self, **kwargs) -> bool: + """Sends a dry run to the Flow to validate if the Flow is ready to receive requests + + :param kwargs: potential kwargs received passed from the public interface + :return: boolean indicating the readiness of the Flow + """ + async with AsyncExitStack() as stack: + try: + proto = 'wss' if self.args.tls else 
'ws' + url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' + iolet = await stack.enter_async_context( + WebsocketClientlet(url=url, logger=self.logger) + ) + + async def _receive(): + try: + async for response in iolet.recv_dry_run(): + return response + except Exception as exc: + self.logger.error( + f'Error while fetching response from Websocket server {exc!r}' + ) + raise + + async def _send(): + return await iolet.send_dry_run() + + receive_task = asyncio.create_task(_receive()) + + if receive_task.done(): + raise RuntimeError( + 'receive task not running, can not send messages' + ) + try: + send_task = asyncio.create_task(_send()) + _, response_result = await asyncio.gather(send_task, receive_task) + if response_result.proto.code == jina_pb2.StatusProto.SUCCESS: + return True + finally: + if iolet.close_code == status.WS_1011_INTERNAL_ERROR: + raise ConnectionError(iolet.close_message) + await receive_task + + except Exception as e: + self.logger.error( + f'Error while getting response from websocket server {e!r}' + ) + return False + async def _get_results( self, inputs: 'InputType', diff --git a/jina/clients/grpc.py b/jina/clients/grpc.py index 9b18ef4dc5fdf..2bca0d5c5ccb9 100644 --- a/jina/clients/grpc.py +++ b/jina/clients/grpc.py @@ -1,12 +1,12 @@ from jina.clients.base.grpc import GRPCBaseClient -from jina.clients.mixin import AsyncPostMixin, PostMixin +from jina.clients.mixin import AsyncPostMixin, HealthCheckMixin, PostMixin -class GRPCClient(GRPCBaseClient, PostMixin): +class GRPCClient(GRPCBaseClient, PostMixin, HealthCheckMixin): """A client communicates the server with GRPC protocol.""" -class AsyncGRPCClient(GRPCBaseClient, AsyncPostMixin): +class AsyncGRPCClient(GRPCBaseClient, AsyncPostMixin, HealthCheckMixin): """ A client communicates the server with GRPC protocol. 
diff --git a/jina/clients/helper.py b/jina/clients/helper.py index 365e83e02943d..028b86d92ec58 100644 --- a/jina/clients/helper.py +++ b/jina/clients/helper.py @@ -29,8 +29,6 @@ def pprint_routes(resp: 'Response', stack_limit: int = 3): status_icon = '🟢' if route.status.code == jina_pb2.StatusProto.ERROR: status_icon = '🔴' - elif route.status.code == jina_pb2.StatusProto.ERROR_CHAINED: - status_icon = '⚪' table.add_row( f'{status_icon} {route.executor}', diff --git a/jina/clients/http.py b/jina/clients/http.py index dba662ac07f3a..ae1023ea899d7 100644 --- a/jina/clients/http.py +++ b/jina/clients/http.py @@ -1,14 +1,22 @@ from jina.clients.base.http import HTTPBaseClient -from jina.clients.mixin import AsyncPostMixin, PostMixin, MutateMixin, AsyncMutateMixin +from jina.clients.mixin import ( + AsyncMutateMixin, + AsyncPostMixin, + HealthCheckMixin, + MutateMixin, + PostMixin, +) -class HTTPClient(HTTPBaseClient, PostMixin, MutateMixin): +class HTTPClient(HTTPBaseClient, PostMixin, MutateMixin, HealthCheckMixin): """ A client communicates the server with HTTP protocol. """ -class AsyncHTTPClient(HTTPBaseClient, AsyncPostMixin, AsyncMutateMixin): +class AsyncHTTPClient( + HTTPBaseClient, AsyncPostMixin, AsyncMutateMixin, HealthCheckMixin +): """ A client communicates the server with HTTP protocol. 
diff --git a/jina/clients/mixin.py b/jina/clients/mixin.py index 21d3e4f75689f..4fa3b6d28040c 100644 --- a/jina/clients/mixin.py +++ b/jina/clients/mixin.py @@ -88,6 +88,18 @@ async def mutate( ) +class HealthCheckMixin: + """The Health check Mixin for Client and Flow to expose `dry_run` API""" + + def dry_run(self, **kwargs) -> bool: + """Sends a dry run to the Flow to validate if the Flow is ready to receive requests + + :param kwargs: potential kwargs received passed from the public interface + :return: boolean indicating the health/readiness of the Flow + """ + return run_async(self.client._dry_run, **kwargs) + + class PostMixin: """The Post Mixin class for Client and Flow""" diff --git a/jina/clients/request/helper.py b/jina/clients/request/helper.py index 6dc653e709cd1..e00b2ec06bf57 100644 --- a/jina/clients/request/helper.py +++ b/jina/clients/request/helper.py @@ -1,10 +1,8 @@ """Module for helper functions for clients.""" from typing import Tuple -from docarray import DocumentArray, Document - +from docarray import Document, DocumentArray from jina.enums import DataInputType -from jina.excepts import BadRequestType from jina.types.request.data import DataRequest @@ -72,30 +70,3 @@ def _add_docs(req, batch, data_type, _kwargs): d, data_type = _new_doc_from_data(content, data_type, **_kwargs) da.append(d) req.data.docs = da - - -def _add_control_propagate(req, kwargs): - from jina.proto import jina_pb2 - - extra_kwargs = kwargs[ - 'extra_kwargs' - ] #: control command and args are stored inside extra_kwargs - _available_commands = dict( - jina_pb2.RequestProto.ControlRequestProto.DESCRIPTOR.enum_values_by_name - ) - - if 'command' in extra_kwargs: - command = extra_kwargs['command'] - else: - raise BadRequestType( - 'sending ControlRequest from Client must contain the field `command`' - ) - - if command in _available_commands: - req.control.command = getattr( - jina_pb2.RequestProto.ControlRequestProto, command - ) - else: - raise ValueError( - f'command 
"{command}" is not supported, must be one of {_available_commands}' - ) diff --git a/jina/clients/websocket.py b/jina/clients/websocket.py index c488c4aff42d8..036298e9c0bbf 100644 --- a/jina/clients/websocket.py +++ b/jina/clients/websocket.py @@ -1,14 +1,14 @@ from jina.clients.base.websocket import WebSocketBaseClient -from jina.clients.mixin import AsyncPostMixin, PostMixin +from jina.clients.mixin import AsyncPostMixin, HealthCheckMixin, PostMixin -class WebSocketClient(WebSocketBaseClient, PostMixin): +class WebSocketClient(WebSocketBaseClient, PostMixin, HealthCheckMixin): """ A client communicates the server with WebSocket protocol. """ -class AsyncWebSocketClient(WebSocketBaseClient, AsyncPostMixin): +class AsyncWebSocketClient(WebSocketBaseClient, AsyncPostMixin, HealthCheckMixin): """ A client communicates the server with WebSocket protocol. diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index e74e09f37ff2e..2a017a0e91436 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -1,9 +1,11 @@ import copy +import json import os import re import subprocess from abc import abstractmethod from argparse import Namespace +from collections import defaultdict from contextlib import ExitStack from itertools import cycle from typing import Dict, List, Optional, Set, Union @@ -14,9 +16,8 @@ from jina.hubble.helper import replace_secret_of_hub_uri from jina.hubble.hubio import HubIO from jina.jaml.helper import complete_path -from jina.orchestrate.pods.container import ContainerPod from jina.orchestrate.pods.factory import PodFactory -from jina.serve.networking import GrpcConnectionPool, host_is_local, in_docker +from jina.serve.networking import host_is_local, in_docker WRAPPED_SLICE_BASE = r'\[[-\d:]+\]' @@ -288,6 +289,17 @@ def is_sandbox(self) -> bool: is_sandbox = uses.startswith('jinahub+sandbox://') return is_sandbox + @property + def _is_docker(self) -> bool: + 
""" + Check if this deployment is to be run in docker. + + :return: True if this deployment is to be run in docker + """ + uses = getattr(self.args, 'uses', '') + is_docker = uses.startswith('jinahub+docker://') or uses.startswith('docker://') + return is_docker + @property def tls_enabled(self): """ @@ -464,52 +476,27 @@ def num_pods(self) -> int: def __eq__(self, other: 'BaseDeployment'): return self.num_pods == other.num_pods and self.name == other.name - def activate(self): - """ - Activate all worker pods in this deployment by registering them with the head - """ - if self.head_pod is not None: - for shard_id in self.pod_args['pods']: - for pod_idx, pod_args in enumerate(self.pod_args['pods'][shard_id]): - worker_host = self.get_worker_host( - pod_args, self.shards[shard_id]._pods[pod_idx], self.head_pod - ) - GrpcConnectionPool.activate_worker_sync( - worker_host, - int(pod_args.port), - self.head_pod.runtime_ctrl_address, - shard_id, - ) - @staticmethod - def get_worker_host(pod_args, pod, head_pod): + def get_worker_host(pod_args, pod_is_container, head_is_container): """ Check if the current pod and head are both containerized on the same host If so __docker_host__ needs to be advertised as the worker's address to the head :param pod_args: arguments of the worker pod - :param pod: the worker pod - :param head_pod: head pod communicating with the worker pod - :return: host to use in activate messages + :param pod_is_container: boolean specifying if pod is to be run in container + :param head_is_container: boolean specifying if head pod is to be run in container + :return: host to pass in connection list of the head """ # Check if the current pod and head are both containerized on the same host # If so __docker_host__ needs to be advertised as the worker's address to the head worker_host = ( __docker_host__ - if Deployment._is_container_to_container(pod, head_pod) + if (pod_is_container and (head_is_container or in_docker())) and host_is_local(pod_args.host) 
else pod_args.host ) return worker_host - @staticmethod - def _is_container_to_container(pod, head_pod): - # Check if both shard_id/pod_idx and the head are containerized - # if the head is not containerized, it still could mean that the deployment itself is containerized - return type(pod) == ContainerPod and ( - type(head_pod) == ContainerPod or in_docker() - ) - def start(self) -> 'Deployment': """ Start to run all :class:`Pod` in this BaseDeployment. @@ -549,8 +536,6 @@ def start(self) -> 'Deployment': ) self.enter_context(self.shards[shard_id]) - if not getattr(self.args, 'noblock_on_start', False): - self.activate() return self def wait_start_success(self) -> None: @@ -571,7 +556,6 @@ def wait_start_success(self) -> None: self.head_pod.wait_start_success() for shard_id in self.shards: self.shards[shard_id].wait_start_success() - self.activate() except: self.close() raise @@ -782,8 +766,18 @@ def _parse_base_deployment_args(self, args): ) parsed_args['head'] = BaseDeployment._copy_to_head_args(args) + parsed_args['pods'] = self._set_pod_args(args) + if parsed_args['head'] is not None: + connection_list = defaultdict(list) + + for shard_id in parsed_args['pods']: + for pod_idx, pod_args in enumerate(parsed_args['pods'][shard_id]): + worker_host = self.get_worker_host(pod_args, self._is_docker, False) + connection_list[shard_id].append(f'{worker_host}:{pod_args.port}') + parsed_args['head'].connection_list = json.dumps(connection_list) + return parsed_args @property diff --git a/jina/orchestrate/flow/base.py b/jina/orchestrate/flow/base.py index e4632339e4d0e..dcb5fa8c455fd 100644 --- a/jina/orchestrate/flow/base.py +++ b/jina/orchestrate/flow/base.py @@ -38,7 +38,7 @@ from jina import __default_host__, __default_port_monitoring__, __docker_host__, helper from jina.clients import Client -from jina.clients.mixin import AsyncPostMixin, PostMixin +from jina.clients.mixin import AsyncPostMixin, HealthCheckMixin, PostMixin from jina.enums import ( DeploymentRoleType, 
FlowBuildLevel, @@ -97,7 +97,7 @@ class FlowType(type(ExitStack), type(JAMLCompatible)): ] -class Flow(PostMixin, JAMLCompatible, ExitStack, metaclass=FlowType): +class Flow(PostMixin, HealthCheckMixin, JAMLCompatible, ExitStack, metaclass=FlowType): """Flow is how Jina streamlines and distributes Executors.""" # overload_inject_start_client_flow diff --git a/jina/orchestrate/pods/helper.py b/jina/orchestrate/pods/helper.py index efc99fb60a8da..50397c2e60d38 100644 --- a/jina/orchestrate/pods/helper.py +++ b/jina/orchestrate/pods/helper.py @@ -3,13 +3,9 @@ from functools import partial from typing import TYPE_CHECKING -from grpc import RpcError - from jina.enums import GatewayProtocolType, PodRoleType from jina.hubble.helper import is_valid_huburi from jina.hubble.hubio import HubIO -from jina.serve.networking import GrpcConnectionPool -from jina.types.request.control import ControlRequest if TYPE_CHECKING: from argparse import Namespace @@ -94,18 +90,3 @@ def update_runtime_cls(args, copy=False) -> 'Namespace': _args.runtime_cls = 'HeadRuntime' return _args - - -def is_ready(address: str) -> bool: - """ - TODO: make this async - Check if status is ready. - :param address: the address where the control message needs to be sent - :return: True if status is ready else False. 
- """ - - try: - GrpcConnectionPool.send_request_sync(ControlRequest('STATUS'), address) - except RpcError: - return False - return True diff --git a/jina/proto/jina.proto b/jina/proto/jina.proto index 74597487fc13c..5689cde828318 100644 --- a/jina/proto/jina.proto +++ b/jina/proto/jina.proto @@ -56,12 +56,7 @@ message StatusProto { enum StatusCode { SUCCESS = 0; // success - PENDING = 1; // there are pending messages, more messages are followed - READY = 2; // ready to use - ERROR = 3; // error - ERROR_DUPLICATE = 4; // already a existing pod running - ERROR_NOTALLOWED = 5; // not allowed to open pod remotely - ERROR_CHAINED = 6; // chained from the previous error + ERROR = 1; // error } // status code @@ -80,7 +75,7 @@ message StatusProto { // the exception traceback stacks repeated string stacks = 3; - // the name of the executor bind to that peapod (if applicable) + // the name of the executor bind to that Executor (if applicable) string executor = 4; } @@ -93,32 +88,13 @@ message StatusProto { * Represents an entity (like an ExecutorRuntime) */ message RelatedEntity { - string id = 1; // unique id of the entity, like the name of a pea + string id = 1; // unique id of the entity, like the name of a pod string address = 2; // address of the entity, could be an IP address, domain name etc, does not include port uint32 port = 3; // port this entity is listening on optional uint32 shard_id = 4; // the id of the shard it belongs to, if it is a shard } -/** - * Represents a ControlRequest - */ -message ControlRequestProto { - - HeaderProto header = 1; // header contains meta info defined by the user - - enum Command { - STATUS = 0; // check the status of the BasePod - ACTIVATE = 1; // used to add Pods to a Pod - DEACTIVATE = 2; // used to remove Pods from a Pod - } - - Command command = 2; // the control command - - repeated RelatedEntity relatedEntities = 3; // list of entities this ControlMessage is related to -} - - /** * Represents a DataRequest */ @@ -148,15 
+124,6 @@ message DataRequestListProto { repeated DataRequestProto requests = 1; // requests in this list } -/** - * jina gRPC service for ControlRequests. - */ -service JinaControlRequestRPC { - // Used for passing ControlRequests to the Executors - rpc process_control (ControlRequestProto) returns (ControlRequestProto) { - } -} - /** * jina gRPC service for DataRequests. */ @@ -193,3 +160,12 @@ service JinaDiscoverEndpointsRPC { } } + +/** + * jina gRPC service to expose Endpoints from Executors. + */ +service JinaGatewayDryRunRPC { + rpc dry_run (google.protobuf.Empty) returns (StatusProto) { + } +} + diff --git a/jina/proto/jina_pb2.py b/jina/proto/jina_pb2.py index fa4f245092b48..d0879789e957b 100644 --- a/jina/proto/jina_pb2.py +++ b/jina/proto/jina_pb2.py @@ -18,7 +18,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"#\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\"\xcf\x02\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 
\x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"z\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07PENDING\x10\x01\x12\t\n\x05READY\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\x13\n\x0f\x45RROR_DUPLICATE\x10\x04\x12\x14\n\x10\x45RROR_NOTALLOWED\x10\x05\x12\x11\n\rERROR_CHAINED\x10\x06\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\xcf\x01\n\x13\x43ontrolRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0e\x32!.jina.ControlRequestProto.Command\x12,\n\x0frelatedEntities\x18\x03 \x03(\x0b\x32\x13.jina.RelatedEntity\"3\n\x07\x43ommand\x12\n\n\x06STATUS\x10\x00\x12\x0c\n\x08\x41\x43TIVATE\x10\x01\x12\x0e\n\nDEACTIVATE\x10\x02\"\xa0\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a\x63\n\x10\x44\x61taContentProto\x12,\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 
\x03(\x0b\x32\x16.jina.DataRequestProto2b\n\x15JinaControlRequestRPC\x12I\n\x0fprocess_control\x12\x19.jina.ControlRequestProto\x1a\x19.jina.ControlRequestProto\"\x00\x32Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"#\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 
\x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\xa0\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a\x63\n\x10\x44\x61taContentProto\x12,\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x62\x06proto3') @@ -28,12 +28,10 @@ _STATUSPROTO = DESCRIPTOR.message_types_by_name['StatusProto'] _STATUSPROTO_EXCEPTIONPROTO = _STATUSPROTO.nested_types_by_name['ExceptionProto'] _RELATEDENTITY = DESCRIPTOR.message_types_by_name['RelatedEntity'] -_CONTROLREQUESTPROTO = DESCRIPTOR.message_types_by_name['ControlRequestProto'] _DATAREQUESTPROTO = DESCRIPTOR.message_types_by_name['DataRequestProto'] _DATAREQUESTPROTO_DATACONTENTPROTO = 
_DATAREQUESTPROTO.nested_types_by_name['DataContentProto'] _DATAREQUESTLISTPROTO = DESCRIPTOR.message_types_by_name['DataRequestListProto'] _STATUSPROTO_STATUSCODE = _STATUSPROTO.enum_types_by_name['StatusCode'] -_CONTROLREQUESTPROTO_COMMAND = _CONTROLREQUESTPROTO.enum_types_by_name['Command'] RouteProto = _reflection.GeneratedProtocolMessageType('RouteProto', (_message.Message,), { 'DESCRIPTOR' : _ROUTEPROTO, '__module__' : 'jina_pb2' @@ -77,13 +75,6 @@ }) _sym_db.RegisterMessage(RelatedEntity) -ControlRequestProto = _reflection.GeneratedProtocolMessageType('ControlRequestProto', (_message.Message,), { - 'DESCRIPTOR' : _CONTROLREQUESTPROTO, - '__module__' : 'jina_pb2' - # @@protoc_insertion_point(class_scope:jina.ControlRequestProto) - }) -_sym_db.RegisterMessage(ControlRequestProto) - DataRequestProto = _reflection.GeneratedProtocolMessageType('DataRequestProto', (_message.Message,), { 'DataContentProto' : _reflection.GeneratedProtocolMessageType('DataContentProto', (_message.Message,), { @@ -106,11 +97,11 @@ }) _sym_db.RegisterMessage(DataRequestListProto) -_JINACONTROLREQUESTRPC = DESCRIPTOR.services_by_name['JinaControlRequestRPC'] _JINADATAREQUESTRPC = DESCRIPTOR.services_by_name['JinaDataRequestRPC'] _JINASINGLEDATAREQUESTRPC = DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'] _JINARPC = DESCRIPTOR.services_by_name['JinaRPC'] _JINADISCOVERENDPOINTSRPC = DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'] +_JINAGATEWAYDRYRUNRPC = DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'] if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -121,31 +112,27 @@ _ENDPOINTSPROTO._serialized_start=491 _ENDPOINTSPROTO._serialized_end=526 _STATUSPROTO._serialized_start=529 - _STATUSPROTO._serialized_end=864 + _STATUSPROTO._serialized_end=778 _STATUSPROTO_EXCEPTIONPROTO._serialized_start=662 _STATUSPROTO_EXCEPTIONPROTO._serialized_end=740 _STATUSPROTO_STATUSCODE._serialized_start=742 - _STATUSPROTO_STATUSCODE._serialized_end=864 - 
_RELATEDENTITY._serialized_start=866 - _RELATEDENTITY._serialized_end=960 - _CONTROLREQUESTPROTO._serialized_start=963 - _CONTROLREQUESTPROTO._serialized_end=1170 - _CONTROLREQUESTPROTO_COMMAND._serialized_start=1119 - _CONTROLREQUESTPROTO_COMMAND._serialized_end=1170 - _DATAREQUESTPROTO._serialized_start=1173 - _DATAREQUESTPROTO._serialized_end=1461 - _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start=1362 - _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end=1461 - _DATAREQUESTLISTPROTO._serialized_start=1463 - _DATAREQUESTLISTPROTO._serialized_end=1527 - _JINACONTROLREQUESTRPC._serialized_start=1529 - _JINACONTROLREQUESTRPC._serialized_end=1627 - _JINADATAREQUESTRPC._serialized_start=1629 - _JINADATAREQUESTRPC._serialized_end=1719 - _JINASINGLEDATAREQUESTRPC._serialized_start=1721 - _JINASINGLEDATAREQUESTRPC._serialized_end=1820 - _JINARPC._serialized_start=1822 - _JINARPC._serialized_end=1893 - _JINADISCOVERENDPOINTSRPC._serialized_start=1895 - _JINADISCOVERENDPOINTSRPC._serialized_end=1991 + _STATUSPROTO_STATUSCODE._serialized_end=778 + _RELATEDENTITY._serialized_start=780 + _RELATEDENTITY._serialized_end=874 + _DATAREQUESTPROTO._serialized_start=877 + _DATAREQUESTPROTO._serialized_end=1165 + _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start=1066 + _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end=1165 + _DATAREQUESTLISTPROTO._serialized_start=1167 + _DATAREQUESTLISTPROTO._serialized_end=1231 + _JINADATAREQUESTRPC._serialized_start=1233 + _JINADATAREQUESTRPC._serialized_end=1323 + _JINASINGLEDATAREQUESTRPC._serialized_start=1325 + _JINASINGLEDATAREQUESTRPC._serialized_end=1424 + _JINARPC._serialized_start=1426 + _JINARPC._serialized_end=1497 + _JINADISCOVERENDPOINTSRPC._serialized_start=1499 + _JINADISCOVERENDPOINTSRPC._serialized_end=1595 + _JINAGATEWAYDRYRUNRPC._serialized_start=1597 + _JINAGATEWAYDRYRUNRPC._serialized_end=1675 # @@protoc_insertion_point(module_scope) diff --git a/jina/proto/jina_pb2_grpc.py b/jina/proto/jina_pb2_grpc.py index 
37aa3bfde40aa..4575da9a47ba5 100644 --- a/jina/proto/jina_pb2_grpc.py +++ b/jina/proto/jina_pb2_grpc.py @@ -6,86 +6,6 @@ from . import serializer as jina__pb2 -class JinaControlRequestRPCStub(object): - """* - jina gRPC service for ControlRequests. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.process_control = channel.unary_unary( - '/jina.JinaControlRequestRPC/process_control', - request_serializer=jina__pb2.ControlRequestProto.SerializeToString, - response_deserializer=jina__pb2.ControlRequestProto.FromString, - ) - - -class JinaControlRequestRPCServicer(object): - """* - jina gRPC service for ControlRequests. - """ - - def process_control(self, request, context): - """Used for passing ControlRequests to the Executors""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_JinaControlRequestRPCServicer_to_server(servicer, server): - rpc_method_handlers = { - 'process_control': grpc.unary_unary_rpc_method_handler( - servicer.process_control, - request_deserializer=jina__pb2.ControlRequestProto.FromString, - response_serializer=jina__pb2.ControlRequestProto.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'jina.JinaControlRequestRPC', rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class JinaControlRequestRPC(object): - """* - jina gRPC service for ControlRequests. 
- """ - - @staticmethod - def process_control( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - '/jina.JinaControlRequestRPC/process_control', - jina__pb2.ControlRequestProto.SerializeToString, - jina__pb2.ControlRequestProto.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - class JinaDataRequestRPCStub(object): """* jina gRPC service for DataRequests. @@ -331,7 +251,7 @@ def Call( class JinaDiscoverEndpointsRPCStub(object): """* - jina gRPC service for DataRequests. + jina gRPC service to expose Endpoints from Executors. """ def __init__(self, channel): @@ -349,7 +269,7 @@ def __init__(self, channel): class JinaDiscoverEndpointsRPCServicer(object): """* - jina gRPC service for DataRequests. + jina gRPC service to expose Endpoints from Executors. """ def endpoint_discovery(self, request, context): @@ -376,7 +296,7 @@ def add_JinaDiscoverEndpointsRPCServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. class JinaDiscoverEndpointsRPC(object): """* - jina gRPC service for DataRequests. + jina gRPC service to expose Endpoints from Executors. """ @staticmethod @@ -407,3 +327,83 @@ def endpoint_discovery( timeout, metadata, ) + + +class JinaGatewayDryRunRPCStub(object): + """* + jina gRPC service to expose Endpoints from Executors. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.dry_run = channel.unary_unary( + '/jina.JinaGatewayDryRunRPC/dry_run', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=jina__pb2.StatusProto.FromString, + ) + + +class JinaGatewayDryRunRPCServicer(object): + """* + jina gRPC service to expose Endpoints from Executors. + """ + + def dry_run(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_JinaGatewayDryRunRPCServicer_to_server(servicer, server): + rpc_method_handlers = { + 'dry_run': grpc.unary_unary_rpc_method_handler( + servicer.dry_run, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=jina__pb2.StatusProto.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'jina.JinaGatewayDryRunRPC', rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class JinaGatewayDryRunRPC(object): + """* + jina gRPC service to expose Endpoints from Executors. 
+ """ + + @staticmethod + def dry_run( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + '/jina.JinaGatewayDryRunRPC/dry_run', + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + jina__pb2.StatusProto.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/jina/proto/serializer.py b/jina/proto/serializer.py index 686bb67a981ba..bbdf0d9f796f5 100644 --- a/jina/proto/serializer.py +++ b/jina/proto/serializer.py @@ -2,39 +2,9 @@ from typing import Iterable, List, Union from jina.proto import jina_pb2 -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest -class ControlRequestProto: - """This class is a drop-in replacement for gRPC default serializer. - - It replace default serializer to make sure we always work with `Request` - - """ - - @staticmethod - def SerializeToString(x: 'ControlRequest'): - """ - # noqa: DAR101 - # noqa: DAR102 - # noqa: DAR201 - """ - return x.proto.SerializePartialToString() - - @staticmethod - def FromString(x: bytes): - """ - # noqa: DAR101 - # noqa: DAR102 - # noqa: DAR201 - """ - proto = jina_pb2.ControlRequestProto() - proto.ParseFromString(x) - - return ControlRequest(request=proto) - - class DataRequestProto: """This class is a drop-in replacement for gRPC default serializer. @@ -130,3 +100,30 @@ def FromString(x: bytes): ep.ParseFromString(x) return ep + + +class StatusProto: + """Since the serializer is replacing the `jina_pb2 to know how to exactly serialize messages, this is just a placeholder that + delegates the serializing and deserializing to the internal protobuf structure with no extra optimization. 
+ """ + + @staticmethod + def SerializeToString(x): + """ + # noqa: DAR101 + # noqa: DAR102 + # noqa: DAR201 + """ + return x.SerializeToString() + + @staticmethod + def FromString(x: bytes): + """ + # noqa: DAR101 + # noqa: DAR102 + # noqa: DAR201 + """ + sp = jina_pb2.StatusProto() + sp.ParseFromString(x) + + return sp diff --git a/jina/resources/extra-requirements.txt b/jina/resources/extra-requirements.txt index 5c97f3e87727a..d4cc2f94bf6dc 100644 --- a/jina/resources/extra-requirements.txt +++ b/jina/resources/extra-requirements.txt @@ -29,6 +29,7 @@ numpy: core protobuf>=3.19.1,<=3.20.1: core grpcio>=1.46.0: core grpcio-reflection>=1.46.0: core +grpcio-health-checking>=1.46.0: core pyyaml>=5.3.1: core packaging>=20.0: core docarray>=0.13.14: core diff --git a/jina/resources/health_check/pod.py b/jina/resources/health_check/pod.py index faec7192f6b3b..ffd6705ca90b1 100644 --- a/jina/resources/health_check/pod.py +++ b/jina/resources/health_check/pod.py @@ -3,22 +3,14 @@ def check_health_pod(addr: str): :param addr: the address on which the pod is serving ex : localhost:1234 """ - import grpc - - from jina.serve.networking import GrpcConnectionPool - from jina.types.request.control import ControlRequest - - try: - GrpcConnectionPool.send_request_sync( - request=ControlRequest('STATUS'), - target=addr, - ) - except grpc.RpcError as e: - print('The pod is unhealthy') - print(e) - raise e - - print('The pod is healthy') + from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime + + is_ready = AsyncNewLoopRuntime.is_ready(addr) + + if not is_ready: + raise Exception('Pod is unhealthy') + + print('The Pod is healthy') if __name__ == '__main__': diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 51bd81fc6b5b6..f725148952567 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -23,7 +23,11 @@ if TYPE_CHECKING: from prometheus_client import Summary -__all__ = ['BaseExecutor'] + from docarray 
import DocumentArray + +__dry_run_endpoint__ = '_jina_dry_run_' + +__all__ = ['BaseExecutor', __dry_run_endpoint__] class ExecutorType(type(JAMLCompatible), type): @@ -114,6 +118,16 @@ def __init__( self._add_runtime_args(runtime_args) self._init_monitoring() self.logger = JinaLogger(self.__class__.__name__) + if __dry_run_endpoint__ not in self.requests: + self.requests[__dry_run_endpoint__] = self._dry_run_func + else: + self.logger.warning( + f' Endpoint {__dry_run_endpoint__} is defined by the Executor. Be aware that this endpoint is usually reserved to enable health checks from the Client through the gateway.' + f' So it is recommended not to expose this endpoint. ' + ) + + def _dry_run_func(self, *args, **kwargs): + pass def _add_runtime_args(self, _runtime_args: Optional[Dict]): if _runtime_args: diff --git a/jina/serve/networking.py b/jina/serve/networking.py index 4939302b6a8fb..45c2fbb8f6d13 100644 --- a/jina/serve/networking.py +++ b/jina/serve/networking.py @@ -8,6 +8,7 @@ import grpc from grpc.aio import AioRpcError +from grpc_health.v1 import health_pb2, health_pb2_grpc from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub @@ -16,7 +17,6 @@ from jina.logging.logger import JinaLogger from jina.proto import jina_pb2, jina_pb2_grpc from jina.types.request import Request -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest TLS_PROTOCOL_SCHEMES = ['grpcs', 'https', 'wss'] @@ -26,6 +26,8 @@ if TYPE_CHECKING: from prometheus_client import CollectorRegistry + from jina.types.request.control import ControlRequest + class ReplicaList: """ @@ -153,7 +155,6 @@ class ConnectionStubs: """ STUB_MAPPING = { - 'jina.JinaControlRequestRPC': jina_pb2_grpc.JinaControlRequestRPCStub, 'jina.JinaDataRequestRPC': jina_pb2_grpc.JinaDataRequestRPCStub, 'jina.JinaSingleDataRequestRPC': jina_pb2_grpc.JinaSingleDataRequestRPCStub, 
'jina.JinaDiscoverEndpointsRPC': jina_pb2_grpc.JinaDiscoverEndpointsRPCStub, @@ -175,7 +176,6 @@ async def _init_stubs(self): stubs = defaultdict(lambda: None) for service in available_services: stubs[service] = self.STUB_MAPPING[service](self.channel) - self.control_stub = stubs['jina.JinaControlRequestRPC'] self.data_list_stub = stubs['jina.JinaDataRequestRPC'] self.single_data_stub = stubs['jina.JinaSingleDataRequestRPC'] self.stream_stub = stubs['jina.JinaRPC'] @@ -265,20 +265,6 @@ async def send_requests( raise ValueError( 'Can not send list of DataRequests. gRPC endpoint not available.' ) - elif request_type == ControlRequest: - if self.control_stub: - call_result = self.control_stub.process_control( - requests[0], timeout=timeout - ) - metadata, response = ( - await call_result.trailing_metadata(), - await call_result, - ) - return response, metadata - else: - raise ValueError( - 'Can not send ControlRequest. gRPC endpoint not available.' - ) else: raise ValueError(f'Unsupported request type {type(requests[0])}') @@ -508,7 +494,7 @@ def send_requests( ) -> List[asyncio.Task]: """Send a request to target via one or all of the pooled connections, depending on polling_type - :param requests: request (DataRequest/ControlRequest) to send + :param requests: request (DataRequest) to send :param deployment: name of the Jina deployment to send the request to :param head: If True it is send to the head, otherwise to the worker pods :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL @@ -806,81 +792,6 @@ def get_grpc_channel( return insecure_channel(address, options) - @staticmethod - def activate_worker_sync( - worker_host: str, - worker_port: int, - target_head: str, - shard_id: Optional[int] = None, - ) -> ControlRequest: - """ - Register a given worker to a head by sending an activate request - - :param worker_host: the host address of the worker - :param worker_port: the port of the worker - :param target_head: address of the head 
to send the activate request to - :param shard_id: id of the shard the worker belongs to - :returns: the response request - """ - activate_request = ControlRequest(command='ACTIVATE') - activate_request.add_related_entity( - 'worker', worker_host, worker_port, shard_id - ) - - if os.name != 'nt': - os.unsetenv('http_proxy') - os.unsetenv('https_proxy') - - return GrpcConnectionPool.send_request_sync(activate_request, target_head) - - @staticmethod - async def activate_worker( - worker_host: str, - worker_port: int, - target_head: str, - shard_id: Optional[int] = None, - ) -> ControlRequest: - """ - Register a given worker to a head by sending an activate request - - :param worker_host: the host address of the worker - :param worker_port: the port of the worker - :param target_head: address of the head to send the activate request to - :param shard_id: id of the shard the worker belongs to - :returns: the response request - """ - activate_request = ControlRequest(command='ACTIVATE') - activate_request.add_related_entity( - 'worker', worker_host, worker_port, shard_id - ) - return await GrpcConnectionPool.send_request_async( - activate_request, target_head - ) - - @staticmethod - async def deactivate_worker( - worker_host: str, - worker_port: int, - target_head: str, - shard_id: Optional[int] = None, - ) -> ControlRequest: - """ - Remove a given worker to a head by sending a deactivate request - - :param worker_host: the host address of the worker - :param worker_port: the port of the worker - :param target_head: address of the head to send the deactivate request to - :param shard_id: id of the shard the worker belongs to - :returns: the response request - """ - activate_request = ControlRequest(command='DEACTIVATE') - activate_request.add_related_entity( - 'worker', worker_host, worker_port, shard_id - ) - return await GrpcConnectionPool.send_request_async( - activate_request, target_head - ) - @staticmethod def send_request_sync( request: Request, @@ -910,22 
+821,51 @@ def send_request_sync( tls=tls, root_certificates=root_certificates, ) as channel: - if type(request) == DataRequest: - metadata = (('endpoint', endpoint),) if endpoint else None - stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) - response, call = stub.process_single_data.with_call( - request, - timeout=timeout, - metadata=metadata, - ) - elif type(request) == ControlRequest: - stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel) - response = stub.process_control(request, timeout=timeout) + metadata = (('endpoint', endpoint),) if endpoint else None + stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) + response, call = stub.process_single_data.with_call( + request, + timeout=timeout, + metadata=metadata, + ) return response except grpc.RpcError as e: if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2: raise + @staticmethod + def send_health_check_sync( + target: str, + timeout=100.0, + tls=False, + root_certificates: Optional[str] = None, + ) -> health_pb2.HealthCheckResponse: + """ + Sends a request synchronously to the target via grpc + + :param target: where to send the request to, like 127.0.0.1:8080 + :param timeout: timeout for the send + :param tls: if True, use tls encryption for the grpc channel + :param root_certificates: the path to the root certificates for tls, only used if tls is True + + :returns: the response health check + """ + + for i in range(3): + try: + with GrpcConnectionPool.get_grpc_channel( + target, + tls=tls, + root_certificates=root_certificates, + ) as channel: + health_check_req = health_pb2.HealthCheckRequest() + health_check_req.service = '' + stub = health_pb2_grpc.HealthStub(channel) + return stub.Check(health_check_req, timeout=timeout) + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.UNAVAILABLE or i == 2: + raise + @staticmethod def send_requests_sync( requests: List[Request], @@ -1005,12 +945,8 @@ async def send_request_async( tls=tls, root_certificates=root_certificates, ) as 
channel: - if type(request) == DataRequest: - stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) - return await stub.process_single_data(request, timeout=timeout) - elif type(request) == ControlRequest: - stub = jina_pb2_grpc.JinaControlRequestRPCStub(channel) - return await stub.process_control(request, timeout=timeout) + stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) + return await stub.process_single_data(request, timeout=timeout) @staticmethod def create_async_channel_stub( @@ -1024,7 +960,7 @@ def create_async_channel_stub( :param root_certificates: the path to the root certificates for tls, only u :param summary: Optional Prometheus summary object - :returns: DataRequest/ControlRequest stubs and an async grpc channel + :returns: DataRequest stubs and an async grpc channel """ channel = GrpcConnectionPool.get_grpc_channel( address, @@ -1057,7 +993,11 @@ async def get_available_services(channel) -> List[str]: [ service.name for service in res.list_services_response.service - if service.name != 'grpc.reflection.v1alpha.ServerReflection' + if service.name + not in { + 'grpc.reflection.v1alpha.ServerReflection', + 'jina.JinaGatewayDryRunRPC', + } ] ) return service_names[0] diff --git a/jina/serve/runtimes/asyncio.py b/jina/serve/runtimes/asyncio.py index 22fd5fc326716..35f7b99413771 100644 --- a/jina/serve/runtimes/asyncio.py +++ b/jina/serve/runtimes/asyncio.py @@ -12,7 +12,6 @@ from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.base import BaseRuntime from jina.serve.runtimes.monitoring import MonitoringMixin -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest if TYPE_CHECKING: @@ -148,12 +147,15 @@ def is_ready(ctrl_address: str, **kwargs) -> bool: """ try: - GrpcConnectionPool.send_request_sync( - ControlRequest('STATUS'), ctrl_address, timeout=1.0 + from grpc_health.v1 import health_pb2, health_pb2_grpc + + response = GrpcConnectionPool.send_health_check_sync( + 
ctrl_address, timeout=1.0 ) - except RpcError as e: + # TODO: Get the proper value of the ServingStatus SERVING KEY + return response.status == 1 + except RpcError: return False - return True @staticmethod def wait_for_ready_or_shutdown( @@ -181,16 +183,8 @@ def wait_for_ready_or_shutdown( time.sleep(0.1) return False - def _log_info_msg(self, request: Union[ControlRequest, DataRequest]): - if type(request) == DataRequest: - self._log_data_request(request) - elif type(request) == ControlRequest: - self._log_control_request(request) - - def _log_control_request(self, request: ControlRequest): - self.logger.debug( - f'recv ControlRequest {request.command} with id: {request.header.request_id}' - ) + def _log_info_msg(self, request: DataRequest): + self._log_data_request(request) def _log_data_request(self, request: DataRequest): self.logger.debug( diff --git a/jina/serve/runtimes/gateway/graph/topology_graph.py b/jina/serve/runtimes/gateway/graph/topology_graph.py index c1c4dc8e6b6cc..3df86cdcf67f9 100644 --- a/jina/serve/runtimes/gateway/graph/topology_graph.py +++ b/jina/serve/runtimes/gateway/graph/topology_graph.py @@ -71,7 +71,7 @@ def _handle_internalnetworkerror(self, err): else: raise - def get_endpoints(self, connection_pool: GrpcConnectionPool): + def get_endpoints(self, connection_pool: GrpcConnectionPool) -> asyncio.Task: return connection_pool.send_discover_endpoint(self.name) async def _wait_previous_and_send( @@ -300,3 +300,29 @@ def origin_nodes(self): :return: A list of nodes """ return self._origin_nodes + + @property + def all_nodes(self): + """ + The set of all the nodes inside this Graph + + :return: A list of nodes + """ + + def _get_all_nodes(node, accum, accum_names): + if node.name not in accum_names: + accum.append(node) + accum_names.append(node.name) + for n in node.outgoing_nodes: + _get_all_nodes(n, accum, accum_names) + return accum, accum_names + + nodes = [] + node_names = [] + for origin_node in self.origin_nodes: + subtree_nodes, 
subtree_node_names = _get_all_nodes(origin_node, [], []) + for st_node, st_node_name in zip(subtree_nodes, subtree_node_names): + if st_node_name not in node_names: + nodes.append(st_node) + node_names.append(st_node_name) + return nodes diff --git a/jina/serve/runtimes/gateway/grpc/__init__.py b/jina/serve/runtimes/gateway/grpc/__init__.py index d25a33a06050b..0fc8cf5b0da18 100644 --- a/jina/serve/runtimes/gateway/grpc/__init__.py +++ b/jina/serve/runtimes/gateway/grpc/__init__.py @@ -1,6 +1,8 @@ +import argparse import os import grpc +from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection from jina import __default_host__ @@ -10,15 +12,26 @@ from jina.serve.runtimes.gateway import GatewayRuntime from jina.serve.runtimes.gateway.request_handling import RequestHandler from jina.serve.stream import RequestStreamer +from jina.types.request.status import StatusMessage __all__ = ['GRPCGatewayRuntime'] -from jina.types.request.control import ControlRequest - class GRPCGatewayRuntime(GatewayRuntime): """Gateway Runtime for gRPC.""" + def __init__( + self, + args: argparse.Namespace, + **kwargs, + ): + """Initialize the runtime + :param args: args from CLI + :param kwargs: keyword args + """ + self._health_servicer = health.HealthServicer(experimental_non_blocking=True) + super().__init__(args, **kwargs) + async def async_setup(self): """ The async method to setup. 
@@ -59,13 +72,18 @@ async def _async_setup_server(self): self.streamer.Call = self.streamer.stream jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server) - jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(self, self.server) + jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(self, self.server) service_names = ( jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaControlRequestRPC'].full_name, + jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].full_name, reflection.SERVICE_NAME, ) + # Mark all services as healthy. + health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server) + + for service in service_names: + self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) reflection.enable_server_reflection(service_names, self.server) bind_addr = f'{__default_host__}:{self.args.port}' @@ -100,6 +118,7 @@ async def async_teardown(self): """Close the connection pool""" # usually async_cancel should already have been called, but then its a noop # if the runtime is stopped without a sigterm (e.g. as a context manager, this can happen) + self._health_servicer.enter_graceful_shutdown() await self.async_cancel() await self._connection_pool.close() @@ -112,18 +131,33 @@ async def async_run_forever(self): self._connection_pool.start() await self.server.wait_for_termination() - async def process_control(self, request: ControlRequest, *args) -> ControlRequest: + async def dry_run(self, empty, context) -> jina_pb2.StatusProto: """ - Should be used to check readiness by sending STATUS ControlRequests. - Throws for any other command than STATUS. 
+        Process the call requested by having a dry run call to every Executor in the graph
 
-        :param request: the ControlRequest, should have command 'STATUS'
-        :param args: additional arguments in the grpc call, ignored
-        :returns: will be the original request
+        :param empty: The service expects an empty protobuf message
+        :param context: grpc context
+        :returns: the response request
         """
-
-        if self.logger.debug_enabled:
-            self._log_control_request(request)
-        if request.command != 'STATUS':
-            raise ValueError('gateway only support STATUS ControlRequests')
-        return request
+        from docarray import DocumentArray
+        from jina.clients.request import request_generator
+        from jina.enums import DataInputType
+        from jina.serve.executors import __dry_run_endpoint__
+
+        da = DocumentArray()
+
+        try:
+            req_iterator = request_generator(
+                exec_endpoint=__dry_run_endpoint__,
+                data=da,
+                data_type=DataInputType.DOCUMENT,
+            )
+            async for _ in self.streamer.stream(request_iterator=req_iterator):
+                pass
+            status_message = StatusMessage()
+            status_message.set_code(jina_pb2.StatusProto.SUCCESS)
+            return status_message.proto
+        except Exception as ex:
+            status_message = StatusMessage()
+            status_message.set_exception(ex)
+            return status_message.proto
diff --git a/jina/serve/runtimes/gateway/http/app.py b/jina/serve/runtimes/gateway/http/app.py
index 0b6b7a403e1ad..a56b293851348 100644
--- a/jina/serve/runtimes/gateway/http/app.py
+++ b/jina/serve/runtimes/gateway/http/app.py
@@ -96,17 +96,54 @@ async def _shutdown():
 
     @app.get(
         path='/',
-        summary='Get the health of Jina service',
+        summary='Get the health of Jina Gateway service',
        response_model=JinaHealthModel,
     )
-    async def _health():
+    async def _gateway_health():
         """
-        Get the health of this Jina service.
+        Get the health of this Gateway service.
 
        .. 
# noqa: DAR201 """ return {} + from docarray import DocumentArray + from jina.proto import jina_pb2 + from jina.serve.executors import __dry_run_endpoint__ + from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS + from jina.types.request.status import StatusMessage + + @app.get( + path='/dry_run', + summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to ' + 'validate connectivity', + response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto, + ) + async def _flow_health(): + """ + Get the health of the complete Flow service. + .. # noqa: DAR201 + + """ + + da = DocumentArray() + + try: + _ = await _get_singleton_result( + request_generator( + exec_endpoint=__dry_run_endpoint__, + data=da, + data_type=DataInputType.DOCUMENT, + ) + ) + status_message = StatusMessage() + status_message.set_code(jina_pb2.StatusProto.SUCCESS) + return status_message.to_dict() + except Exception as ex: + status_message = StatusMessage() + status_message.set_exception(ex) + return status_message.to_dict(use_integers_for_enums=True) + @app.get( path='/status', summary='Get the status of Jina service', @@ -180,6 +217,8 @@ async def post( def _generate_exception_header(error: InternalNetworkError): import traceback + from jina.proto.serializer import DataRequest + exception_dict = { 'name': str(error.__class__), 'stacks': [ @@ -188,7 +227,7 @@ def _generate_exception_header(error: InternalNetworkError): 'executor': '', } status_dict = { - 'code': 3, # status error + 'code': DataRequest().status.ERROR, 'description': error.details() if error.details() else '', 'exception': exception_dict, } @@ -263,7 +302,6 @@ async def foo(body: JinaRequestModel): from dataclasses import asdict import strawberry - from docarray import DocumentArray from docarray.document.strawberry_type import ( JSONScalar, StrawberryDocument, @@ -271,6 +309,8 @@ async def foo(body: JinaRequestModel): ) from strawberry.fastapi import GraphQLRouter + from 
docarray import DocumentArray + async def get_docs_from_endpoint( data, target_executor, parameters, exec_endpoint ): diff --git a/jina/serve/runtimes/gateway/request_handling.py b/jina/serve/runtimes/gateway/request_handling.py index 804ce7421e708..bc261ff6e2701 100644 --- a/jina/serve/runtimes/gateway/request_handling.py +++ b/jina/serve/runtimes/gateway/request_handling.py @@ -4,8 +4,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional import grpc.aio -from docarray import DocumentArray +from docarray import DocumentArray from jina.excepts import InternalNetworkError from jina.importer import ImportExtensions from jina.serve.networking import GrpcConnectionPool @@ -72,22 +72,7 @@ def handle_request( """ async def gather_endpoints(request_graph): - def _get_all_nodes(node, accum, accum_names): - if node.name not in accum_names: - accum.append(node) - accum_names.append(node.name) - for n in node.outgoing_nodes: - _get_all_nodes(n, accum, accum_names) - return accum, accum_names - - nodes = [] - node_names = [] - for origin_node in request_graph.origin_nodes: - subtree_nodes, subtree_node_names = _get_all_nodes(origin_node, [], []) - for st_node, st_node_name in zip(subtree_nodes, subtree_node_names): - if st_node_name not in node_names: - nodes.append(st_node) - node_names.append(st_node_name) + nodes = request_graph.all_nodes try: tasks_to_get_endpoints = [ node.get_endpoints(connection_pool) for node in nodes diff --git a/jina/serve/runtimes/gateway/websocket/app.py b/jina/serve/runtimes/gateway/websocket/app.py index 61366c8232235..0126039849297 100644 --- a/jina/serve/runtimes/gateway/websocket/app.py +++ b/jina/serve/runtimes/gateway/websocket/app.py @@ -1,13 +1,14 @@ import argparse -import json -from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union +from docarray import DocumentArray from jina.clients.request import request_generator from 
jina.enums import DataInputType, WebsocketSubProtocols from jina.excepts import InternalNetworkError from jina.importer import ImportExtensions from jina.logging.logger import JinaLogger from jina.types.request.data import DataRequest +from jina.types.request.status import StatusMessage if TYPE_CHECKING: from prometheus_client import CollectorRegistry @@ -104,7 +105,9 @@ async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]: except WebSocketDisconnect: pass - async def send(self, websocket: WebSocket, data: DataRequest) -> None: + async def send( + self, websocket: WebSocket, data: Union[DataRequest, StatusMessage] + ) -> None: subprotocol = self.protocol_dict[self.get_client(websocket)] if subprotocol == WebsocketSubProtocols.JSON: return await websocket.send_json(data.to_dict(), mode='text') @@ -190,4 +193,90 @@ async def req_iter(): logger.info('Client successfully disconnected from server') manager.disconnect(websocket) + async def _get_singleton_result(request_iterator) -> Dict: + """ + Streams results from AsyncPrefetchCall as a dict + + :param request_iterator: request iterator, with length of 1 + :return: the first result from the request iterator + """ + async for k in streamer.stream(request_iterator=request_iterator): + request_dict = k.to_dict() + return request_dict + + from docarray import DocumentArray + from jina.proto import jina_pb2 + from jina.serve.executors import __dry_run_endpoint__ + from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS + + @app.get( + path='/dry_run', + summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to ' + 'validate connectivity', + response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto, + ) + async def _dry_run_http(): + """ + Get the health of the complete Flow service. + .. 
# noqa: DAR201 + + """ + + da = DocumentArray() + + try: + _ = await _get_singleton_result( + request_generator( + exec_endpoint=__dry_run_endpoint__, + data=da, + data_type=DataInputType.DOCUMENT, + ) + ) + status_message = StatusMessage() + status_message.set_code(jina_pb2.StatusProto.SUCCESS) + return status_message.to_dict() + except Exception as ex: + status_message = StatusMessage() + status_message.set_exception(ex) + return status_message.to_dict(use_integers_for_enums=True) + + @app.websocket('/dry_run') + async def websocket_endpoint( + websocket: WebSocket, response: Response + ): # 'response' is a FastAPI response, not a Jina response + from jina.proto import jina_pb2 + from jina.serve.executors import __dry_run_endpoint__ + + await manager.connect(websocket) + + da = DocumentArray() + try: + async for _ in streamer.stream( + request_iterator=request_generator( + exec_endpoint=__dry_run_endpoint__, + data=da, + data_type=DataInputType.DOCUMENT, + ) + ): + pass + status_message = StatusMessage() + status_message.set_code(jina_pb2.StatusProto.SUCCESS) + await manager.send(websocket, status_message) + except InternalNetworkError as err: + manager.disconnect(websocket) + msg = ( + err.details() + if _fits_ws_close_msg(err.details()) # some messages are too long + else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.' 
+ ) + await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg) + except WebSocketDisconnect: + logger.info('Client successfully disconnected from server') + manager.disconnect(websocket) + except Exception as ex: + manager.disconnect(websocket) + status_message = StatusMessage() + status_message.set_exception(ex) + await manager.send(websocket, status_message) + return app diff --git a/jina/serve/runtimes/head/__init__.py b/jina/serve/runtimes/head/__init__.py index 848e1935d342f..2dc630a04a30a 100644 --- a/jina/serve/runtimes/head/__init__.py +++ b/jina/serve/runtimes/head/__init__.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple import grpc +from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection from jina.enums import PollingType @@ -17,7 +18,6 @@ from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest, Response @@ -37,6 +37,8 @@ def __init__( :param args: args from CLI :param kwargs: keyword args """ + self._health_servicer = health.HealthServicer(experimental_non_blocking=True) + super().__init__(args, **kwargs) if args.name is None: args.name = '' @@ -148,19 +150,22 @@ async def async_setup(self): self, self._grpc_server ) jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server) - jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server( - self, self._grpc_server - ) jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server( self, self._grpc_server ) service_names = ( jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name, jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaControlRequestRPC'].full_name, 
jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name, reflection.SERVICE_NAME, ) + # Mark all services as healthy. + health_pb2_grpc.add_HealthServicer_to_server( + self._health_servicer, self._grpc_server + ) + + for service in service_names: + self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) reflection.enable_server_reflection(service_names, self._grpc_server) bind_addr = f'0.0.0.0:{self.args.port}' @@ -181,6 +186,7 @@ async def async_cancel(self): async def async_teardown(self): """Close the connection pool""" + self._health_servicer.enter_graceful_shutdown() await self.async_cancel() await self.connection_pool.close() @@ -236,48 +242,6 @@ async def process_data(self, requests: List[DataRequest], context) -> DataReques context.set_trailing_metadata((('is-error', 'true'),)) return requests[0] - async def process_control(self, request: ControlRequest, *args) -> ControlRequest: - """ - Process the received control request and return the input request - - :param request: the data request to process - :param args: additional arguments in the grpc call, ignored - :returns: the input request - """ - try: - if self.logger.debug_enabled: - self._log_control_request(request) - - if request.command == 'ACTIVATE': - - for relatedEntity in request.relatedEntities: - connection_string = f'{relatedEntity.address}:{relatedEntity.port}' - - self.connection_pool.add_connection( - deployment=self._deployment_name, - address=connection_string, - shard_id=relatedEntity.shard_id - if relatedEntity.HasField('shard_id') - else None, - ) - elif request.command == 'DEACTIVATE': - for relatedEntity in request.relatedEntities: - connection_string = f'{relatedEntity.address}:{relatedEntity.port}' - await self.connection_pool.remove_connection( - deployment=self._deployment_name, - address=connection_string, - shard_id=relatedEntity.shard_id, - ) - return request - except (RuntimeError, Exception) as ex: - self.logger.error( - f'{ex!r}' + f'\n 
add "--quiet-error" to suppress the exception details' - if not self.args.quiet_error - else '', - exc_info=not self.args.quiet_error, - ) - raise - async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: """ Uses the connection pool to send a discover endpoint call to the workers diff --git a/jina/serve/runtimes/worker/__init__.py b/jina/serve/runtimes/worker/__init__.py index 190aadca8a040..2834192f7ac54 100644 --- a/jina/serve/runtimes/worker/__init__.py +++ b/jina/serve/runtimes/worker/__init__.py @@ -4,13 +4,13 @@ from typing import List import grpc +from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection from jina.importer import ImportExtensions from jina.proto import jina_pb2, jina_pb2_grpc from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest @@ -26,6 +26,7 @@ def __init__( :param args: args from CLI :param kwargs: keyword args """ + self._health_servicer = health.HealthServicer(experimental_non_blocking=True) super().__init__(args, **kwargs) async def async_setup(self): @@ -78,19 +79,23 @@ async def _async_setup_grpc_server(self): self, self._grpc_server ) jina_pb2_grpc.add_JinaDataRequestRPCServicer_to_server(self, self._grpc_server) - jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server( - self, self._grpc_server - ) + jina_pb2_grpc.add_JinaDiscoverEndpointsRPCServicer_to_server( self, self._grpc_server ) service_names = ( jina_pb2.DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC'].full_name, jina_pb2.DESCRIPTOR.services_by_name['JinaDataRequestRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaControlRequestRPC'].full_name, jina_pb2.DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC'].full_name, reflection.SERVICE_NAME, ) + # Mark all services as healthy. 
+ health_pb2_grpc.add_HealthServicer_to_server( + self._health_servicer, self._grpc_server + ) + + for service in service_names: + self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) reflection.enable_server_reflection(service_names, self._grpc_server) bind_addr = f'0.0.0.0:{self.args.port}' self.logger.debug(f'start listening on {bind_addr}') @@ -112,6 +117,7 @@ async def async_cancel(self): async def async_teardown(self): """Close the data request handler""" + self._health_servicer.enter_graceful_shutdown() await self.async_cancel() self._data_request_handler.close() @@ -166,32 +172,3 @@ async def process_data(self, requests: List[DataRequest], context) -> DataReques requests[0].add_exception(ex, self._data_request_handler._executor) context.set_trailing_metadata((('is-error', 'true'),)) return requests[0] - - async def process_control(self, request: ControlRequest, *args) -> ControlRequest: - """ - Process the received control request and return the same request - - :param request: the control request to process - :param args: additional arguments in the grpc call, ignored - :returns: the input request - """ - try: - if self.logger.debug_enabled: - self._log_control_request(request) - - if request.command == 'STATUS': - pass - else: - raise RuntimeError( - f'WorkerRuntime received unsupported ControlRequest command {request.command}' - ) - except (RuntimeError, Exception) as ex: - self.logger.error( - f'{ex!r}' + f'\n add "--quiet-error" to suppress the exception details' - if not self.args.quiet_error - else '', - exc_info=not self.args.quiet_error, - ) - - request.add_exception(ex, self._data_request_handler._executor) - return request diff --git a/jina/types/mixin.py b/jina/types/mixin.py index e6961b9d6fbdb..8dff0234a3325 100644 --- a/jina/types/mixin.py +++ b/jina/types/mixin.py @@ -1,6 +1,6 @@ from typing import Dict -from jina.helper import typename, T, TYPE_CHECKING, deprecate_by +from jina.helper import TYPE_CHECKING, T, 
deprecate_by, typename if TYPE_CHECKING: from jina.proto import jina_pb2 @@ -18,7 +18,6 @@ class ProtoTypeMixin: .. code-block:: python class MyJinaType(ProtoTypeMixin): - def __init__(self, proto: Optional[jina_pb2.SomePbMsg] = None): self._pb_body = proto or jina_pb2.SomePbMsg() @@ -35,21 +34,20 @@ def to_json(self) -> str: self.proto, preserving_proto_field_name=True, sort_keys=True ) - def to_dict(self) -> Dict: + def to_dict(self, **kwargs) -> Dict: """Return the object in Python dictionary. .. note:: Array like object such as :class:`numpy.ndarray` (i.e. anything described as :class:`jina_pb2.NdArrayProto`) will be converted to Python list. + :param kwargs: Extra kwargs to be passed to MessageToDict, like use_integers_for_enums + :return: dict representation of the object """ from google.protobuf.json_format import MessageToDict - return MessageToDict( - self.proto, - preserving_proto_field_name=True, - ) + return MessageToDict(self.proto, preserving_proto_field_name=True, **kwargs) @property def proto(self) -> 'jina_pb2._reflection.GeneratedProtocolMessageType': diff --git a/jina/types/request/control.py b/jina/types/request/control.py deleted file mode 100644 index 025e579e291d6..0000000000000 --- a/jina/types/request/control.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Optional - -from jina.helper import random_identity, typename -from jina.proto import jina_pb2 -from jina.types.request import Request - -_available_commands = dict(jina_pb2.ControlRequestProto.DESCRIPTOR.enum_values_by_name) - - -class ControlRequest(Request): - """ - :class:`ControlRequest` is one of the **primitive data type** in Jina. - - It offers a Pythonic interface to allow users access and manipulate - :class:`jina.jina_pb2.ControlRequestProto` object without working with Protobuf itself. - - A container for serialized :class:`jina_pb2.ControlRequestProto` that only triggers deserialization - and decompression when receives the first read access to its member. 
- - It overrides :meth:`__getattr__` to provide the same get/set interface as an - :class:`jina_pb2.ControlRequestProtoProto` object. - - :param command: the command for this request, can be STATUS, ACTIVATE or DEACTIVATE - :param request: The request. - """ - - def __init__( - self, - command: Optional[str] = None, - request: Optional['jina_pb2.jina_pb2.ControlRequestProto'] = None, - ): - - if isinstance(request, jina_pb2.ControlRequestProto): - self._pb_body = request - elif request is not None: - # note ``None`` is not considered as a bad type - raise ValueError(f'{typename(request)} is not recognizable') - if command: - proto = jina_pb2.ControlRequestProto() - proto.header.request_id = random_identity() - if command in _available_commands: - proto.command = getattr(jina_pb2.ControlRequestProto, command) - else: - raise ValueError( - f'command "{command}" is not supported, must be one of {_available_commands}' - ) - self._pb_body = proto - - def add_related_entity( - self, id: str, address: str, port: int, shard_id: Optional[int] = None - ): - """ - Add a related entity to this ControlMessage - - :param id: jina id of the entity - :param address: address of the entity - :param port: Port of the entity - :param shard_id: Optional id of the shard this entity belongs to - """ - self.proto.relatedEntities.append( - jina_pb2.RelatedEntity(id=id, address=address, port=port, shard_id=shard_id) - ) - - @property - def proto(self) -> 'jina_pb2.ControlRequestProto': - """ - Cast ``self`` to a :class:`jina_pb2.ControlRequestProto`. Laziness will be broken and serialization will be recomputed when calling - :meth:`SerializeToString`. - :return: protobuf instance - """ - return self._pb_body - - @property - def command(self) -> str: - """Get the command. - - .. 
#noqa: DAR201""" - return jina_pb2.ControlRequestProto.Command.Name(self.proto.command) diff --git a/jina/types/request/status.py b/jina/types/request/status.py new file mode 100644 index 0000000000000..fac4312f7f592 --- /dev/null +++ b/jina/types/request/status.py @@ -0,0 +1,60 @@ +from typing import Dict, Optional, TypeVar + +from google.protobuf import json_format + +from jina.excepts import BadRequestType +from jina.helper import typename +from jina.proto import jina_pb2 +from jina.types.mixin import ProtoTypeMixin + +StatusSourceType = TypeVar('StatusSourceType', jina_pb2.StatusProto, str, Dict, bytes) + + +class StatusMessage(ProtoTypeMixin): + """Represents a Status message used for health check of the Flow""" + + def __init__( + self, + status_object: Optional[StatusSourceType] = None, + ): + self._pb_body = jina_pb2.StatusProto() + try: + if isinstance(status_object, jina_pb2.StatusProto): + self._pb_body = status_object + elif isinstance(status_object, dict): + json_format.ParseDict(status_object, self._pb_body) + elif isinstance(status_object, str): + json_format.Parse(status_object, self._pb_body) + elif isinstance(status_object, bytes): + self._pb_body.ParseFromString(status_object) + elif status_object is not None: + # note ``None`` is not considered as a bad type + raise ValueError(f'{typename(status_object)} is not recognizable') + else: + self._pb_body = jina_pb2.StatusProto() + except Exception as ex: + raise BadRequestType( + f'fail to construct a {self.__class__} object from {status_object}' + ) from ex + + def set_exception(self, ex: Exception): + """Set exception information into the Status Message + + :param ex: The Exception to be filled + """ + import traceback + + self.proto.code = jina_pb2.StatusProto.ERROR + self.proto.description = repr(ex) + self.proto.exception.name = ex.__class__.__name__ + self.proto.exception.args.extend([str(v) for v in ex.args]) + self.proto.exception.stacks.extend( + traceback.format_exception(etype=type(ex), 
value=ex, tb=ex.__traceback__) + ) + + def set_code(self, code): + """Set the code of the Status Message + + :param code: The code to be added + """ + self.proto.code = code diff --git a/tests/integration/flow-dry-run/__init__.py b/tests/integration/flow-dry-run/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/flow-dry-run/test_flow_dry_run.py b/tests/integration/flow-dry-run/test_flow_dry_run.py new file mode 100644 index 0000000000000..851cf5f4d0f73 --- /dev/null +++ b/tests/integration/flow-dry-run/test_flow_dry_run.py @@ -0,0 +1,14 @@ +import pytest + +from jina import Flow + + +@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) +def test_dry_run(protocol): + f = Flow(protocol=protocol).add() + with f: + dry_run = f.dry_run() + dry_run_negative = f.dry_run() + + assert dry_run + assert not dry_run_negative diff --git a/tests/integration/pods/container/test_pod.py b/tests/integration/pods/container/test_pod.py index 6e68377ffbd7e..3ddc69b6ae57a 100644 --- a/tests/integration/pods/container/test_pod.py +++ b/tests/integration/pods/container/test_pod.py @@ -1,5 +1,5 @@ import asyncio -import multiprocessing +import json import os import time @@ -11,10 +11,8 @@ from jina.orchestrate.pods import Pod from jina.orchestrate.pods.container import ContainerPod from jina.parsers import set_gateway_parser, set_pod_parser -from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.head import HeadRuntime from jina.serve.runtimes.worker import WorkerRuntime -from jina.types.request.control import ControlRequest cur_dir = os.path.dirname(os.path.abspath(__file__)) @@ -61,8 +59,12 @@ async def test_pods_trivial_topology( # create a single worker pod worker_pod = _create_worker_pod(worker_port) + # this would be done by the Pod, its adding the worker to the head + worker_host, worker_port = worker_pod.runtime_ctrl_address.split(':') + connection_list_dict = {'0': 
[f'{worker_host}:{worker_port}']} + # create a single head pod - head_pod = _create_head_pod(head_port) + head_pod = _create_head_pod(head_port, connection_list_dict) # create a single gateway pod gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port) @@ -86,14 +88,6 @@ async def test_pods_trivial_topology( worker_pod.ready_or_shutdown.event.wait(timeout=5.0) gateway_pod.ready_or_shutdown.event.wait(timeout=5.0) - # this would be done by the Pod, its adding the worker to the head - activate_msg = ControlRequest(command='ACTIVATE') - worker_host, worker_port = worker_pod.runtime_ctrl_address.split(':') - activate_msg.add_related_entity('worker', worker_host, int(worker_port)) - assert GrpcConnectionPool.send_request_sync( - activate_msg, head_pod.runtime_ctrl_address - ) - # send requests to the gateway c = Client(host='localhost', port=port, asyncio=True) responses = c.post( @@ -115,13 +109,14 @@ def _create_worker_pod(port): return ContainerPod(args) -def _create_head_pod(port): +def _create_head_pod(port, connection_list_dict): args = set_pod_parser().parse_args([]) args.port = port args.name = 'head' args.pod_role = PodRoleType.HEAD args.polling = PollingType.ANY args.uses = 'docker://head-runtime' + args.connection_list = json.dumps(connection_list_dict) return ContainerPod(args) diff --git a/tests/integration/pods/test_pod.py b/tests/integration/pods/test_pod.py index 2753e64f0960d..29e3273cd46a7 100644 --- a/tests/integration/pods/test_pod.py +++ b/tests/integration/pods/test_pod.py @@ -2,6 +2,7 @@ import inspect import json import time +from collections import defaultdict import pytest @@ -15,7 +16,6 @@ ) from jina.resources.health_check.pod import check_health_pod from jina.serve.networking import GrpcConnectionPool -from jina.types.request.control import ControlRequest @pytest.mark.asyncio @@ -31,7 +31,8 @@ async def test_pods_trivial_topology(port_generator): worker_pod = _create_worker_pod(worker_port) # create a single head pod - 
head_pod = _create_head_pod(head_port) + connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']} + head_pod = _create_head_pod(head_port, connection_list_dict) # create a single gateway pod gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port) @@ -40,11 +41,6 @@ async def test_pods_trivial_topology(port_generator): # this would be done by the Pod, its adding the worker to the head head_pod.wait_start_success() worker_pod.wait_start_success() - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity('worker', '127.0.0.1', worker_port) - assert GrpcConnectionPool.send_request_sync( - activate_msg, f'127.0.0.1:{head_port}' - ) # send requests to the gateway gateway_pod.wait_start_success() @@ -82,7 +78,8 @@ async def test_pods_health_check(port_generator, protocol, health_check): worker_pod = _create_worker_pod(worker_port) # create a single head pod - head_pod = _create_head_pod(head_port) + connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']} + head_pod = _create_head_pod(head_port, connection_list_dict) # create a single gateway pod gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port, protocol) @@ -91,12 +88,6 @@ async def test_pods_health_check(port_generator, protocol, health_check): # this would be done by the Pod, its adding the worker to the head head_pod.wait_start_success() worker_pod.wait_start_success() - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity('worker', '127.0.0.1', worker_port) - assert GrpcConnectionPool.send_request_sync( - activate_msg, f'127.0.0.1:{head_port}' - ) - # send requests to the gateway gateway_pod.wait_start_success() @@ -139,7 +130,6 @@ async def test_pods_flow_topology( ] pods = [] pod_addresses = '{' - ports = [] for deployment in deployments: if uses_before: uses_before_port, uses_before_pod = await _start_create_pod( @@ -152,11 +142,18 @@ async def test_pods_flow_topology( ) pods.append(uses_after_pod) + # create 
worker + worker_port, worker_pod = await _start_create_pod(deployment, port_generator) + pods.append(worker_pod) + # create head head_port = port_generator() pod_addresses += f'"{deployment}": ["0.0.0.0:{head_port}"],' + + connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']} head_pod = _create_head_pod( head_port, + connection_list_dict, f'{deployment}/head', 'ANY', f'127.0.0.1:{uses_before_port}' if uses_before else None, @@ -166,18 +163,11 @@ async def test_pods_flow_topology( pods.append(head_pod) head_pod.start() - # create worker - worker_port, worker_pod = await _start_create_pod(deployment, port_generator) - ports.append((head_port, worker_port)) - pods.append(worker_pod) await asyncio.sleep(0.1) for pod in pods: pod.wait_start_success() - for head_port, worker_port in ports: - await _activate_worker(head_port, worker_port) - # remove last comma pod_addresses = pod_addresses[:-1] pod_addresses += '}' @@ -218,12 +208,9 @@ async def test_pods_shards(polling, port_generator): graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - # create a single head pod - head_pod = _create_head_pod(head_port, 'head', polling) - head_pod.start() - # create the shards shard_pods = [] + connection_list_dict = {} for i in range(10): # create worker worker_port = port_generator() @@ -231,18 +218,18 @@ async def test_pods_shards(polling, port_generator): worker_pod = _create_worker_pod(worker_port, f'pod0/shard/{i}') shard_pods.append(worker_pod) worker_pod.start() + connection_list_dict[i] = [f'127.0.0.1:{worker_port}'] await asyncio.sleep(0.1) + # create a single head pod + head_pod = _create_head_pod(head_port, connection_list_dict, 'head', polling) + head_pod.start() + head_pod.wait_start_success() for i, pod in enumerate(shard_pods): # this would be done by the Pod, its adding the worker to the head pod.wait_start_success() - activate_msg = ControlRequest(command='ACTIVATE') - 
activate_msg.add_related_entity( - 'worker', '127.0.0.1', pod.args.port, shard_id=i - ) - GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}') # create a single gateway pod gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port) @@ -275,12 +262,10 @@ async def test_pods_replicas(port_generator): graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - # create a single head pod - head_pod = _create_head_pod(head_port, 'head') - head_pod.start() - # create the shards replica_pods = [] + + connection_list_dict = defaultdict(list) for i in range(10): # create worker worker_port = port_generator() @@ -288,17 +273,16 @@ async def test_pods_replicas(port_generator): worker_pod = _create_worker_pod(worker_port, f'pod0/{i}') replica_pods.append(worker_pod) worker_pod.start() + connection_list_dict[0].append(f'127.0.0.1:{worker_port}') await asyncio.sleep(0.1) # this would be done by the Pod, its adding the worker to the head - head_pod.wait_start_success() - for worker_pod in replica_pods: - worker_pod.wait_start_success() - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity('worker', '127.0.0.1', worker_pod.args.port) - GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}') + # create a single head pod + head_pod = _create_head_pod(head_port, connection_list_dict, 'head') + head_pod.start() + head_pod.wait_start_success() # create a single gateway pod gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port) gateway_pod.start() @@ -337,11 +321,23 @@ async def test_pods_with_executor(port_generator): ) pods.append(uses_after_pod) + connection_list_dict = {} + + # create some shards + for i in range(10): + # create worker + worker_port, worker_pod = await _start_create_pod( + 'pod0', port_generator, type=f'shards/{i}', executor='NameChangeExecutor' + ) + pods.append(worker_pod) + await 
asyncio.sleep(0.1) + connection_list_dict[i] = [f'127.0.0.1:{worker_port}'] + # create head head_port = port_generator() - pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' head_pod = _create_head_pod( head_port, + connection_list_dict, f'pod0/head', 'ALL', f'127.0.0.1:{uses_before_port}', @@ -351,21 +347,12 @@ async def test_pods_with_executor(port_generator): pods.append(head_pod) head_pod.start() - # create some shards - for i in range(10): - # create worker - worker_port, worker_pod = await _start_create_pod( - 'pod0', port_generator, type=f'shards/{i}', executor='NameChangeExecutor' - ) - pods.append(worker_pod) - await asyncio.sleep(0.1) - await _activate_worker(head_port, worker_port, shard_id=i) - for pod in pods: pod.wait_start_success() # create a single gateway pod port = port_generator() + pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port) gateway_pod.start() @@ -439,34 +426,34 @@ async def test_pods_with_replicas_advance_faster(port_generator): graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - # create a single head pod - head_pod = _create_head_pod(head_port, 'head') - head_pod.start() - # create a single gateway pod gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port) gateway_pod.start() # create the shards + connection_list_dict = {} pods = [] for i in range(10): # create worker worker_port = port_generator() # create a single worker pod worker_pod = _create_worker_pod(worker_port, f'pod0/{i}', 'FastSlowExecutor') + connection_list_dict[i] = [f'127.0.0.1:{worker_port}'] + pods.append(worker_pod) worker_pod.start() await asyncio.sleep(0.1) + # create a single head pod + head_pod = _create_head_pod(head_port, connection_list_dict, 'head') + head_pod.start() + head_pod.wait_start_success() gateway_pod.wait_start_success() for pod in pods: # this would be done by the 
Pod, its adding the worker to the head pod.wait_start_success() - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity('worker', '127.0.0.1', pod.args.port) - GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}') c = Client(host='localhost', port=port, asyncio=True) input_docs = [Document(text='slow'), Document(text='fast')] @@ -536,7 +523,14 @@ def _create_worker_pod(port, name='', executor=None): return Pod(args) -def _create_head_pod(port, name='', polling='ANY', uses_before=None, uses_after=None): +def _create_head_pod( + port, + connection_list_dict, + name='', + polling='ANY', + uses_before=None, + uses_after=None, +): args = set_pod_parser().parse_args([]) args.port = port args.name = name @@ -548,6 +542,7 @@ def _create_head_pod(port, name='', polling='ANY', uses_before=None, uses_after= args.uses_before_address = uses_before if uses_after: args.uses_after_address = uses_after + args.connection_list = json.dumps(connection_list_dict) return Pod(args) diff --git a/tests/integration/runtimes/test_gateway_dry_run.py b/tests/integration/runtimes/test_gateway_dry_run.py new file mode 100644 index 0000000000000..00f9c833dba7a --- /dev/null +++ b/tests/integration/runtimes/test_gateway_dry_run.py @@ -0,0 +1,99 @@ +import asyncio +import json +import multiprocessing +import threading +import time +from collections import defaultdict + +import pytest + +from jina import Client, Document, Executor, requests +from jina.enums import PollingType +from jina.parsers import set_gateway_parser, set_pod_parser +from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime +from jina.serve.runtimes.gateway.grpc import GRPCGatewayRuntime +from jina.serve.runtimes.gateway.http import HTTPGatewayRuntime +from jina.serve.runtimes.gateway.websocket import WebSocketGatewayRuntime +from jina.serve.runtimes.worker import WorkerRuntime + + +def _create_worker_runtime(port, name='', executor=None): + args = 
set_pod_parser().parse_args([]) + args.port = port + args.name = name + if executor: + args.uses = executor + with WorkerRuntime(args) as runtime: + runtime.run_forever() + + +def _create_gateway_runtime(graph_description, pod_addresses, port, protocol='grpc'): + if protocol == 'http': + gateway_runtime = HTTPGatewayRuntime + elif protocol == 'websocket': + gateway_runtime = WebSocketGatewayRuntime + else: + gateway_runtime = GRPCGatewayRuntime + with gateway_runtime( + set_gateway_parser().parse_args( + [ + '--graph-description', + graph_description, + '--deployments-addresses', + pod_addresses, + '--port', + str(port), + ] + ) + ) as runtime: + runtime.run_forever() + + +@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) +def test_dry_run_of_flow(port_generator, protocol): + worker_port = port_generator() + port = port_generator() + graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' + pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}' + + # create a single worker runtime + worker_process = multiprocessing.Process( + target=_create_worker_runtime, args=(worker_port,) + ) + worker_process.start() + + # create a single gateway runtime + gateway_process = multiprocessing.Process( + target=_create_gateway_runtime, + args=(graph_description, pod_addresses, port, protocol), + ) + gateway_process.start() + + AsyncNewLoopRuntime.wait_for_ready_or_shutdown( + timeout=5.0, + ctrl_address=f'0.0.0.0:{worker_port}', + ready_or_shutdown_event=multiprocessing.Event(), + ) + AsyncNewLoopRuntime.wait_for_ready_or_shutdown( + timeout=5.0, + ctrl_address=f'0.0.0.0:{port}', + ready_or_shutdown_event=multiprocessing.Event(), + ) + + # send requests to the gateway + c = Client(host='localhost', port=port, asyncio=True, protocol=protocol) + dry_run_alive = c.dry_run() + + worker_process.terminate() + worker_process.join() + + dry_run_worker_removed = c.dry_run() + + gateway_process.terminate() + gateway_process.join() + + assert 
dry_run_alive + assert not dry_run_worker_removed + + assert gateway_process.exitcode == 0 + assert worker_process.exitcode == 0 diff --git a/tests/integration/runtimes/test_network_failures.py b/tests/integration/runtimes/test_network_failures.py index 664576c836355..747bcc11eb0bc 100644 --- a/tests/integration/runtimes/test_network_failures.py +++ b/tests/integration/runtimes/test_network_failures.py @@ -5,11 +5,9 @@ from jina import Client, Document, Executor, requests from jina.parsers import set_gateway_parser, set_pod_parser -from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.gateway.http import HTTPGatewayRuntime from jina.serve.runtimes.worker import WorkerRuntime -from jina.types.request.control import ControlRequest from .test_runtimes import _create_gateway_runtime, _create_head_runtime @@ -51,9 +49,9 @@ def _create_gateway(port, graph, pod_addr, protocol): return p -def _create_head(port, polling): +def _create_head(port, connection_list_dict, polling): p = multiprocessing.Process( - target=_create_head_runtime, args=(port, 'head', polling) + target=_create_head_runtime, args=(port, connection_list_dict, 'head', polling) ) p.start() time.sleep(0.1) @@ -230,7 +228,9 @@ async def test_runtimes_headful_topology(port_generator, protocol, terminate_hea graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - head_process = _create_head(head_port, 'ANY') + connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']} + + head_process = _create_head(head_port, connection_list_dict, 'ANY') worker_process = _create_worker(worker_port) gateway_process = _create_gateway( gateway_port, graph_description, pod_addresses, protocol @@ -256,11 +256,6 @@ async def test_runtimes_headful_topology(port_generator, protocol, terminate_hea ready_or_shutdown_event=multiprocessing.Event(), ) - # this would be done by the 
Pod, its adding the worker to the head - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity('worker', '127.0.0.1', worker_port) - GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}') - # terminate pod, either head or worker behind the head if terminate_head: head_process.terminate() diff --git a/tests/integration/runtimes/test_runtimes.py b/tests/integration/runtimes/test_runtimes.py index 5372b8edbc3b3..3a94682ed822c 100644 --- a/tests/integration/runtimes/test_runtimes.py +++ b/tests/integration/runtimes/test_runtimes.py @@ -3,20 +3,19 @@ import multiprocessing import threading import time +from collections import defaultdict import pytest from jina import Client, Document, Executor, requests from jina.enums import PollingType from jina.parsers import set_gateway_parser, set_pod_parser -from jina.serve.networking import GrpcConnectionPool from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.gateway.grpc import GRPCGatewayRuntime from jina.serve.runtimes.gateway.http import HTTPGatewayRuntime from jina.serve.runtimes.gateway.websocket import WebSocketGatewayRuntime from jina.serve.runtimes.head import HeadRuntime from jina.serve.runtimes.worker import WorkerRuntime -from jina.types.request.control import ControlRequest @pytest.mark.asyncio @@ -35,8 +34,10 @@ async def test_runtimes_trivial_topology(port_generator): worker_process.start() # create a single head runtime + connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']} + head_process = multiprocessing.Process( - target=_create_head_runtime, args=(head_port,) + target=_create_head_runtime, args=(head_port, connection_list_dict) ) head_process.start() @@ -67,11 +68,6 @@ async def test_runtimes_trivial_topology(port_generator): ready_or_shutdown_event=multiprocessing.Event(), ) - # this would be done by the Pod, its adding the worker to the head - activate_msg = ControlRequest(command='ACTIVATE') - 
activate_msg.add_related_entity('worker', '127.0.0.1', worker_port) - GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}') - # send requests to the gateway c = Client(host='localhost', port=port, asyncio=True) responses = c.post('/', inputs=async_inputs, request_size=1, return_responses=True) @@ -146,13 +142,24 @@ async def test_runtimes_flow_topology( ) runtime_processes.append(uses_after_process) + # create worker + worker_port, worker_process = await _create_worker(pod, port_generator) + AsyncNewLoopRuntime.wait_for_ready_or_shutdown( + timeout=5.0, + ready_or_shutdown_event=threading.Event(), + ctrl_address=f'127.0.0.1:{worker_port}', + ) + runtime_processes.append(worker_process) + # create head head_port = port_generator() pod_addresses += f'"{pod}": ["0.0.0.0:{head_port}"],' + connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']} head_process = multiprocessing.Process( target=_create_head_runtime, args=( head_port, + connection_list_dict, f'{pod}/head', 'ANY', f'127.0.0.1:{uses_before_port}' if uses_before else None, @@ -162,18 +169,8 @@ async def test_runtimes_flow_topology( runtime_processes.append(head_process) head_process.start() - # create worker - worker_port, worker_process = await _create_worker(pod, port_generator) - AsyncNewLoopRuntime.wait_for_ready_or_shutdown( - timeout=5.0, - ready_or_shutdown_event=threading.Event(), - ctrl_address=f'127.0.0.1:{worker_port}', - ) - runtime_processes.append(worker_process) await asyncio.sleep(0.1) - await _activate_worker(head_port, worker_port) - # remove last comma pod_addresses = pod_addresses[:-1] pod_addresses += '}' @@ -227,15 +224,10 @@ async def test_runtimes_shards(polling, port_generator): graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - # create a single head runtime - head_process = multiprocessing.Process( - target=_create_head_runtime, args=(head_port, 'head', polling) - ) - 
head_process.start() - # create the shards shard_processes = [] worker_ports = [] + connection_list_dict = defaultdict(list) for i in range(10): # create worker worker_port = port_generator() @@ -248,8 +240,14 @@ async def test_runtimes_shards(polling, port_generator): await asyncio.sleep(0.1) worker_ports.append(worker_port) + connection_list_dict[i].append(f'127.0.0.1:{worker_port}') - await _activate_runtimes(head_port, worker_ports) + # create a single head runtime + head_process = multiprocessing.Process( + target=_create_head_runtime, + args=(head_port, connection_list_dict, 'head', polling), + ) + head_process.start() # create a single gateway runtime gateway_process = multiprocessing.Process( @@ -300,15 +298,10 @@ async def test_runtimes_replicas(port_generator): graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - # create a single head runtime - head_process = multiprocessing.Process( - target=_create_head_runtime, args=(head_port, 'head') - ) - head_process.start() - # create the shards replica_processes = [] worker_ports = [] + connection_list_dict = defaultdict(list) for i in range(10): # create worker worker_port = port_generator() @@ -321,8 +314,13 @@ async def test_runtimes_replicas(port_generator): await asyncio.sleep(0.1) worker_ports.append(worker_port) + connection_list_dict[0].append(f'127.0.0.1:{worker_port}') - await _activate_runtimes(head_port, worker_ports) + # create a single head runtime + head_process = multiprocessing.Process( + target=_create_head_runtime, args=(head_port, connection_list_dict, 'head') + ) + head_process.start() # create a single gateway runtime gateway_process = multiprocessing.Process( @@ -383,21 +381,9 @@ async def test_runtimes_with_executor(port_generator): # create head head_port = port_generator() pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - head_process = multiprocessing.Process( - target=_create_head_runtime, - args=( - 
head_port, - f'pod0/head', - 'ALL', - f'127.0.0.1:{uses_before_port}', - f'127.0.0.1:{uses_after_port}', - ), - ) - runtime_processes.append(head_process) - head_process.start() - runtime_processes.append(head_process) # create some shards + connection_list_dict = defaultdict(list) worker_ports = [] for i in range(10): # create worker @@ -407,8 +393,21 @@ async def test_runtimes_with_executor(port_generator): runtime_processes.append(worker_process) await asyncio.sleep(0.1) worker_ports.append(worker_port) + connection_list_dict[i].append(f'127.0.0.1:{worker_port}') - await _activate_runtimes(head_port, worker_ports) + head_process = multiprocessing.Process( + target=_create_head_runtime, + args=( + head_port, + connection_list_dict, + f'pod0/head', + 'ALL', + f'127.0.0.1:{uses_before_port}', + f'127.0.0.1:{uses_after_port}', + ), + ) + runtime_processes.append(head_process) + head_process.start() # create a single gateway runtime port = port_generator() @@ -507,15 +506,10 @@ async def test_runtimes_with_replicas_advance_faster(port_generator): graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}' pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}' - # create a single head runtime - head_process = multiprocessing.Process( - target=_create_head_runtime, args=(head_port, 'head') - ) - head_process.start() - # create the shards replica_processes = [] worker_ports = [] + connection_list_dict = defaultdict(list) for i in range(10): # create worker worker_port = port_generator() @@ -529,8 +523,13 @@ async def test_runtimes_with_replicas_advance_faster(port_generator): await asyncio.sleep(0.1) worker_ports.append(worker_port) + connection_list_dict[i].append(f'127.0.0.1:{worker_port}') - await _activate_runtimes(head_port, worker_ports) + # create a single head runtime + head_process = multiprocessing.Process( + target=_create_head_runtime, args=(head_port, connection_list_dict, 'head') + ) + head_process.start() # create a single gateway runtime 
gateway_process = multiprocessing.Process( @@ -660,7 +659,6 @@ def __init__(self, runtime_args, *args, **kwargs): @requests def foo(self, docs, **kwargs): - print(f'{self.name} doc count {len(docs)}') docs.append(Document(text=self.name)) return docs @@ -673,15 +671,6 @@ def foo(self, docs, **kwargs): time.sleep(1.0) -async def _activate_worker(head_port, worker_port, shard_id=None): - # this would be done by the Pod, its adding the worker to the head - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity( - 'worker', '127.0.0.1', worker_port, shard_id=shard_id - ) - GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}') - - async def _create_worker(pod, port_generator, type='worker', executor=None): worker_port = port_generator() worker_process = multiprocessing.Process( @@ -702,7 +691,12 @@ def _create_worker_runtime(port, name='', executor=None): def _create_head_runtime( - port, name='', polling='ANY', uses_before=None, uses_after=None + port, + connection_list_dict, + name='', + polling='ANY', + uses_before=None, + uses_after=None, ): args = set_pod_parser().parse_args([]) args.port = port @@ -712,6 +706,7 @@ def _create_head_runtime( args.uses_before_address = uses_before if uses_after: args.uses_after_address = uses_after + args.connection_list = json.dumps(connection_list_dict) with HeadRuntime(args) as runtime: runtime.run_forever() @@ -742,13 +737,3 @@ def _create_gateway_runtime(graph_description, pod_addresses, port, protocol='gr async def async_inputs(): for _ in range(20): yield Document(text='client0-Request') - - -async def _activate_runtimes(head_port, worker_ports): - for i, worker_port in enumerate(worker_ports): - AsyncNewLoopRuntime.wait_for_ready_or_shutdown( - timeout=5.0, - ready_or_shutdown_event=threading.Event(), - ctrl_address=f'127.0.0.1:{worker_port}', - ) - await _activate_worker(head_port, worker_port, shard_id=i) diff --git a/tests/integration/v2_api/test_func_routing.py 
b/tests/integration/v2_api/test_func_routing.py index b2fb79aee174a..0585f29d6d6d2 100644 --- a/tests/integration/v2_api/test_func_routing.py +++ b/tests/integration/v2_api/test_func_routing.py @@ -48,7 +48,8 @@ def foo(self, **kwargs): inputs=[(Document(), Document()) for _ in range(3)], return_responses=True, ) - assert results[0].header.status.code == 3 + + assert results[0].header.status.code == 1 def test_func_default_routing(): diff --git a/tests/unit/serve/executors/test_executor.py b/tests/unit/serve/executors/test_executor.py index ff29e9c072cb9..70102db27c32e 100644 --- a/tests/unit/serve/executors/test_executor.py +++ b/tests/unit/serve/executors/test_executor.py @@ -253,6 +253,10 @@ def do(self, *args, **kwargs): ], ) def test_override_requests(uses_requests, expected): + from jina.serve.executors import __dry_run_endpoint__ + + expected.add(__dry_run_endpoint__) + class OverrideExec(Executor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py index 904727d1f79aa..804d035c3f2be 100644 --- a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py +++ b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py @@ -469,10 +469,7 @@ def _create_runtime(): async with grpc.aio.insecure_channel(f'127.0.0.1:{port}') as channel: service_names = await GrpcConnectionPool.get_available_services(channel) - assert all( - service_name in service_names - for service_name in ['jina.JinaControlRequestRPC', 'jina.JinaRPC'] - ) + assert all(service_name in service_names for service_name in ['jina.JinaRPC']) p.terminate() p.join() diff --git a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_tls.py b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_tls.py index 1b5c40a04a360..aac9647596134 100644 --- a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_tls.py +++ 
b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_tls.py @@ -1,13 +1,11 @@ import os import time -import grpc import pytest -from docarray import Document +from docarray import Document from jina import Client, Flow from jina.serve.networking import GrpcConnectionPool -from jina.types.request.control import ControlRequest @pytest.fixture @@ -54,8 +52,7 @@ def test_grpc_ssl_with_flow_and_client(cert_pem, key_pem, error_log_level): with open(cert_pem, 'rb') as f: creds = f.read() - GrpcConnectionPool.send_request_sync( - request=ControlRequest('STATUS'), + GrpcConnectionPool.send_health_check_sync( target=f'localhost:{flow.port}', root_certificates=creds, tls=True, diff --git a/tests/unit/serve/runtimes/gateway/http/test_models.py b/tests/unit/serve/runtimes/gateway/http/test_models.py index 949e2811ec77d..0e2c022dd93c1 100644 --- a/tests/unit/serve/runtimes/gateway/http/test_models.py +++ b/tests/unit/serve/runtimes/gateway/http/test_models.py @@ -19,7 +19,7 @@ def test_enum_definitions(): status_code_enum_definition = PROTO_TO_PYDANTIC_MODELS.StatusProto().schema()[ 'definitions' ]['StatusCode'] - assert status_code_enum_definition['enum'] == [0, 1, 2, 3, 4, 5, 6] + assert status_code_enum_definition['enum'] == [0, 1] def test_timestamp(): diff --git a/tests/unit/serve/runtimes/head/test_head_runtime.py b/tests/unit/serve/runtimes/head/test_head_runtime.py index 463e2b161bffb..eb90d6751cb5c 100644 --- a/tests/unit/serve/runtimes/head/test_head_runtime.py +++ b/tests/unit/serve/runtimes/head/test_head_runtime.py @@ -7,7 +7,6 @@ import grpc import pytest -from grpc import RpcError from jina import Document, DocumentArray from jina.clients.request import request_generator @@ -18,17 +17,16 @@ from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.head import HeadRuntime from jina.types.request import Request -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest def test_regular_data_case(): 
args = set_pod_parser().parse_args([]) args.polling = PollingType.ANY + connection_list_dict = {0: [f'fake_ip:8080']} + args.connection_list = json.dumps(connection_list_dict) cancel_event, handle_queue, runtime_thread = _create_runtime(args) - _add_worker(args) - with grpc.insecure_channel( f'{args.host}:{args.port}', options=GrpcConnectionPool.get_default_grpc_options(), @@ -44,33 +42,6 @@ def test_regular_data_case(): _destroy_runtime(args, cancel_event, runtime_thread) -def test_control_message_processing(): - args = set_pod_parser().parse_args([]) - cancel_event, handle_queue, runtime_thread = _create_runtime(args) - - # no connection registered yet - resp = GrpcConnectionPool.send_request_sync( - _create_test_data_message(), f'{args.host}:{args.port}' - ) - assert resp.status.code == resp.status.ERROR - - _add_worker(args, 'ip1') - # after adding a connection, sending should work - result = GrpcConnectionPool.send_request_sync( - _create_test_data_message(), f'{args.host}:{args.port}' - ) - assert result - - _remove_worker(args, 'ip1') - # after removing the connection again, sending does not work anymore - resp = GrpcConnectionPool.send_request_sync( - _create_test_data_message(), f'{args.host}:{args.port}' - ) - assert resp.status.code == resp.status.ERROR - - _destroy_runtime(args, cancel_event, runtime_thread) - - @pytest.mark.parametrize('disable_reduce', [False, True]) def test_message_merging(disable_reduce): if not disable_reduce: @@ -78,12 +49,10 @@ def test_message_merging(disable_reduce): else: args = set_pod_parser().parse_args(['--disable-reduce']) args.polling = PollingType.ALL + connection_list_dict = {0: [f'ip1:8080'], 1: [f'ip2:8080'], 2: [f'ip3:8080']} + args.connection_list = json.dumps(connection_list_dict) cancel_event, handle_queue, runtime_thread = _create_runtime(args) - assert handle_queue.empty() - _add_worker(args, 'ip1', shard_id=0) - _add_worker(args, 'ip2', shard_id=1) - _add_worker(args, 'ip3', shard_id=2) assert 
handle_queue.empty() data_request = _create_test_data_message() @@ -102,12 +71,10 @@ def test_uses_before_uses_after(): args.polling = PollingType.ALL args.uses_before_address = 'fake_address' args.uses_after_address = 'fake_address' + connection_list_dict = {0: [f'ip1:8080'], 1: [f'ip2:8080'], 2: [f'ip3:8080']} + args.connection_list = json.dumps(connection_list_dict) cancel_event, handle_queue, runtime_thread = _create_runtime(args) - assert handle_queue.empty() - _add_worker(args, 'ip1', shard_id=0) - _add_worker(args, 'ip2', shard_id=1) - _add_worker(args, 'ip3', shard_id=2) assert handle_queue.empty() result = GrpcConnectionPool.send_request_sync( @@ -139,10 +106,10 @@ def decompress(self): args = set_pod_parser().parse_args([]) args.polling = PollingType.ANY + connection_list_dict = {0: [f'fake_ip:8080']} + args.connection_list = json.dumps(connection_list_dict) cancel_event, handle_queue, runtime_thread = _create_runtime(args) - _add_worker(args) - with grpc.insecure_channel( f'{args.host}:{args.port}', options=GrpcConnectionPool.get_default_grpc_options(), @@ -172,10 +139,11 @@ def test_dynamic_polling(polling): str(2), ] ) - cancel_event, handle_queue, runtime_thread = _create_runtime(args) - _add_worker(args, shard_id=0) - _add_worker(args, shard_id=1) + connection_list_dict = {0: [f'fake_ip:8080'], 1: [f'fake_ip:8080']} + args.connection_list = json.dumps(connection_list_dict) + + cancel_event, handle_queue, runtime_thread = _create_runtime(args) with grpc.insecure_channel( f'{args.host}:{args.port}', @@ -214,11 +182,10 @@ def test_base_polling(polling): str(2), ] ) + connection_list_dict = {0: [f'fake_ip:8080'], 1: [f'fake_ip:8080']} + args.connection_list = json.dumps(connection_list_dict) cancel_event, handle_queue, runtime_thread = _create_runtime(args) - _add_worker(args, shard_id=0) - _add_worker(args, shard_id=1) - with grpc.insecure_channel( f'{args.host}:{args.port}', options=GrpcConnectionPool.get_default_grpc_options(), @@ -263,7 +230,6 @@ 
async def test_head_runtime_reflection(): assert all( service_name in service_names for service_name in [ - 'jina.JinaControlRequestRPC', 'jina.JinaDataRequestRPC', 'jina.JinaSingleDataRequestRPC', ] @@ -275,10 +241,10 @@ async def test_head_runtime_reflection(): def test_timeout_behaviour(): args = set_pod_parser().parse_args(['--timeout-send', '100']) args.polling = PollingType.ANY + connection_list_dict = {0: [f'fake_ip:8080']} + args.connection_list = json.dumps(connection_list_dict) cancel_event, handle_queue, runtime_thread = _create_runtime(args) - _add_worker(args) - with grpc.insecure_channel( f'{args.host}:{args.port}', options=GrpcConnectionPool.get_default_grpc_options(), @@ -340,22 +306,6 @@ async def mock_task_wrapper(new_requests, *args, **kwargs): return cancel_event, handle_queue, runtime_thread -def _add_worker(args, ip='fake_ip', shard_id=None): - activate_msg = ControlRequest(command='ACTIVATE') - activate_msg.add_related_entity('worker', ip, 8080, shard_id) - assert GrpcConnectionPool.send_request_sync( - activate_msg, f'{args.host}:{args.port}' - ) - - -def _remove_worker(args, ip='fake_ip', shard_id=None): - activate_msg = ControlRequest(command='DEACTIVATE') - activate_msg.add_related_entity('worker', ip, 8080, shard_id) - assert GrpcConnectionPool.send_request_sync( - activate_msg, f'{args.host}:{args.port}' - ) - - def _destroy_runtime(args, cancel_event, runtime_thread): cancel_event.set() runtime_thread.join() diff --git a/tests/unit/serve/runtimes/worker/test_worker_runtime.py b/tests/unit/serve/runtimes/worker/test_worker_runtime.py index b64b89bd2e4d2..d9247c5a8be29 100644 --- a/tests/unit/serve/runtimes/worker/test_worker_runtime.py +++ b/tests/unit/serve/runtimes/worker/test_worker_runtime.py @@ -9,8 +9,8 @@ import grpc import pytest import requests as req -from docarray import Document +from docarray import Document from jina import DocumentArray, Executor, requests from jina.clients.request import request_generator from 
jina.parsers import set_pod_parser @@ -367,7 +367,6 @@ def start_runtime(args, cancel_event): assert all( service_name in service_names for service_name in [ - 'jina.JinaControlRequestRPC', 'jina.JinaDataRequestRPC', 'jina.JinaSingleDataRequestRPC', ] diff --git a/tests/unit/test_helper.py b/tests/unit/test_helper.py index 73b5416fef7e3..41e749c0d2d19 100644 --- a/tests/unit/test_helper.py +++ b/tests/unit/test_helper.py @@ -112,17 +112,12 @@ def test_pprint_routes(capfd): r.status.exception.stacks.extend(['r1\nline1', 'r2\nline2']) result.append(r) r = jina_pb2.RouteProto() - r.status.code = jina_pb2.StatusProto.ERROR_CHAINED - r.status.exception.stacks.extend(['line1', 'line2']) - result.append(r) - r = jina_pb2.RouteProto() r.status.code = jina_pb2.StatusProto.SUCCESS result.append(r) rr = DataRequest() rr.routes.extend(result) pprint_routes(rr) out, err = capfd.readouterr() - assert '⚪' in out assert '🟢' in out assert 'Executor' in out assert 'Time' in out @@ -130,7 +125,6 @@ def test_pprint_routes(capfd): assert 'r1' in out assert 'line1r2' in out assert 'line2' in out - assert 'line1line2' in out def test_convert_tuple_to_list(): diff --git a/tests/unit/types/request/test_request.py b/tests/unit/types/request/test_request.py index 200431dbf91c1..ff447deee2f61 100644 --- a/tests/unit/types/request/test_request.py +++ b/tests/unit/types/request/test_request.py @@ -8,7 +8,6 @@ from jina.helper import random_identity from jina.proto import jina_pb2 from jina.proto.serializer import DataRequestProto -from jina.types.request.control import ControlRequest from jina.types.request.data import DataRequest, Response @@ -20,13 +19,6 @@ def req(): return r -@pytest.fixture(scope='function') -def control_req(): - r = jina_pb2.ControlRequestProto() - r.header.request_id = random_identity() - return r - - def test_init(req): assert DataRequest(request=None) assert DataRequest(request=req) @@ -60,13 +52,6 @@ def test_data_backwards_compatibility(req): assert 
len(req.data.docs) == len(req.docs) -def test_command(control_req): - request = ControlRequest(request=control_req) - cmd = request.command - assert cmd - assert isinstance(cmd, str) - - def test_as_pb_object(req): request = DataRequest(request=None) assert request.proto diff --git a/tests/unit/types/request/test_status_message.py b/tests/unit/types/request/test_status_message.py new file mode 100644 index 0000000000000..46bfa63833a4d --- /dev/null +++ b/tests/unit/types/request/test_status_message.py @@ -0,0 +1,41 @@ +import pytest +from google.protobuf.json_format import MessageToDict, MessageToJson + +from jina.excepts import BadRequestType +from jina.proto import jina_pb2 +from jina.types.request.status import StatusMessage + + +@pytest.fixture(scope='function') +def status_pb(): + return jina_pb2.StatusProto() + + +def test_init(status_pb): + assert StatusMessage(status_object=None) + assert StatusMessage(status_object=status_pb) + assert StatusMessage(status_object=MessageToDict(status_pb)) + assert StatusMessage(status_object=MessageToJson(status_pb)) + + +def test_init_fail(): + with pytest.raises(BadRequestType): + StatusMessage(status_object=5) + + +@pytest.mark.parametrize( + 'status_code', [jina_pb2.StatusProto.SUCCESS, jina_pb2.StatusProto.ERROR] +) +def test_set_code(status_code): + status = StatusMessage() + status.set_code(status_code) + assert status.proto.code == status_code + + +def test_set_exception(): + status = StatusMessage() + exc = Exception('exception code') + status.set_exception(exc) + assert status.proto.code == jina_pb2.StatusProto.ERROR + assert status.proto.description == repr(exc) + assert status.proto.exception.name == exc.__class__.__name__ diff --git a/tests/unit/yaml/datauriindex.yml b/tests/unit/yaml/datauriindex.yml deleted file mode 100644 index 7c84255062fe1..0000000000000 --- a/tests/unit/yaml/datauriindex.yml +++ /dev/null @@ -1,23 +0,0 @@ -!DataURIPbIndexer -with: - index_filename: doc.gzip -metas: - name: doc_indexer # 
a customized name - workspace: $TEST_DATAURIINDEX_WORKSPACE -requests: - on: - ControlRequest: - - !ControlReqDriver {} - SearchRequest: - - !KVSearchDriver {} - IndexRequest: - - !URI2DataURI - with: - override: true - - !URI2Text {} - - !ExcludeQL - with: - fields: - - buffer - - chunks - - !KVIndexDriver {}