Skip to content
Permalink
Browse files

refactor(score_fn): use post_init instead of property

  • Loading branch information...
hanxiao committed Sep 4, 2019
1 parent 909a44b commit e9feaa6174ada9242b42a66564607d86671c6bab
Showing with 123 additions and 34 deletions.
  1. +1 −1 gnes/base/__init__.py
  2. +32 −0 gnes/score_fn/__init__.py
  3. +41 −21 gnes/score_fn/base.py
  4. +15 −0 gnes/score_fn/chunk.py
  5. +15 −0 gnes/score_fn/doc.py
  6. +19 −12 gnes/score_fn/normalize.py
@@ -50,7 +50,7 @@ def _import(module_name, class_name):
if class_name in cls2file:
return getattr(importlib.import_module('gnes.%s.%s' % (module_name, cls2file[class_name])), class_name)

search_modules = ['encoder', 'indexer', 'preprocessor', 'router']
search_modules = ['encoder', 'indexer', 'preprocessor', 'router', 'score_fn']

for m in search_modules:
r = _import(m, name)
@@ -0,0 +1,32 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A key-value map for Class to the (module)file it located in
from ..base import register_all_class

_cls2file_map = {
'BaseScoreFn': 'base',
'ScoreCombinedFn': 'base',
'ModifierFn': 'base',
'WeightedChunkScoreFn': 'chunk',
'WeightedDocScoreFn': 'doc',
'Normalizer1': 'normalize',
'Normalizer2': 'normalize',
'Normalizer3': 'normalize',
'Normalizer4': 'normalize',
'Normalizer5': 'normalize',
}

register_all_class(_cls2file_map, 'score_fn')
@@ -1,3 +1,18 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from functools import reduce
from math import log, log1p, log10, sqrt
@@ -18,6 +33,8 @@ def get_unary_score(value: float, **kwargs):


class BaseScoreFn(TrainableBase):
"""Base score function. A score function must implement __call__ method"""

warn_unnamed = False

def __call__(self, *args, **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
@@ -32,9 +49,6 @@ def new_score(self, *, operands: Sequence['gnes_pb2.Response.QueryResponse.Score
operands=[json.loads(s.explained) for s in operands],
**kwargs)

def op(self, *args, **kwargs) -> float:
raise NotImplementedError


class ScoreCombinedFn(BaseScoreFn):
"""Combine multiple scores into one score, defaults to 'multiply'"""
@@ -44,24 +58,29 @@ def __init__(self, score_mode: str = 'multiply', *args, **kwargs):
:param score_mode: specifies how the computed scores are combined
"""
super().__init__(*args, **kwargs)
if score_mode not in {'multiply', 'sum', 'avg', 'max', 'min'}:
raise AttributeError('score_mode=%s is not supported!' % score_mode)
if score_mode not in self.supported_ops:
raise AttributeError(
'score_mode=%s is not supported! must be one of %s' % (score_mode, self.supported_ops.keys()))
self.score_mode = score_mode

def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
return self.new_score(
value=self.op(s.value for s in last_scores),
operands=last_scores,
score_mode=self.score_mode)

def op(self, *args, **kwargs) -> float:
@property
def supported_ops(self):
return {
'multiply': lambda v: reduce(mul, v),
'sum': lambda v: reduce(add, v),
'max': lambda v: reduce(max, v),
'min': lambda v: reduce(min, v),
'avg': lambda v: reduce(add, v) / len(v),
}[self.score_mode](*args, **kwargs)
}

def post_init(self):
self.op = self.supported_ops[self.score_mode]

def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
return self.new_score(
value=self.op(s.value for s in last_scores),
operands=last_scores,
score_mode=self.score_mode)


class ModifierFn(BaseScoreFn):
@@ -72,18 +91,15 @@ class ModifierFn(BaseScoreFn):
def __init__(self, modifier: str = 'none', factor: float = 1.0, factor_name: str = 'GivenConstant', *args,
**kwargs):
super().__init__(*args, **kwargs)
if modifier not in {'none', 'log', 'log1p', 'log2p', 'ln', 'ln1p', 'ln2p', 'square', 'sqrt', 'reciprocal',
'reciprocal1p', 'abs'}:
raise AttributeError('modifier=%s is not supported!' % modifier)
if modifier not in self.supported_ops:
raise AttributeError(
'modifier=%s is not supported! must be one of %s' % (modifier, self.supported_ops.keys()))
self._modifier = modifier
self._factor = factor
self._factor_name = factor_name

@property
def factor(self):
return get_unary_score(value=self._factor, name=self._factor_name)

def op(self, *args, **kwargs) -> float:
def supported_ops(self):
return {
'none': lambda x: x,
'log': log10,
@@ -99,7 +115,11 @@ def op(self, *args, **kwargs) -> float:
'abs': abs,
'invert': lambda x: - x,
'invert1p': lambda x: 1 - x
}[self._modifier](*args, **kwargs)
}

def post_init(self):
self.factor = get_unary_score(value=self._factor, name=self._factor_name)
self.op = self.supported_ops[self._modifier]

def __call__(self,
last_score: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score',
@@ -1,3 +1,18 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import get_unary_score, ScoreCombinedFn


@@ -1,3 +1,18 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import get_unary_score, ScoreCombinedFn


@@ -1,12 +1,26 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import ModifierFn, ScoreOps as so


class Normalizer1(ModifierFn):
"""Do normalizing: score = 1 / (1 + sqrt(score))"""

def __init__(self):
super().__init__()
self._modifier = 'reciprocal1p'
super().__init__(modifier='reciprocal1p')

def __call__(self, last_score, *args, **kwargs):
return super().__call__(so.sqrt(last_score))
@@ -16,10 +30,7 @@ class Normalizer2(ModifierFn):
"""Do normalizing: score = 1 / (1 + score / num_dim)"""

def __init__(self, num_dim: int):
super().__init__()
self._modifier = 'reciprocal1p'
self._factor = 1.0 / num_dim
self._factor_name = '1/num_dim'
super().__init__(modifier='reciprocal1p', factor=1.0 / num_dim, factor_name='1/num_dim')


class Normalizer3(Normalizer2):
@@ -33,18 +44,14 @@ class Normalizer4(ModifierFn):
"""Do normalizing: score = 1 - score / num_bytes """

def __init__(self, num_bytes: int):
super().__init__()
self._modifier = 'invert1p'
self._factor = 1.0 / num_bytes
self._factor_name = '1/num_bytes'
super().__init__(modifier='invert1p', factor=1.0 / num_bytes, factor_name='1/num_bytes')


class Normalizer5(ModifierFn):
"""Do normalizing: score = 1 / (1 + sqrt(abs(score)))"""

def __init__(self):
super().__init__()
self._modifier = 'reciprocal1p'
super().__init__(modifier='reciprocal1p')

def __call__(self, last_score, *args, **kwargs):
return super().__call__(so.sqrt(so.abs(last_score)))

0 comments on commit e9feaa6

Please sign in to comment.
You can’t perform that action at this time.