Skip to content

Commit

Permalink
[NeuralChat] enable grammar check and query polish to enhance RAG perf…
Browse files Browse the repository at this point in the history
…ormance (#1245)
  • Loading branch information
XuhuiRen committed Feb 23, 2024
1 parent 15feadf commit a63ec0d
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,10 @@ def generate_intent_prompt(query):
conv = PromptTemplate("intent")
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], None)
return conv.get_prompt()

def polish_query_prompt(query):
    """Build the LLM prompt that asks the model to polish (clarify) *query*.

    Uses the registered "polish" conversation template: the raw query is
    appended as the user turn, and the assistant turn is left empty so the
    model completes it with the polished query.
    """
    template = PromptTemplate("polish")
    user_role, assistant_role = template.roles
    template.append_message(user_role, query)
    template.append_message(assistant_role, None)
    return template.get_prompt()
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Polish and clarify the input user query with LLM."""

from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompt.prompt_template \
import polish_query_prompt
from intel_extension_for_transformers.neural_chat.models.model_utils import predict


class QueryPolisher:
    """Rewrite a user query with an LLM to fix grammar errors and clarify intent.

    Stateless helper: it builds a "polish" prompt from the raw query and runs
    a single LLM generation to obtain the rewritten query.
    """

    def __init__(self):
        pass

    def polish_query(self, model_name, query):
        """Return *query* rewritten by the model named *model_name*.

        Low temperature / small top_k keep the rewrite conservative while
        do_sample avoids fully greedy decoding.
        """
        generation_params = {
            "model_name": model_name,
            "prompt": polish_query_prompt(query),
            "temperature": 0.1,
            "top_k": 3,
            "max_new_tokens": 512,
            "do_sample": True,
        }
        return predict(**generation_params)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from typing import Dict, List, Any, ClassVar, Collection
from .detector.intent_detection import IntentDetector
from .detector.query_explainer import QueryPolisher
from .parser.parser import DocumentParser
from .retriever_adapter import RetrieverAdapter
from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompt.prompt_template \
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(self,
mode = "accuracy",
process=True,
append=True,
polish=False,
**kwargs):

self.intent_detector = IntentDetector()
Expand All @@ -90,6 +92,10 @@ def __init__(self,
"accuracy",
"general",
)
if polish:
self.polisher = QueryPolisher()
else:
self.polisher = None

assert self.retrieval_type in allowed_retrieval_type, "search_type of {} not allowed.".format( \
self.retrieval_type)
Expand Down Expand Up @@ -259,6 +265,13 @@ def append_localdb(self, append_path, **kwargs):


def pre_llm_inference_actions(self, model_name, query):
if self.polisher:
try:
query = self.polisher.polish_query(model_name, query)
except Exception as e:
logging.info(f"Polish the user query failed, {e}")
raise Exception("[Rereieval ERROR] query polish failed!")

try:
intent = self.intent_detector.intent_detection(model_name, query)
except Exception as e:
Expand Down
12 changes: 11 additions & 1 deletion intel_extension_for_transformers/neural_chat/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@
)
)


# Intent template
register_conv_template(
Conversation(
Expand All @@ -192,6 +191,17 @@
)
)

# Query Polish template
# Registers the conversation layout used to ask an LLM to rewrite a raw user
# query into a clearer form. Presumably resolved via PromptTemplate("polish")
# by the query-polish plugin — confirm against prompt_template helpers.
# NO_COLON_SINGLE: role labels are emitted as-is (no appended colon),
# turns separated by a newline.
register_conv_template(
    Conversation(
        name="polish",
        system_message="### Please polish the following user query to make it clear and easy to be understood.\n",
        roles=("### User Query: ", "### Polished Query: "),
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="\n",
    )
)

# NER template
register_conv_template(
Conversation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

from intel_extension_for_transformers.neural_chat.pipeline.plugins.retrieval.parser.parser import DocumentParser
import unittest

import os
import shutil
from intel_extension_for_transformers.neural_chat import build_chatbot
from intel_extension_for_transformers.neural_chat import PipelineConfig
from intel_extension_for_transformers.neural_chat import plugins

class TestMemory(unittest.TestCase):
def setUp(self):
Expand All @@ -33,5 +37,32 @@ def test_html_loader(self):
vectordb = doc_parser.load(url)
self.assertIsNotNone(vectordb)

class TestPolisher(unittest.TestCase):
    """End-to-end check that the retrieval plugin works with query polishing on."""

    def setUp(self):
        # Start from a clean persist directory. ignore_errors=True already
        # tolerates a missing path, so no os.path.exists() guard is needed.
        shutil.rmtree("test_for_polish", ignore_errors=True)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree("test_for_polish", ignore_errors=True)
        # Restore the shared global plugin config here (not in the test body)
        # so later tests see defaults even if the test fails part-way through.
        plugins.retrieval.args["persist_directory"] = "./output"
        plugins.retrieval.args["polish"] = False
        plugins.retrieval.enable = False
        return super().tearDown()

    def test_retrieval_accuracy(self):
        # Mutates the module-level `plugins` config; tearDown undoes this.
        plugins.retrieval.enable = True
        plugins.retrieval.args["input_path"] = ['https://www.ces.tech/']
        plugins.retrieval.args["persist_directory"] = "./test_for_polish"
        plugins.retrieval.args["retrieval_type"] = 'default'
        plugins.retrieval.args["polish"] = True
        config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                plugins=plugins)
        chatbot = build_chatbot(config)
        response = chatbot.predict("How many cores does the Intel Xeon Platinum 8480+ Processor have in total?")
        print(response)
        self.assertIsNotNone(response)

# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()

0 comments on commit a63ec0d

Please sign in to comment.