Skip to content

Commit

Permalink
feat(peas): add modes_id argument, prepend filterql driver when found
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 24, 2020
1 parent e581304 commit 8b850e2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
14 changes: 6 additions & 8 deletions jina/executors/__init__.py
Expand Up @@ -533,21 +533,19 @@ def _dump_instance_to_yaml(data):
r['metas'] = p
return r

def add_driver(self, driver: 'BaseDriver', req_type: str):
""" Add a driver to this executor.
def prepend_driver(self, driver: 'BaseDriver', req_type: str):
""" Add a driver at the beginning of the drivers array.
.. warning::
This has to be used *before* ``.attach(pea=self)`` to be effective
:param driver: the driver to add
:param req_type: the request type to handle by this driver
:param req_type: the request type to handle by this driver.
If the req_type is not already handled by a driver, it won't be added
:return:
"""
if not isinstance(driver, list):
driver = [driver]
if req_type not in self._drivers:
self._drivers[req_type] = []
self._drivers[req_type].extend(driver)
if req_type in self._drivers.keys():
self._drivers[req_type].insert(0, driver)

def attach(self, *args, **kwargs):
"""Attach this executor to a :class:`jina.peapods.pea.BasePea`.
Expand Down
1 change: 1 addition & 0 deletions jina/main/parser.py
Expand Up @@ -197,6 +197,7 @@ def set_pea_parser(parser=None):
gp0.add_argument('--py-modules', type=str, nargs='*',
help='the customized python modules need to be imported before loading the'
' executor')
gp0.add_argument('--modes', type=str, nargs='*', help='List of mode_id the pea wants to process documents from')

gp1 = add_arg_group(parser, 'pea container arguments')
gp1.add_argument('--uses-internal', type=str, default='BaseExecutor',
Expand Down
5 changes: 5 additions & 0 deletions jina/peapods/pea.py
Expand Up @@ -15,6 +15,7 @@
from .zmq import send_ctrl_message, Zmqlet, ZmqStreamlet
from .. import __ready_msg__, __stop_msg__
from ..drivers.helper import routes2str, add_route
from ..drivers.querylang.filter import FilterQL
from ..enums import PeaRoleType, OnErrorSkip
from ..excepts import NoExplicitMessage, ExecutorFailToLoad, MemoryOverHighWatermark, DriverError
from ..executors import BaseExecutor
Expand Down Expand Up @@ -191,6 +192,10 @@ def load_executor(self):
try:
self.executor = BaseExecutor.load_config(self.args.uses if valid_local_config_source(self.args.uses) else self.args.uses_internal,
self.args.separated_workspace, self.args.replica_id)
if self.args.mode_ids:
# if mode_ids are explicitly requested, it means we are in multimode so we prepend a driver that will filter mode_ids not requested
for req_type in ['IndexRequest', 'SearchRequest', 'TrainRequest', 'ControlRequest']:
self.executor.prepend_driver(FilterQL({'mode_id__in': self.args.mode_ids}), req_type)
self.executor.attach(pea=self)
except FileNotFoundError:
raise ExecutorFailToLoad
Expand Down

0 comments on commit 8b850e2

Please sign in to comment.