[NeuralChat] Add askdoc retrieval api & example (#514)
letonghan committed Oct 27, 2023
1 parent af741bb commit 89cf760
Showing 34 changed files with 887 additions and 46 deletions.
1 change: 1 addition & 0 deletions .github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml
@@ -39,6 +39,7 @@ jobs:
pip uninstall intel-extension-for-transformers -y; \
pip install -r requirements.txt; \
python setup.py install; \
pip install -r intel_extension_for_transformers/neural_chat/requirements.txt; \
python workflows/chatbot/inference/generate.py --base_model_path \"meta-llama/Llama-2-7b-chat-hf\" --hf_access_token \"${{ env.HF_ACCESS_TOKEN }}\" --instructions \"Transform the following sentence into one that shows contrast. The tree is rotten.\" "
- name: Stop Container
1 change: 1 addition & 0 deletions .github/workflows/chatbot-inference-mpt-7b-chat.yml
@@ -39,6 +39,7 @@ jobs:
pip uninstall intel-extension-for-transformers -y; \
pip install -r requirements.txt; \
python setup.py install; \
pip install -r intel_extension_for_transformers/neural_chat/requirements.txt; \
python workflows/chatbot/inference/generate.py --base_model_path \"mosaicml/mpt-7b-chat\" --instructions \"Transform the following sentence into one that shows contrast. The tree is rotten.\" "
- name: Stop Container
@@ -0,0 +1,79 @@
This README is intended to guide you through setting up the server for the AskDoc demo using the NeuralChat framework. You can deploy it on various platforms, including Intel Xeon Scalable Processors, Habana Gaudi processors (HPU), Intel Data Center and Client GPUs, and NVIDIA Data Center and Client GPUs.

# Introduction
The popularity of applications like ChatGPT has attracted many users seeking to address everyday problems. However, some users have encountered a challenge known as "model hallucination," where LLMs generate incorrect or nonexistent information, raising concerns about content accuracy. This example introduces our solution: a retrieval-based chatbot backend server. With just a few lines of code, our API helps you build a local reference database that enhances the accuracy of the generated results.
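
Below is a minimal sketch of how the same retrieval capability can be used through the NeuralChat Python API. The document folder and query are illustrative, and the plugin argument names follow the NeuralChat examples, so verify them against the framework documentation for your version:

```python
from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig, plugins

# Enable the retrieval plugin and point it at a folder of local reference documents.
plugins.retrieval.enable = True
plugins.retrieval.args["input_path"] = "./docs"

# Build a chatbot whose answers are grounded in the indexed documents.
config = PipelineConfig(plugins=plugins)
chatbot = build_chatbot(config)

print(chatbot.predict("What key features does the Intel oneAPI DPC++/C++ Compiler support?"))
```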

Before deploying this example, please follow the instructions in the [README](../../README.md) to install the necessary dependencies.

# Setup Environment

## Setup Conda

First, you need to install and configure the Conda environment:

```shell
# Download and install Miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc
```

## Install numactl

Next, install the numactl library:

```shell
sudo apt install numactl
```

## Install Python dependencies

Install the following Python dependencies using Conda:

```shell
conda install astunparse ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses -y
conda install jemalloc gperftools -c conda-forge -y
conda install git-lfs -y
```

Install other dependencies using pip:

```bash
pip install -r ../../../requirements.txt
```


## Download Models
```shell
git-lfs install
git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
```


# Configure YAML

You can customize the configuration file 'askdoc.yaml' to match your environment setup. Here's a table to help you understand the configurable options:

| Item | Value |
| --------------------------------- | ---------------------------------------|
| host | 127.0.0.1 |
| port | 8000 |
| model_name_or_path | "./Llama-2-7b-chat-hf" |
| device | "auto" |
| retrieval.enable | true |
| retrieval.args.input_path | "./docs" |
| retrieval.args.persist_dir | "./example_persist" |
| retrieval.args.response_template | "We cannot find suitable content to answer your query, please contact to find help." |
| retrieval.args.append | True |
| tasks_list | ['textchat', 'retrieval'] |
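
For reference, a minimal `askdoc.yaml` matching the table above could look like the sketch below; adjust the paths and the response template to your environment:

```yaml
host: 127.0.0.1
port: 8000

model_name_or_path: "./Llama-2-7b-chat-hf"
device: "auto"

retrieval:
  enable: true
  args:
    input_path: "./docs"
    persist_dir: "./example_persist"
    response_template: "We cannot find suitable content to answer your query, please contact to find help."
    append: True

tasks_list: ['textchat', 'retrieval']
```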


# Run the AskDoc server
The Neural Chat API offers an easy way to create and utilize chatbot models while integrating local documents. Our API simplifies the process of automatically handling and storing local documents in a document store. In this example, we use `./docs/test_doc.txt` as the sample document. You can construct your own retrieval document for the Intel® oneAPI DPC++/C++ Compiler by following [this link](https://www.intel.com/content/www/us/en/docs/dpcpp-cpp-compiler/developer-guide-reference/2023-2/overview.html).


To start the AskDoc server, run the following command:

```shell
nohup bash run.sh &
```
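
Once the server is running, you can send a test query from another terminal. The snippet below is a rough sketch using Python `requests`; the endpoint path and payload shape are assumptions based on NeuralChat's OpenAI-compatible textchat API, so verify them against the RESTful API documentation for your version:

```python
import requests

# Assumed route for the 'textchat' task; adjust host/port to match your YAML config.
url = "http://127.0.0.1:8000/v1/chat/completions"

payload = {
    "model": "Llama-2-7b-chat-hf",
    "messages": [
        {"role": "user", "content": "What key features does the Intel oneAPI DPC++/C++ Compiler support?"}
    ],
}

response = requests.post(url, json=payload, timeout=120)
response.raise_for_status()
print(response.json())
```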
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from intel_extension_for_transformers.neural_chat import NeuralChatServerExecutor

def main():
server_executor = NeuralChatServerExecutor()
server_executor(
config_file="./askgm.yaml",
log_file="./askgm.log")


if __name__ == "__main__":
main()
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is the parameter configuration file for NeuralChat Serving.

#################################################################################
# SERVER SETTING #
#################################################################################
host: 127.0.0.1
port: 8000

model_name_or_path: "./Llama-2-7b-chat-hf"
device: "auto"

retrieval:
enable: true
args:
input_path: "./docs"
persist_dir: "./example_persist"
response_template: "We cannot find suitable content to answer your query, please contact AskGM to find help. Mail: ask.gm.zizhu@intel.com."
append: True

safety_checker:
enable: true

tasks_list: ['textchat', 'retrieval']

@@ -0,0 +1,13 @@
This guide provides information about the Intel® oneAPI DPC++/C++ Compiler and runtime environment. This document is valid for version 2024.0 of the compilers.

The Intel® oneAPI DPC++/C++ Compiler is available as part of the Intel® oneAPI Base Toolkit, Intel® oneAPI HPC Toolkit, Intel® oneAPI IoT Toolkit, or as a standalone compiler.

Refer to the Intel® oneAPI DPC++/C++ Compiler product page and the Release Notes for more information about features, specifications, and downloads.


The compiler supports these key features:
Intel® oneAPI Level Zero: The Intel® oneAPI Level Zero (Level Zero) Application Programming Interface (API) provides direct-to-metal interfaces to offload accelerator devices.
OpenMP* Support: Compiler support for OpenMP 5.0 Version TR4 features and some OpenMP Version 5.1 features.
Pragmas: Information about directives to provide the compiler with instructions for specific tasks, including splitting large loops into smaller ones, enabling or disabling optimization for code, or offloading computation to the target.
Offload Support: Information about SYCL*, OpenMP, and parallel processing options you can use to affect optimization, code generation, and more.
Latest Standards: Use the latest standards including C++ 20, SYCL, and OpenMP 5.0 and 5.1 for GPU offload.
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Kill any existing askgm process and re-run
ps -ef |grep 'askgm' |awk '{print $2}' |xargs kill -9

# KMP
export KMP_BLOCKTIME=1
export KMP_SETTINGS=1
export KMP_AFFINITY=granularity=fine,compact,1,0

# OMP
export OMP_NUM_THREADS=56
export LD_PRELOAD=${CONDA_PREFIX}/lib/libiomp5.so

# tc malloc
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so

# database
export MYSQL_PASSWORD="root"
export MYSQL_HOST="127.0.0.1"
export MYSQL_DB="fastrag"

numactl -l -C 0-55 askdoc -m askgm 2>&1 | tee run.log
31 changes: 9 additions & 22 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -148,6 +148,7 @@ def predict_stream(self, query, config=None):
query_include_prompt = True

# plugin pre actions
link = []
for plugin_name in get_registered_plugins():
if is_plugin_enabled(plugin_name):
plugin_instance = get_plugin_instance(plugin_name)
@@ -156,11 +157,13 @@
if plugin_name == "asr" and not is_audio_file(query):
continue
if plugin_name == "retrieval":
response = plugin_instance.pre_llm_inference_actions(self.model_name, query)
response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
if response == "Response with template.":
return plugin_instance.response_template, link
else:
response = plugin_instance.pre_llm_inference_actions(query)
if plugin_name == "safety_checker" and response:
return "Your query contains sensitive words, please try another query."
return "Your query contains sensitive words, please try another query.", link
else:
if response != None and response != False:
query = response
Expand All @@ -183,16 +186,7 @@ def is_generator(obj):
continue
response = plugin_instance.post_llm_inference_actions(response)

# clear plugins config
for key in plugins:
plugins[key] = {
"enable": False,
"class": None,
"args": {},
"instance": None
}

return response
return response, link

def predict(self, query, config=None):
"""
@@ -230,7 +224,9 @@ def predict(self, query, config=None):
if plugin_name == "asr" and not is_audio_file(query):
continue
if plugin_name == "retrieval":
response = plugin_instance.pre_llm_inference_actions(self.model_name, query)
response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
if response == "Response with template.":
return plugin_instance.response_template
else:
response = plugin_instance.pre_llm_inference_actions(query)
if plugin_name == "safety_checker" and response:
@@ -253,15 +249,6 @@
if hasattr(plugin_instance, 'post_llm_inference_actions'):
response = plugin_instance.post_llm_inference_actions(response)

# clear plugins config
for key in plugins:
plugins[key] = {
"enable": False,
"class": None,
"args": {},
"instance": None
}

return response

def chat_stream(self, query, config=None):
@@ -37,6 +37,20 @@ def generate_qa_prompt(query, context=None, history=None):
conv.append_message(conv.roles[1], None)
return conv.get_prompt()

def generate_qa_enterprise(query, context=None, history=None):
if context and history:
conv = PromptTemplate("rag_with_threshold")
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], context)
conv.append_message(conv.roles[2], history)
conv.append_message(conv.roles[3], None)
else:
conv = PromptTemplate("rag_with_threshold")
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], context)
conv.append_message(conv.roles[3], None)
return conv.get_prompt()


def generate_prompt(query, history=None):
if history:
@@ -32,7 +32,7 @@ def intent_detection(self, model_name, query):
params["prompt"] = prompt
params["temperature"] = 0.001
params["top_k"] = 1
params["max_new_tokens"] = 5
params["max_new_tokens"] = 10
intent = predict(**params)
return intent

@@ -121,7 +121,6 @@ def load_xlsx(input):
df = pd.read_excel(input)
all_data = []
documents = []

for index, row in df.iterrows():
sub = "User Query: " + row['Questions'] + "Answer: " + row["Answers"]
all_data.append(sub)
@@ -134,6 +133,38 @@ def load_xlsx(input):
return documents


def load_faq_xlsx(input):
"""Load and process faq xlsx file."""
df = pd.read_excel(input)
all_data = []

for index, row in df.iterrows():
sub = "Question: " + row['question'] + " Answer: " + row["answer"]
sub = sub.replace('#', " ")
sub = sub.replace(r'\t', " ")
sub = sub.replace('\n', ' ')
sub = sub.replace('\n\n', ' ')
sub = re.sub(r'\s+', ' ', sub)
all_data.append([sub, row['link']])
return all_data


def load_general_xlsx(input):
"""Load and process doc xlsx file."""
df = pd.read_excel(input)
all_data = []

for index, row in df.iterrows():
sub = row['context']
sub = sub.replace('#', " ")
sub = sub.replace(r'\t', " ")
sub = sub.replace('\n', ' ')
sub = sub.replace('\n\n', ' ')
sub = re.sub(r'\s+', ' ', sub)
all_data.append([sub, row['link']])
return all_data


def load_unstructured_data(input):
"""Load unstructured context."""
if input.endswith("pdf"):
@@ -158,6 +189,10 @@ def laod_structured_data(input, process, max_length):
"""Load structured context."""
if input.endswith("jsonl"):
content = load_json(input, process, max_length)
elif "faq" in input and input.endswith("xlsx"):
content = load_faq_xlsx(input)
elif "enterprise_docs" in input and input.endswith("xlsx"):
content = load_general_xlsx(input)
else:
content = load_xlsx(input)
return content