Skip to content

Commit

Permalink
[NeuralChat] Add bm25 into enabled retrievers and add Uts (#1313)
Browse files Browse the repository at this point in the history
 Add bm25 into enabled retrievers and add Uts

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
  • Loading branch information
XuhuiRen committed Feb 29, 2024
1 parent f432a7a commit a19467d
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@
from FlagEmbedding import FlagReranker

class BgeReranker(BaseDocumentCompressor):
model_name:str = 'bge_reranker_model_path'
top_n: int = 3 # Number of documents to return.
model:FlagReranker = FlagReranker(model_name)
model:FlagReranker
"""CrossEncoder instance to use for reranking."""

def bge_rerank(self, query, docs):
model_inputs = [[query, doc] for doc in docs]
scores = self.model.compute_score(model_inputs)
if len(docs) == 1:
return [(0, scores)]
results = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
return results[:self.top_n]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self,
allowed_retrieval_type: ClassVar[Collection[str]] = (
"default",
"child_parent",
'bm25',
)
allowed_generation_mode: ClassVar[Collection[str]] = (
"accuracy",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ def __init__(self, retrieval_type='default', document_store=None, child_document
self.retrieval_type = retrieval_type
if enable_rerank:
from intel_extension_for_transformers.langchain.retrievers.bge_reranker import BgeReranker
self.reranker = BgeReranker(model_name = reranker_model, top_n=top_n)
from FlagEmbedding import FlagReranker
reranker = FlagReranker(reranker_model)
self.reranker = BgeReranker(model = reranker, top_n=top_n)
else:
self.reranker = None

if self.retrieval_type == "default":
self.retriever = VectorStoreRetriever(vectorstore=document_store, **kwargs)
if self.retrieval_type == "bm25":
elif self.retrieval_type == "bm25":
self.retriever = BM25Retriever.from_documents(docs, **kwargs)
elif self.retrieval_type == "child_parent":
self.retriever = ChildParentRetriever(parentstore=document_store, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,61 @@ def test_accuracy_mode(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestBM25Retriever(unittest.TestCase):
def setUp(self):
if os.path.exists("./bm25"):
shutil.rmtree("./bm25", ignore_errors=True)
return super().setUp()

def tearDown(self) -> None:
if os.path.exists("./bm25"):
shutil.rmtree("./bm25", ignore_errors=True)
return super().tearDown()

def test_accuracy_mode(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/sample.txt"
plugins.retrieval.args["persist_directory"] = "./bm25"
plugins.retrieval.args["retrieval_type"] = 'bm25'
plugins.retrieval.args["mode"] = 'accuracy'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestRerank(unittest.TestCase):
def setUp(self):
if os.path.exists("./rerank"):
shutil.rmtree("./rerank", ignore_errors=True)
return super().setUp()

def tearDown(self) -> None:
if os.path.exists("./rerank"):
shutil.rmtree("./rerank", ignore_errors=True)
return super().tearDown()

def test_general_mode(self):
plugins.retrieval.args = {}
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "../assets/docs/sample.txt"
plugins.retrieval.args["persist_directory"] = "./rerank"
plugins.retrieval.args["retrieval_type"] = 'default'
plugins.retrieval.args['enable_rerank'] = True
plugins.retrieval.args['reranker_model'] = 'BAAI/bge-reranker-base'
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
plugins=plugins)
chatbot = build_chatbot(config)
response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
print(response)
self.assertIsNotNone(response)
plugins.retrieval.args = {}
plugins.retrieval.enable = False

class TestGeneralMode(unittest.TestCase):
def setUp(self):
if os.path.exists("./general_mode"):
Expand Down

0 comments on commit a19467d

Please sign in to comment.