Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(executor): add use_default and with to requests field #1959

Merged
merged 3 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 46 additions & 12 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,41 +125,75 @@ def __init__(self, *args, **kwargs):
self._post_init_vars = set()
self._last_snapshot_ts = datetime.now()


def _post_init_wrapper(self, _metas: Dict = None, _requests: Dict = None, fill_in_metas: bool = True) -> None:
with TimeContext('post_init may take some time', self.logger):
if fill_in_metas:
if not _metas:
_metas = get_default_metas()

self._fill_metas(_metas)

from ..executors.requests import get_default_reqs
default_requests = get_default_reqs(type.mro(self.__class__))

if not _requests:
from ..executors.requests import get_default_reqs
_requests = get_default_reqs(type.mro(self.__class__))
self._drivers = self._get_drivers_from_requests(default_requests)
else:
parsed_drivers = self._get_drivers_from_requests(_requests)

self._fill_metas(_metas)
self._fill_requests(_requests)
if _requests.get('use_default', False):
default_drivers = self._get_drivers_from_requests(default_requests)

for k, v in default_drivers.items():
if k not in parsed_drivers:
parsed_drivers[k] = v

self._drivers = parsed_drivers

_before = set(list(vars(self).keys()))
self.post_init()
self._post_init_vars = {k for k in vars(self) if k not in _before}

def _fill_requests(self, _requests):
self._drivers = {} # type: Dict[str, List['BaseDriver']]
@staticmethod
def _get_drivers_from_requests(_requests):
_drivers = {} # type: Dict[str, List['BaseDriver']]

if _requests and 'on' in _requests and isinstance(_requests['on'], dict):
# if the control request is forgotten in the YAML, fill it in
if 'ControlRequest' not in _requests['on']:
from ..drivers.control import ControlReqDriver
_requests['on']['ControlRequest'] = [ControlReqDriver()]

for req_type, drivers in _requests['on'].items():
for req_type, drivers_spec in _requests['on'].items():
if isinstance(req_type, str):
req_type = [req_type]
if isinstance(drivers_spec, list):
# old syntax
drivers = drivers_spec
common_kwargs = {}
elif isinstance(drivers_spec, dict):
drivers = drivers_spec['drivers']
common_kwargs = drivers_spec['with']
else:
raise TypeError(f'unsupported type of driver spec: {drivers_spec}')

for r in req_type:
if r not in self._drivers:
self._drivers[r] = list()
if self._drivers[r] != drivers:
self._drivers[r].extend(drivers)
if r not in _drivers:
_drivers[r] = list()
if _drivers[r] != drivers:
_drivers[r].extend(drivers)

# inject common kwargs to drivers
if common_kwargs:
new_drivers = []
for d in _drivers[r]:
new_init_kwargs_dict = {k:v for k, v in d._init_kwargs_dict.items()}
new_init_kwargs_dict.update(common_kwargs)
new_drivers.append(d.__class__(**new_init_kwargs_dict))
_drivers[r].clear()
_drivers[r] = new_drivers

return _drivers

def _fill_metas(self, _metas):
unresolved_attr = False
Expand Down
23 changes: 11 additions & 12 deletions jina/resources/executors._eval_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@ metas:
requests:
on:
SearchRequest:
- !RankEvaluateDriver
with:
field: tags__id
executor: precision
running_avg: true
traversal_paths: [ 'r' ]
- !RankEvaluateDriver
with:
field: tags__id
executor: recall
running_avg: true
traversal_paths: [ 'r' ]
with:
traversal_paths: ['r']
running_avg: true
field: tags__id
drivers:
- !RankEvaluateDriver
with:
executor: precision
- !RankEvaluateDriver
with:
executor: recall
25 changes: 12 additions & 13 deletions jina/resources/executors._merge_matches_topk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@ metas:
requests:
on:
[SearchRequest, TrainRequest, IndexRequest]:
- !ReduceAllDriver
with:
traversal_paths: ['m']
- !SortQL
with:
reverse: False
traversal_paths: ['m']
field: 'score__value'
- !SliceQL
with:
start: 0
end: 10 # is overwritten by the QueryLangDriver
traversal_paths: ['m']
with:
traversal_paths: [ 'm' ]
drivers:
- !ReduceAllDriver {}
- !SortQL
with:
reverse: False
field: 'score__value'
- !SliceQL
with:
start: 0
end: 10 # is overwritten by the QueryLangDriver
ControlRequest:
- !ControlReqDriver {}
[DeleteRequest, UpdateRequest]:
Expand Down
27 changes: 14 additions & 13 deletions jina/resources/executors.requests.BaseRanker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ on:
ControlRequest:
- !ControlReqDriver {}
SearchRequest:
- !ExcludeQL
with:
fields:
- embedding
- !SortQL
with:
field: 'score__value'
traversal_paths: ['m']
- !SliceQL
with:
traversal_paths: ['m']
start: 0
end: 50
with:
traversal_paths: [ 'm' ]
drivers:
- !ExcludeQL
with:
fields:
- embedding
- !SortQL
with:
field: 'score__value'
- !SliceQL
with:
start: 0
end: 50
23 changes: 12 additions & 11 deletions jina/resources/executors.requests.BaseRankingEvaluator.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ on:
ControlRequest:
- !ControlReqDriver {}
SearchRequest:
- !ExcludeQL
with:
fields:
- embedding
- buffer
- blob
- text
with:
traversal_paths: ['r']
- !RankEvaluateDriver
with:
id_tag: 'id'
traversal_paths: ['r']
drivers:
- !ExcludeQL
with:
fields:
- embedding
- buffer
- blob
- text
- !RankEvaluateDriver
with:
id_tag: 'id'
2 changes: 1 addition & 1 deletion jina/resources/executors.requests.BinaryPbIndexer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ on:
- buffer
- !KVIndexDriver {}
DeleteRequest:
- !DeleteDriver { }
- !DeleteDriver {}
27 changes: 13 additions & 14 deletions jina/resources/helloworld.reduce.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ metas:
name: top50
requests:
on:
[SearchRequest]:
- !ReduceAllDriver
with:
traversal_paths: ['m']
- !SortQL
with:
reverse: true
field: 'score__value'
traversal_paths: ['m']
- !SliceQL
with:
start: 0
end: 20
traversal_paths: ['m']
SearchRequest:
with:
traversal_paths: ['m']
drivers:
- !ReduceAllDriver {}
- !SortQL
with:
reverse: true
field: 'score__value'
- !SliceQL
with:
start: 0
end: 20
ControlRequest:
- !ControlReqDriver {}
119 changes: 119 additions & 0 deletions tests/unit/executors/test_set_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from jina.drivers.delete import DeleteDriver
from jina.drivers.encode import EncodeDriver
from jina.drivers.querylang.filter import FilterQL
from jina.executors import BaseExecutor

# YAML config for a BaseEncoder whose `requests` section only sets
# `use_default: false` and declares no `on` mapping.
# NOTE(review): nested YAML indentation appears to be missing in this view — confirm against the original file.
y_no_fill = """
!BaseEncoder
requests:
use_default: false
"""


def test_no_fill():
    """Loading `y_no_fill` must leave the executor's driver mapping empty."""
    executor = BaseExecutor.load_config(y_no_fill)
    registered = executor._drivers
    assert not registered


# YAML config for a BaseEncoder with `use_default: false` and a single
# explicit IndexRequest entry (old list syntax) holding one RouteDriver.
# NOTE(review): nested YAML indentation appears to be missing in this view — confirm against the original file.
y_no_fill_with_index_request = """
!BaseEncoder
requests:
use_default: false
on:
IndexRequest:
- !RouteDriver {}
"""


def test_no_fill_with_index_request():
    """Expect exactly two request types: the declared IndexRequest and ControlRequest."""
    executor = BaseExecutor.load_config(y_no_fill_with_index_request)
    registered = executor._drivers
    assert len(registered) == 2
    assert 'IndexRequest' in registered
    assert 'ControlRequest' in registered


# YAML config for a BaseEncoder with `use_default: true` and an explicit
# IndexRequest entry (old list syntax) holding one EncodeDriver.
# NOTE(review): nested YAML indentation appears to be missing in this view — confirm against the original file.
y_fill_default_with_index_request = """
!BaseEncoder
requests:
use_default: true
on:
IndexRequest:
- !EncodeDriver {}
"""


def test_fill_default_with_index_request():
    """Expect six request types in total, with an EncodeDriver first on IndexRequest."""
    executor = BaseExecutor.load_config(y_fill_default_with_index_request)
    registered = executor._drivers
    assert len(registered) == 6
    first_index_driver = registered['IndexRequest'][0]
    assert isinstance(first_index_driver, EncodeDriver)
    print(first_index_driver._init_kwargs_dict)


# YAML config for a BaseEncoder with `use_default: true` whose IndexRequest
# uses the dict syntax: a shared `with` section (traversal_paths) plus a
# `drivers` list containing a FilterQL and an EncodeDriver.
# NOTE(review): nested YAML indentation appears to be missing in this view — confirm against the original file.
y_fill_default_with_index_request_with_common = """
!BaseEncoder
requests:
use_default: true
on:
IndexRequest:
with:
traversal_paths: ['mmm']
drivers:
- !FilterQL
with:
lookups:
mime_type: image/jpeg
- !EncodeDriver {}
"""


def test_with_common_kwargs_on_index():
    """Both IndexRequest drivers must carry the shared `traversal_paths` value."""
    executor = BaseExecutor.load_config(y_fill_default_with_index_request_with_common)
    assert len(executor._drivers) == 6

    index_drivers = executor._drivers['IndexRequest']
    assert isinstance(index_drivers[1], EncodeDriver)
    assert isinstance(index_drivers[0], FilterQL)
    for position in (0, 1):
        assert index_drivers[position]._traversal_paths == ['mmm']


# YAML config for a BaseEncoder with `use_default: true` and two dict-syntax
# entries: one keyed by [IndexRequest, SearchRequest] sharing
# traversal_paths ['mmm'], and one keyed by [DeleteRequest] sharing
# traversal_paths ['ccc'].
# NOTE(review): nested YAML indentation appears to be missing in this view — confirm against the original file.
y_fill_default_with_two_request_with_common = """
!BaseEncoder
requests:
use_default: true
on:
[IndexRequest, SearchRequest]:
with:
traversal_paths: ['mmm']
drivers:
- !FilterQL
with:
lookups:
mime_type: image/jpeg
- !EncodeDriver {}
[DeleteRequest]:
with:
traversal_paths: ['ccc']
drivers:
- !FilterQL
with:
lookups:
mime_type: image/jpeg
- !DeleteDriver {}
"""


def test_with_common_kwargs_on_two_requests():
    """Each request group must get its own shared `traversal_paths` on both drivers."""
    executor = BaseExecutor.load_config(y_fill_default_with_two_request_with_common)
    assert len(executor._drivers) == 6

    # request type -> (class expected at position 1, shared traversal paths)
    expectations = {
        'IndexRequest': (EncodeDriver, ['mmm']),
        'SearchRequest': (EncodeDriver, ['mmm']),
        'DeleteRequest': (DeleteDriver, ['ccc']),
    }
    for request_type, (second_cls, shared_paths) in expectations.items():
        chain = executor._drivers[request_type]
        assert isinstance(chain[1], second_cls)
        assert isinstance(chain[0], FilterQL)
        assert chain[0]._traversal_paths == shared_paths
        assert chain[1]._traversal_paths == shared_paths