refactor(types): move add_chunk add_match to Set (#1343)
* refactor(types): refactor constructor for ql, doc, req
hanxiao committed Nov 24, 2020
1 parent c596897 commit f3a7c15
Showing 22 changed files with 255 additions and 210 deletions.
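The pattern behind most of the 22 files is the same: the `Document.add_chunk` and `Document.add_match` helpers are gone, and callers now construct a `Document` themselves and `append` it to the `chunks` or `matches` set. A minimal before/after sketch of the migration, assuming the post-refactor `jina.Document` API shown in the diffs below (ids and scores here are made up):

```python
from jina import Document

doc = Document(id=1)

# before: helpers on Document itself (removed by this commit)
#   doc.add_chunk(c)
#   doc.add_match(doc_id=123, score_value=0.9, op_name='MyRanker')

# after: build the sub-document, then append it to the Set
c = Document()
doc.chunks.append(c)

m = Document(id=123)          # the match carries its own id ...
m.score.value = 0.9           # ... and its own score fields
m.score.op_name = 'MyRanker'
doc.matches.append(m)
```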
2 changes: 1 addition & 1 deletion jina/drivers/craft.py
@@ -46,4 +46,4 @@ def update(doc, ret):
     with Document(length=len(ret), **r) as c:
         if not c.mime_type:
             c.mime_type = doc.mime_type
-        doc.add_chunk(c)
+        doc.chunks.append(c)
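For context, the hunk sits inside the crafter's `update` callback: each crafted result dict `r` becomes a chunk, inherits the parent's MIME type when it has none, and is now attached via the generic `chunks.append`. A rough stand-alone sketch of that flow, with the crafter output `crafted` fabricated for illustration:

```python
from jina import Document

doc = Document()
doc.mime_type = 'text/plain'

crafted = [{'text': 'hello'}, {'text': 'world'}]  # hypothetical crafter output
for r in crafted:
    with Document(length=len(crafted), **r) as c:
        if not c.mime_type:
            c.mime_type = doc.mime_type  # chunk falls back to the parent's MIME type
        doc.chunks.append(c)             # replaces the removed doc.add_chunk(c)
```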
2 changes: 1 addition & 1 deletion jina/drivers/evaluate.py
@@ -140,7 +140,7 @@ class LoadGroundTruthDriver(KVSearchDriver):
     def __call__(self, *args, **kwargs):
         miss_idx = []  #: missed hit results, some documents may not have groundtruth and thus will be removed
         for idx, doc in enumerate(self.docs):
-            serialized_groundtruth = self.exec_fn(self.id2hash(doc.id))
+            serialized_groundtruth = self.exec_fn(hash(doc.id))
             if serialized_groundtruth:
                 self.req.groundtruths.append(Document(serialized_groundtruth))
             else:
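Two constructor paths from the commit message show up in this hunk: the lookup key is now the plain `hash(doc.id)`, and `Document(serialized_groundtruth)` rebuilds a document straight from the bytes the KV executor returns. A hedged round-trip sketch of that deserializing constructor, assuming `Document` exposes the underlying proto's `SerializeToString` (as the index.py hunk below suggests):

```python
from jina import Document

gt = Document(id=42)
gt.text = 'ground truth'

blob = gt.SerializeToString()  # what a KV indexer would have stored
restored = Document(blob)      # the constructor path used in the hunk above
assert restored.text == 'ground truth'
```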
4 changes: 2 additions & 2 deletions jina/drivers/index.py
@@ -28,14 +28,14 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
             self.pea.logger.warning(f'these bad docs can not be added: {bad_doc_ids}')

         if docs_pts:
-            self.exec_fn(np.array([doc.id_in_hash for doc in docs_pts]), np.stack(embed_vecs))
+            self.exec_fn(np.array([hash(doc.id) for doc in docs_pts]), np.stack(embed_vecs))


 class KVIndexDriver(BaseIndexDriver):
     """Serialize the documents/chunks in the request to key-value JSON pairs and write it using the executor
     """

     def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
-        keys = [doc.id_in_hash for doc in docs]
+        keys = [hash(doc.id) for doc in docs]
         values = [doc.SerializeToString() for doc in docs]
         self.exec_fn(keys, values)
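Both drivers above now derive their storage keys the same way: Python's builtin `hash()` over `doc.id` replaces the removed `id_in_hash` property. A stand-in sketch of the key/value preparation, with the executor call mocked since the real `exec_fn` is bound at runtime:

```python
import numpy as np
from jina import Document

docs = [Document(id=i) for i in (1, 2, 3)]

keys = np.array([hash(doc.id) for doc in docs])     # vector-index keys
values = [doc.SerializeToString() for doc in docs]  # KV-index payloads

def mock_exec_fn(keys, values):  # hypothetical stand-in for the bound executor method
    return dict(zip(keys.tolist(), values))

store = mock_exec_fn(keys, values)
assert len(store) == 3
```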
45 changes: 24 additions & 21 deletions jina/drivers/rank.py
@@ -6,12 +6,11 @@
 import numpy as np

 from . import BaseExecutableDriver
-from ..executors.rankers import Chunk2DocRanker, Match2DocRanker
-from ..types.document import uid
+from ..executors.rankers import Chunk2DocRanker
+from ..types.document import uid, Document

 if False:
     from ..types.sets import DocumentSet
-    from ..types.document import Document


 class BaseRankDriver(BaseExecutableDriver):
@@ -63,15 +62,15 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
         query_chunk_meta = {}  # type: Dict[int, Dict]
         match_chunk_meta = {}  # type: Dict[int, Dict]
         for chunk in docs:
-            query_chunk_meta[chunk.id_in_hash] = chunk.get_attrs(*self.exec.required_keys)
+            query_chunk_meta[hash(chunk.id)] = chunk.get_attrs(*self.exec.required_keys)
             for match in chunk.matches:
                 match_idx.append(
-                    (match.parent_id_in_hash,
-                     match.id_in_hash,
-                     chunk.id_in_hash,
+                    (hash(match.parent_id),
+                     hash(match.id),
+                     hash(chunk.id),
                      match.score.value)
                 )
-                match_chunk_meta[match.id_in_hash] = match.get_attrs(*self.exec.required_keys)
+                match_chunk_meta[hash(match.id)] = match.get_attrs(*self.exec.required_keys)

         if match_idx:
             match_idx = np.array(
@@ -85,10 +84,12 @@
             )

         docs_scores = self.exec_fn(match_idx, query_chunk_meta, match_chunk_meta)
+        op_name = exec.__class__.__name__
         for doc_hash, score in docs_scores:
-            context_doc.add_match(doc_id=doc_hash,
-                                  score_value=score,
-                                  op_name=exec.__class__.__name__)
+            m = Document(id=doc_hash)
+            m.score.value = score
+            m.score.op_name = op_name
+            context_doc.matches.append(m)


 class CollectMatches2DocRankDriver(BaseRankDriver):
@@ -140,14 +141,14 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
         query_chunk_meta = {}
         match_chunk_meta = {}
         for match in docs:
-            query_chunk_meta[context_doc.id_in_hash] = context_doc.get_attrs(*self.exec.required_keys)
+            query_chunk_meta[hash(context_doc.id)] = context_doc.get_attrs(*self.exec.required_keys)
             match_idx.append((
-                match.parent_id_in_hash,
-                match.id_in_hash,
-                context_doc.id_in_hash,
+                hash(match.parent_id),
+                hash(match.id),
+                hash(context_doc.id),
                 match.score.value
             ))
-            match_chunk_meta[match.id_in_hash] = match.get_attrs(*self.exec.required_keys)
+            match_chunk_meta[hash(match.id)] = match.get_attrs(*self.exec.required_keys)

         if match_idx:
             match_idx = np.array(match_idx,
@@ -162,10 +163,12 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
         docs_scores = self.exec_fn(match_idx, query_chunk_meta, match_chunk_meta)
         # These ranker will change the current matches
         context_doc.ClearField('matches')
+        op_name = exec.__class__.__name__
         for doc_hash, score in docs_scores:
-            context_doc.add_match(doc_hash,
-                                  score_value=score,
-                                  op_name=exec.__class__.__name__)
+            m = Document(id=doc_hash)
+            m.score.value = score
+            m.score.op_name = op_name
+            context_doc.matches.append(m)


 class Matches2DocRankDriver(BaseRankDriver):
@@ -201,8 +204,8 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
         # if at the top-level already, no need to aggregate further
         query_meta = context_doc.get_attrs(*self.exec.required_keys)

-        old_match_scores = {match.id_in_hash: match.score.value for match in docs}
-        match_meta = {match.id_in_hash: match.get_attrs(*self.exec.required_keys) for match in docs}
+        old_match_scores = {hash(match.id): match.score.value for match in docs}
+        match_meta = {hash(match.id): match.get_attrs(*self.exec.required_keys) for match in docs}
         # if there are no matches, no need to sort them
         if not old_match_scores:
             return
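All three rank drivers now rebuild matches with the same four lines: construct a `Document` from the ranker's returned hash, fill in its `score` fields, and append it to `context_doc.matches`; hoisting `op_name` out of the loop just avoids re-reading `exec.__class__.__name__` per match. A stand-alone sketch of that loop with the ranker output `docs_scores` fabricated:

```python
from jina import Document

context_doc = Document(id=1)
docs_scores = [(7001, 0.83), (7002, 0.61)]  # hypothetical (doc_hash, score) pairs
op_name = 'MaxRanker'                       # normally exec.__class__.__name__

for doc_hash, score in docs_scores:
    m = Document(id=doc_hash)
    m.score.value = score
    m.score.op_name = op_name
    context_doc.matches.append(m)
```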
13 changes: 6 additions & 7 deletions jina/drivers/search.py
@@ -29,9 +29,6 @@ def __init__(
             **kwargs
         )

-        self.hash2id = uid.hash2id
-        self.id2hash = uid.id2hash
-

 class KVSearchDriver(BaseSearchDriver):
     """Fill in the doc/chunk-level top-k results using the :class:`jina.executors.indexers.meta.BinaryPbIndexer`
@@ -62,7 +59,7 @@ def __init__(self, is_merge: bool = True, *args, **kwargs):
     def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
         miss_idx = []  #: missed hit results, some search may not end with results. especially in shards
         for idx, retrieved_doc in enumerate(docs):
-            serialized_doc = self.exec_fn(retrieved_doc.id_in_hash)
+            serialized_doc = self.exec_fn(hash(retrieved_doc.id))
             if serialized_doc:
                 r = Document(serialized_doc)
@@ -87,7 +84,7 @@ def __init__(self, executor: str = None, method: str = 'query_by_id', *args, **kwargs):
         super().__init__(executor, method, *args, **kwargs)

     def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:
-        embeds = self.exec_fn([d.id_in_hash for d in docs])
+        embeds = self.exec_fn([hash(d.id) for d in docs])
         for doc, embedding in zip(docs, embeds):
             doc.embedding = embedding
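`VectorFillDriver` follows the same keying change: embeddings are fetched by `hash(d.id)` and assigned back one per document. A rough sketch with the executor's `query_by_id` mocked as a random matrix:

```python
import numpy as np
from jina import Document

docs = [Document(id=i) for i in (10, 11)]

def mock_query_by_id(keys):  # hypothetical stand-in for exec_fn
    return np.random.random((len(keys), 8))

embeds = mock_query_by_id([hash(d.id) for d in docs])
for doc, embedding in zip(docs, embeds):
    doc.embedding = embedding  # same assignment as in the hunk above
```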

@@ -128,7 +125,9 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs) -> None:

         topk_embed = fill_fn(topks) if (self._fill_embedding and fill_fn) else [None] * len(topks)
         for match_hash, score, vec in zip(topks, scores, topk_embed):
-            r = doc.add_match(doc_id=match_hash,
-                              score_value=score, op_name=op_name)
+            m = Document(id=int(match_hash))
+            m.score.value = score
+            m.score.op_name = op_name
+            r = doc.matches.append(m)
             if vec is not None:
                 r.embedding = vec
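One subtlety in this last hunk: `r = doc.matches.append(m)` keeps using `r` afterwards, which only works because the set's `append` evidently returns the stored match (unlike `list.append`); the returned view is what receives the optional embedding. A stand-in sketch with fabricated top-k results:

```python
import numpy as np
from jina import Document

doc = Document(id=1)
topks = [9001, 9002]             # hypothetical match hashes
scores = [0.9, 0.4]
topk_embed = [np.ones(4), None]  # only the first match carries a vector

for match_hash, score, vec in zip(topks, scores, topk_embed):
    m = Document(id=int(match_hash))
    m.score.value = score
    m.score.op_name = 'NumpyIndexer'  # hypothetical executor name
    r = doc.matches.append(m)         # append returns the appended match
    if vec is not None:
        r.embedding = vec
```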
8 changes: 4 additions & 4 deletions jina/excepts.py
@@ -96,10 +96,6 @@ class BadPersistantFile(Exception):
     """Bad or broken dump file that can not be deserialized with ``pickle.load``"""


-class BadRequestType(Exception):
-    """Bad request type and the pod does not know how to handle """
-
-
 class GRPCServerError(Exception):
     """Can not connect to the grpc gateway"""
@@ -167,5 +163,9 @@ class BadQueryLangType(TypeError):
     """ Exception when can not construct a query language from the given data """


+class BadRequestType(TypeError):
+    """Exception when can not construct a request object from given data"""
+
+
 class RemotePodClosed(Exception):
     """ Exception when remote pod is closed and log streaming needs to exit """