Skip to content

Commit

Permalink
fix: pass shards args in k8s (#3689)
Browse files (browse the repository at this point in the history)
  • Loading branch information
JoanFM committed Oct 18, 2021
1 parent f91dce4 commit cf6705f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 6 deletions.
5 changes: 4 additions & 1 deletion jina/peapods/pods/k8s.py
Expand Up @@ -402,7 +402,10 @@ def _parse_deployment_args(self, args):
args.uses_after or __default_executor__
)

parsed_args['deployments'] = [args] * shards
for i in range(shards):
cargs = copy.deepcopy(args)
cargs.shard_id = i
parsed_args['deployments'].append(cargs)
return parsed_args

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
Expand Down
11 changes: 10 additions & 1 deletion tests/k8s/executor-merger/exec_merger.py
Expand Up @@ -12,17 +12,26 @@ def __init__(self, *args, **kwargs):
self.logger = JinaLogger(self.__class__.__name__)

@requests
def debug(self, docs_matrix: List[DocumentArray], parameters: Dict, **kwargs):
def debug(self, docs_matrix: List[DocumentArray], **kwargs):
self.logger.debug(
f'Received doc matrix in exec-merger with length {len(docs_matrix)}.'
)

result = DocumentArray()
for docs in zip(*docs_matrix):
traversed_executors = [doc.tags['traversed-executors'] for doc in docs]
shard_ids = [doc.tags['shard_id'] for doc in docs]
pea_ids = [doc.tags['pea_id'] for doc in docs]
shards = [doc.tags['shards'] for doc in docs]
parallels = [doc.tags['parallel'] for doc in docs]
traversed_executors = list(chain(*traversed_executors))
doc = Document()
doc.tags['traversed-executors'] = traversed_executors
doc.tags['shard_id'] = shard_ids
doc.tags['pea_id'] = pea_ids
doc.tags['shards'] = shards
doc.tags['parallel'] = parallels

result.append(doc)

return result
5 changes: 4 additions & 1 deletion tests/k8s/test-executor/debug_executor.py
@@ -1,5 +1,4 @@
import os
from typing import Dict

from jina import Executor, requests, DocumentArray

Expand All @@ -25,6 +24,10 @@ def debug(self, docs: DocumentArray, **kwargs):
traversed = list(doc.tags.get(key))
traversed.append(self._name)
doc.tags[key] = traversed
doc.tags['parallel'] = self.runtime_args.parallel
doc.tags['shards'] = self.runtime_args.shards
doc.tags['shard_id'] = self.runtime_args.shard_id
doc.tags['pea_id'] = self.runtime_args.pea_id

@requests(on='/env')
def env(self, docs: DocumentArray, **kwargs):
Expand Down
7 changes: 6 additions & 1 deletion tests/k8s/test_k8s.py
Expand Up @@ -190,6 +190,10 @@ def test_flow_with_sharding(
assert len(docs) == 10
for doc in docs:
assert set(doc.tags['traversed-executors']) == expected_traversed_executors
assert set(doc.tags['pea_id']) == {0, 1}
assert set(doc.tags['shard_id']) == {0, 1}
assert doc.tags['parallel'] == [2, 2]
assert doc.tags['shards'] == [2, 2]


@pytest.mark.timeout(3600)
Expand Down Expand Up @@ -281,7 +285,7 @@ def send_requests(
except:
_logger.error(f' Some error happened while sending requests')
exception_to_raise_event.set()
_logger.debug(f' finishing the process')
_logger.debug(f' send requests finished')

with k8s_flow_with_reload_executor as flow:
with kubernetes_tools.get_port_forward_contextmanager(
Expand Down Expand Up @@ -333,6 +337,7 @@ def send_requests(
)
logger.debug(f' Joining the process')
process.join()
logger.debug(f' Process succesfully joined')

assert not exception_to_raise.set()

Expand Down
14 changes: 12 additions & 2 deletions tests/unit/peapods/pods/test_k8s_pod.py
Expand Up @@ -64,7 +64,12 @@ def test_parse_args(shards: int):
assert namespace_equal(
pod.deployment_args['tail_deployment'], None if shards == 1 else args
)
assert pod.deployment_args['deployments'] == [args] * shards
for i, depl_arg in enumerate(pod.deployment_args['deployments']):
import copy

cargs = copy.deepcopy(args)
cargs.shard_id = i
assert depl_arg == cargs


@pytest.mark.parametrize('shards', [2, 3, 4, 5])
Expand All @@ -91,7 +96,12 @@ def test_parse_args_custom_executor(shards: int):
args, pod.deployment_args['tail_deployment'], skip_attr=('uses',)
)
assert pod.deployment_args['tail_deployment'].uses == uses_after
assert pod.deployment_args['deployments'] == [args] * shards
for i, depl_arg in enumerate(pod.deployment_args['deployments']):
import copy

cargs = copy.deepcopy(args)
cargs.shard_id = i
assert depl_arg == cargs


@pytest.mark.parametrize(
Expand Down

0 comments on commit cf6705f

Please sign in to comment.