refactor: small optimization
JoanFM committed Jul 28, 2022
1 parent b04e04a commit a339669
Showing 3 changed files with 45 additions and 19 deletions.
24 changes: 20 additions & 4 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
@@ -101,7 +101,8 @@ async def _wait_previous_and_send(
     endpoint: Optional[str],
     executor_endpoint_mapping: Optional[Dict] = None,
     target_executor_pattern: Optional[str] = None,
-    request_input_parameters: Dict = {}
+    request_input_parameters: Dict = {},
+    copy_request_at_send: bool = False
 ):
     # Check my condition and send request with the condition
     metadata = {}
@@ -114,7 +115,10 @@ async def _wait_previous_and_send(
     request.parameters = _parse_specific_params(
         request.parameters, self.name
     )
-    self.parts_to_send.append(copy.deepcopy(request))
+    if copy_request_at_send:
+        self.parts_to_send.append(copy.deepcopy(request))
+    else:
+        self.parts_to_send.append(request)
     # this is a specific needs
     if len(self.parts_to_send) == self.number_of_parts:
         self.start_time = datetime.utcnow()
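The hunk above replaces an unconditional `copy.deepcopy(request)` with a copy that only happens when the caller asks for it. A minimal standalone sketch of that pattern (the `append_part` helper is hypothetical, not the actual method):

```python
import copy

def append_part(parts_to_send: list, request, copy_request_at_send: bool) -> list:
    # Deep-copying a request is relatively expensive; only do it when the
    # caller knows the same request object will also be rewritten for other
    # nodes. Otherwise, append the original object untouched.
    if copy_request_at_send:
        parts_to_send.append(copy.deepcopy(request))
    else:
        parts_to_send.append(request)
    return parts_to_send
```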
@@ -174,6 +178,8 @@ def get_leaf_tasks(
     executor_endpoint_mapping: Optional[Dict] = None,
     target_executor_pattern: Optional[str] = None,
     request_input_parameters: Dict = {},
+    request_input_has_specific_params: bool = False,
+    copy_request_at_send: bool = False
 ) -> List[Tuple[bool, asyncio.Task]]:
     """
     Gets all the tasks corresponding from all the subgraphs born from this node
@@ -185,6 +191,8 @@
     :param executor_endpoint_mapping: Optional map that maps the name of a Deployment with the endpoints that it binds to so that they can be skipped if needed
     :param target_executor_pattern: Optional regex pattern for the target executor to decide whether or not the Executor should receive the request
     :param request_input_parameters: The parameters coming from the Request as they arrive to the gateway
+    :param request_input_has_specific_params: Parameter added for optimization. If this is False, there is no need to copy at all the request
+    :param copy_request_at_send: Copy the request before actually calling the `ConnectionPool` sending
     .. note:
         deployment1 -> outgoing_nodes: deployment2
@@ -216,14 +224,16 @@
         endpoint=endpoint,
         executor_endpoint_mapping=executor_endpoint_mapping,
         target_executor_pattern=target_executor_pattern,
-        request_input_parameters=request_input_parameters
+        request_input_parameters=request_input_parameters,
+        copy_request_at_send=copy_request_at_send
     )
 )
 if self.leaf: # I am like a leaf
     return [
         (not self.floating, wait_previous_and_send_task)
     ] # I am the last in the chain
 hanging_tasks_tuples = []
+num_outgoing_nodes = len(self.outgoing_nodes)
 for outgoing_node in self.outgoing_nodes:
     t = outgoing_node.get_leaf_tasks(
         connection_pool=connection_pool,
@@ -232,7 +242,9 @@
         endpoint=endpoint,
         executor_endpoint_mapping=executor_endpoint_mapping,
         target_executor_pattern=target_executor_pattern,
-        request_input_parameters=request_input_parameters
+        request_input_parameters=request_input_parameters,
+        request_input_has_specific_params=request_input_has_specific_params,
+        copy_request_at_send=num_outgoing_nodes > 1 and request_input_has_specific_params
     )
     # We are interested in the last one, that will be the task that awaits all the previous
     hanging_tasks_tuples.extend(t)
@@ -392,6 +404,10 @@ def _get_all_nodes(node, accum, accum_names):
     return nodes

 def collect_all_results(self):
+    """Collect all the results from every node into a single dictionary so that gateway can collect them
+    :return: A dictionary of the results
+    """
     res = {}
     for node in self.all_nodes:
         if node.result_in_params_returned:
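In `get_leaf_tasks`, the copy flag passed down to each outgoing node is `num_outgoing_nodes > 1 and request_input_has_specific_params`: a deep copy is only requested when the request fans out to more than one node and carries executor-specific parameters that will be rewritten per node. A hedged, standalone sketch of that propagation (the `Node` class below is hypothetical, not the real `TopologyGraph` API):

```python
class Node:
    def __init__(self, name, outgoing=None):
        self.name = name
        self.outgoing_nodes = outgoing or []

    def copy_flags(self, request_input_has_specific_params: bool):
        """Yield (child_name, copy_request_at_send) for every edge below this node."""
        num_outgoing_nodes = len(self.outgoing_nodes)
        for child in self.outgoing_nodes:
            # Copy only when the same request goes to several children AND
            # will be mutated with executor-specific parameters.
            yield child.name, (
                num_outgoing_nodes > 1 and request_input_has_specific_params
            )
            yield from child.copy_flags(request_input_has_specific_params)


root = Node('gateway', [Node('exec1'), Node('exec2')])
print(list(root.copy_flags(True)))   # [('exec1', True), ('exec2', True)]
print(list(root.copy_flags(False)))  # [('exec1', False), ('exec2', False)]
```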
38 changes: 24 additions & 14 deletions jina/serve/runtimes/gateway/request_handling.py
@@ -10,6 +10,7 @@
 from jina.importer import ImportExtensions
 from jina.serve.networking import GrpcConnectionPool
 from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph
+from jina.serve.runtimes.helper import _is_param_for_specific_executor
 from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler

 if TYPE_CHECKING:
@@ -29,18 +30,18 @@ class RequestHandler:
     """

     def __init__(
-        self,
-        metrics_registry: Optional['CollectorRegistry'] = None,
-        runtime_name: Optional[str] = None,
+        self,
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        runtime_name: Optional[str] = None,
     ):
         self._request_init_time = {} if metrics_registry else None
         self._executor_endpoint_mapping = None
         self._gathering_endpoints = False

         if metrics_registry:
             with ImportExtensions(
-                required=True,
-                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+                required=True,
+                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
             ):
                 from prometheus_client import Gauge, Summary

@@ -81,7 +82,7 @@ def _update_end_request_metrics(self, result: 'Request'):
         self._pending_requests_metrics.dec()

     def handle_request(
-        self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
+        self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
     ) -> Callable[['Request'], 'Tuple[Future, Optional[Future]]']:
         """
         Function that handles the requests arriving to the gateway. This will be passed to the streamer.
@@ -102,8 +103,8 @@ async def gather_endpoints(request_graph):
     err_code = err.code()
     if err_code == grpc.StatusCode.UNAVAILABLE:
         err._details = (
-            err.details()
-            + f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
+            err.details()
+            + f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
         )
         raise err
     else:
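For context, the code above appends gateway details to gRPC `UNAVAILABLE` errors before re-raising them. A rough sketch of the same pattern, assuming an async stub call and `grpc.aio` (the `call_and_annotate` wrapper and `dest_addr` argument are illustrative, not part of the commit):

```python
import grpc

async def call_and_annotate(stub_call, dest_addr: str):
    # Append gateway context to UNAVAILABLE errors so the client can see
    # which deployment address could not be reached; other errors propagate
    # unchanged.
    try:
        return await stub_call()
    except grpc.aio.AioRpcError as err:
        if err.code() == grpc.StatusCode.UNAVAILABLE:
            err._details = (
                err.details()
                + f' |Gateway: Communication error with deployment at address(es) {dest_addr}. Head or worker(s) may be down.'
            )
        raise err
```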
@@ -121,8 +122,8 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':

 if graph.has_filter_conditions:
     request_doc_ids = request.data.docs[
-        :, 'id'
-    ] # used to maintain order of docs that are filtered by executors
+        :, 'id'
+    ] # used to maintain order of docs that are filtered by executors
 responding_tasks = []
 floating_tasks = []
 endpoint = request.header.exec_endpoint
@@ -131,6 +132,13 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
     r.start_time.GetCurrentTime()
 # If the request is targeting a specific deployment, we can send directly to the deployment instead of
 # querying the graph
+num_outgoing_nodes = len(request_graph.origin_nodes)
+has_specific_params = False
+request_input_parameters = request.parameters
+for key in request_input_parameters:
+    if _is_param_for_specific_executor(key):
+        has_specific_params = True
+
 for origin_node in request_graph.origin_nodes:
     leaf_tasks = origin_node.get_leaf_tasks(
         connection_pool=connection_pool,
@@ -139,7 +147,9 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
         endpoint=endpoint,
         executor_endpoint_mapping=self._executor_endpoint_mapping,
         target_executor_pattern=request.header.target_executor,
-        request_input_parameters=request.parameters
+        request_input_parameters=request_input_parameters,
+        request_input_has_specific_params=has_specific_params,
+        copy_request_at_send=num_outgoing_nodes > 1 and has_specific_params
     )
     # Every origin node returns a set of tasks that are the ones corresponding to the leafs of each of their
     # subtrees that unwrap all the previous tasks. It starts like a chain of waiting for tasks from previous
@@ -159,12 +169,12 @@ def sort_by_request_order(doc):
     response.data.docs = DocumentArray(sorted_docs)

 async def _process_results_at_end_gateway(
-    tasks: List[asyncio.Task], request_graph: TopologyGraph
+    tasks: List[asyncio.Task], request_graph: TopologyGraph
 ) -> asyncio.Future:
     try:
         if (
-            self._executor_endpoint_mapping is None
-            and not self._gathering_endpoints
+            self._executor_endpoint_mapping is None
+            and not self._gathering_endpoints
         ):
             self._gathering_endpoints = True
             asyncio.create_task(gather_endpoints(request_graph))
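Putting the gateway-side changes together: the request parameters are scanned once for executor-specific keys, and copying is only enabled when the request fans out to more than one origin node. A condensed sketch (the `should_copy_at_send` helper is hypothetical; the import is the one added by this commit):

```python
from jina.serve.runtimes.helper import _is_param_for_specific_executor

def should_copy_at_send(request_parameters: dict, num_outgoing_nodes: int) -> bool:
    # A single executor-specific key is enough to require per-node copies
    # when the request is sent to several origin nodes.
    has_specific_params = any(
        _is_param_for_specific_executor(key) for key in request_parameters
    )
    return num_outgoing_nodes > 1 and has_specific_params
```

For example, `should_copy_at_send({'my_executor__top_k': 3}, 2)` would be `True`, while a plain `{'top_k': 3}` never triggers a copy.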
2 changes: 1 addition & 1 deletion jina/serve/runtimes/helper.py
@@ -36,7 +36,7 @@ def _get_name_from_replicas_name(name: str) -> Tuple[str]:
 def _is_param_for_specific_executor(key_name: str) -> bool:
     """Tell if a key is for a specific Executor
-    ex: 'key' is for every Executor whereas 'key__my_executor' is only for 'my_executor'
+    ex: 'key' is for every Executor whereas 'my_executor__key' is only for 'my_executor'
     :param key_name: key name of the param
     :return: return True if key_name is for specific Executor, False otherwise
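The corrected docstring pins down the naming convention: the Executor name is the prefix, separated from the key by a double underscore. A minimal re-implementation sketch of that contract, for illustration only (the real helper lives in this file; treating leading-double-underscore keys as reserved is an assumption here):

```python
def is_param_for_specific_executor(key_name: str) -> bool:
    # 'my_executor__key' targets only 'my_executor'; a plain 'key' targets
    # every Executor; keys starting with '__' are treated as reserved.
    return '__' in key_name and not key_name.startswith('__')

assert is_param_for_specific_executor('my_executor__key')
assert not is_param_for_specific_executor('key')
```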