diff --git a/jina/serve/runtimes/gateway/graph/topology_graph.py b/jina/serve/runtimes/gateway/graph/topology_graph.py
index fde8a11d337b3..eea9015bca82c 100644
--- a/jina/serve/runtimes/gateway/graph/topology_graph.py
+++ b/jina/serve/runtimes/gateway/graph/topology_graph.py
@@ -101,7 +101,8 @@ async def _wait_previous_and_send(
         endpoint: Optional[str],
         executor_endpoint_mapping: Optional[Dict] = None,
         target_executor_pattern: Optional[str] = None,
-        request_input_parameters: Dict = {}
+        request_input_parameters: Dict = {},
+        copy_request_at_send: bool = False
     ):
         # Check my condition and send request with the condition
         metadata = {}
@@ -114,7 +115,10 @@ async def _wait_previous_and_send(
                 request.parameters = _parse_specific_params(
                     request.parameters, self.name
                 )
-            self.parts_to_send.append(copy.deepcopy(request))
+            if copy_request_at_send:
+                self.parts_to_send.append(copy.deepcopy(request))
+            else:
+                self.parts_to_send.append(request)
             # this is a specific needs
             if len(self.parts_to_send) == self.number_of_parts:
                 self.start_time = datetime.utcnow()
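Note on the hunk above: `_wait_previous_and_send` used to deep-copy every request unconditionally, yet the copy only matters when one request object fans out to several nodes that each rewrite `request.parameters` through `_parse_specific_params`. A minimal sketch of the hazard, using a hypothetical `parse_specific_params` as a stand-in for Jina's `_parse_specific_params`:

```python
# Sketch (not Jina code): why fan-out needs a copy when executor-specific
# parameters are present. Each branch rewrites parameters to its own view.
import copy


def parse_specific_params(parameters: dict, executor_name: str) -> dict:
    """Hypothetical stand-in: keep generic keys, resolve '<executor>__key' ones."""
    parsed = {k: v for k, v in parameters.items() if '__' not in k}
    for key, value in parameters.items():
        executor, sep, stripped = key.partition('__')
        if sep and executor == executor_name:
            parsed[stripped] = value
    return parsed


request = {'parameters': {'limit': 10, 'exec1__limit': 3}}

# Two outgoing nodes share one request object: without a per-branch copy,
# exec1's rewrite would overwrite the parameters exec2 still has to parse.
views = {}
for name in ('exec1', 'exec2'):
    branch = copy.deepcopy(request)
    branch['parameters'] = parse_specific_params(branch['parameters'], name)
    views[name] = branch['parameters']

assert views['exec1'] == {'limit': 3}   # the specific key wins for exec1
assert views['exec2'] == {'limit': 10}  # exec2 is unaffected by exec1's key

# Without any '<executor>__' key every branch computes the same view, so the
# deepcopy is pure overhead: exactly the case the PR now skips.
```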
@@ -174,6 +178,8 @@ def get_leaf_tasks(
         executor_endpoint_mapping: Optional[Dict] = None,
         target_executor_pattern: Optional[str] = None,
         request_input_parameters: Dict = {},
+        request_input_has_specific_params: bool = False,
+        copy_request_at_send: bool = False
     ) -> List[Tuple[bool, asyncio.Task]]:
         """
         Gets all the tasks corresponding from all the subgraphs born from this node
@@ -185,6 +191,8 @@ def get_leaf_tasks(
         :param executor_endpoint_mapping: Optional map that maps the name of a Deployment with the endpoints that it binds to so that they can be skipped if needed
         :param target_executor_pattern: Optional regex pattern for the target executor to decide whether or not the Executor should receive the request
         :param request_input_parameters: The parameters coming from the Request as they arrive to the gateway
+        :param request_input_has_specific_params: Parameter added for optimization. If this is False, there is no need to copy the request at all
+        :param copy_request_at_send: Copy the request before actually sending it through the `ConnectionPool`

         .. note:
             deployment1 -> outgoing_nodes: deployment2
@@ -216,7 +224,8 @@ def get_leaf_tasks(
                 endpoint=endpoint,
                 executor_endpoint_mapping=executor_endpoint_mapping,
                 target_executor_pattern=target_executor_pattern,
-                request_input_parameters=request_input_parameters
+                request_input_parameters=request_input_parameters,
+                copy_request_at_send=copy_request_at_send
             )
         )
         if self.leaf:  # I am like a leaf
@@ -224,6 +233,7 @@ def get_leaf_tasks(
                 (not self.floating, wait_previous_and_send_task)
             ]  # I am the last in the chain
         hanging_tasks_tuples = []
+        num_outgoing_nodes = len(self.outgoing_nodes)
         for outgoing_node in self.outgoing_nodes:
             t = outgoing_node.get_leaf_tasks(
                 connection_pool=connection_pool,
@@ -232,7 +242,9 @@ def get_leaf_tasks(
                 endpoint=endpoint,
                 executor_endpoint_mapping=executor_endpoint_mapping,
                 target_executor_pattern=target_executor_pattern,
-                request_input_parameters=request_input_parameters
+                request_input_parameters=request_input_parameters,
+                request_input_has_specific_params=request_input_has_specific_params,
+                copy_request_at_send=num_outgoing_nodes > 1 and request_input_has_specific_params
             )
             # We are interested in the last one, that will be the task that awaits all the previous
             hanging_tasks_tuples.extend(t)
@@ -392,8 +404,12 @@ def _get_all_nodes(node, accum, accum_names):
         return nodes

     def collect_all_results(self):
+        """Collect all the results from every node into a single dictionary so that the gateway can collect them
+
+        :return: A dictionary of the results
+        """
         res = {}
         for node in self.all_nodes:
             if node.result_in_params_returned:
-                res.update(**node.result_in_params_returned)
+                res.update(node.result_in_params_returned)
         return res
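Note on the `collect_all_results` hunk: besides the new docstring, `res.update(**node.result_in_params_returned)` becomes `res.update(node.result_in_params_returned)`. Keyword-unpacking a dict into `update()` forces every key to be a string, while passing the mapping directly carries no such restriction. A quick illustration (the key names are invented):

```python
# Illustration of the update() change; the key names are invented.
res = {}
node_results = {'executor0__results': {'score': 1.0}}

res.update(**node_results)  # keyword-unpacking: every key must be a string
res.update(node_results)    # passing the mapping directly: same effect here

try:
    res.update(**{1: 'x'})  # a non-string key cannot be keyword-unpacked
except TypeError as err:
    print(err)              # "keywords must be strings"
res.update({1: 'x'})        # fine when the mapping is passed directly
```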
diff --git a/jina/serve/runtimes/gateway/request_handling.py b/jina/serve/runtimes/gateway/request_handling.py
index 72d7dd42437ad..9b056007d46a9 100644
--- a/jina/serve/runtimes/gateway/request_handling.py
+++ b/jina/serve/runtimes/gateway/request_handling.py
@@ -10,6 +10,7 @@
 from jina.importer import ImportExtensions
 from jina.serve.networking import GrpcConnectionPool
 from jina.serve.runtimes.gateway.graph.topology_graph import TopologyGraph
+from jina.serve.runtimes.helper import _is_param_for_specific_executor
 from jina.serve.runtimes.request_handlers.data_request_handler import DataRequestHandler

 if TYPE_CHECKING:
@@ -29,9 +30,9 @@ class RequestHandler:
     """

     def __init__(
-            self,
-            metrics_registry: Optional['CollectorRegistry'] = None,
-            runtime_name: Optional[str] = None,
+        self,
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        runtime_name: Optional[str] = None,
     ):
         self._request_init_time = {} if metrics_registry else None
         self._executor_endpoint_mapping = None
@@ -39,8 +40,8 @@ def __init__(

         if metrics_registry:
             with ImportExtensions(
-                    required=True,
-                    help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+                required=True,
+                help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
             ):
                 from prometheus_client import Gauge, Summary
@@ -81,7 +82,7 @@ def _update_end_request_metrics(self, result: 'Request'):
             self._pending_requests_metrics.dec()

     def handle_request(
-            self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
+        self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
     ) -> Callable[['Request'], 'Tuple[Future, Optional[Future]]']:
         """
         Function that handles the requests arriving to the gateway. This will be passed to the streamer.
@@ -102,8 +103,8 @@ async def gather_endpoints(request_graph):
                 err_code = err.code()
                 if err_code == grpc.StatusCode.UNAVAILABLE:
                     err._details = (
-                            err.details()
-                            + f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
+                        err.details()
+                        + f' |Gateway: Communication error with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
                     )
                     raise err
                 else:
@@ -121,8 +122,8 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':

             if graph.has_filter_conditions:
                 request_doc_ids = request.data.docs[
-                                  :, 'id'
-                                  ]  # used to maintain order of docs that are filtered by executors
+                    :, 'id'
+                ]  # used to maintain order of docs that are filtered by executors
             responding_tasks = []
             floating_tasks = []
             endpoint = request.header.exec_endpoint
@@ -131,6 +132,13 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
             r.start_time.GetCurrentTime()
             # If the request is targeting a specific deployment, we can send directly to the deployment instead of
             # querying the graph
+            num_outgoing_nodes = len(request_graph.origin_nodes)
+            has_specific_params = False
+            request_input_parameters = request.parameters
+            for key in request_input_parameters:
+                if _is_param_for_specific_executor(key):
+                    has_specific_params = True
+
             for origin_node in request_graph.origin_nodes:
                 leaf_tasks = origin_node.get_leaf_tasks(
                     connection_pool=connection_pool,
@@ -139,7 +147,9 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
                     endpoint=endpoint,
                     executor_endpoint_mapping=self._executor_endpoint_mapping,
                     target_executor_pattern=request.header.target_executor,
-                    request_input_parameters=request.parameters
+                    request_input_parameters=request_input_parameters,
+                    request_input_has_specific_params=has_specific_params,
+                    copy_request_at_send=num_outgoing_nodes > 1 and has_specific_params
                 )
                 # Every origin node returns a set of tasks that are the ones corresponding to the leafs of each of their
                 # subtrees that unwrap all the previous tasks. It starts like a chain of waiting for tasks from previous
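Note: `_handle_request` now scans the incoming parameters once and enables the per-branch copy only when the request actually fans out (`num_outgoing_nodes > 1`) and carries executor-specific keys. On the client side this path is triggered by the `<executor>__<param>` naming convention; a hedged usage sketch (the Flow layout, names and port are illustrative, not taken from this PR):

```python
# Hedged sketch: a request that fans out from the gateway into two parallel
# branches while carrying an executor-specific parameter.
from jina import Client, Document, Flow

f = (
    Flow(port=12345)
    .add(name='exec1')
    .add(name='exec2', needs='gateway')  # second branch straight off the gateway
)

with f:
    Client(port=12345).post(
        '/',
        inputs=Document(),
        parameters={
            'limit': 10,        # generic key: every Executor sees it
            'exec1__limit': 3,  # specific key: only 'exec1' receives limit=3
        },
    )
```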
@@ -159,12 +169,12 @@ def sort_by_request_order(doc):
                 response.data.docs = DocumentArray(sorted_docs)

             async def _process_results_at_end_gateway(
-                    tasks: List[asyncio.Task], request_graph: TopologyGraph
+                tasks: List[asyncio.Task], request_graph: TopologyGraph
             ) -> asyncio.Future:
                 try:
                     if (
-                            self._executor_endpoint_mapping is None
-                            and not self._gathering_endpoints
+                        self._executor_endpoint_mapping is None
+                        and not self._gathering_endpoints
                     ):
                         self._gathering_endpoints = True
                         asyncio.create_task(gather_endpoints(request_graph))
@@ -185,7 +195,9 @@ async def _process_results_at_end_gateway(
                     _sort_response_docs(response)

                 collect_results = request_graph.collect_all_results()
-                response.parameters[DataRequestHandler._KEY_RESULT] = collect_results
+                resp_params = response.parameters
+                resp_params[DataRequestHandler._KEY_RESULT] = collect_results
+                response.parameters = resp_params
                 return response

             # In case of empty topologies
diff --git a/jina/serve/runtimes/helper.py b/jina/serve/runtimes/helper.py
index 100e8463f94e4..7a4a39235cd05 100644
--- a/jina/serve/runtimes/helper.py
+++ b/jina/serve/runtimes/helper.py
@@ -36,7 +36,7 @@ def _get_name_from_replicas_name(name: str) -> Tuple[str]:
 def _is_param_for_specific_executor(key_name: str) -> bool:
     """Tell if a key is for a specific Executor

-    ex: 'key' is for every Executor whereas 'key__my_executor' is only for 'my_executor'
+    ex: 'key' is for every Executor whereas 'my_executor__key' is only for 'my_executor'

     :param key_name: key name of the param
     :return: return True if key_name is for specific Executor, False otherwise
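Note: the helper's docstring previously contradicted the implementation, since the executor name is a prefix, not a suffix. The corrected convention behaves as below (plain truthiness asserts built from the docstring's own example):

```python
from jina.serve.runtimes.helper import _is_param_for_specific_executor

assert _is_param_for_specific_executor('my_executor__key')  # prefixed key: specific
assert not _is_param_for_specific_executor('key')           # plain key: for everyone
```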
diff --git a/tests/integration/reduce/test_reduce.py b/tests/integration/reduce/test_reduce.py
index 8fea4dec02ce9..b58badf2b34c8 100644
--- a/tests/integration/reduce/test_reduce.py
+++ b/tests/integration/reduce/test_reduce.py
@@ -3,8 +3,6 @@

 from jina import Client, Document, DocumentArray, Executor, Flow, requests

-exposed_port = 12345
-

 class ShardsExecutor(Executor):
     def __init__(self, n_docs: int = 5, **kwargs):
@@ -46,7 +44,8 @@ def fake_reduce(self, **kwargs):


 @pytest.mark.parametrize('n_docs', [3, 5])
-def test_reduce_shards(n_docs):
+def test_reduce_shards(n_docs, port_generator):
+    exposed_port = port_generator()
     n_shards = 3
     search_flow = Flow(port=exposed_port).add(
         uses=ShardsExecutor,
@@ -55,10 +54,10 @@ def test_reduce_shards(n_docs):
         uses_with={'n_docs': n_docs},
     )

-    with search_flow as f:
+    with search_flow:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post(
-            '/search', inputs=da
+        resp = Client(port=exposed_port).post(
+            '/search', inputs=da, return_responses=True
         )

     assert len(resp[0].docs) == 5
@@ -85,7 +84,8 @@ def test_reduce_shards(n_docs):

 @pytest.mark.parametrize('n_shards', [3, 5])
 @pytest.mark.parametrize('n_docs', [3, 5])
-def test_uses_after_no_reduce(n_shards, n_docs):
+def test_uses_after_no_reduce(n_shards, n_docs, port_generator):
+    exposed_port = port_generator()
     search_flow = Flow(port=exposed_port).add(
         uses=ShardsExecutor,
         shards=n_shards,
@@ -94,10 +94,10 @@ def test_uses_after_no_reduce(n_shards, n_docs):
         uses_with={'n_docs': n_docs},
     )

-    with search_flow as f:
+    with search_flow:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post(
-            '/search', inputs=da
+        resp = Client(port=exposed_port).post(
+            '/search', inputs=da, return_responses=True
         )

     # assert no reduce happened
@@ -140,18 +140,19 @@ def endpoint(self, docs: DocumentArray, **kwargs):
         return status


-def test_reduce_needs():
+def test_reduce_needs(port_generator):
+    exposed_port = port_generator()
     flow = (
         Flow(port=exposed_port)
-            .add(uses=Executor1, name='pod0')
-            .add(uses=Executor2, needs='gateway', name='pod1')
-            .add(uses=Executor3, needs='gateway', name='pod2')
-            .add(needs=['pod0', 'pod1', 'pod2'], name='pod3')
+        .add(uses=Executor1, name='pod0')
+        .add(uses=Executor2, needs='gateway', name='pod1')
+        .add(uses=Executor3, needs='gateway', name='pod2')
+        .add(needs=['pod0', 'pod1', 'pod2'], name='pod3')
     )

-    with flow as f:
+    with flow:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post('/', inputs=da)
+        resp = Client(port=exposed_port).post('/', inputs=da, return_responses=True)

     assert len(resp[0].docs) == 5
     for doc in resp[0].docs:
@@ -161,60 +162,64 @@ def test_reduce_needs():
         assert (doc.embedding == np.zeros(3)).all()


-def test_uses_before_reduce():
+def test_uses_before_reduce(port_generator):
+    exposed_port = port_generator()
     flow = (
         Flow(port=exposed_port)
-            .add(uses=Executor1, name='pod0')
-            .add(uses=Executor2, needs='gateway', name='pod1')
-            .add(uses=Executor3, needs='gateway', name='pod2')
-            .add(needs=['pod0', 'pod1', 'pod2'], name='pod3', uses_before='BaseExecutor')
+        .add(uses=Executor1, name='pod0')
+        .add(uses=Executor2, needs='gateway', name='pod1')
+        .add(uses=Executor3, needs='gateway', name='pod2')
+        .add(needs=['pod0', 'pod1', 'pod2'], name='pod3', uses_before='BaseExecutor')
     )

-    with flow as f:
+    with flow:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post('/', inputs=da)
+        resp = Client(port=exposed_port).post('/', inputs=da, return_responses=True)

     # assert reduce happened because there is only BaseExecutor as uses_before
     assert len(resp[0].docs) == 5


-def test_uses_before_no_reduce_real_executor():
+def test_uses_before_no_reduce_real_executor(port_generator):
+    exposed_port = port_generator()
     flow = (
         Flow(port=exposed_port)
-            .add(uses=Executor1, name='pod0')
-            .add(uses=Executor2, needs='gateway', name='pod1')
-            .add(uses=Executor3, needs='gateway', name='pod2')
-            .add(needs=['pod0', 'pod1', 'pod2'], name='pod3', uses=DummyExecutor)
+        .add(uses=Executor1, name='pod0')
+        .add(uses=Executor2, needs='gateway', name='pod1')
+        .add(uses=Executor3, needs='gateway', name='pod2')
+        .add(needs=['pod0', 'pod1', 'pod2'], name='pod3', uses=DummyExecutor)
     )

-    with flow as f:
+    with flow:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post('/', inputs=da)
+        resp = Client(port=exposed_port).post('/', inputs=da, return_responses=True)

     # assert no reduce happened
     assert len(resp[0].docs) == 1
     assert resp[0].docs[0].id == 'fake_document'


-def test_uses_before_no_reduce_real_executor_uses():
+def test_uses_before_no_reduce_real_executor_uses(port_generator):
+    exposed_port = port_generator()
     flow = (
         Flow(port=exposed_port)
-            .add(uses=Executor1, name='pod0')
-            .add(uses=Executor2, needs='gateway', name='pod1')
-            .add(uses=Executor3, needs='gateway', name='pod2')
-            .add(needs=['pod0', 'pod1', 'pod2'], name='pod3', uses=DummyExecutor)
+        .add(uses=Executor1, name='pod0')
+        .add(uses=Executor2, needs='gateway', name='pod1')
+        .add(uses=Executor3, needs='gateway', name='pod2')
+        .add(needs=['pod0', 'pod1', 'pod2'], name='pod3', uses=DummyExecutor)
     )

-    with flow as f:
+    with flow:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post('/', inputs=da)
+        resp = Client(port=exposed_port).post('/', inputs=da, return_responses=True)

     # assert no reduce happened
     assert len(resp[0].docs) == 1
     assert resp[0].docs[0].id == 'fake_document'


-def test_reduce_status():
+def test_reduce_status(port_generator):
+    exposed_port = port_generator()
     n_shards = 2
     flow = Flow(port=exposed_port).add(
         uses=ExecutorStatus, name='pod0', shards=n_shards, polling='all'
@@ -222,8 +227,8 @@ def test_reduce_status():

     with flow as f:
         da = DocumentArray([Document() for _ in range(5)])
-        resp = Client(port=exposed_port, return_responses=True).post(
-            '/status', parameters={'foo': 'bar'}, inputs=da
+        resp = Client(port=exposed_port).post(
+            '/status', parameters={'foo': 'bar'}, inputs=da, return_responses=True
         )

     assert resp[0].parameters['foo'] == 'bar'
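Side note on the test changes: the shared module-level `exposed_port = 12345` made these tests collide when run in parallel, so each test now draws a fresh port from the `port_generator` fixture, and `return_responses` moves from the `Client` constructor to the `post()` call. A hedged sketch of what such a fixture can look like; the real one lives in the test suite's conftest.py and may be implemented differently:

```python
# Hedged sketch of a port_generator-style fixture; the real fixture in the
# repo's conftest.py may differ.
import socket

import pytest


def _random_free_port() -> int:
    # Bind to port 0 and let the OS hand back a currently free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


@pytest.fixture
def port_generator():
    return _random_free_port
```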