From e48a8885ce7328cb35d8c1a02034054605b427c5 Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Wed, 6 Sep 2023 17:21:56 -0400 Subject: [PATCH] Add support for custom scoring instances, closes #544 --- src/python/txtai/scoring/factory.py | 44 +++++++++++++++++++++---- test/python/testpipeline/testtrainer.py | 2 +- test/python/testscoring.py | 22 +++++++++---- 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/src/python/txtai/scoring/factory.py b/src/python/txtai/scoring/factory.py index 40c1d27e0..cd15a60f6 100644 --- a/src/python/txtai/scoring/factory.py +++ b/src/python/txtai/scoring/factory.py @@ -2,6 +2,8 @@ Factory module """ +from ..util import Resolver + from .base import Scoring from .bm25 import BM25 from .sif import SIF @@ -24,18 +26,46 @@ def create(config): Scoring """ + # Scoring instance + scoring = None + # Support string and dict configuration if isinstance(config, str): config = {"method": config} - method = config.get("method") if config else None + # Get scoring method + method = config.get("method", "bm25") if method == "bm25": - return BM25(config) - if method == "sif": - return SIF(config) - if method == "tfidf": + scoring = BM25(config) + elif method == "sif": + scoring = SIF(config) + elif method == "tfidf": # Default scoring class implements tf-idf - return Scoring(config) + scoring = Scoring(config) + else: + # Resolve custom method + scoring = ScoringFactory.resolve(method, config) + + # Store config back + config["method"] = method + + return scoring + + @staticmethod + def resolve(backend, config): + """ + Attempt to resolve a custom backend. + + Args: + backend: backend class + config: index configuration parameters + + Returns: + Graph + """ - return None + try: + return Resolver()(backend)(config) + except Exception as e: + raise ImportError(f"Unable to resolve scoring backend: '{backend}'") from e diff --git a/test/python/testpipeline/testtrainer.py b/test/python/testpipeline/testtrainer.py index b5afdc54b..b992acc12 100644 --- a/test/python/testpipeline/testtrainer.py +++ b/test/python/testpipeline/testtrainer.py @@ -202,7 +202,7 @@ def testMultiLabel(self): data = [] for x in self.data: - data.append({"text": x["text"], "label": [float(x["label"])] * 2}) + data.append({"text": x["text"], "label": [0.0, 1.0] if x["label"] else [1.0, 0.0]}) trainer = HFTrainer() model, tokenizer = trainer("google/bert_uncased_L-2_H-128_A-2", data) diff --git a/test/python/testscoring.py b/test/python/testscoring.py index 54c460837..8b6886739 100644 --- a/test/python/testscoring.py +++ b/test/python/testscoring.py @@ -41,6 +41,21 @@ def testBM25(self): self.runTests("bm25") + def testCustom(self): + """ + Test custom method + """ + + self.runTests("txtai.scoring.BM25") + + def testCustomNotFound(self): + """ + Test unresolvable custom method + """ + + with self.assertRaises(ImportError): + ScoringFactory.create("notfound.scoring") + def testSIF(self): """ Test sif @@ -55,13 +70,6 @@ def testTFIDF(self): self.runTests("tfidf") - def testUnknown(self): - """ - Test unknown method - """ - - self.assertIsNone(ScoringFactory.create("unknown")) - def runTests(self, method): """ Runs a series of tests for a scoring method.