Skip to content

Commit

Permalink
Add support for custom scoring instances, closes #544
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Sep 6, 2023
1 parent ce5a2fc commit e48a888
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 15 deletions.
44 changes: 37 additions & 7 deletions src/python/txtai/scoring/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Factory module
"""

from ..util import Resolver

from .base import Scoring
from .bm25 import BM25
from .sif import SIF
Expand All @@ -24,18 +26,46 @@ def create(config):
Scoring
"""

# Scoring instance
scoring = None

# Support string and dict configuration
if isinstance(config, str):
config = {"method": config}

method = config.get("method") if config else None
# Get scoring method
method = config.get("method", "bm25")

if method == "bm25":
return BM25(config)
if method == "sif":
return SIF(config)
if method == "tfidf":
scoring = BM25(config)
elif method == "sif":
scoring = SIF(config)
elif method == "tfidf":
# Default scoring class implements tf-idf
return Scoring(config)
scoring = Scoring(config)
else:
# Resolve custom method
scoring = ScoringFactory.resolve(method, config)

# Store config back
config["method"] = method

return scoring

@staticmethod
def resolve(backend, config):
"""
Attempt to resolve a custom backend.
Args:
backend: backend class
config: index configuration parameters
Returns:
Graph
"""

return None
try:
return Resolver()(backend)(config)
except Exception as e:
raise ImportError(f"Unable to resolve scoring backend: '{backend}'") from e
2 changes: 1 addition & 1 deletion test/python/testpipeline/testtrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def testMultiLabel(self):

data = []
for x in self.data:
data.append({"text": x["text"], "label": [float(x["label"])] * 2})
data.append({"text": x["text"], "label": [0.0, 1.0] if x["label"] else [1.0, 0.0]})

trainer = HFTrainer()
model, tokenizer = trainer("google/bert_uncased_L-2_H-128_A-2", data)
Expand Down
22 changes: 15 additions & 7 deletions test/python/testscoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ def testBM25(self):

self.runTests("bm25")

def testCustom(self):
"""
Test custom method
"""

self.runTests("txtai.scoring.BM25")

def testCustomNotFound(self):
"""
Test unresolvable custom method
"""

with self.assertRaises(ImportError):
ScoringFactory.create("notfound.scoring")

def testSIF(self):
"""
Test sif
Expand All @@ -55,13 +70,6 @@ def testTFIDF(self):

self.runTests("tfidf")

def testUnknown(self):
"""
Test unknown method
"""

self.assertIsNone(ScoringFactory.create("unknown"))

def runTests(self, method):
"""
Runs a series of tests for a scoring method.
Expand Down

0 comments on commit e48a888

Please sign in to comment.