Skip to content

Commit

Permalink
refactor: extract gateway app logic into custom gateway class (#5153)
Browse files Browse the repository at this point in the history
  • Loading branch information
alaeddine-13 committed Sep 20, 2022
1 parent caf4a3d commit 243639d
Show file tree
Hide file tree
Showing 15 changed files with 792 additions and 443 deletions.
101 changes: 101 additions & 0 deletions jina/serve/gateway.py
@@ -0,0 +1,101 @@
import abc
import argparse
from typing import TYPE_CHECKING, Optional

from jina.jaml import JAMLCompatible
from jina.logging.logger import JinaLogger
from jina.serve.streamer import GatewayStreamer

__all__ = ['BaseGateway']

if TYPE_CHECKING:
from prometheus_client import CollectorRegistry


class BaseGateway(JAMLCompatible):
"""
The base class of all custom Gateways, can be used to build a custom interface to a Jina Flow that supports
gateway logic
:class:`jina.Gateway` as an alias for this class.
"""

def __init__(
self,
name: Optional[str] = 'gateway',
**kwargs,
):
"""
:param name: Gateway pod name
:param kwargs: additional extra keyword arguments to avoid failing when extra params ara passed that are not expected
"""
self.streamer = None
self.name = name
# TODO: original implementation also passes args, maybe move this to a setter/initializer func
self.logger = JinaLogger(self.name)

def set_streamer(
self,
args: 'argparse.Namespace' = None,
timeout_send: Optional[float] = None,
metrics_registry: Optional['CollectorRegistry'] = None,
runtime_name: Optional[str] = None,
):
"""
Set streamer object by providing runtime parameters.
:param args: runtime args
:param timeout_send: grpc connection timeout
:param metrics_registry: metric registry when monitoring is enabled
:param runtime_name: name of the runtime providing the streamer
"""
import json

from jina.serve.streamer import GatewayStreamer

graph_description = json.loads(args.graph_description)
graph_conditions = json.loads(args.graph_conditions)
deployments_addresses = json.loads(args.deployments_addresses)
deployments_disable_reduce = json.loads(args.deployments_disable_reduce)

self.streamer = GatewayStreamer(
graph_representation=graph_description,
executor_addresses=deployments_addresses,
graph_conditions=graph_conditions,
deployments_disable_reduce=deployments_disable_reduce,
timeout_send=timeout_send,
retries=args.retries,
compression=args.compression,
runtime_name=runtime_name,
prefetch=args.prefetch,
logger=self.logger,
metrics_registry=metrics_registry,
)

@abc.abstractmethod
async def setup_server(self):
"""Setup server"""
...

@abc.abstractmethod
async def run_server(self):
"""Run server forever"""
...

async def teardown(self):
"""Free other resources allocated with the server, e.g, gateway object, ..."""
await self.streamer.close()

@abc.abstractmethod
async def stop_server(self):
"""Stop server"""
...

# some servers need to set a flag useful in handling termination signals
# e.g, HTTPGateway/ WebSocketGateway
@property
def should_exit(self) -> bool:
"""
Boolean flag that indicates whether the gateway server should exit or not
:return: boolean flag
"""
return False
1 change: 0 additions & 1 deletion jina/serve/runtimes/gateway/__init__.py
Expand Up @@ -28,4 +28,3 @@ def __init__(
if self.timeout_send:
self.timeout_send /= 1e3 # convert ms to seconds
super().__init__(args, cancel_event, **kwargs)

159 changes: 15 additions & 144 deletions jina/serve/runtimes/gateway/grpc/__init__.py
@@ -1,37 +1,17 @@
import argparse
import os

import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection

from jina import __default_host__
from jina.excepts import PortAlreadyUsed
from jina.helper import get_full_version, is_port_free
from jina.proto import jina_pb2, jina_pb2_grpc
from jina.serve.bff import GatewayBFF
from jina.helper import is_port_free
from jina.serve.runtimes.gateway import GatewayRuntime
from jina.serve.runtimes.helper import _get_grpc_server_options
from jina.types.request.status import StatusMessage
from jina.serve.runtimes.gateway.grpc.gateway import GRPCGateway

__all__ = ['GRPCGatewayRuntime']


class GRPCGatewayRuntime(GatewayRuntime):
"""Gateway Runtime for gRPC."""

def __init__(
self,
args: argparse.Namespace,
**kwargs,
):
"""Initialize the runtime
:param args: args from CLI
:param kwargs: keyword args
"""
self._health_servicer = health.HealthServicer(experimental_non_blocking=True)
super().__init__(args, **kwargs)

async def async_setup(self):
"""
The async method to setup.
Expand All @@ -45,142 +25,33 @@ async def async_setup(self):
if not (is_port_free(__default_host__, self.args.port)):
raise PortAlreadyUsed(f'port:{self.args.port}')

self.server = grpc.aio.server(
options=_get_grpc_server_options(self.args.grpc_server_options)
self.gateway = GRPCGateway(
name=self.name,
grpc_server_options=self.args.grpc_server_options,
port=self.args.port,
ssl_keyfile=self.args.ssl_keyfile,
ssl_certfile=self.args.ssl_certfile,
)

await self._async_setup_server()

async def _async_setup_server(self):

import json

graph_description = json.loads(self.args.graph_description)
graph_conditions = json.loads(self.args.graph_conditions)
deployments_addresses = json.loads(self.args.deployments_addresses)
deployments_disable_reduce = json.loads(self.args.deployments_disable_reduce)

self.gateway_bff = GatewayBFF(
graph_representation=graph_description,
executor_addresses=deployments_addresses,
graph_conditions=graph_conditions,
deployments_disable_reduce=deployments_disable_reduce,
self.gateway.set_streamer(
args=self.args,
timeout_send=self.timeout_send,
retries=self.args.retries,
compression=self.args.compression,
runtime_name=self.name,
prefetch=self.args.prefetch,
logger=self.logger,
metrics_registry=self.metrics_registry,
runtime_name=self.name,
)

jina_pb2_grpc.add_JinaRPCServicer_to_server(
self.gateway_bff._streamer, self.server
)
jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(self, self.server)
jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server)

service_names = (
jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name,
jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].full_name,
jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
reflection.SERVICE_NAME,
)
# Mark all services as healthy.
health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server)

for service in service_names:
self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
reflection.enable_server_reflection(service_names, self.server)

bind_addr = f'{__default_host__}:{self.args.port}'

if self.args.ssl_keyfile and self.args.ssl_certfile:
with open(self.args.ssl_keyfile, 'rb') as f:
private_key = f.read()
with open(self.args.ssl_certfile, 'rb') as f:
certificate_chain = f.read()

server_credentials = grpc.ssl_server_credentials(
(
(
private_key,
certificate_chain,
),
)
)
self.server.add_secure_port(bind_addr, server_credentials)
elif (
self.args.ssl_keyfile != self.args.ssl_certfile
): # if we have only ssl_keyfile and not ssl_certfile or vice versa
raise ValueError(
f"you can't pass a ssl_keyfile without a ssl_certfile and vice versa"
)
else:
self.server.add_insecure_port(bind_addr)
self.logger.debug(f'start server bound to {bind_addr}')
await self.server.start()
await self.gateway.setup_server()

async def async_teardown(self):
"""Close the connection pool"""
# usually async_cancel should already have been called, but then its a noop
# if the runtime is stopped without a sigterm (e.g. as a context manager, this can happen)
self._health_servicer.enter_graceful_shutdown()
await self.gateway_bff.close()
await self.gateway.teardown()
await self.async_cancel()

async def async_cancel(self):
"""The async method to stop server."""
await self.server.stop(0)
await self.gateway.stop_server()

async def async_run_forever(self):
"""The async running of server."""
await self.server.wait_for_termination()

async def dry_run(self, empty, context) -> jina_pb2.StatusProto:
"""
Process the the call requested by having a dry run call to every Executor in the graph
:param empty: The service expects an empty protobuf message
:param context: grpc context
:returns: the response request
"""
from docarray import DocumentArray

from jina.clients.request import request_generator
from jina.enums import DataInputType
from jina.serve.executors import __dry_run_endpoint__

da = DocumentArray()

try:
req_iterator = request_generator(
exec_endpoint=__dry_run_endpoint__,
data=da,
data_type=DataInputType.DOCUMENT,
)
async for _ in self.gateway_bff.stream(request_iterator=req_iterator):
pass
status_message = StatusMessage()
status_message.set_code(jina_pb2.StatusProto.SUCCESS)
return status_message.proto
except Exception as ex:
status_message = StatusMessage()
status_message.set_exception(ex)
return status_message.proto

async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
"""
Process the the call requested and return the JinaInfo of the Runtime
:param empty: The service expects an empty protobuf message
:param context: grpc context
:returns: the response request
"""
infoProto = jina_pb2.JinaInfoProto()
version, env_info = get_full_version()
for k, v in version.items():
infoProto.jina[k] = str(v)
for k, v in env_info.items():
infoProto.envs[k] = str(v)
return infoProto
await self.gateway.run_server()

0 comments on commit 243639d

Please sign in to comment.