Fix the bug in rag example (Neural Chat) (#226)
* code revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* revise

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* Update __init__.py

* Update document_parser.py

* fixed import error.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

---------

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>
Co-authored-by: Ye, Xinyu <xinyu.ye@intel.com>
XuhuiRen and XinyuYe-Intel committed Sep 4, 2023
1 parent 81a651d commit d2cee03
Showing 7 changed files with 23 additions and 29 deletions.
3 changes: 1 addition & 2 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -30,9 +30,8 @@
 from .pipeline.plugins.audio.asr_chinese import ChineseAudioSpeechRecognition
 from .pipeline.plugins.audio.tts import TextToSpeech
 from .pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
-from .pipeline.plugins.security import SafetyChecker
-from .pipeline.plugins.retrievals import QA_Client
+from .pipeline.plugins.security.safety_checker import SafetyChecker
 from .pipeline.plugins.intent_detector import IntentDetector
 from .models.llama_model import LlamaModel
 from .models.mpt_model import MptModel
 from .models.chatglm_model import ChatGlmModel
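The underlying cause: a package-level import like "from .pipeline.plugins.security import SafetyChecker" only resolves if the package's __init__.py re-exports the class, whereas importing from the defining module always works. A self-contained sketch of the difference (the security/safety_checker names mirror the diff; the throwaway package is illustrative only):

    # Build a toy package whose __init__.py does NOT re-export the class,
    # reproducing the situation this commit fixes.
    import os
    import sys
    import tempfile

    root = tempfile.mkdtemp()
    pkg = os.path.join(root, "security")
    os.makedirs(pkg)
    open(os.path.join(pkg, "__init__.py"), "w").close()       # empty __init__.py
    with open(os.path.join(pkg, "safety_checker.py"), "w") as f:
        f.write("class SafetyChecker:\n    pass\n")
    sys.path.insert(0, root)

    try:
        from security import SafetyChecker                    # old import: fails
    except ImportError as err:
        print("package-level import fails:", err)
    from security.safety_checker import SafetyChecker         # fixed import: works
    print("module-level import works:", SafetyChecker.__name__)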
@@ -20,7 +20,7 @@
 from typing import List
 from ..utils.command import NeuralChatCommandDict
 from .base_executor import BaseCommandExecutor
-from ..config import PipelineConfig, FinetuningConfig, GenerationConfig # pylint: disable=E0611
+from ..config import PipelineConfig, TextGenerationFinetuningConfig, GenerationConfig
 from ..config import ModelArguments, DataArguments, FinetuningArguments
 from ..plugins import plugins
 from transformers import TrainingArguments
@@ -311,7 +311,7 @@ def execute(self, argv: List[str]) -> bool:
         training_args = TrainingArguments(output_dir="./output")
         finetune_args= FinetuningArguments()

-        self.finetuneCfg = FinetuningConfig(model_args, data_args, training_args, finetune_args)
+        self.finetuneCfg = TextGenerationFinetuningConfig(model_args, data_args, training_args, finetune_args)
         try:
             res = self()
             print(res)
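FinetuningConfig has evidently been renamed to TextGenerationFinetuningConfig in config.py. A minimal sketch of assembling it, assuming the same four-argument call used above (the no-argument defaults are placeholders; a real run would point ModelArguments at a concrete checkpoint):

    from transformers import TrainingArguments
    from intel_extension_for_transformers.neural_chat.config import (
        ModelArguments,
        DataArguments,
        FinetuningArguments,
        TextGenerationFinetuningConfig,
    )

    model_args = ModelArguments()        # placeholder defaults, assumed valid
    data_args = DataArguments()
    training_args = TrainingArguments(output_dir="./output")
    finetune_args = FinetuningArguments()

    finetune_cfg = TextGenerationFinetuningConfig(
        model_args, data_args, training_args, finetune_args
    )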
@@ -18,35 +18,23 @@
 import os
 import sys
 from transformers import TrainingArguments, HfArgumentParser
-from intel_extension_for_transformers.neural_chat.config import (
-    PipelineConfig,
-    RetrieverConfig,
-    SafetyConfig,
-    GenerationConfig
-)
+from intel_extension_for_transformers.neural_chat.config import PipelineConfig
 from intel_extension_for_transformers.neural_chat.chatbot import build_chatbot


 def main():
     # See all possible arguments in config.py
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-    parser = HfArgumentParser(
-        (PipelineConfig, RetrieverConfig, SafetyConfig, GenerationConfig)
-    )
+    parser = HfArgumentParser(PipelineConfig)

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        pipeline_args, retriever_args, safety_args, generation_args = parser.parse_json_file(
-            json_file = os.path.abspath(sys.argv[1])
-        )
+        pipeline_args= parser.parse_json_file(json_file = os.path.abspath(sys.argv[1]))
     else:
-        (pipeline_args, retriever_args, safety_args, generation_args) = parser.parse_args_into_dataclasses()
+        pipeline_args= parser.parse_args_into_dataclasses()

-    pipeline_args.saftey_config = safety_args
-    pipeline_args.retrieval_config = retriever_args
-    pipeline_args.generation_config = generation_args
     chatbot = build_chatbot(pipeline_args)

     response = chatbot.predict(query="What is IDM 2.0?", config=pipeline_args)
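One caveat: HfArgumentParser.parse_json_file and parse_args_into_dataclasses both return a tuple of dataclass instances even when only one dataclass is registered, so the assignments above likely bind a one-element tuple rather than the config itself. The usual pattern unpacks it explicitly; a sketch with a stand-in dataclass (PipelineArgs is hypothetical, standing in for PipelineConfig):

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class PipelineArgs:                  # stand-in for PipelineConfig
        model_name_or_path: str = field(default="placeholder-model")

    parser = HfArgumentParser(PipelineArgs)
    (pipeline_args,) = parser.parse_args_into_dataclasses(args=[])
    print(pipeline_args.model_name_or_path)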
@@ -97,14 +97,17 @@ def KB_construct(self, input):

         documents = []
         for data, meta in data_collection:
+            if len(data) < 5:
+                continue
             metadata = {"source": meta}
             new_doc = Document(page_content=data, metadata=metadata)
             documents.append(new_doc)
+        assert documents!= [], "The given file/files cannot be loaded."
         embedding = HuggingFaceInstructEmbeddings(model_name=self.embedding_model)
         vectordb = Chroma.from_documents(documents=documents, embedding=embedding,
                                          persist_directory=self.persist_dir)
         vectordb.persist()
-        print("success")
+        print("The local knowledge base has been successfully built!")
         return vectordb
     else:
         print("There might be some errors, please wait and try again!")
@@ -125,11 +128,13 @@ def KB_construct(self, input):
         documents = []
         for data, meta in data_collection:
             metadata = {"source": meta}
-            # pylint: disable=E1123
-            new_doc = SDocument(content=data, metadata=metadata)
+            if len(data) < 5:
+                continue
+            new_doc = SDocument(content=data, meta=metadata)
             documents.append(new_doc)
+        assert documents != [], "The given file/files cannot be loaded."
         document_store.write_documents(documents)
-        print("success")
+        print("The local knowledge base has been successfully built!")
         return document_store
     else:
         print("There might be some errors, please wait and try again!")
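For context, a stripped-down sketch of the dense path through KB_construct, using the langchain APIs of this era; the embedding model name and persist directory are placeholders:

    from langchain.docstore.document import Document
    from langchain.embeddings import HuggingFaceInstructEmbeddings
    from langchain.vectorstores import Chroma

    # Mirror the new guard: skip near-empty parses, then refuse to build
    # a knowledge base from nothing.
    raw = [("IDM 2.0 is Intel's integrated device manufacturing strategy.", "intel.txt"),
           ("hi", "noise.txt")]
    documents = [Document(page_content=data, metadata={"source": meta})
                 for data, meta in raw if len(data) >= 5]
    assert documents != [], "The given file/files cannot be loaded."

    embedding = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
    vectordb = Chroma.from_documents(documents=documents, embedding=embedding,
                                     persist_directory="./chroma_db")
    vectordb.persist()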
@@ -165,19 +165,21 @@ def laod_structured_data(input, process, max_length):

 def get_chuck_data(content, max_length, input):
     """Process the context to make it maintain a suitable length for the generation."""
-    sentences = re.split('(?<=[;!.?])', content)
+    sentences = re.split('(?<=[!.?])', content)

     paragraphs = []
     current_length = 0
+    count = 0
     current_paragraph = ""
     for sub_sen in sentences:
         if sub_sen == "":
             continue
+        count +=1
         sentence_length = len(sub_sen)
         if current_length + sentence_length <= max_length:
             current_paragraph += sub_sen
             current_length += sentence_length
-            if count == len(sentences):
+            if count == len(sentences) and len(current_paragraph.strip())>5:
                 paragraphs.append([current_paragraph.strip() ,input])
         else:
             paragraphs.append([current_paragraph.strip() ,input])
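The chunker change does two things: a semicolon no longer ends a sentence, and a trailing fragment of five characters or fewer is dropped. A simplified, self-contained sketch of the same windowing idea (the original also pairs each chunk with the query):

    import re

    def chunk(content, max_length):
        # (?<=[!.?]) splits after a terminator while keeping it attached
        # to its sentence; ';' no longer starts a new chunk.
        sentences = re.split('(?<=[!.?])', content)
        paragraphs, current, length = [], "", 0
        for s in sentences:
            if not s:
                continue
            if length + len(s) <= max_length:
                current += s
                length += len(s)
            else:
                paragraphs.append(current.strip())
                current, length = s, len(s)
        if len(current.strip()) > 5:     # drop tiny trailing fragments
            paragraphs.append(current.strip())
        return paragraphs

    print(chunk("One; two. Three! Four five?", max_length=12))
    # ['One; two.', 'Three!', 'Four five?']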
@@ -22,7 +22,7 @@ class SparseBM25Retriever():

     def __int__(self, document_store = None, top_k = 1):
         assert document_store is not None, "Please give a document database for retrieving."
-        self.retriever = BM25Retriever(document_store=document_store)
+        self.retriever = BM25Retriever(document_store=document_store, top_k=top_k)

     def query_the_database(self, query):
         documents = self.retriever.retrieve(query)
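A usage sketch assuming haystack v1, which this retriever appears to wrap. Incidentally, the constructor above is spelled __int__, so as written Python never calls it; the top_k fix only takes effect once that typo is also corrected to __init__:

    from haystack.document_stores import InMemoryDocumentStore
    from haystack.nodes import BM25Retriever
    from haystack.schema import Document

    store = InMemoryDocumentStore(use_bm25=True)   # use_bm25 assumed available (haystack >= 1.15)
    store.write_documents([Document(content="IDM 2.0 is Intel's manufacturing strategy.")])

    retriever = BM25Retriever(document_store=store, top_k=1)
    print(retriever.retrieve(query="What is IDM 2.0?"))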
@@ -26,7 +26,7 @@
     ModelArguments,
     DataArguments,
     FinetuningArguments,
-    FinetuningConfig,
+    TextGenerationFinetuningConfig,
 )
 from intel_extension_for_transformers.neural_chat.server.restful.request import FinetuneRequest

@@ -66,7 +66,7 @@ def handle_finetune_request(self, request: FinetuneRequest) -> str:
             overwrite_output_dir=request.overwrite_output_dir
         )
         finetune_args = FinetuningArguments(peft=request.peft)
-        finetune_cfg = FinetuningConfig(
+        finetune_cfg = TextGenerationFinetuningConfig(
             model_args=model_args,
             data_args=data_args,
             training_args=training_args,
