Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(driver): move dbms to index module #2312

Merged
merged 4 commits into from
Apr 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 0 additions & 26 deletions jina/drivers/dbms.py

This file was deleted.

28 changes: 26 additions & 2 deletions jina/drivers/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from typing import Iterable, Optional

import numpy as np

from . import BaseExecutableDriver, FlatRecursiveMixin

if False:
Expand Down Expand Up @@ -68,3 +66,29 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
keys, values = zip(*info)
self.check_key_length(keys)
self.exec_fn(keys, values)


class DBMSIndexDriver(BaseIndexDriver):
    """Index driver that hands ids, embeddings and serialized Documents to the executor.

    For every Document it collects the id, the embedding, and the bytes of a
    copy of the Document whose ``embedding`` field has been cleared, then
    passes the three collections to ``exec_fn`` in one call.
    """

    def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
        """Collect ``(id, embedding, serialized doc)`` per Document and forward them.

        Nothing is forwarded when ``docs`` yields no Documents.

        :param docs: the Documents to index
        :param args: unused positional arguments
        :param kwargs: unused keyword arguments
        """
        doc_ids, embeddings, serialized = [], [], []
        for doc in docs:
            doc_ids.append(doc.id)
            embeddings.append(doc.embedding)
            serialized.append(self._doc_without_embedding(doc).SerializeToString())

        # Guard clause: an empty batch must not reach the executor.
        if not doc_ids:
            return

        # The executor receives tuples, one per column.
        ids, vecs, metas = tuple(doc_ids), tuple(embeddings), tuple(serialized)
        self.check_key_length(ids)
        self.exec_fn(ids, vecs, metas)

    @staticmethod
    def _doc_without_embedding(d):
        """Return a Document built from ``d`` with its ``embedding`` field cleared.

        The new Document is constructed with ``copy=True`` before the field is
        cleared.

        :param d: the source Document
        :return: the new Document without an embedding
        """
        from .. import Document

        stripped = Document(d, copy=True)
        stripped.ClearField('embedding')
        return stripped
17 changes: 8 additions & 9 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,16 @@ def _post_init_wrapper(
_requests: Optional[Dict] = None,
fill_in_metas: bool = True,
) -> None:
with TimeContext('post_init may take some time', self.logger):
if fill_in_metas:
if not _metas:
_metas = get_default_metas()
if fill_in_metas:
if not _metas:
_metas = get_default_metas()

self._fill_metas(_metas)
self.fill_in_drivers(_requests)
self._fill_metas(_metas)
self.fill_in_drivers(_requests)

_before = set(list(vars(self).keys()))
self.post_init()
self._post_init_vars = {k for k in vars(self) if k not in _before}
_before = set(list(vars(self).keys()))
self.post_init()
self._post_init_vars = {k for k in vars(self) if k not in _before}

def fill_in_drivers(self, _requests: Optional[Dict]):
"""
Expand Down
12 changes: 9 additions & 3 deletions tests/integration/dump/test_dump_dbms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from jina import Flow, Document
from jina.drivers.dbms import _doc_without_embedding
from jina.drivers.index import DBMSIndexDriver
from jina.executors.indexers.dump import import_vectors, import_metas
from jina.executors.indexers.query import BaseQueryIndexer
from jina.executors.indexers.query.compound import CompoundQueryExecutor
Expand Down Expand Up @@ -70,7 +70,10 @@ def assert_dump_data(dump_path, docs, shards, pea_id):
metas_dump = list(metas_dump)
np.testing.assert_equal(
metas_dump,
[_doc_without_embedding(d).SerializeToString() for d in docs_expected],
[
DBMSIndexDriver._doc_without_embedding(d).SerializeToString()
for d in docs_expected
],
)

# assert with Indexers
Expand Down Expand Up @@ -129,7 +132,10 @@ def _validate_results_nonempty(resp):
assert len(d.matches) > 0
for m in d.matches:
assert m.embedding.shape[0] == emb_size
assert _doc_without_embedding(m).SerializeToString() is not None
assert (
DBMSIndexDriver._doc_without_embedding(m).SerializeToString()
is not None
)
assert 'hello world' in m.text
assert f'tag data' in m.tags['tag_field']

Expand Down
14 changes: 11 additions & 3 deletions tests/unit/executors/dbms/test_dbms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jina.drivers.dbms import _doc_without_embedding
from jina.drivers.index import DBMSIndexDriver
from jina.executors.indexers.dbms import BaseDBMSIndexer
from jina.executors.indexers.dbms.keyvalue import KeyValueDBMSIndexer
from tests import get_documents
Expand All @@ -8,7 +8,11 @@ def test_dbms_keyvalue(tmpdir, test_metas):
docs = list(get_documents(chunks=False, nr=10, same_content=True))
ids, vecs, meta = zip(
*[
(doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString())
(
doc.id,
doc.embedding,
DBMSIndexDriver._doc_without_embedding(doc).SerializeToString(),
)
for doc in docs
]
)
Expand All @@ -21,7 +25,11 @@ def test_dbms_keyvalue(tmpdir, test_metas):
new_docs = list(get_documents(chunks=False, nr=10, same_content=False))
ids, vecs, meta = zip(
*[
(doc.id, doc.embedding, _doc_without_embedding(doc).SerializeToString())
(
doc.id,
doc.embedding,
DBMSIndexDriver._doc_without_embedding(doc).SerializeToString(),
)
for doc in new_docs
]
)
Expand Down