Skip to content

Commit

Permalink
Fix the code structure and the plugin name in NeuralChat (#233)
Browse files Browse the repository at this point in the history
* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* Update intent_detection.py

* Rename context_util.py to context_utils.py

* revision

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* add

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* fix parm value

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* fix

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

* fix

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>

---------

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
  • Loading branch information
XuhuiRen committed Sep 5, 2023
1 parent cd4c33d commit 486e9ec
Show file tree
Hide file tree
Showing 16 changed files with 27 additions and 43 deletions.
2 changes: 1 addition & 1 deletion intel_extension_for_transformers/neural_chat/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .pipeline.plugins.audio.tts import TextToSpeech
from .pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
from .pipeline.plugins.security import SafetyChecker
from .pipeline.plugins.retrievals import QA_Client
from .pipeline.plugins.retrieval import Agent_QA
from .models.llama_model import LlamaModel
from .models.mpt_model import MptModel
from .models.chatglm_model import ChatGlmModel
Expand Down
6 changes: 3 additions & 3 deletions intel_extension_for_transformers/neural_chat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from .pipeline.plugins.audio.asr_chinese import ChineseAudioSpeechRecognition
from .pipeline.plugins.audio.tts import TextToSpeech
from .pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
from .pipeline.plugins.retrievals.indexing import DocumentIndexing
from .pipeline.plugins.retrievals.retrieval import SparseBM25Retriever, ChromaRetriever
from .pipeline.plugins.intent_detector import IntentDetector
from .pipeline.plugins.retrieval.indexing import DocumentIndexing
from .pipeline.plugins.retrieval import SparseBM25Retriever, ChromaRetriever
from .pipeline.plugins.retrieval.detector import IntentDetector
from .pipeline.plugins.security import SafetyChecker
from .plugins import plugins

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ The user can costomize the retrieval parameters to meet the personal demmads for
```
>>>persist_dir [str]: The local path to save the processed database. Default to "./output".
>>>process [bool]: Select to process the too long document into small chucks. Default to "False".
>>>process [bool]: Whether to split overly long documents into small chunks. Defaults to "True".
>>>input_path [str]: The user's local path to a folder or to a specific file. The code itself will check whether the path is a folder or a file. If it is a folder, the code will process all the files in the given folder. If it is a file, the code will process this single file.
Expand All @@ -55,9 +55,9 @@ The user can costomize the retrieval parameters to meet the personal demmads for
>>>top_k [int]: The number of the retrieved documents. Default to "1".
>>>search_type [str]: Select a ranking method for dense retrieval from "mmr" or "similarity". "similarity" will return the most similar docs to the input query. "mmr" will rank the docs using the maximal marginal relevance method. Deault to "mmr".
>>>search_type [str]: Select a ranking method for dense retrieval from "mmr", "similarity" and "similarity_score_threshold". "similarity" will return the most similar docs to the input query. "mmr" will rank the docs using the maximal marginal relevance method. "similarity_score_threshold" will return the most similar docs that also meet the threshold. Defaults to "mmr".
>>>search_kwargs [dict]: Used by dense retrieval. Should be in the same format with {"k":1, "fetch_k":5}. "fetch_k" determines the amount of documents to pass to the ranking algorithm. Default to {"k":1, "fetch_k":5}.
>>>search_kwargs [dict]: Used by dense retrieval. Should follow the same format as {"k":1, "fetch_k":5}. "k" is the number of documents to return. "score_threshold" is the minimal relevance threshold for the "similarity_score_threshold" search. "lambda_mult" is the diversity of the results returned by "mmr". "fetch_k" determines the number of documents to pass to the "mmr" algorithm. Defaults to {"k":1, "fetch_k":5}.
```


Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .prompt import generate_intent_prompt, generate_qa_prompt, generate_prompt
from .prompt_template import generate_intent_prompt, generate_qa_prompt, generate_prompt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .document_parser import DocumentIndexing
from .retrieval_agent import Agent_QA
from .retrieval_base import Retriever
from .retrieval_bm25 import SparseBM25Retriever
from .retrieval_chroma import ChromaRetriever
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import transformers
import torch
from ..prompts import generate_intent_prompt
from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompt import generate_intent_prompt
from intel_extension_for_transformers.llm.inference import predict

class IntentDetector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .retrievals import QA_Client
from .indexing import DocumentIndexing
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def read_docx(doc_path):

def read_md(md_path):
"""Read docx file."""
loader = UnstructuredMarkdownLoader("instruction_data.md")
loader = UnstructuredMarkdownLoader(md_path)
text = loader.load()[0].page_content
return text

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma
from haystack.schema import Document as SDocument
from .utils import load_unstructured_data, laod_structured_data, get_chuck_data
from .context_utils import load_unstructured_data, laod_structured_data, get_chuck_data


class DocumentIndexing:
def __init__(self, retrieval_type="dense", document_store=None, persist_dir="./output",
process=False, embedding_model="hkunlp/instructor-large", max_length=512):
process=True, embedding_model="hkunlp/instructor-large", max_length=512):
"""
Wrapper for document indexing. Support dense and sparse indexing method.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
import os
import torch
import transformers
from intel_extension_for_transformers.neural_chat.pipeline.plugins.intent_detector import IntentDetector
from intel_extension_for_transformers.neural_chat.pipeline.plugins.retrievals.indexing import DocumentIndexing
from intel_extension_for_transformers.neural_chat.pipeline.plugins.retrievals.retrieval import Retriever
from intel_extension_for_transformers.neural_chat.pipeline.plugins.retrieval import Retriever
from intel_extension_for_transformers.neural_chat.pipeline.plugins.retrieval.detector import IntentDetector
from intel_extension_for_transformers.neural_chat.pipeline.plugins.retrieval.indexing import DocumentIndexing
from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompt import generate_qa_prompt, generate_prompt
from intel_extension_for_transformers.neural_chat.plugins import register_plugin
from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompts import generate_qa_prompt, generate_prompt


@register_plugin("retrieval")
class QA_Client():
def __init__(self, persist_dir="./output", process=False, input_path=None,
class Agent_QA():
def __init__(self, persist_dir="./output", process=True, input_path=None,
embedding_model="hkunlp/instructor-large", max_length=512, retrieval_type="dense",
document_store=None, top_k=1, search_type="mmr", search_kwargs={"k": 1, "fetch_k": 5}):
self.model = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

"""The class defination for the retriever. Supporting langchain-based and haystack-based retriever."""

from .bm25_retrieval import SparseBM25Retriever
from .chroma_retrieval import ChromaRetriever
from .retrieval_bm25 import SparseBM25Retriever
from .retrieval_chroma import ChromaRetriever

class Retriever():
"""Retrieve the document database with BM25 sparse algorithm."""
"""The wrapper for sparse retriever and dense retriever."""

def __init__(self, retrieval_type="dense", document_store=None,
top_k=1, search_type="mmr", search_kwargs={"k": 1, "fetch_k": 5}):
Expand Down

This file was deleted.

0 comments on commit 486e9ec

Please sign in to comment.