Skip to content

Commit

Permalink
refactor(driver): move dbms to index module (#2312)
Browse files Browse the repository at this point in the history
* refactor(driver): move dbms to index module

* refactor(driver): remove post_init takes too long hint
  • Loading branch information
hanxiao committed Apr 18, 2021
1 parent d858a66 commit 4375be5
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 43 deletions.
26 changes: 0 additions & 26 deletions jina/drivers/dbms.py

This file was deleted.

28 changes: 26 additions & 2 deletions jina/drivers/index.py
Expand Up @@ -3,8 +3,6 @@

from typing import Iterable, Optional

import numpy as np

from . import BaseExecutableDriver, FlatRecursiveMixin

if False:
Expand Down Expand Up @@ -78,3 +76,29 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
keys, values = zip(*info)
self.check_key_length(keys)
self.exec_fn(keys, values)


class DBMSIndexDriver(BaseIndexDriver):
    """Forwards ids, vectors, serialized Document to a BaseDBMSIndexer"""

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
        """Collect id, embedding and serialized metadata per doc and hand them to ``exec_fn``.

        For every document three values are gathered: its ``id``, its
        ``embedding``, and the bytes produced by serializing a copy of the
        document whose ``embedding`` field has been cleared. When ``docs``
        yields nothing, ``exec_fn`` is not called.

        :param docs: the documents to forward
        :param args: unused positional arguments
        :param kwargs: unused keyword arguments
        """
        records = []
        for doc in docs:
            doc_id = doc.id
            vector = doc.embedding
            payload = self._doc_without_embedding(doc).SerializeToString()
            records.append((doc_id, vector, payload))

        # Nothing collected: skip the key check and the executor call entirely.
        if not records:
            return

        # Transpose the per-doc rows into three parallel tuples.
        ids, vecs, metas = zip(*records)
        self.check_key_length(ids)
        self.exec_fn(ids, vecs, metas)

    @staticmethod
    def _doc_without_embedding(d):
        """Return a copy of ``d`` with its ``embedding`` field cleared.

        The input is passed to ``Document`` with ``copy=True``; only the
        newly built object has its field cleared.

        :param d: the source document
        :return: the new ``Document`` without the ``embedding`` field
        """
        # Imported locally, as in the original (presumably to avoid a circular import).
        from .. import Document

        stripped = Document(d, copy=True)
        stripped.ClearField('embedding')
        return stripped
17 changes: 8 additions & 9 deletions jina/executors/__init__.py
Expand Up @@ -177,17 +177,16 @@ def _post_init_wrapper(
_requests: Optional[Dict] = None,
fill_in_metas: bool = True,
) -> None:
with TimeContext('post_init may take some time', self.logger):
if fill_in_metas:
if not _metas:
_metas = get_default_metas()
if fill_in_metas:
if not _metas:
_metas = get_default_metas()

self._fill_metas(_metas)
self.fill_in_drivers(_requests)
self._fill_metas(_metas)
self.fill_in_drivers(_requests)

_before = set(list(vars(self).keys()))
self.post_init()
self._post_init_vars = {k for k in vars(self) if k not in _before}
_before = set(list(vars(self).keys()))
self.post_init()
self._post_init_vars = {k for k in vars(self) if k not in _before}

def fill_in_drivers(self, _requests: Optional[Dict]):
"""
Expand Down
12 changes: 9 additions & 3 deletions tests/integration/dump/test_dump_dbms.py
Expand Up @@ -5,7 +5,7 @@
import pytest

from jina import Flow, Document
from jina.drivers.dbms import _doc_without_embedding
from jina.drivers.index import DBMSIndexDriver
from jina.executors.indexers.dump import import_vectors, import_metas
from jina.executors.indexers.query import BaseQueryIndexer
from jina.executors.indexers.query.compound import CompoundQueryExecutor
Expand Down Expand Up @@ -70,7 +70,10 @@ def assert_dump_data(dump_path, docs, shards, pea_id):
metas_dump = list(metas_dump)
np.testing.assert_equal(
metas_dump,
[_doc_without_embedding(d).SerializeToString() for d in docs_expected],
[
DBMSIndexDriver._doc_without_embedding(d).SerializeToString()
for d in docs_expected
],
)

# assert with Indexers
Expand Down Expand Up @@ -129,7 +132,10 @@ def _validate_results_nonempty(resp):
assert len(d.matches) > 0
for m in d.matches:
assert m.embedding.shape[0] == emb_size
assert _doc_without_embedding(m).SerializeToString() is not None
assert (
DBMSIndexDriver._doc_without_embedding(m).SerializeToString()
is not None
)
assert 'hello world' in m.text
assert f'tag data' in m.tags['tag_field']

Expand Down
14 changes: 11 additions & 3 deletions tests/unit/executors/dbms/test_dbms.py
@@ -1,4 +1,4 @@
from jina.drivers.dbms import _doc_without_embedding
from jina.drivers.index import DBMSIndexDriver
from jina.executors.indexers.dbms import BaseDBMSIndexer
from jina.executors.indexers.dbms.keyvalue import KeyValueDBMSIndexer
from tests import get_documents
Expand All @@ -8,7 +8,11 @@ def test_dbms_keyvalue(tmpdir, test_metas):
docs = list(get_documents(chunks=False, nr=10, same_content=True))
ids, vecs, meta = zip(
*[
(doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString())
(
doc.id,
doc.embedding,
DBMSIndexDriver._doc_without_embedding(doc).SerializeToString(),
)
for doc in docs
]
)
Expand All @@ -21,7 +25,11 @@ def test_dbms_keyvalue(tmpdir, test_metas):
new_docs = list(get_documents(chunks=False, nr=10, same_content=False))
ids, vecs, meta = zip(
*[
(doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString())
(
doc.id,
doc.embedding,
DBMSIndexDriver._doc_without_embedding(doc).SerializeToString(),
)
for doc in new_docs
]
)
Expand Down

0 comments on commit 4375be5

Please sign in to comment.