diff --git a/jina/drivers/dbms.py b/jina/drivers/dbms.py deleted file mode 100644 index 63736947352be..0000000000000 --- a/jina/drivers/dbms.py +++ /dev/null @@ -1,26 +0,0 @@ -from jina.drivers.index import BaseIndexDriver -from .. import Document - -# noinspection PyUnreachableCode -if False: - from ..types.sets import DocumentSet - - -def _doc_without_embedding(d): - new_doc = Document(d, copy=True) - new_doc.ClearField('embedding') - return new_doc - - -class DBMSIndexDriver(BaseIndexDriver): - """Forwards ids, vectors, serialized Document to a BaseDBMSIndexer""" - - def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None: - info = [ - (doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString()) - for doc in docs - ] - if info: - ids, vecs, metas = zip(*info) - self.check_key_length(ids) - self.exec_fn(ids, vecs, metas) diff --git a/jina/drivers/index.py b/jina/drivers/index.py index f877c824477b2..7f7483f132c0c 100644 --- a/jina/drivers/index.py +++ b/jina/drivers/index.py @@ -3,8 +3,6 @@ from typing import Iterable, Optional -import numpy as np - from . import BaseExecutableDriver, FlatRecursiveMixin if False: @@ -78,3 +76,29 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None: keys, values = zip(*info) self.check_key_length(keys) self.exec_fn(keys, values) + + +class DBMSIndexDriver(BaseIndexDriver): + """Forwards ids, vectors, serialized Document to a BaseDBMSIndexer""" + + def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None: + info = [ + ( + doc.id, + doc.embedding, + self._doc_without_embedding(doc).SerializeToString(), + ) + for doc in docs + ] + if info: + ids, vecs, metas = zip(*info) + self.check_key_length(ids) + self.exec_fn(ids, vecs, metas) + + @staticmethod + def _doc_without_embedding(d): + from .. import Document + + new_doc = Document(d, copy=True) + new_doc.ClearField('embedding') + return new_doc diff --git a/jina/executors/__init__.py b/jina/executors/__init__.py index fdb9161cf53df..ae9793ad4afa1 100644 --- a/jina/executors/__init__.py +++ b/jina/executors/__init__.py @@ -177,17 +177,16 @@ def _post_init_wrapper( _requests: Optional[Dict] = None, fill_in_metas: bool = True, ) -> None: - with TimeContext('post_init may take some time', self.logger): - if fill_in_metas: - if not _metas: - _metas = get_default_metas() + if fill_in_metas: + if not _metas: + _metas = get_default_metas() - self._fill_metas(_metas) - self.fill_in_drivers(_requests) + self._fill_metas(_metas) + self.fill_in_drivers(_requests) - _before = set(list(vars(self).keys())) - self.post_init() - self._post_init_vars = {k for k in vars(self) if k not in _before} + _before = set(list(vars(self).keys())) + self.post_init() + self._post_init_vars = {k for k in vars(self) if k not in _before} def fill_in_drivers(self, _requests: Optional[Dict]): """ diff --git a/tests/integration/dump/test_dump_dbms.py b/tests/integration/dump/test_dump_dbms.py index ace290ad8ad3c..f306152700e53 100644 --- a/tests/integration/dump/test_dump_dbms.py +++ b/tests/integration/dump/test_dump_dbms.py @@ -5,7 +5,7 @@ import pytest from jina import Flow, Document -from jina.drivers.dbms import _doc_without_embedding +from jina.drivers.index import DBMSIndexDriver from jina.executors.indexers.dump import import_vectors, import_metas from jina.executors.indexers.query import BaseQueryIndexer from jina.executors.indexers.query.compound import CompoundQueryExecutor @@ -70,7 +70,10 @@ def assert_dump_data(dump_path, docs, shards, pea_id): metas_dump = list(metas_dump) np.testing.assert_equal( metas_dump, - [_doc_without_embedding(d).SerializeToString() for d in docs_expected], + [ + DBMSIndexDriver._doc_without_embedding(d).SerializeToString() + for d in docs_expected + ], ) # assert with Indexers @@ -129,7 +132,10 @@ def _validate_results_nonempty(resp): assert len(d.matches) > 0 for m in d.matches: assert m.embedding.shape[0] == emb_size - assert _doc_without_embedding(m).SerializeToString() is not None + assert ( + DBMSIndexDriver._doc_without_embedding(m).SerializeToString() + is not None + ) assert 'hello world' in m.text assert f'tag data' in m.tags['tag_field'] diff --git a/tests/unit/executors/dbms/test_dbms.py b/tests/unit/executors/dbms/test_dbms.py index b475ac05ae50b..926d3691078d8 100644 --- a/tests/unit/executors/dbms/test_dbms.py +++ b/tests/unit/executors/dbms/test_dbms.py @@ -1,4 +1,4 @@ -from jina.drivers.dbms import _doc_without_embedding +from jina.drivers.index import DBMSIndexDriver from jina.executors.indexers.dbms import BaseDBMSIndexer from jina.executors.indexers.dbms.keyvalue import KeyValueDBMSIndexer from tests import get_documents @@ -8,7 +8,11 @@ def test_dbms_keyvalue(tmpdir, test_metas): docs = list(get_documents(chunks=False, nr=10, same_content=True)) ids, vecs, meta = zip( *[ - (doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString()) + ( + doc.id, + doc.embedding, + DBMSIndexDriver._doc_without_embedding(doc).SerializeToString(), + ) for doc in docs ] ) @@ -21,7 +25,11 @@ def test_dbms_keyvalue(tmpdir, test_metas): new_docs = list(get_documents(chunks=False, nr=10, same_content=False)) ids, vecs, meta = zip( *[ - (doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString()) + ( + doc.id, + doc.embedding, + DBMSIndexDriver._doc_without_embedding(doc).SerializeToString(), + ) for doc in new_docs ] )