Skip to content

Commit

Permalink
Merge 2896d75 into be02bd9
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Feb 17, 2021
2 parents be02bd9 + 2896d75 commit 4dafc09
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 76 deletions.
58 changes: 46 additions & 12 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,41 +125,75 @@ def __init__(self, *args, **kwargs):
self._post_init_vars = set()
self._last_snapshot_ts = datetime.now()


def _post_init_wrapper(self, _metas: Dict = None, _requests: Dict = None, fill_in_metas: bool = True) -> None:
with TimeContext('post_init may take some time', self.logger):
if fill_in_metas:
if not _metas:
_metas = get_default_metas()

self._fill_metas(_metas)

from ..executors.requests import get_default_reqs
default_requests = get_default_reqs(type.mro(self.__class__))

if not _requests:
from ..executors.requests import get_default_reqs
_requests = get_default_reqs(type.mro(self.__class__))
self._drivers = self._get_drivers_from_requests(default_requests)
else:
parsed_drivers = self._get_drivers_from_requests(_requests)

self._fill_metas(_metas)
self._fill_requests(_requests)
if _requests.get('use_default', False):
default_drivers = self._get_drivers_from_requests(default_requests)

for k, v in default_drivers.items():
if k not in parsed_drivers:
parsed_drivers[k] = v

self._drivers = parsed_drivers

_before = set(list(vars(self).keys()))
self.post_init()
self._post_init_vars = {k for k in vars(self) if k not in _before}

def _fill_requests(self, _requests):
self._drivers = {} # type: Dict[str, List['BaseDriver']]
@staticmethod
def _get_drivers_from_requests(_requests):
_drivers = {} # type: Dict[str, List['BaseDriver']]

if _requests and 'on' in _requests and isinstance(_requests['on'], dict):
# if control request is forget in YAML, then fill it
if 'ControlRequest' not in _requests['on']:
from ..drivers.control import ControlReqDriver
_requests['on']['ControlRequest'] = [ControlReqDriver()]

for req_type, drivers in _requests['on'].items():
for req_type, drivers_spec in _requests['on'].items():
if isinstance(req_type, str):
req_type = [req_type]
if isinstance(drivers_spec, list):
# old syntax
drivers = drivers_spec
common_kwargs = {}
elif isinstance(drivers_spec, dict):
drivers = drivers_spec['drivers']
common_kwargs = drivers_spec['with']
else:
raise TypeError(f'unsupported type of driver spec: {drivers_spec}')

for r in req_type:
if r not in self._drivers:
self._drivers[r] = list()
if self._drivers[r] != drivers:
self._drivers[r].extend(drivers)
if r not in _drivers:
_drivers[r] = list()
if _drivers[r] != drivers:
_drivers[r].extend(drivers)

# inject common kwargs to drivers
if common_kwargs:
new_drivers = []
for d in _drivers[r]:
new_init_kwargs_dict = {k:v for k, v in d._init_kwargs_dict.items()}
new_init_kwargs_dict.update(common_kwargs)
new_drivers.append(d.__class__(**new_init_kwargs_dict))
_drivers[r].clear()
_drivers[r] = new_drivers

return _drivers

def _fill_metas(self, _metas):
unresolved_attr = False
Expand Down
23 changes: 11 additions & 12 deletions jina/resources/executors._eval_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@ metas:
requests:
on:
SearchRequest:
- !RankEvaluateDriver
with:
field: tags__id
executor: precision
running_avg: true
traversal_paths: [ 'r' ]
- !RankEvaluateDriver
with:
field: tags__id
executor: recall
running_avg: true
traversal_paths: [ 'r' ]
with:
traversal_paths: ['r']
running_avg: true
field: tags__id
drivers:
- !RankEvaluateDriver
with:
executor: precision
- !RankEvaluateDriver
with:
executor: recall
25 changes: 12 additions & 13 deletions jina/resources/executors._merge_matches_topk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@ metas:
requests:
on:
[SearchRequest, TrainRequest, IndexRequest]:
- !ReduceAllDriver
with:
traversal_paths: ['m']
- !SortQL
with:
reverse: False
traversal_paths: ['m']
field: 'score__value'
- !SliceQL
with:
start: 0
end: 10 # is overwritten by the QueryLangDriver
traversal_paths: ['m']
with:
traversal_paths: [ 'm' ]
drivers:
- !ReduceAllDriver {}
- !SortQL
with:
reverse: False
field: 'score__value'
- !SliceQL
with:
start: 0
end: 10 # is overwritten by the QueryLangDriver
ControlRequest:
- !ControlReqDriver {}
[DeleteRequest, UpdateRequest]:
Expand Down
27 changes: 14 additions & 13 deletions jina/resources/executors.requests.BaseRanker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ on:
ControlRequest:
- !ControlReqDriver {}
SearchRequest:
- !ExcludeQL
with:
fields:
- embedding
- !SortQL
with:
field: 'score__value'
traversal_paths: ['m']
- !SliceQL
with:
traversal_paths: ['m']
start: 0
end: 50
with:
traversal_paths: [ 'm' ]
drivers:
- !ExcludeQL
with:
fields:
- embedding
- !SortQL
with:
field: 'score__value'
- !SliceQL
with:
start: 0
end: 50
23 changes: 12 additions & 11 deletions jina/resources/executors.requests.BaseRankingEvaluator.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ on:
ControlRequest:
- !ControlReqDriver {}
SearchRequest:
- !ExcludeQL
with:
fields:
- embedding
- buffer
- blob
- text
with:
traversal_paths: ['r']
- !RankEvaluateDriver
with:
id_tag: 'id'
traversal_paths: ['r']
drivers:
- !ExcludeQL
with:
fields:
- embedding
- buffer
- blob
- text
- !RankEvaluateDriver
with:
id_tag: 'id'
2 changes: 1 addition & 1 deletion jina/resources/executors.requests.BinaryPbIndexer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ on:
- buffer
- !KVIndexDriver {}
DeleteRequest:
- !DeleteDriver { }
- !DeleteDriver {}
27 changes: 13 additions & 14 deletions jina/resources/helloworld.reduce.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ metas:
name: top50
requests:
on:
[SearchRequest]:
- !ReduceAllDriver
with:
traversal_paths: ['m']
- !SortQL
with:
reverse: true
field: 'score__value'
traversal_paths: ['m']
- !SliceQL
with:
start: 0
end: 20
traversal_paths: ['m']
SearchRequest:
with:
traversal_paths: ['m']
drivers:
- !ReduceAllDriver {}
- !SortQL
with:
reverse: true
field: 'score__value'
- !SliceQL
with:
start: 0
end: 20
ControlRequest:
- !ControlReqDriver {}
119 changes: 119 additions & 0 deletions tests/unit/executors/test_set_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from jina.drivers.delete import DeleteDriver
from jina.drivers.encode import EncodeDriver
from jina.drivers.querylang.filter import FilterQL
from jina.executors import BaseExecutor

y_no_fill = """
!BaseEncoder
requests:
use_default: false
"""


def test_no_fill():
be = BaseExecutor.load_config(y_no_fill)
assert not be._drivers


y_no_fill_with_index_request = """
!BaseEncoder
requests:
use_default: false
on:
IndexRequest:
- !RouteDriver {}
"""


def test_no_fill_with_index_request():
be = BaseExecutor.load_config(y_no_fill_with_index_request)
assert len(be._drivers) == 2
assert 'IndexRequest' in be._drivers
assert 'ControlRequest' in be._drivers


y_fill_default_with_index_request = """
!BaseEncoder
requests:
use_default: true
on:
IndexRequest:
- !EncodeDriver {}
"""


def test_fill_default_with_index_request():
be = BaseExecutor.load_config(y_fill_default_with_index_request)
assert len(be._drivers) == 6
assert isinstance(be._drivers['IndexRequest'][0], EncodeDriver)
print(be._drivers['IndexRequest'][0]._init_kwargs_dict)


y_fill_default_with_index_request_with_common = """
!BaseEncoder
requests:
use_default: true
on:
IndexRequest:
with:
traversal_paths: ['mmm']
drivers:
- !FilterQL
with:
lookups:
mime_type: image/jpeg
- !EncodeDriver {}
"""


def test_with_common_kwargs_on_index():
be = BaseExecutor.load_config(y_fill_default_with_index_request_with_common)
assert len(be._drivers) == 6
assert isinstance(be._drivers['IndexRequest'][1], EncodeDriver)
assert isinstance(be._drivers['IndexRequest'][0], FilterQL)
assert be._drivers['IndexRequest'][0]._traversal_paths == ['mmm']
assert be._drivers['IndexRequest'][1]._traversal_paths == ['mmm']


y_fill_default_with_two_request_with_common = """
!BaseEncoder
requests:
use_default: true
on:
[IndexRequest, SearchRequest]:
with:
traversal_paths: ['mmm']
drivers:
- !FilterQL
with:
lookups:
mime_type: image/jpeg
- !EncodeDriver {}
[DeleteRequest]:
with:
traversal_paths: ['ccc']
drivers:
- !FilterQL
with:
lookups:
mime_type: image/jpeg
- !DeleteDriver {}
"""


def test_with_common_kwargs_on_two_requests():
be = BaseExecutor.load_config(y_fill_default_with_two_request_with_common)
assert len(be._drivers) == 6

for r in ('IndexRequest', 'SearchRequest', 'DeleteRequest'):
if r == 'DeleteRequest':
assert isinstance(be._drivers[r][1], DeleteDriver)
else:
assert isinstance(be._drivers[r][1], EncodeDriver)
assert isinstance(be._drivers[r][0], FilterQL)
if r == 'DeleteRequest':
assert be._drivers[r][0]._traversal_paths == ['ccc']
assert be._drivers[r][1]._traversal_paths == ['ccc']
else:
assert be._drivers[r][0]._traversal_paths == ['mmm']
assert be._drivers[r][1]._traversal_paths == ['mmm']

0 comments on commit 4dafc09

Please sign in to comment.