Skip to content
Permalink
Browse files

feat(score_fn): make score_fn a TrainableBase

  • Loading branch information...
hanxiao committed Sep 3, 2019
1 parent 14c7e52 commit f908f3811790b17209154c638a1fe6d67f81780a
Showing with 23 additions and 12 deletions.
  1. +8 −6 gnes/base/__init__.py
  2. +5 −2 gnes/indexer/base.py
  3. +8 −3 gnes/score_fn/base.py
  4. +1 −1 gnes/score_fn/normalize.py
  5. +1 −0 tests/test_annoyindexer.py
@@ -66,7 +66,8 @@ class TrainableType(type):
'batch_size': None,
'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
'name': None,
'on_gpu': False
'on_gpu': False,
'unnamed_warning': True
}

def __new__(cls, *args, **kwargs):
@@ -180,11 +181,12 @@ def _post_init_wrapper(self):
if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1':
_id = str(uuid.uuid4()).split('-')[0]
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
'this object is not named ("name" is not found under "gnes_config" in YAML config), '
'i will call it "%s". '
'naming the object is important as it provides an unique identifier when '
'serializing/deserializing this object.' % _name)
if self.unnamed_warning:
self.logger.warning(
'this object is not named ("name" is not found under "gnes_config" in YAML config), '
'i will call it "%s". '
'naming the object is important as it provides an unique identifier when '
'serializing/deserializing this object.' % _name)
setattr(self, 'name', _name)

_before = set(list(self.__dict__.keys()))
@@ -22,8 +22,11 @@


class BaseIndexer(TrainableBase):
normalize_fn = ModifierFn()
score_fn = ModifierFn()
def __init__(self, normalize_fn=ModifierFn(),
score_fn=ModifierFn(), *args, **kwargs):
super().__init__(*args, **kwargs)
self.normalize_fn = normalize_fn
self.score_fn = score_fn

def add(self, keys: Any, docs: Any, weights: List[float], *args, **kwargs):
pass
@@ -4,6 +4,7 @@
from operator import mul, add
from typing import Sequence

from ..base import TrainableBase
from ..proto import gnes_pb2


@@ -16,7 +17,8 @@ def get_unary_score(value: float, **kwargs):
return score


class BaseScoreFn:
class BaseScoreFn(TrainableBase):
unnamed_warning = False

def __call__(self, *args, **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
raise NotImplementedError
@@ -37,10 +39,11 @@ def op(self, *args, **kwargs) -> float:
class ScoreCombinedFn(BaseScoreFn):
"""Combine multiple scores into one score, defaults to 'multiply'"""

def __init__(self, score_mode: str = 'multiply'):
def __init__(self, score_mode: str = 'multiply', *args, **kwargs):
"""
:param score_mode: specifies how the computed scores are combined
"""
super().__init__(*args, **kwargs)
if score_mode not in {'multiply', 'sum', 'avg', 'max', 'min'}:
raise AttributeError('score_mode=%s is not supported!' % score_mode)
self.score_mode = score_mode
@@ -66,7 +69,9 @@ class ModifierFn(BaseScoreFn):
score = modifier(factor * value)
"""

def __init__(self, modifier: str = 'none', factor: float = 1.0, factor_name: str = 'GivenConstant'):
def __init__(self, modifier: str = 'none', factor: float = 1.0, factor_name: str = 'GivenConstant', *args,
**kwargs):
super().__init__(*args, **kwargs)
if modifier not in {'none', 'log', 'log1p', 'log2p', 'ln', 'ln1p', 'ln2p', 'square', 'sqrt', 'reciprocal',
'reciprocal1p', 'abs'}:
raise AttributeError('modifier=%s is not supported!' % modifier)
@@ -40,7 +40,7 @@ def __init__(self, num_bytes: int):
class Normalizer5(ModifierFn):
"""Do normalizing: score = 1 / (1 + sqrt(abs(score)))"""

def __init__(self, num_dim: int):
def __init__(self):
super().__init__()
self.modifier = 'reciprocal1p'

@@ -25,3 +25,4 @@ def test_search(self):
self.assertEqual(top_1, list(range(10)))
a.close()
a.dump()
a.dump_yaml()

0 comments on commit f908f38

Please sign in to comment.
You can’t perform that action at this time.