Skip to content

Commit

Permalink
fix(ranker): correct column names in ranker score matrix (#1973)
Browse files Browse the repository at this point in the history
* fix(ranker): correct column names in ranker score matrix

* fix(ranker): correct column names in ranker score matrix
  • Loading branch information
hanxiao committed Feb 18, 2021
1 parent 9cea5e7 commit c379b09
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
14 changes: 7 additions & 7 deletions jina/drivers/rank/aggregate/__init__.py
Expand Up @@ -3,7 +3,7 @@

import numpy as np

- from ....executors.rankers import Chunk2DocRanker
+ from ....executors.rankers import Chunk2DocRanker, COL_STR_TYPE
from ....types.document import Document
from ....types.score import NamedScore

Expand Down Expand Up @@ -124,9 +124,9 @@ def _apply_all(self, docs: 'DocumentSet',
if match_idx:
match_idx = np.array(match_idx,
dtype=[
- (Chunk2DocRanker.COL_MATCH_PARENT_ID, np.object),
- (Chunk2DocRanker.COL_MATCH_ID, np.object),
- (Chunk2DocRanker.COL_DOC_CHUNK_ID, np.object),
+ (Chunk2DocRanker.COL_PARENT_ID, COL_STR_TYPE),
+ (Chunk2DocRanker.COL_DOC_CHUNK_ID, COL_STR_TYPE),
+ (Chunk2DocRanker.COL_QUERY_CHUNK_ID, COL_STR_TYPE),
(Chunk2DocRanker.COL_SCORE, np.float64)
]
)
Expand Down Expand Up @@ -203,9 +203,9 @@ def _apply_all(self, docs: 'DocumentSet', context_doc: 'Document', *args,
if match_idx:
match_idx = np.array(match_idx,
dtype=[
- (Chunk2DocRanker.COL_MATCH_PARENT_ID, np.object),
- (Chunk2DocRanker.COL_MATCH_ID, np.object),
- (Chunk2DocRanker.COL_DOC_CHUNK_ID, np.object),
+ (Chunk2DocRanker.COL_PARENT_ID, COL_STR_TYPE),
+ (Chunk2DocRanker.COL_DOC_CHUNK_ID, COL_STR_TYPE),
+ (Chunk2DocRanker.COL_QUERY_CHUNK_ID, COL_STR_TYPE),
(Chunk2DocRanker.COL_SCORE, np.float64)
]
)
Expand Down
16 changes: 9 additions & 7 deletions jina/executors/rankers/__init__.py
Expand Up @@ -7,6 +7,8 @@

from .. import BaseExecutor

+ COL_STR_TYPE = 'U64' #: the ID column data type for score matrix


class BaseRanker(BaseExecutor):
"""The base class for a `Ranker`"""
Expand Down Expand Up @@ -40,9 +42,9 @@ class Chunk2DocRanker(BaseRanker):
:meth:`get_attrs` of :class:`Document`
"""
- COL_MATCH_PARENT_ID = 'match_parent_id'
- COL_MATCH_ID = 'match_id'
- COL_DOC_CHUNK_ID = 'doc_chunk_id'
+ COL_PARENT_ID = 'match_parent_id'
+ COL_DOC_CHUNK_ID = 'match_doc_chunk_id'
+ COL_QUERY_CHUNK_ID = 'match_query_chunk_id'
COL_SCORE = 'score'

def score(self, match_idx: 'np.ndarray', query_chunk_meta: Dict, match_chunk_meta: Dict) -> 'np.ndarray':
Expand Down Expand Up @@ -80,7 +82,7 @@ def group_by_doc_id(self, match_idx):
:return: an iterator over the groups.
:rtype: :class:`Chunk2DocRanker`.
"""
- return self._group_by(match_idx, self.COL_MATCH_PARENT_ID)
+ return self._group_by(match_idx, self.COL_PARENT_ID)

@staticmethod
def _group_by(match_idx, col_name):
Expand All @@ -103,14 +105,14 @@ def sort_doc_by_score(r):
:rtype: np.ndarray
"""
r = np.array(r, dtype=[
- (Chunk2DocRanker.COL_MATCH_PARENT_ID, np.object),
+ (Chunk2DocRanker.COL_PARENT_ID, COL_STR_TYPE),
(Chunk2DocRanker.COL_SCORE, np.float64)]
)
return np.sort(r, order=Chunk2DocRanker.COL_SCORE)[::-1]

def get_doc_id(self, match_with_same_doc_id):
"""Return document id that matches with given id :param:`match_with_same_doc_id`"""
- return match_with_same_doc_id[0][self.COL_MATCH_PARENT_ID]
+ return match_with_same_doc_id[0][self.COL_PARENT_ID]


class Match2DocRanker(BaseRanker):
Expand All @@ -124,7 +126,7 @@ class Match2DocRanker(BaseRanker):
- BucketShuffleRanker (first buckets matches and then sort each bucket).
"""

- COL_MATCH_ID = 'match_id'
+ COL_MATCH_ID = 'match_doc_chunk_id'
COL_SCORE = 'score'

def score(self, query_meta: Dict, old_match_scores: Dict, match_meta: Dict) -> 'np.ndarray':
Expand Down
Expand Up @@ -31,7 +31,7 @@ def __init__(self, *args, **kwargs):
self.required_keys = {'length'}

def _get_score(self, match_idx, query_chunk_meta, match_chunk_meta, *args, **kwargs):
- return match_idx[0][self.COL_MATCH_PARENT_ID], match_chunk_meta[match_idx[0][self.COL_MATCH_ID]]['length']
+ return match_idx[0][self.COL_PARENT_ID], match_chunk_meta[match_idx[0][self.COL_DOC_CHUNK_ID]]['length']


def create_document_to_score_same_depth_level():
Expand Down
Expand Up @@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs):
self.required_keys = {'length'}

def _get_score(self, match_idx, query_chunk_meta, match_chunk_meta, *args, **kwargs):
- return match_idx[0][self.COL_MATCH_PARENT_ID], match_chunk_meta[match_idx[0][self.COL_MATCH_ID]]['length']
+ return match_idx[0][self.COL_PARENT_ID], match_chunk_meta[match_idx[0][self.COL_DOC_CHUNK_ID]]['length']


class SimpleChunk2DocRankDriver(Chunk2DocRankDriver):
Expand Down

0 comments on commit c379b09

Please sign in to comment.