diff --git a/jina/serve/gateway.py b/jina/serve/gateway.py new file mode 100644 index 0000000000000..7c83cb6ac4c88 --- /dev/null +++ b/jina/serve/gateway.py @@ -0,0 +1,101 @@ +import abc +import argparse +from typing import TYPE_CHECKING, Optional + +from jina.jaml import JAMLCompatible +from jina.logging.logger import JinaLogger +from jina.serve.streamer import GatewayStreamer + +__all__ = ['BaseGateway'] + +if TYPE_CHECKING: + from prometheus_client import CollectorRegistry + + +class BaseGateway(JAMLCompatible): + """ + The base class of all custom Gateways, can be used to build a custom interface to a Jina Flow that supports + gateway logic + + :class:`jina.Gateway` as an alias for this class. + """ + + def __init__( + self, + name: Optional[str] = 'gateway', + **kwargs, + ): + """ + :param name: Gateway pod name + :param kwargs: additional extra keyword arguments to avoid failing when extra params are passed that are not expected + """ + self.streamer = None + self.name = name + # TODO: original implementation also passes args, maybe move this to a setter/initializer func + self.logger = JinaLogger(self.name) + + def set_streamer( + self, + args: 'argparse.Namespace' = None, + timeout_send: Optional[float] = None, + metrics_registry: Optional['CollectorRegistry'] = None, + runtime_name: Optional[str] = None, + ): + """ + Set streamer object by providing runtime parameters.
+ :param args: runtime args + :param timeout_send: grpc connection timeout + :param metrics_registry: metric registry when monitoring is enabled + :param runtime_name: name of the runtime providing the streamer + """ + import json + + from jina.serve.streamer import GatewayStreamer + + graph_description = json.loads(args.graph_description) + graph_conditions = json.loads(args.graph_conditions) + deployments_addresses = json.loads(args.deployments_addresses) + deployments_disable_reduce = json.loads(args.deployments_disable_reduce) + + self.streamer = GatewayStreamer( + graph_representation=graph_description, + executor_addresses=deployments_addresses, + graph_conditions=graph_conditions, + deployments_disable_reduce=deployments_disable_reduce, + timeout_send=timeout_send, + retries=args.retries, + compression=args.compression, + runtime_name=runtime_name, + prefetch=args.prefetch, + logger=self.logger, + metrics_registry=metrics_registry, + ) + + @abc.abstractmethod + async def setup_server(self): + """Setup server""" + ... + + @abc.abstractmethod + async def run_server(self): + """Run server forever""" + ... + + async def teardown(self): + """Free other resources allocated with the server, e.g, gateway object, ...""" + await self.streamer.close() + + @abc.abstractmethod + async def stop_server(self): + """Stop server""" + ... 
+ + # some servers need to set a flag useful in handling termination signals + # e.g, HTTPGateway/ WebSocketGateway + @property + def should_exit(self) -> bool: + """ + Boolean flag that indicates whether the gateway server should exit or not + :return: boolean flag + """ + return False diff --git a/jina/serve/runtimes/gateway/__init__.py b/jina/serve/runtimes/gateway/__init__.py index cc48b494b2dc7..591f7628089e2 100644 --- a/jina/serve/runtimes/gateway/__init__.py +++ b/jina/serve/runtimes/gateway/__init__.py @@ -28,4 +28,3 @@ def __init__( if self.timeout_send: self.timeout_send /= 1e3 # convert ms to seconds super().__init__(args, cancel_event, **kwargs) - diff --git a/jina/serve/runtimes/gateway/grpc/__init__.py b/jina/serve/runtimes/gateway/grpc/__init__.py index f2d62c7590986..644cc624ec514 100644 --- a/jina/serve/runtimes/gateway/grpc/__init__.py +++ b/jina/serve/runtimes/gateway/grpc/__init__.py @@ -1,18 +1,10 @@ -import argparse import os -import grpc -from grpc_health.v1 import health, health_pb2, health_pb2_grpc -from grpc_reflection.v1alpha import reflection - from jina import __default_host__ from jina.excepts import PortAlreadyUsed -from jina.helper import get_full_version, is_port_free -from jina.proto import jina_pb2, jina_pb2_grpc -from jina.serve.bff import GatewayBFF +from jina.helper import is_port_free from jina.serve.runtimes.gateway import GatewayRuntime -from jina.serve.runtimes.helper import _get_grpc_server_options -from jina.types.request.status import StatusMessage +from jina.serve.runtimes.gateway.grpc.gateway import GRPCGateway __all__ = ['GRPCGatewayRuntime'] @@ -20,18 +12,6 @@ class GRPCGatewayRuntime(GatewayRuntime): """Gateway Runtime for gRPC.""" - def __init__( - self, - args: argparse.Namespace, - **kwargs, - ): - """Initialize the runtime - :param args: args from CLI - :param kwargs: keyword args - """ - self._health_servicer = health.HealthServicer(experimental_non_blocking=True) - super().__init__(args, **kwargs) - async def 
async_setup(self): """ The async method to setup. @@ -45,142 +25,33 @@ async def async_setup(self): if not (is_port_free(__default_host__, self.args.port)): raise PortAlreadyUsed(f'port:{self.args.port}') - self.server = grpc.aio.server( - options=_get_grpc_server_options(self.args.grpc_server_options) + self.gateway = GRPCGateway( + name=self.name, + grpc_server_options=self.args.grpc_server_options, + port=self.args.port, + ssl_keyfile=self.args.ssl_keyfile, + ssl_certfile=self.args.ssl_certfile, ) - await self._async_setup_server() - - async def _async_setup_server(self): - - import json - - graph_description = json.loads(self.args.graph_description) - graph_conditions = json.loads(self.args.graph_conditions) - deployments_addresses = json.loads(self.args.deployments_addresses) - deployments_disable_reduce = json.loads(self.args.deployments_disable_reduce) - - self.gateway_bff = GatewayBFF( - graph_representation=graph_description, - executor_addresses=deployments_addresses, - graph_conditions=graph_conditions, - deployments_disable_reduce=deployments_disable_reduce, + self.gateway.set_streamer( + args=self.args, timeout_send=self.timeout_send, - retries=self.args.retries, - compression=self.args.compression, - runtime_name=self.name, - prefetch=self.args.prefetch, - logger=self.logger, metrics_registry=self.metrics_registry, + runtime_name=self.name, ) - - jina_pb2_grpc.add_JinaRPCServicer_to_server( - self.gateway_bff._streamer, self.server - ) - jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(self, self.server) - jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server) - - service_names = ( - jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].full_name, - jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name, - reflection.SERVICE_NAME, - ) - # Mark all services as healthy. 
- health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server) - - for service in service_names: - self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) - reflection.enable_server_reflection(service_names, self.server) - - bind_addr = f'{__default_host__}:{self.args.port}' - - if self.args.ssl_keyfile and self.args.ssl_certfile: - with open(self.args.ssl_keyfile, 'rb') as f: - private_key = f.read() - with open(self.args.ssl_certfile, 'rb') as f: - certificate_chain = f.read() - - server_credentials = grpc.ssl_server_credentials( - ( - ( - private_key, - certificate_chain, - ), - ) - ) - self.server.add_secure_port(bind_addr, server_credentials) - elif ( - self.args.ssl_keyfile != self.args.ssl_certfile - ): # if we have only ssl_keyfile and not ssl_certfile or vice versa - raise ValueError( - f"you can't pass a ssl_keyfile without a ssl_certfile and vice versa" - ) - else: - self.server.add_insecure_port(bind_addr) - self.logger.debug(f'start server bound to {bind_addr}') - await self.server.start() + await self.gateway.setup_server() async def async_teardown(self): """Close the connection pool""" # usually async_cancel should already have been called, but then its a noop # if the runtime is stopped without a sigterm (e.g. 
as a context manager, this can happen) - self._health_servicer.enter_graceful_shutdown() - await self.gateway_bff.close() + await self.gateway.teardown() await self.async_cancel() async def async_cancel(self): """The async method to stop server.""" - await self.server.stop(0) + await self.gateway.stop_server() async def async_run_forever(self): """The async running of server.""" - await self.server.wait_for_termination() - - async def dry_run(self, empty, context) -> jina_pb2.StatusProto: - """ - Process the the call requested by having a dry run call to every Executor in the graph - - :param empty: The service expects an empty protobuf message - :param context: grpc context - :returns: the response request - """ - from docarray import DocumentArray - - from jina.clients.request import request_generator - from jina.enums import DataInputType - from jina.serve.executors import __dry_run_endpoint__ - - da = DocumentArray() - - try: - req_iterator = request_generator( - exec_endpoint=__dry_run_endpoint__, - data=da, - data_type=DataInputType.DOCUMENT, - ) - async for _ in self.gateway_bff.stream(request_iterator=req_iterator): - pass - status_message = StatusMessage() - status_message.set_code(jina_pb2.StatusProto.SUCCESS) - return status_message.proto - except Exception as ex: - status_message = StatusMessage() - status_message.set_exception(ex) - return status_message.proto - - async def _status(self, empty, context) -> jina_pb2.JinaInfoProto: - """ - Process the the call requested and return the JinaInfo of the Runtime - - :param empty: The service expects an empty protobuf message - :param context: grpc context - :returns: the response request - """ - infoProto = jina_pb2.JinaInfoProto() - version, env_info = get_full_version() - for k, v in version.items(): - infoProto.jina[k] = str(v) - for k, v in env_info.items(): - infoProto.envs[k] = str(v) - return infoProto + await self.gateway.run_server() diff --git a/jina/serve/runtimes/gateway/grpc/gateway.py 
b/jina/serve/runtimes/gateway/grpc/gateway.py new file mode 100644 index 0000000000000..22013691e20b4 --- /dev/null +++ b/jina/serve/runtimes/gateway/grpc/gateway.py @@ -0,0 +1,157 @@ +from typing import Optional + +import grpc +from grpc_health.v1 import health, health_pb2, health_pb2_grpc +from grpc_reflection.v1alpha import reflection + +from jina import __default_host__ +from jina.helper import get_full_version +from jina.proto import jina_pb2, jina_pb2_grpc +from jina.types.request.status import StatusMessage + +from ....gateway import BaseGateway +from ...helper import _get_grpc_server_options + + +class GRPCGateway(BaseGateway): + """GRPC Gateway implementation""" + + def __init__( + self, + port: Optional[int] = None, + grpc_server_options: Optional[dict] = None, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + **kwargs, + ): + """Initialize the gateway + :param port: The port of the Gateway, which the client should connect to. + :param grpc_server_options: Dictionary of kwargs arguments that will be passed to the grpc server as options when starting the server, example : {'grpc.max_send_message_length': -1} + :param ssl_keyfile: the path to the key file + :param ssl_certfile: the path to the certificate file + :param kwargs: keyword args + """ + super().__init__(**kwargs) + self.port = port + self.grpc_server_options = grpc_server_options + self.ssl_keyfile = ssl_keyfile + self.ssl_certfile = ssl_certfile + self.server = grpc.aio.server( + options=_get_grpc_server_options(self.grpc_server_options) + ) + self.health_servicer = health.HealthServicer(experimental_non_blocking=True) + + async def setup_server(self): + """ + setup GRPC server + """ + jina_pb2_grpc.add_JinaRPCServicer_to_server( + self.streamer._streamer, self.server + ) + + jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(self, self.server) + jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server) + + service_names = ( + 
jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name, + jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].full_name, + jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name, + reflection.SERVICE_NAME, + ) + # Mark all services as healthy. + health_pb2_grpc.add_HealthServicer_to_server(self.health_servicer, self.server) + + for service in service_names: + self.health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) + reflection.enable_server_reflection(service_names, self.server) + + bind_addr = f'{__default_host__}:{self.port}' + + if self.ssl_keyfile and self.ssl_certfile: + with open(self.ssl_keyfile, 'rb') as f: + private_key = f.read() + with open(self.ssl_certfile, 'rb') as f: + certificate_chain = f.read() + + server_credentials = grpc.ssl_server_credentials( + ( + ( + private_key, + certificate_chain, + ), + ) + ) + self.server.add_secure_port(bind_addr, server_credentials) + elif ( + self.ssl_keyfile != self.ssl_certfile + ): # if we have only ssl_keyfile and not ssl_certfile or vice versa + raise ValueError( + f"you can't pass a ssl_keyfile without a ssl_certfile and vice versa" + ) + else: + self.server.add_insecure_port(bind_addr) + self.logger.debug(f'start server bound to {bind_addr}') + await self.server.start() + + async def teardown(self): + """Free other resources allocated with the server, e.g, gateway object, ...""" + await super().teardown() + self.health_servicer.enter_graceful_shutdown() + + async def stop_server(self): + """ + Stop GRPC server + """ + await self.server.stop(0) + + async def run_server(self): + """Run GRPC server forever""" + await self.server.wait_for_termination() + + async def dry_run(self, empty, context) -> jina_pb2.StatusProto: + """ + Process the call requested by having a dry run call to every Executor in the graph + + :param empty: The service expects an empty protobuf message + :param context: grpc context + :returns: the response request + """ + from docarray import
DocumentArray + + from jina.clients.request import request_generator + from jina.enums import DataInputType + from jina.serve.executors import __dry_run_endpoint__ + + da = DocumentArray() + + try: + req_iterator = request_generator( + exec_endpoint=__dry_run_endpoint__, + data=da, + data_type=DataInputType.DOCUMENT, + ) + async for _ in self.streamer.stream(request_iterator=req_iterator): + pass + status_message = StatusMessage() + status_message.set_code(jina_pb2.StatusProto.SUCCESS) + return status_message.proto + except Exception as ex: + status_message = StatusMessage() + status_message.set_exception(ex) + return status_message.proto + + async def _status(self, empty, context) -> jina_pb2.JinaInfoProto: + """ + Process the call requested and return the JinaInfo of the Runtime + + :param empty: The service expects an empty protobuf message + :param context: grpc context + :returns: the response request + """ + infoProto = jina_pb2.JinaInfoProto() + version, env_info = get_full_version() + for k, v in version.items(): + infoProto.jina[k] = str(v) + for k, v in env_info.items(): + infoProto.envs[k] = str(v) + return infoProto diff --git a/jina/serve/runtimes/gateway/http/__init__.py b/jina/serve/runtimes/gateway/http/__init__.py index fb0e7d286ecdc..0cb2903194f9d 100644 --- a/jina/serve/runtimes/gateway/http/__init__.py +++ b/jina/serve/runtimes/gateway/http/__init__.py @@ -1,14 +1,14 @@ import asyncio -import logging import os from jina import __default_host__ -from jina.importer import ImportExtensions from jina.serve.runtimes.gateway import GatewayRuntime from jina.serve.runtimes.gateway.http.app import get_fastapi_app __all__ = ['HTTPGatewayRuntime'] +from jina.serve.runtimes.gateway.http.gateway import HTTPGateway + class HTTPGatewayRuntime(GatewayRuntime): """Runtime for HTTP interface.""" @@ -19,88 +19,45 @@ async def async_setup(self): Setup the uvicorn server.
""" - with ImportExtensions(required=True): - from uvicorn import Config, Server - - class UviServer(Server): - """The uvicorn server.""" - - async def setup(self, sockets=None): - """ - Setup uvicorn server. - - :param sockets: sockets of server. - """ - config = self.config - if not config.loaded: - config.load() - self.lifespan = config.lifespan_class(config) - self.install_signal_handlers() - await self.startup(sockets=sockets) - if self.should_exit: - return - - async def serve(self, **kwargs): - """ - Start the server. - - :param kwargs: keyword arguments - """ - await self.main_loop() - - if 'CICD_JINA_DISABLE_HEALTHCHECK_LOGS' in os.environ: - - class _EndpointFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - # NOTE: space is important after `GET /`, else all logs will be disabled. - return record.getMessage().find("GET / ") == -1 - - # Filter out healthcheck endpoint `GET /` - logging.getLogger("uvicorn.access").addFilter(_EndpointFilter()) - - from jina.helper import extend_rest_interface - - uvicorn_kwargs = self.args.uvicorn_kwargs or {} - - for ssl_file in ['ssl_keyfile', 'ssl_certfile']: - if getattr(self.args, ssl_file): - if ssl_file not in uvicorn_kwargs.keys(): - uvicorn_kwargs[ssl_file] = getattr(self.args, ssl_file) - - self._server = UviServer( - config=Config( - app=extend_rest_interface( - get_fastapi_app( - args=self.args, - logger=self.logger, - timeout_send=self.timeout_send, - metrics_registry=self.metrics_registry, - ) - ), - host=__default_host__, - port=self.args.port, - log_level=os.getenv('JINA_LOG_LEVEL', 'error').lower(), - **uvicorn_kwargs - ) + self.gateway = HTTPGateway( + name=self.name, + port=self.args.port, + title=self.args.title, + description=self.args.description, + no_debug_endpoints=self.args.no_debug_endpoints, + no_crud_endpoints=self.args.no_crud_endpoints, + expose_endpoints=self.args.expose_endpoints, + expose_graphql_endpoint=self.args.expose_graphql_endpoint, + cors=self.args.cors, 
+ ssl_keyfile=self.args.ssl_keyfile, + ssl_certfile=self.args.ssl_certfile, + uvicorn_kwargs=self.args.uvicorn_kwargs, ) - await self._server.setup() - async def async_run_forever(self): - """Running method of the server.""" - await self._server.serve() + self.gateway.set_streamer( + args=self.args, + timeout_send=self.timeout_send, + metrics_registry=self.metrics_registry, + runtime_name=self.args.name, + ) + await self.gateway.setup_server() async def _wait_for_cancel(self): """Do NOT override this method when inheriting from :class:`GatewayPod`""" # handle terminate signals - while not self.is_cancel.is_set() and not self._server.should_exit: + while not self.is_cancel.is_set() and not self.gateway.should_exit: await asyncio.sleep(0.1) await self.async_cancel() async def async_teardown(self): """Shutdown the server.""" - await self._server.shutdown() + await self.gateway.teardown() async def async_cancel(self): """Stop the server.""" - self._server.should_exit = True + await self.gateway.stop_server() + + async def async_run_forever(self): + """Running method of the server.""" + await self.gateway.run_server() diff --git a/jina/serve/runtimes/gateway/http/app.py b/jina/serve/runtimes/gateway/http/app.py index 558547a4662b8..2ff5f9175d714 100644 --- a/jina/serve/runtimes/gateway/http/app.py +++ b/jina/serve/runtimes/gateway/http/app.py @@ -11,22 +11,34 @@ from jina.logging.logger import JinaLogger if TYPE_CHECKING: - from prometheus_client import CollectorRegistry + from jina.serve.streamer import GatewayStreamer def get_fastapi_app( - args: 'argparse.Namespace', - logger: 'JinaLogger', - timeout_send: Optional[float] = None, - metrics_registry: Optional['CollectorRegistry'] = None, + streamer: 'GatewayStreamer', + title: str, + description: str, + no_debug_endpoints: bool, + no_crud_endpoints: bool, + expose_endpoints: bool, + expose_graphql_endpoint: bool, + cors: bool, + logger: 'JinaLogger', ): """ Get the app from FastAPI as the REST interface. 
- :param args: passed arguments. + :param streamer: gateway streamer object + :param title: The title of this HTTP server. It will be used in automatic docs such as Swagger UI. + :param description: The description of this HTTP server. It will be used in automatic docs such as Swagger UI. + :param no_debug_endpoints: If set, `/status` `/post` endpoints are removed from HTTP interface. + :param no_crud_endpoints: If set, `/index`, `/search`, `/update`, `/delete` endpoints are removed from HTTP interface. + + Any executor that has `@requests(on=...)` bind with those values will receive data requests. + :param expose_endpoints: A JSON string that represents a map from executor endpoints (`@requests(on=...)`) to HTTP endpoints. + :param expose_graphql_endpoint: If set, /graphql endpoint is added to HTTP interface. + :param cors: If set, a CORS middleware is added to FastAPI frontend to allow cross-origin access. :param logger: Jina logger. - :param timeout_send: Timeout to be used when sending to Executors - :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler :return: fastapi app """ with ImportExtensions(required=True): @@ -40,14 +52,14 @@ def get_fastapi_app( ) app = FastAPI( - title=args.title or 'My Jina Service', - description=args.description - or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` ' - 'to customize the title and description.', + title=title or 'My Jina Service', + description=description + or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` ' + 'to customize the title and description.', version=__version__, ) - if args.cors: + if cors: app.add_middleware( CORSMiddleware, allow_origins=['*'], @@ -57,38 +69,17 @@ def get_fastapi_app( ) logger.warning('CORS is enabled.
This service is accessible from any website!') - from jina.serve.bff import GatewayBFF - - import json - - graph_description = json.loads(args.graph_description) - graph_conditions = json.loads(args.graph_conditions) - deployments_addresses = json.loads(args.deployments_addresses) - deployments_disable_reduce = json.loads(args.deployments_disable_reduce) - - gateway_bff = GatewayBFF(graph_representation=graph_description, - executor_addresses=deployments_addresses, - graph_conditions=graph_conditions, - deployments_disable_reduce=deployments_disable_reduce, - timeout_send=timeout_send, - retries=args.retries, - compression=args.compression, - runtime_name=args.name, - prefetch=args.prefetch, - logger=logger, - metrics_registry=metrics_registry) - @app.on_event('shutdown') async def _shutdown(): - await gateway_bff.close() + await streamer.close() openapi_tags = [] - if not args.no_debug_endpoints: + if not no_debug_endpoints: openapi_tags.append( { 'name': 'Debug', 'description': 'Debugging interface. 
In production, you should hide them by setting ' - '`--no-debug-endpoints` in `Flow`/`Gateway`.', + '`--no-debug-endpoints` in `Flow`/`Gateway`.', } ) @@ -108,6 +99,7 @@ async def _gateway_health(): return {} from docarray import DocumentArray + from jina.proto import jina_pb2 from jina.serve.executors import __dry_run_endpoint__ from jina.serve.runtimes.gateway.http.models import ( @@ -119,7 +111,7 @@ async def _gateway_health(): @app.get( path='/dry_run', summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to ' - 'validate connectivity', + 'validate connectivity', response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto, ) async def _flow_health(): @@ -176,7 +168,7 @@ async def _status(): # do not add response_model here, this debug endpoint should not restricts the response model ) async def post( - body: JinaEndpointRequestModel, response: Response + body: JinaEndpointRequestModel, response: Response ): # 'response' is a FastAPI response, not a Jina response """ Post a data request to some endpoint. @@ -278,12 +270,12 @@ async def foo(body: JinaRequestModel): ) return result - if not args.no_crud_endpoints: + if not no_crud_endpoints: openapi_tags.append( { 'name': 'CRUD', 'description': 'CRUD interface. 
If your service does not implement those interfaces, you can should ' - 'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.', + 'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.', } ) crud = { @@ -302,16 +294,17 @@ async def foo(body: JinaRequestModel): if openapi_tags: app.openapi_tags = openapi_tags - if args.expose_endpoints: - endpoints = json.loads(args.expose_endpoints) # type: Dict[str, Dict] + if expose_endpoints: + endpoints = json.loads(expose_endpoints) # type: Dict[str, Dict] for k, v in endpoints.items(): expose_executor_endpoint(exec_endpoint=k, **v) - if args.expose_graphql_endpoint: + if expose_graphql_endpoint: with ImportExtensions(required=True): from dataclasses import asdict import strawberry + from docarray import DocumentArray from docarray.document.strawberry_type import ( JSONScalar, StrawberryDocument, @@ -319,10 +312,8 @@ async def foo(body: JinaRequestModel): ) from strawberry.fastapi import GraphQLRouter - from docarray import DocumentArray - async def get_docs_from_endpoint( - data, target_executor, parameters, exec_endpoint + data, target_executor, parameters, exec_endpoint ): req_generator_input = { 'data': [asdict(d) for d in data], @@ -333,8 +324,8 @@ async def get_docs_from_endpoint( } if ( - req_generator_input['data'] is not None - and 'docs' in req_generator_input['data'] + req_generator_input['data'] is not None + and 'docs' in req_generator_input['data'] ): req_generator_input['data'] = req_generator_input['data']['docs'] try: @@ -352,11 +343,11 @@ async def get_docs_from_endpoint( class Mutation: @strawberry.mutation async def docs( - self, - data: Optional[List[StrawberryDocumentInput]] = None, - target_executor: Optional[str] = None, - parameters: Optional[JSONScalar] = None, - exec_endpoint: str = '/search', + self, + data: Optional[List[StrawberryDocumentInput]] = None, + target_executor: Optional[str] = None, + parameters: Optional[JSONScalar] = None, + exec_endpoint: str = '/search', ) -> 
List[StrawberryDocument]: return await get_docs_from_endpoint( data, target_executor, parameters, exec_endpoint @@ -366,11 +357,11 @@ async def docs( class Query: @strawberry.field async def docs( - self, - data: Optional[List[StrawberryDocumentInput]] = None, - target_executor: Optional[str] = None, - parameters: Optional[JSONScalar] = None, - exec_endpoint: str = '/search', + self, + data: Optional[List[StrawberryDocumentInput]] = None, + target_executor: Optional[str] = None, + parameters: Optional[JSONScalar] = None, + exec_endpoint: str = '/search', ) -> List[StrawberryDocument]: return await get_docs_from_endpoint( data, target_executor, parameters, exec_endpoint @@ -386,7 +377,7 @@ async def _get_singleton_result(request_iterator) -> Dict: :param request_iterator: request iterator, with length of 1 :return: the first result from the request iterator """ - async for k in gateway_bff.stream(request_iterator=request_iterator): + async for k in streamer.stream(request_iterator=request_iterator): request_dict = k.to_dict() return request_dict diff --git a/jina/serve/runtimes/gateway/http/gateway.py b/jina/serve/runtimes/gateway/http/gateway.py new file mode 100644 index 0000000000000..1d6ab534af742 --- /dev/null +++ b/jina/serve/runtimes/gateway/http/gateway.py @@ -0,0 +1,162 @@ +import logging +import os +from typing import Optional + +from jina import __default_host__ +from jina.importer import ImportExtensions + +from ....gateway import BaseGateway +from . 
import get_fastapi_app + + +class HTTPGateway(BaseGateway): + """HTTP Gateway implementation""" + + def __init__( + self, + port: Optional[int] = None, + title: Optional[str] = None, + description: Optional[str] = None, + no_debug_endpoints: Optional[bool] = False, + no_crud_endpoints: Optional[bool] = False, + expose_endpoints: Optional[str] = None, + expose_graphql_endpoint: Optional[bool] = False, + cors: Optional[bool] = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + uvicorn_kwargs: Optional[dict] = None, + **kwargs + ): + """Initialize the gateway + Get the app from FastAPI as the REST interface. + :param port: The port of the Gateway, which the client should connect to. + :param title: The title of this HTTP server. It will be used in automatics docs such as Swagger UI. + :param description: The description of this HTTP server. It will be used in automatics docs such as Swagger UI. + :param no_debug_endpoints: If set, `/status` `/post` endpoints are removed from HTTP interface. + :param no_crud_endpoints: If set, `/index`, `/search`, `/update`, `/delete` endpoints are removed from HTTP interface. + + Any executor that has `@requests(on=...)` bind with those values will receive data requests. + :param expose_endpoints: A JSON string that represents a map from executor endpoints (`@requests(on=...)`) to HTTP endpoints. + :param expose_graphql_endpoint: If set, /graphql endpoint is added to HTTP interface. + :param cors: If set, a CORS middleware is added to FastAPI frontend to allow cross-origin access. 
+ :param ssl_keyfile: the path to the key file + :param ssl_certfile: the path to the certificate file + :param uvicorn_kwargs: Dictionary of kwargs arguments that will be passed to Uvicorn server when starting the server + :param kwargs: keyword args + """ + super().__init__(**kwargs) + self.port = port + self.title = title + self.description = description + self.no_debug_endpoints = no_debug_endpoints + self.no_crud_endpoints = no_crud_endpoints + self.expose_endpoints = expose_endpoints + self.expose_graphql_endpoint = expose_graphql_endpoint + self.cors = cors + self.ssl_keyfile = ssl_keyfile + self.ssl_certfile = ssl_certfile + self.uvicorn_kwargs = uvicorn_kwargs + + async def setup_server(self): + """ + Initialize and return GRPC server + """ + from jina.helper import extend_rest_interface + + self.app = extend_rest_interface( + get_fastapi_app( + streamer=self.streamer, + title=self.title, + description=self.description, + no_debug_endpoints=self.no_debug_endpoints, + no_crud_endpoints=self.no_crud_endpoints, + expose_endpoints=self.expose_endpoints, + expose_graphql_endpoint=self.expose_graphql_endpoint, + cors=self.cors, + logger=self.logger, + ) + ) + + with ImportExtensions(required=True): + from uvicorn import Config, Server + + class UviServer(Server): + """The uvicorn server.""" + + async def setup(self, sockets=None): + """ + Setup uvicorn server. + + :param sockets: sockets of server. + """ + config = self.config + if not config.loaded: + config.load() + self.lifespan = config.lifespan_class(config) + self.install_signal_handlers() + await self.startup(sockets=sockets) + if self.should_exit: + return + + async def serve(self, **kwargs): + """ + Start the server. 
+ + :param kwargs: keyword arguments + """ + await self.main_loop() + + if 'CICD_JINA_DISABLE_HEALTHCHECK_LOGS' in os.environ: + + class _EndpointFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + # NOTE: space is important after `GET /`, else all logs will be disabled. + return record.getMessage().find("GET / ") == -1 + + # Filter out healthcheck endpoint `GET /` + logging.getLogger("uvicorn.access").addFilter(_EndpointFilter()) + + uvicorn_kwargs = self.uvicorn_kwargs or {} + + if self.ssl_keyfile and 'ssl_keyfile' not in uvicorn_kwargs.keys(): + uvicorn_kwargs['ssl_keyfile'] = self.ssl_keyfile + + if self.ssl_certfile and 'ssl_certfile' not in uvicorn_kwargs.keys(): + uvicorn_kwargs['ssl_certfile'] = self.ssl_certfile + + self.server = UviServer( + config=Config( + app=self.app, + host=__default_host__, + port=self.port, + log_level=os.getenv('JINA_LOG_LEVEL', 'error').lower(), + **uvicorn_kwargs, + ) + ) + + await self.server.setup() + + async def teardown(self): + """ + Free resources allocated when setting up HTTP server + """ + await super().teardown() + await self.server.shutdown() + + async def stop_server(self): + """ + Stop HTTP server + """ + self.server.should_exit = True + + async def run_server(self): + """Run HTTP server forever""" + await self.server.serve() + + @property + def should_exit(self) -> bool: + """ + Boolean flag that indicates whether the gateway server should exit or not + :return: boolean flag + """ + return self.server.should_exit diff --git a/jina/serve/runtimes/gateway/websocket/__init__.py b/jina/serve/runtimes/gateway/websocket/__init__.py index b36e0957f937a..3f132f73c164a 100644 --- a/jina/serve/runtimes/gateway/websocket/__init__.py +++ b/jina/serve/runtimes/gateway/websocket/__init__.py @@ -1,14 +1,12 @@ import asyncio -import logging -import os -from jina import __default_host__ -from jina.importer import ImportExtensions from jina.serve.runtimes.gateway import GatewayRuntime from 
jina.serve.runtimes.gateway.websocket.app import get_fastapi_app __all__ = ['WebSocketGatewayRuntime'] +from jina.serve.runtimes.gateway.websocket.gateway import WebSocketGateway + class WebSocketGatewayRuntime(GatewayRuntime): """Runtime for Websocket interface.""" @@ -19,88 +17,40 @@ async def async_setup(self): Setup the uvicorn server. """ - with ImportExtensions(required=True): - from uvicorn import Config, Server - - class UviServer(Server): - """The uvicorn server.""" - - async def setup(self, sockets=None): - """ - Setup uvicorn server. - - :param sockets: sockets of server. - """ - config = self.config - if not config.loaded: - config.load() - self.lifespan = config.lifespan_class(config) - self.install_signal_handlers() - await self.startup(sockets=sockets) - if self.should_exit: - return - - async def serve(self, **kwargs): - """ - Start the server. - - :param kwargs: keyword arguments - """ - await self.main_loop() - - if 'CICD_JINA_DISABLE_HEALTHCHECK_LOGS' in os.environ: - - class _EndpointFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - # NOTE: space is important after `GET /`, else all logs will be disabled. 
- return record.getMessage().find("GET / ") == -1 - - # Filter out healthcheck endpoint `GET /` - logging.getLogger("uvicorn.access").addFilter(_EndpointFilter()) - from jina.helper import extend_rest_interface - - uvicorn_kwargs = self.args.uvicorn_kwargs or {} - for ssl_file in ['ssl_keyfile', 'ssl_certfile']: - if getattr(self.args, ssl_file): - if ssl_file not in uvicorn_kwargs.keys(): - uvicorn_kwargs[ssl_file] = getattr(self.args, ssl_file) - - self._server = UviServer( - config=Config( - app=extend_rest_interface( - get_fastapi_app( - args=self.args, - logger=self.logger, - timeout_send=self.timeout_send, - metrics_registry=self.metrics_registry, - ) - ), - host=__default_host__, - port=self.args.port, - ws_max_size=1024 * 1024 * 1024, - log_level=os.getenv('JINA_LOG_LEVEL', 'error').lower(), - **uvicorn_kwargs - ) + self.gateway = WebSocketGateway( + name=self.name, + port=self.args.port, + ssl_keyfile=self.args.ssl_keyfile, + ssl_certfile=self.args.ssl_certfile, + uvicorn_kwargs=self.args.uvicorn_kwargs, + logger=self.logger, ) - await self._server.setup() - async def async_run_forever(self): - """Running method of ther server.""" - await self._server.serve() + self.gateway.set_streamer( + args=self.args, + timeout_send=self.timeout_send, + metrics_registry=self.metrics_registry, + runtime_name=self.args.name, + ) + await self.gateway.setup_server() async def _wait_for_cancel(self): """Do NOT override this method when inheriting from :class:`GatewayPod`""" # handle terminate signals - while not self.is_cancel.is_set() and not self._server.should_exit: + while not self.is_cancel.is_set() and not self.gateway.should_exit: await asyncio.sleep(0.1) await self.async_cancel() async def async_teardown(self): """Shutdown the server.""" - await self._server.shutdown() + await self.gateway.teardown() async def async_cancel(self): """Stop the server.""" - self._server.should_exit = True + await self.gateway.stop_server() + + async def async_run_forever(self): + 
"""Running method of ther server.""" + await self.gateway.run_server() diff --git a/jina/serve/runtimes/gateway/websocket/app.py b/jina/serve/runtimes/gateway/websocket/app.py index dade64791c6e8..2793e18898b51 100644 --- a/jina/serve/runtimes/gateway/websocket/app.py +++ b/jina/serve/runtimes/gateway/websocket/app.py @@ -11,7 +11,7 @@ from jina.types.request.status import StatusMessage if TYPE_CHECKING: - from prometheus_client import CollectorRegistry + from jina.serve.streamer import GatewayStreamer def _fits_ws_close_msg(msg: str): @@ -22,18 +22,14 @@ def _fits_ws_close_msg(msg: str): def get_fastapi_app( - args: 'argparse.Namespace', - logger: 'JinaLogger', - timeout_send: Optional[float] = None, - metrics_registry: Optional['CollectorRegistry'] = None, + streamer: 'GatewayStreamer', + logger: 'JinaLogger', ): """ Get the app from FastAPI as the Websocket interface. - :param args: passed arguments. + :param streamer: gateway streamer object. :param logger: Jina logger. - :param timeout_send: Timeout to be used when sending to Executors - :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler :return: fastapi app """ @@ -101,7 +97,7 @@ async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]: pass async def send( - self, websocket: WebSocket, data: Union[DataRequest, StatusMessage] + self, websocket: WebSocket, data: Union[DataRequest, StatusMessage] ) -> None: subprotocol = self.protocol_dict[self.get_client(websocket)] if subprotocol == WebsocketSubProtocols.JSON: @@ -113,26 +109,6 @@ async def send( app = FastAPI() - from jina.serve.bff import GatewayBFF - import json - - graph_description = json.loads(args.graph_description) - graph_conditions = json.loads(args.graph_conditions) - deployments_addresses = json.loads(args.deployments_addresses) - deployments_disable_reduce = json.loads(args.deployments_disable_reduce) - - gateway_bff = 
GatewayBFF(graph_representation=graph_description, - executor_addresses=deployments_addresses, - graph_conditions=graph_conditions, - deployments_disable_reduce=deployments_disable_reduce, - timeout_send=timeout_send, - retries=args.retries, - compression=args.compression, - runtime_name=args.name, - prefetch=args.prefetch, - logger=logger, - metrics_registry=metrics_registry) - @app.get( path='/', summary='Get the health of Jina service', @@ -166,11 +142,11 @@ async def _status(): @app.on_event('shutdown') async def _shutdown(): - await gateway_bff.close() + await streamer.close() @app.websocket('/') async def websocket_endpoint( - websocket: WebSocket, response: Response + websocket: WebSocket, response: Response ): # 'response' is a FastAPI response, not a Jina response await manager.connect(websocket) @@ -198,7 +174,7 @@ async def req_iter(): yield DataRequest(request) try: - async for msg in gateway_bff.stream(request_iterator=req_iter()): + async for msg in streamer.stream(request_iterator=req_iter()): await manager.send(websocket, msg) except InternalNetworkError as err: import grpc @@ -228,11 +204,12 @@ async def _get_singleton_result(request_iterator) -> Dict: :param request_iterator: request iterator, with length of 1 :return: the first result from the request iterator """ - async for k in gateway_bff.stream(request_iterator=request_iterator): + async for k in streamer.stream(request_iterator=request_iterator): request_dict = k.to_dict() return request_dict from docarray import DocumentArray + from jina.proto import jina_pb2 from jina.serve.executors import __dry_run_endpoint__ from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS @@ -240,7 +217,7 @@ async def _get_singleton_result(request_iterator) -> Dict: @app.get( path='/dry_run', summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to ' - 'validate connectivity', + 'validate connectivity', 
response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto, ) async def _dry_run_http(): @@ -270,7 +247,7 @@ async def _dry_run_http(): @app.websocket('/dry_run') async def websocket_endpoint( - websocket: WebSocket, response: Response + websocket: WebSocket, response: Response ): # 'response' is a FastAPI response, not a Jina response from jina.proto import jina_pb2 from jina.serve.executors import __dry_run_endpoint__ @@ -279,12 +256,12 @@ async def websocket_endpoint( da = DocumentArray() try: - async for _ in gateway_bff.stream( - request_iterator=request_generator( - exec_endpoint=__dry_run_endpoint__, - data=da, - data_type=DataInputType.DOCUMENT, - ) + async for _ in streamer.stream( + request_iterator=request_generator( + exec_endpoint=__dry_run_endpoint__, + data=da, + data_type=DataInputType.DOCUMENT, + ) ): pass status_message = StatusMessage() diff --git a/jina/serve/runtimes/gateway/websocket/gateway.py b/jina/serve/runtimes/gateway/websocket/gateway.py new file mode 100644 index 0000000000000..7295adf486aa0 --- /dev/null +++ b/jina/serve/runtimes/gateway/websocket/gateway.py @@ -0,0 +1,130 @@ +import logging +import os +from typing import Optional + +from jina import __default_host__ +from jina.importer import ImportExtensions + +from ....gateway import BaseGateway +from . import get_fastapi_app + + +class WebSocketGateway(BaseGateway): + """WebSocket Gateway implementation""" + + def __init__( + self, + port: Optional[int] = None, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + uvicorn_kwargs: Optional[dict] = None, + **kwargs + ): + """Initialize the gateway + :param port: The port of the Gateway, which the client should connect to. 
+ :param ssl_keyfile: the path to the key file + :param ssl_certfile: the path to the certificate file + :param uvicorn_kwargs: Dictionary of kwargs arguments that will be passed to Uvicorn server when starting the server + :param kwargs: keyword args + """ + super().__init__(**kwargs) + self.port = port + self.ssl_keyfile = ssl_keyfile + self.ssl_certfile = ssl_certfile + self.uvicorn_kwargs = uvicorn_kwargs + + async def setup_server(self): + """ + Setup WebSocket Server + """ + from jina.helper import extend_rest_interface + + self.app = extend_rest_interface( + get_fastapi_app( + streamer=self.streamer, + logger=self.logger, + ) + ) + + with ImportExtensions(required=True): + from uvicorn import Config, Server + + class UviServer(Server): + """The uvicorn server.""" + + async def setup(self, sockets=None): + """ + Setup uvicorn server. + + :param sockets: sockets of server. + """ + config = self.config + if not config.loaded: + config.load() + self.lifespan = config.lifespan_class(config) + self.install_signal_handlers() + await self.startup(sockets=sockets) + if self.should_exit: + return + + async def serve(self, **kwargs): + """ + Start the server. + + :param kwargs: keyword arguments + """ + await self.main_loop() + + if 'CICD_JINA_DISABLE_HEALTHCHECK_LOGS' in os.environ: + + class _EndpointFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + # NOTE: space is important after `GET /`, else all logs will be disabled. 
+ return record.getMessage().find("GET / ") == -1 + + # Filter out healthcheck endpoint `GET /` + logging.getLogger("uvicorn.access").addFilter(_EndpointFilter()) + + uvicorn_kwargs = self.uvicorn_kwargs or {} + + if self.ssl_keyfile and 'ssl_keyfile' not in uvicorn_kwargs.keys(): + uvicorn_kwargs['ssl_keyfile'] = self.ssl_keyfile + + if self.ssl_certfile and 'ssl_certfile' not in uvicorn_kwargs.keys(): + uvicorn_kwargs['ssl_certfile'] = self.ssl_certfile + + self.server = UviServer( + config=Config( + app=self.app, + host=__default_host__, + port=self.port, + ws_max_size=1024 * 1024 * 1024, + log_level=os.getenv('JINA_LOG_LEVEL', 'error').lower(), + **uvicorn_kwargs, + ) + ) + + await self.server.setup() + + async def teardown(self): + """Free other resources allocated with the server, e.g, gateway object, ...""" + await super().teardown() + await self.server.shutdown() + + async def stop_server(self): + """ + Stop WebSocket server + """ + self.server.should_exit = True + + async def run_server(self): + """Run WebSocket server forever""" + await self.server.serve() + + @property + def should_exit(self) -> bool: + """ + Boolean flag that indicates whether the gateway server should exit or not + :return: boolean flag + """ + return self.server.should_exit diff --git a/jina/serve/bff.py b/jina/serve/streamer.py similarity index 77% rename from jina/serve/bff.py rename to jina/serve/streamer.py index 75b62e866bbe4..a4f8e5b62af28 100644 --- a/jina/serve/bff.py +++ b/jina/serve/streamer.py @@ -1,38 +1,37 @@ -from typing import Optional, Dict, Union, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional, Union -from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph -from jina.serve.networking import GrpcConnectionPool +from docarray import DocumentArray from jina.logging.logger import JinaLogger +from jina.serve.networking import GrpcConnectionPool +from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph from 
jina.serve.runtimes.gateway.request_handling import RequestHandler from jina.serve.stream import RequestStreamer -from docarray import DocumentArray - -__all__ = ['GatewayBFF'] +__all__ = ['GatewayStreamer'] if TYPE_CHECKING: from prometheus_client import CollectorRegistry -class GatewayBFF: +class GatewayStreamer: """ - Wrapper object to be used in a BFF or in the Gateway. Naming to be defined + Wrapper object to be used in a Custom Gateway. Naming to be defined """ def __init__( - self, - graph_representation: Dict, - executor_addresses: Dict[str, Union[str, List[str]]], - graph_conditions: Dict = {}, - deployments_disable_reduce: List[str] = [], - timeout_send: Optional[float] = None, - retries: int = 0, - compression: Optional[str] = None, - runtime_name: str = 'gateway_bff', - prefetch: int = 0, - logger: Optional['JinaLogger'] = None, - metrics_registry: Optional['CollectorRegistry'] = None, + self, + graph_representation: Dict, + executor_addresses: Dict[str, Union[str, List[str]]], + graph_conditions: Dict = {}, + deployments_disable_reduce: List[str] = [], + timeout_send: Optional[float] = None, + retries: int = 0, + compression: Optional[str] = None, + runtime_name: str = 'custom gateway', + prefetch: int = 0, + logger: Optional['JinaLogger'] = None, + metrics_registry: Optional['CollectorRegistry'] = None, ): """ :param graph_representation: A dictionary describing the topology of the Deployments. 
2 special nodes are expected, the name `start-gateway` and `end-gateway` to @@ -49,9 +48,16 @@ def __init__( :param logger: Optional logger that can be used for logging :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics """ - topology_graph = self._create_topology_graph(graph_representation, graph_conditions, - deployments_disable_reduce, timeout_send, retries) - self._connection_pool = self._create_connection_pool(executor_addresses, compression, metrics_registry, logger) + topology_graph = self._create_topology_graph( + graph_representation, + graph_conditions, + deployments_disable_reduce, + timeout_send, + retries, + ) + self._connection_pool = self._create_connection_pool( + executor_addresses, compression, metrics_registry, logger + ) request_handler = RequestHandler(metrics_registry, runtime_name) self._streamer = RequestStreamer( @@ -64,8 +70,14 @@ def __init__( ) self._streamer.Call = self._streamer.stream - def _create_topology_graph(self, graph_description, graph_conditions, deployments_disable_reduce, timeout_send, - retries): + def _create_topology_graph( + self, + graph_description, + graph_conditions, + deployments_disable_reduce, + timeout_send, + retries, + ): # check if it should be in K8s, maybe ConnectionPoolFactory to be created return TopologyGraph( graph_representation=graph_description, @@ -75,7 +87,9 @@ def _create_topology_graph(self, graph_description, graph_conditions, deployment retries=retries, ) - def _create_connection_pool(self, deployments_addresses, compression, metrics_registry, logger): + def _create_connection_pool( + self, deployments_addresses, compression, metrics_registry, logger + ): # add the connections needed connection_pool = GrpcConnectionPool( logger=logger, @@ -100,9 +114,15 @@ def stream(self, *args, **kwargs): """ return self._streamer.stream(*args, **kwargs) - async def stream_docs(self, docs: DocumentArray, request_size: int, return_results: bool = False, - 
exec_endpoint: Optional[str] = None, - target_executor: Optional[str] = None, parameters: Optional[Dict] = None): + async def stream_docs( + self, + docs: DocumentArray, + request_size: int, + return_results: bool = False, + exec_endpoint: Optional[str] = None, + target_executor: Optional[str] = None, + parameters: Optional[Dict] = None, + ): """ stream documents and stream responses back. @@ -115,6 +135,7 @@ async def stream_docs(self, docs: DocumentArray, request_size: int, return_resul :yield: Yields DocumentArrays or Responses from the Executors """ from jina.types.request.data import DataRequest + def _req_generator(): for docs_batch in docs.batch(batch_size=request_size, shuffle=False): req = DataRequest() diff --git a/tests/integration/clients_extra_kwargs/test_clients_post_extra_kwargs.py b/tests/integration/clients_extra_kwargs/test_clients_post_extra_kwargs.py index 21032a8814a5d..e6a67499a65d6 100644 --- a/tests/integration/clients_extra_kwargs/test_clients_post_extra_kwargs.py +++ b/tests/integration/clients_extra_kwargs/test_clients_post_extra_kwargs.py @@ -7,6 +7,7 @@ from jina.clients import Client from jina.excepts import PortAlreadyUsed from jina.helper import is_port_free +from jina.serve.runtimes.gateway.grpc import GRPCGateway from jina.serve.runtimes.gateway.grpc import GRPCGatewayRuntime as _GRPCGatewayRuntime from jina.serve.runtimes.helper import _get_grpc_server_options from tests import random_docs @@ -38,6 +39,14 @@ async def intercept_service(self, continuation, handler_call_details): return self._deny + class AlternativeGRPCGateway(GRPCGateway): + def __init__(self, *args, **kwargs): + super(AlternativeGRPCGateway, self).__init__(*args, **kwargs) + self.server = grpc.aio.server( + interceptors=(AuthInterceptor('access_key'),), + options=_get_grpc_server_options(self.grpc_server_options), + ) + class AlternativeGRPCGatewayRuntime(_GRPCGatewayRuntime): async def async_setup(self): """ @@ -51,12 +60,21 @@ async def async_setup(self): if 
not (is_port_free(__default_host__, self.args.port)): raise PortAlreadyUsed(f'port:{self.args.port}') - self.server = grpc.aio.server( - interceptors=(AuthInterceptor('access_key'),), - options=_get_grpc_server_options(self.args.grpc_server_options), + self.gateway = AlternativeGRPCGateway( + name=self.name, + grpc_server_options=self.args.grpc_server_options, + port=self.args.port, + ssl_keyfile=self.args.ssl_keyfile, + ssl_certfile=self.args.ssl_certfile, ) - await self._async_setup_server() + self.gateway.set_streamer( + args=self.args, + timeout_send=self.timeout_send, + metrics_registry=self.metrics_registry, + runtime_name=self.name, + ) + await self.gateway.setup_server() monkeypatch.setattr( 'jina.serve.runtimes.gateway.grpc.GRPCGatewayRuntime', diff --git a/tests/integration/gateway_bff/__init__.py b/tests/integration/gateway_streamer/__init__.py similarity index 100% rename from tests/integration/gateway_bff/__init__.py rename to tests/integration/gateway_streamer/__init__.py diff --git a/tests/integration/gateway_bff/test_gateway_bff.py b/tests/integration/gateway_streamer/test_gateway_streamer.py similarity index 67% rename from tests/integration/gateway_bff/test_gateway_bff.py rename to tests/integration/gateway_streamer/test_gateway_streamer.py index a7c337e94133c..ce7f2b831d48b 100644 --- a/tests/integration/gateway_bff/test_gateway_bff.py +++ b/tests/integration/gateway_streamer/test_gateway_streamer.py @@ -2,16 +2,14 @@ import pytest +from jina import DocumentArray, Executor, requests from jina.parsers import set_pod_parser from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime from jina.serve.runtimes.worker import WorkerRuntime -from jina.serve.bff import GatewayBFF -from jina import Executor, requests -from jina import DocumentArray +from jina.serve.streamer import GatewayStreamer class BffTestExecutor(Executor): - @requests def foo(self, docs, parameters, **kwargs): text_to_add = parameters.get('text_to_add', 'default ') @@ -52,24 +50,41 
@@ def _setup(pod0_port, pod1_port): return pod0_process, pod1_process -@pytest.mark.parametrize('parameters, target_executor, expected_text', [ # (None, None, 'default default '), - ({'pod0__text_to_add': 'param_pod0 '}, None, 'param_pod0 default '), - (None, 'pod1', 'default '), ({'pod0__text_to_add': 'param_pod0 '}, 'pod0', 'param_pod0 ')]) +@pytest.mark.parametrize( + 'parameters, target_executor, expected_text', + [ # (None, None, 'default default '), + ({'pod0__text_to_add': 'param_pod0 '}, None, 'param_pod0 default '), + (None, 'pod1', 'default '), + ({'pod0__text_to_add': 'param_pod0 '}, 'pod0', 'param_pod0 '), + ], +) @pytest.mark.asyncio -async def test_gateway_bff(port_generator, parameters, target_executor, expected_text): +async def test_custom_gateway( + port_generator, parameters, target_executor, expected_text +): pod0_port = port_generator() pod1_port = port_generator() pod0_process, pod1_process = _setup(pod0_port, pod1_port) - graph_description = {"start-gateway": ["pod0"], "pod0": ["pod1"], "pod1": ["end-gateway"]} + graph_description = { + "start-gateway": ["pod0"], + "pod0": ["pod1"], + "pod1": ["end-gateway"], + } pod_addresses = {"pod0": [f"0.0.0.0:{pod0_port}"], "pod1": [f"0.0.0.0:{pod1_port}"]} # send requests to the gateway - gateway_bff = GatewayBFF(graph_representation=graph_description, executor_addresses=pod_addresses) + gateway_streamer = GatewayStreamer( + graph_representation=graph_description, executor_addresses=pod_addresses + ) try: input_da = DocumentArray.empty(60) resp = DocumentArray.empty(0) num_resp = 0 - async for r in gateway_bff.stream_docs(docs=input_da, request_size=10, parameters=parameters, - target_executor=target_executor): + async for r in gateway_streamer.stream_docs( + docs=input_da, + request_size=10, + parameters=parameters, + target_executor=target_executor, + ): num_resp += 1 resp.extend(r) @@ -84,4 +99,4 @@ async def test_gateway_bff(port_generator, parameters, target_executor, expected 
pod1_process.terminate() pod0_process.join() pod1_process.join() - await gateway_bff.close() + await gateway_streamer.close() diff --git a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py index 8afe0723cdae4..077f47d9a65dc 100644 --- a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py +++ b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py @@ -189,7 +189,7 @@ async def _test(): req = request_generator( '/', DocumentArray([Document(text='client0-Request')]) ) - async for resp in runtime.gateway_bff.stream(request_iterator=req): + async for resp in runtime.gateway.streamer.stream(request_iterator=req): r.append(resp) return r @@ -241,7 +241,7 @@ async def _test(): req = request_generator( '/', DocumentArray([Document(text='client0-Request')]) ) - async for resp in runtime.gateway_bff.stream(request_iterator=req): + async for resp in runtime.gateway.streamer.stream(request_iterator=req): responses.append(resp) return responses @@ -295,7 +295,7 @@ async def _test(): req = request_generator( '/', DocumentArray([Document(text='client0-Request')]) ) - async for resp in runtime.gateway_bff.stream(request_iterator=req): + async for resp in runtime.gateway.streamer.stream(request_iterator=req): responses.append(resp) return responses @@ -352,7 +352,7 @@ async def _test(): req = request_generator( '/', DocumentArray([Document(text='client0-Request')]) ) - async for resp in runtime.gateway_bff.stream(request_iterator=req): + async for resp in runtime.gateway.streamer.stream(request_iterator=req): responses.append(resp) return responses @@ -409,7 +409,7 @@ async def _test(): req = request_generator( '/', DocumentArray([Document(text='client0-Request')]) ) - async for resp in runtime.gateway_bff.stream(request_iterator=req): + async for resp in runtime.gateway.streamer.stream(request_iterator=req): responses.append(resp) return responses @@ -459,7 +459,7 
@@ async def _test(): req = request_generator( '/', DocumentArray([Document(text='client0-Request')]) ) - async for resp in runtime.gateway_bff.stream(request_iterator=req): + async for resp in runtime.gateway.streamer.stream(request_iterator=req): responses.append(resp) return responses