Skip to content

Commit

Permalink
fix: fix specific params problem with branches (#5038)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 28, 2022
1 parent 81999e9 commit 032bd5e
Show file tree
Hide file tree
Showing 12 changed files with 271 additions and 152 deletions.
79 changes: 51 additions & 28 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
Expand Up @@ -44,33 +44,31 @@ def __init__(
self.number_of_parts = number_of_parts
self.floating = floating
self.parts_to_send = []
self.original_parameters = []
self.start_time = None
self.end_time = None
self.status = None
self._filter_condition = filter_condition
self._reduce = reduce
self._timeout_send = timeout_send
self._retries = retries
self.result_in_params_returned = None

@property
def leaf(self):
    """Whether this node terminates a branch, i.e. it has no outgoing nodes."""
    return not self.outgoing_nodes

def _update_requests_with_filter_condition(self):
def _update_requests_with_filter_condition(self, need_copy):
for i in range(len(self.parts_to_send)):
copy_req = copy.deepcopy(self.parts_to_send[i])
filtered_docs = copy_req.docs.find(self._filter_condition)
copy_req.data.docs = filtered_docs

self.parts_to_send[i] = copy_req

def _update_request_by_params(self, deployment_name: str):

req = self.parts_to_send[i] if not need_copy else copy.deepcopy(self.parts_to_send[i])
filtered_docs = req.docs.find(self._filter_condition)
req.data.docs = filtered_docs
self.parts_to_send[i] = req

def _update_request_by_params(self, deployment_name: str, request_input_parameters: Dict):
    """Overwrite the parameters of every queued request part with the parameters
    specific to this deployment, extracted from the original request parameters.

    :param deployment_name: name of the deployment whose specific params apply
    :param request_input_parameters: the parameters as they arrived at the gateway
    """
    # Parse once: the result is the same for every part of this node.
    params = _parse_specific_params(request_input_parameters, deployment_name)
    for part in self.parts_to_send:
        part.parameters = params

def _handle_internalnetworkerror(self, err):
Expand Down Expand Up @@ -99,12 +97,14 @@ def get_endpoints(self, connection_pool: GrpcConnectionPool) -> asyncio.Task:

async def _wait_previous_and_send(
self,
request: DataRequest,
request: Optional[DataRequest],
previous_task: Optional[asyncio.Task],
connection_pool: GrpcConnectionPool,
endpoint: Optional[str],
executor_endpoint_mapping: Optional[Dict] = None,
target_executor_pattern: Optional[str] = None,
request_input_parameters: Dict = {},
copy_request_at_send: bool = False
):
# Check my condition and send request with the condition
metadata = {}
Expand All @@ -114,18 +114,19 @@ async def _wait_previous_and_send(
if metadata and 'is-error' in metadata:
return request, metadata
elif request is not None:
original_parameters = copy.deepcopy(request.parameters)
request.parameters = _parse_specific_params(
request.parameters, self.name
)
self.parts_to_send.append(request)
self.original_parameters.append(original_parameters)
if copy_request_at_send:
self.parts_to_send.append(copy.deepcopy(request))
else:
self.parts_to_send.append(request)
# this handles a specific need
if len(self.parts_to_send) == self.number_of_parts:
self.start_time = datetime.utcnow()
self._update_request_by_params(self.name)
self._update_request_by_params(self.name, request_input_parameters)
if self._filter_condition is not None:
self._update_requests_with_filter_condition()
self._update_requests_with_filter_condition(need_copy=not copy_request_at_send)
if self._reduce and len(self.parts_to_send) > 1:
self.parts_to_send = [
DataRequestHandler.reduce_requests(self.parts_to_send)
Expand Down Expand Up @@ -154,13 +155,12 @@ async def _wait_previous_and_send(
timeout=self._timeout_send,
retries=self._retries,
)
return_parameters = original_parameters
if DataRequestHandler._KEY_RESULT in resp.parameters:
return_parameters[
DataRequestHandler._KEY_RESULT
] = resp.parameters[DataRequestHandler._KEY_RESULT]

resp.parameters = return_parameters
# Accumulate results from each Node and then add them to the original
self.result_in_params_returned = resp.parameters[DataRequestHandler._KEY_RESULT]
request.parameters = request_input_parameters
resp.parameters = request_input_parameters
self.parts_to_send.clear()
except InternalNetworkError as err:
self._handle_internalnetworkerror(err)

Expand All @@ -179,6 +179,9 @@ def get_leaf_tasks(
endpoint: Optional[str] = None,
executor_endpoint_mapping: Optional[Dict] = None,
target_executor_pattern: Optional[str] = None,
request_input_parameters: Dict = {},
request_input_has_specific_params: bool = False,
copy_request_at_send: bool = False
) -> List[Tuple[bool, asyncio.Task]]:
"""
Gets all the tasks corresponding from all the subgraphs born from this node
Expand All @@ -189,6 +192,9 @@ def get_leaf_tasks(
:param endpoint: Optional string defining the endpoint of this request
:param executor_endpoint_mapping: Optional map that maps the name of a Deployment with the endpoints that it binds to so that they can be skipped if needed
:param target_executor_pattern: Optional regex pattern for the target executor to decide whether or not the Executor should receive the request
:param request_input_parameters: The parameters coming from the Request as they arrive to the gateway
:param request_input_has_specific_params: Parameter added for optimization. If this is False, there is no need to copy the request at all
:param copy_request_at_send: Copy the request before actually calling the `ConnectionPool` sending
.. note:
deployment1 -> outgoing_nodes: deployment2
Expand All @@ -214,19 +220,22 @@ def get_leaf_tasks(
"""
wait_previous_and_send_task = asyncio.create_task(
self._wait_previous_and_send(
request_to_send,
previous_task,
connection_pool,
request=request_to_send,
previous_task=previous_task,
connection_pool=connection_pool,
endpoint=endpoint,
executor_endpoint_mapping=executor_endpoint_mapping,
target_executor_pattern=target_executor_pattern,
request_input_parameters=request_input_parameters,
copy_request_at_send=copy_request_at_send
)
)
if self.leaf: # I am like a leaf
return [
(not self.floating, wait_previous_and_send_task)
] # I am the last in the chain
hanging_tasks_tuples = []
num_outgoing_nodes = len(self.outgoing_nodes)
for outgoing_node in self.outgoing_nodes:
t = outgoing_node.get_leaf_tasks(
connection_pool=connection_pool,
Expand All @@ -235,6 +244,9 @@ def get_leaf_tasks(
endpoint=endpoint,
executor_endpoint_mapping=executor_endpoint_mapping,
target_executor_pattern=target_executor_pattern,
request_input_parameters=request_input_parameters,
request_input_has_specific_params=request_input_has_specific_params,
copy_request_at_send=num_outgoing_nodes > 1 and request_input_has_specific_params
)
# We are interested in the last one, that will be the task that awaits all the previous
hanging_tasks_tuples.extend(t)
Expand Down Expand Up @@ -392,3 +404,14 @@ def _get_all_nodes(node, accum, accum_names):
nodes.append(st_node)
node_names.append(st_node_name)
return nodes

def collect_all_results(self):
    """Collect all the results from every node into a single dictionary so that gateway can collect them

    :return: A dictionary of the results
    """
    merged = {}
    for node in self.all_nodes:
        returned = node.result_in_params_returned
        # Falsy values (None or empty dict) contribute nothing to the merge.
        if returned:
            merged.update(returned)
    return merged
45 changes: 32 additions & 13 deletions jina/serve/runtimes/gateway/request_handling.py
Expand Up @@ -10,6 +10,8 @@
from jina.importer import ImportExtensions
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph
from jina.serve.runtimes.helper import _is_param_for_specific_executor
from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler

if TYPE_CHECKING:
from asyncio import Future
Expand All @@ -28,18 +30,18 @@ class RequestHandler:
"""

def __init__(
self,
metrics_registry: Optional['CollectorRegistry'] = None,
runtime_name: Optional[str] = None,
self,
metrics_registry: Optional['CollectorRegistry'] = None,
runtime_name: Optional[str] = None,
):
self._request_init_time = {} if metrics_registry else None
self._executor_endpoint_mapping = None
self._gathering_endpoints = False

if metrics_registry:
with ImportExtensions(
required=True,
help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
required=True,
help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
):
from prometheus_client import Gauge, Summary

Expand Down Expand Up @@ -80,7 +82,7 @@ def _update_end_request_metrics(self, result: 'Request'):
self._pending_requests_metrics.dec()

def handle_request(
self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
) -> Callable[['Request'], 'Tuple[Future, Optional[Future]]']:
"""
Function that handles the requests arriving to the gateway. This will be passed to the streamer.
Expand All @@ -101,8 +103,8 @@ async def gather_endpoints(request_graph):
err_code = err.code()
if err_code == grpc.StatusCode.UNAVAILABLE:
err._details = (
err.details()
+ f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
err.details()
+ f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
)
raise err
else:
Expand All @@ -120,8 +122,8 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':

if graph.has_filter_conditions:
request_doc_ids = request.data.docs[
:, 'id'
] # used to maintain order of docs that are filtered by executors
:, 'id'
] # used to maintain order of docs that are filtered by executors
responding_tasks = []
floating_tasks = []
endpoint = request.header.exec_endpoint
Expand All @@ -130,6 +132,14 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
r.start_time.GetCurrentTime()
# If the request is targeting a specific deployment, we can send directly to the deployment instead of
# querying the graph
num_outgoing_nodes = len(request_graph.origin_nodes)
has_specific_params = False
request_input_parameters = request.parameters
for key in request_input_parameters:
if _is_param_for_specific_executor(key):
has_specific_params = True
break

for origin_node in request_graph.origin_nodes:
leaf_tasks = origin_node.get_leaf_tasks(
connection_pool=connection_pool,
Expand All @@ -138,6 +148,9 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
endpoint=endpoint,
executor_endpoint_mapping=self._executor_endpoint_mapping,
target_executor_pattern=request.header.target_executor,
request_input_parameters=request_input_parameters,
request_input_has_specific_params=has_specific_params,
copy_request_at_send=num_outgoing_nodes > 1 and has_specific_params
)
# Every origin node returns a set of tasks that are the ones corresponding to the leafs of each of their
# subtrees that unwrap all the previous tasks. It starts like a chain of waiting for tasks from previous
Expand All @@ -157,12 +170,12 @@ def sort_by_request_order(doc):
response.data.docs = DocumentArray(sorted_docs)

async def _process_results_at_end_gateway(
tasks: List[asyncio.Task], request_graph: TopologyGraph
tasks: List[asyncio.Task], request_graph: TopologyGraph
) -> asyncio.Future:
try:
if (
self._executor_endpoint_mapping is None
and not self._gathering_endpoints
self._executor_endpoint_mapping is None
and not self._gathering_endpoints
):
self._gathering_endpoints = True
asyncio.create_task(gather_endpoints(request_graph))
Expand All @@ -181,6 +194,12 @@ async def _process_results_at_end_gateway(

if graph.has_filter_conditions:
_sort_response_docs(response)

collect_results = request_graph.collect_all_results()
resp_params = response.parameters
if len(collect_results) > 0:
resp_params[DataRequestHandler._KEY_RESULT] = collect_results
response.parameters = resp_params
return response

# In case of empty topologies
Expand Down
2 changes: 1 addition & 1 deletion jina/serve/runtimes/helper.py
Expand Up @@ -36,7 +36,7 @@ def _get_name_from_replicas_name(name: str) -> Tuple[str]:
def _is_param_for_specific_executor(key_name: str) -> bool:
"""Tell if a key is for a specific Executor
ex: 'key' is for every Executor whereas 'key__my_executor' is only for 'my_executor'
ex: 'key' is for every Executor whereas 'my_executor__key' is only for 'my_executor'
:param key_name: key name of the param
:return: return True if key_name is for specific Executor, False otherwise
Expand Down
Expand Up @@ -8,7 +8,6 @@
cur_dir = os.path.dirname(os.path.abspath(__file__))

img_name = 'jina/replica-exec'
exposed_port = 12345


@pytest.fixture(scope='function')
Expand All @@ -26,7 +25,8 @@ def docker_image_built():

@pytest.mark.parametrize('shards', [1, 2])
@pytest.mark.parametrize('replicas', [1, 3, 4])
def test_containerruntime_args(docker_image_built, shards, replicas):
def test_containerruntime_args(docker_image_built, shards, replicas, port_generator):
exposed_port = port_generator()
f = Flow(port=exposed_port).add(
name='executor_container',
uses=f'docker://{img_name}',
Expand Down
10 changes: 6 additions & 4 deletions tests/integration/high_order_matches/test_document.py
@@ -1,7 +1,5 @@
from jina import Client, Document, Executor, Flow, requests

exposed_port = 12345


def validate_results(results):
req = results[0]
Expand All @@ -25,7 +23,9 @@ def index(self, docs, **kwargs):
doc.matches.append(Document())


def test_single_executor():
def test_single_executor(port_generator):

exposed_port = port_generator()

f = Flow(port=exposed_port).add(
uses={'jtype': 'MatchAdder', 'with': {'traversal_paths': 'r,m'}}
Expand All @@ -38,7 +38,9 @@ def test_single_executor():
validate_results(results)


def test_multi_executor():
def test_multi_executor(port_generator):

exposed_port = port_generator()

f = (
Flow(port=exposed_port)
Expand Down

0 comments on commit 032bd5e

Please sign in to comment.