Refine notebook and fix restful api issues (#445)
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel committed Oct 15, 2023
1 parent ebbbff4 commit d8cc116
Showing 19 changed files with 84 additions and 137 deletions.
@@ -19,7 +19,8 @@ Welcome to use Jupyter Notebooks to explore how to build and customize chatbots
| 2.4 | Deploying Chatbot on Habana Gaudi1/Gaudi2 | Learn how to deploy chatbot on Habana Gaudi1/Gaudi2 | [Notebook](./notebooks/deploy_chatbot_on_habana_gaudi.ipynb) |
| 2.5 | Deploying Chatbot on Nvidia A100 | Learn how to deploy chatbot on A100 | [Notebook](./notebooks/deploy_chatbot_on_nv_a100.ipynb) |
| 2.6 | Deploying Chatbot with Load Balance | Learn how to deploy chatbot with load balance | [Notebook](./notebooks/chatbot_with_load_balance.ipynb) |
| 2.7 | Deploying End-to-end Chatbot on Intel CPU SPR | Learn how to deploy an end to end text chatbot on Intel CPU SPR including frontend GUI and backend | [Notebook](./notebooks/setup_text_chatbot_service_on_spr.ipynb) |
| 2.7 | Deploying End-to-end text Chatbot on Intel CPU SPR | Learn how to deploy an end to end text chatbot on Intel CPU SPR including frontend GUI and backend | [Notebook](./notebooks/setup_text_chatbot_service_on_spr.ipynb) |
| 2.8 | Deploying End-to-end talkingbot on Intel CPU SPR | Learn how to deploy an end to end talkingbot on Intel CPU SPR including frontend GUI and backend | [Notebook](./notebooks/setup_talking_chatbot_service_on_spr.ipynb) |
| 3 | Optimizing Chatbots | | |
| 3.1 | Enabling Chatbot with BF16 Optimization on SPR | Learn how to optimize chatbot using mixed precision on SPR | [Notebook](./notebooks/amp_optimization_on_spr.ipynb) |
| 3.2 | Enabling Chatbot with BF16 Optimization on Habana Gaudi1/Gaudi2 | Learn how to optimize chatbot using mixed precision on Habana Gaudi1/Gaudi2 | [Notebook](./notebooks/amp_optimization_on_habana_gaudi.ipynb) |
@@ -154,27 +154,6 @@
"print(\" Play Output Audio ......\")\n",
"IPython.display.display(IPython.display.Audio(\"welcome.wav\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access Finetune Service"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from neural_chat import FinetuningClientExecutor\n",
"executor = FinetuningClientExecutor()\n",
"tuning_status = executor(\n",
" server_ip=\"127.0.0.1\", # master server ip\n",
" port=8000 # master server port (port on socket 0, if both sockets are deployed)\n",
" )"
]
}
],
"metadata": {
@@ -167,27 +167,6 @@
"print(\" Play Output Audio ......\")\n",
"IPython.display.display(IPython.display.Audio(\"welcome.wav\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access Finetune Service"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from neural_chat import FinetuningClientExecutor\n",
"executor = FinetuningClientExecutor()\n",
"tuning_status = executor(\n",
" server_ip=\"127.0.0.1\", # master server ip\n",
" port=8000 # master server port (port on socket 0, if both sockets are deployed)\n",
" )"
]
}
],
"metadata": {
@@ -167,27 +167,6 @@
"print(\" Play Output Audio ......\")\n",
"IPython.display.display(IPython.display.Audio(\"welcome.wav\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access Finetune Service"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from neural_chat import FinetuningClientExecutor\n",
"executor = FinetuningClientExecutor()\n",
"tuning_status = executor(\n",
" server_ip=\"127.0.0.1\", # master server ip\n",
" port=8000 # master server port (port on socket 0, if both sockets are deployed)\n",
" )"
]
}
],
"metadata": {
@@ -169,27 +169,6 @@
"print(\" Play Output Audio ......\")\n",
"IPython.display.display(IPython.display.Audio(\"welcome.wav\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access Finetune Service"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from intel_extension_for_transformers.neural_chat import FinetuningClientExecutor\n",
"executor = FinetuningClientExecutor()\n",
"tuning_status = executor(\n",
" server_ip=\"127.0.0.1\", # master server ip\n",
" port=8000 # master server port (port on socket 0, if both sockets are deployed)\n",
" )"
]
}
],
"metadata": {
@@ -197,27 +197,6 @@
"print(\" Play Output Audio ......\")\n",
"IPython.display.display(IPython.display.Audio(\"welcome.wav\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access Finetune Service"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from neural_chat import FinetuningClientExecutor\n",
"executor = FinetuningClientExecutor()\n",
"tuning_status = executor(\n",
" server_ip=\"127.0.0.1\", # master server ip\n",
" port=8000 # master server port (port on socket 0, if both sockets are deployed)\n",
" )"
]
}
],
"metadata": {
26 changes: 22 additions & 4 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -17,7 +17,7 @@

from abc import ABC
from typing import List
import os
import os, types
from fastchat.conversation import get_conv_template, Conversation
from ..config import GenerationConfig
from ..plugins import is_plugin_enabled, get_plugin_instance, get_registered_plugins, plugins
@@ -135,6 +135,11 @@ def predict_stream(self, query, config=None):
if not os.path.exists(query):
raise ValueError(f"The audio file path {query} is invalid.")

query_include_prompt = False
self.get_conv_template(self.model_name, config.task)
if self.conv_template.roles[0] in query and self.conv_template.roles[1] in query:
query_include_prompt = True

# plugin pre actions
for plugin_name in get_registered_plugins():
if is_plugin_enabled(plugin_name):
@@ -150,18 +150,25 @@ def predict_stream(self, query, config=None):
if plugin_name == "safety_checker" and response:
return "Your query contains sensitive words, please try another query."
else:
query = response
if response != None and response != False:
query = response
assert query is not None, "Query cannot be None."

query = self.prepare_prompt(query, self.model_name, config.task)
if not query_include_prompt:
query = self.prepare_prompt(query, self.model_name, config.task)
response = predict_stream(**construct_parameters(query, self.model_name, self.device, config))

def is_generator(obj):
return isinstance(obj, types.GeneratorType)

# plugin post actions
for plugin_name in get_registered_plugins():
if is_plugin_enabled(plugin_name):
plugin_instance = get_plugin_instance(plugin_name)
if plugin_instance:
if hasattr(plugin_instance, 'post_llm_inference_actions'):
if plugin_name == "safety_checker" and is_generator(response):
continue
response = plugin_instance.post_llm_inference_actions(response)

# clear plugins config
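
A note on the streaming path above: when streaming is requested, predict_stream returns a Python generator, so the new is_generator check lets the post-inference loop skip the safety_checker for streamed output (the full text is not available until the generator is consumed). A minimal sketch of that guard, using a hypothetical plugins mapping for illustration:

    import types

    def is_generator(obj):
        # Streamed responses are plain Python generators rather than strings.
        return isinstance(obj, types.GeneratorType)

    def apply_post_actions(response, plugins):
        # 'plugins' is a hypothetical mapping of plugin name -> plugin instance,
        # used only to illustrate the guard added above.
        for name, plugin in plugins.items():
            if not hasattr(plugin, "post_llm_inference_actions"):
                continue
            if name == "safety_checker" and is_generator(response):
                # The streamed text does not exist yet, so the safety check
                # cannot inspect it and is skipped.
                continue
            response = plugin.post_llm_inference_actions(response)
        return response
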
@@ -195,6 +207,11 @@ def predict(self, query, config=None):
if not os.path.exists(query):
raise ValueError(f"The audio file path {query} is invalid.")

query_include_prompt = False
self.get_conv_template(self.model_name, config.task)
if self.conv_template.roles[0] in query and self.conv_template.roles[1] in query:
query_include_prompt = True

# plugin pre actions
for plugin_name in get_registered_plugins():
if is_plugin_enabled(plugin_name):
@@ -214,8 +231,9 @@
query = response
assert query is not None, "Query cannot be None."

if not query_include_prompt:
query = self.prepare_prompt(query, self.model_name, config.task)
# LLM inference
query = self.prepare_prompt(query, self.model_name, config.task)
response = predict(**construct_parameters(query, self.model_name, self.device, config))

# plugin post actions
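
Both predict and predict_stream now set query_include_prompt by checking whether the conversation template's two role markers already appear in the query; if they do, prepare_prompt is skipped so the prompt is not wrapped a second time. A rough, self-contained sketch of that check, with a stand-in template whose role strings ("### Human" / "### Assistant") are assumptions, not the repository's actual roles:

    from collections import namedtuple

    # Stand-in for the fastchat Conversation object; only its roles are needed here.
    ConvTemplate = namedtuple("ConvTemplate", ["roles"])

    def query_includes_prompt(query, conv_template):
        # If both role markers already occur in the query, the caller has
        # pre-formatted the prompt and prepare_prompt() would wrap it again.
        return conv_template.roles[0] in query and conv_template.roles[1] in query

    template = ConvTemplate(roles=("### Human", "### Assistant"))  # assumed roles
    print(query_includes_prompt("### Human: hi\n### Assistant:", template))  # True
    print(query_includes_prompt("hi", template))                             # False
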
@@ -52,6 +52,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
if "neural-chat-7b-v2" in model_path.lower():
return get_conv_template("neural-chat-7b-v2")
else:
return get_conv_template("neural-chat-7b-v1.1")
return get_conv_template("neural-chat-7b-v1-1")

register_model_adapter(NeuralChatModel)
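
The fallback template name changes from "neural-chat-7b-v1.1" to "neural-chat-7b-v1-1", matching the renamed registration later in this commit; template lookup is by exact string, so the two names must agree. A stand-in registry sketch (not fastchat's actual API) showing why a mismatch fails:

    # Stand-in registry; the real code stores fastchat Conversation objects,
    # but the exact-name lookup behaviour is the same.
    _templates = {}

    def register_conv_template(name, template):
        _templates[name] = template

    def get_conv_template(name):
        # Lookup is by exact string: "neural-chat-7b-v1.1" and
        # "neural-chat-7b-v1-1" are different keys.
        return _templates[name]

    register_conv_template("neural-chat-7b-v1-1", {"system_message": "..."})
    get_conv_template("neural-chat-7b-v1-1")     # found
    # get_conv_template("neural-chat-7b-v1.1")   # would raise KeyError
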
@@ -17,10 +17,10 @@

# pylint: disable=wrong-import-position
from typing import Any

import os
import gptcache.processor.post
import gptcache.processor.pre
from gptcache import Cache, Config
from gptcache import Cache, cache, Config
from gptcache.adapter.adapter import adapt
from gptcache.embedding import (
Onnx,
@@ -55,8 +55,11 @@
import time

class ChatCache:
def __init__(self):
self.cache_obj = Cache()
def __init__(self, config_dir: str=os.path.join(
os.path.dirname(os.path.abspath(__file__)), "./cache_config.yaml"),
embedding_model_dir: str="hkunlp/instructor-large"):
self.cache_obj = cache
self.init_similar_cache_from_config(config_dir, embedding_model_dir)

def _cache_data_converter(self, cache_data):
return self._construct_resp_from_cache(cache_data)
@@ -130,8 +133,7 @@ def init_similar_cache(self, data_dir: str = "api_cache", pre_func=get_prompt,
config=config,
)

def init_similar_cache_from_config(self, config_dir: str="./config/cache_config.yml",
embedding_model_dir: str="hkunlp/instructor-large"):
def init_similar_cache_from_config(self, config_dir, embedding_model_dir):
import_ruamel()
from ruamel.yaml import YAML # pylint: disable=C0415

@@ -1 +1,2 @@
gptcache
git+https://github.com/UKPLab/sentence-transformers.git
@@ -37,7 +37,7 @@
# neuralchat-v1.1 prompt template
register_conv_template(
Conversation(
name="neural-chat-7b-v1.1",
name="neural-chat-7b-v1-1",
system_template="""<|im_start|>system
{system_message}""",
system_message="""- You are a helpful assistant chatbot trained by Intel.
@@ -29,7 +29,7 @@ yacs
uvicorn
optimum
optimum-intel
sentence_transformers
git+https://github.com/UKPLab/sentence-transformers.git
unstructured
markdown
rouge_score
@@ -27,7 +27,7 @@ yacs
uvicorn
optimum
optimum-intel
sentence_transformers
git+https://github.com/UKPLab/sentence-transformers.git
unstructured
markdown
rouge_score
@@ -25,7 +25,7 @@ starlette
yacs
uvicorn
optimum
sentence_transformers
git+https://github.com/UKPLab/sentence-transformers.git
unstructured
markdown
rouge_score
@@ -26,7 +26,7 @@ yacs
uvicorn
optimum
optimum-intel
sentence_transformers
git+https://github.com/UKPLab/sentence-transformers.git
unstructured
markdown
rouge_score
@@ -24,6 +24,7 @@
from ...cli.log import logger
from ...server.restful.openai_protocol import ChatCompletionRequest, ChatCompletionResponse
from ...config import GenerationConfig
import json

def check_completion_request(request: BaseModel) -> Optional[str]:
logger.info(f"Checking parameters of completion request...")
@@ -66,11 +67,11 @@ def get_chatbot(self):
return self.chatbot


async def handle_completion_request(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
return await self.handle_chat_completion_request(request)
def handle_completion_request(self, request: ChatCompletionRequest):
return self.handle_chat_completion_request(request)


async def handle_chat_completion_request(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
def handle_chat_completion_request(self, request: ChatCompletionRequest):
chatbot = self.get_chatbot()

try:
@@ -83,35 +84,44 @@ async def handle_chat_completion_request(self, request: ChatCompletionRequest) -
setattr(config, attr, value)
if request.stream:
generator = chatbot.predict_stream(query=request.prompt, config=config)
return StreamingResponse(generator, media_type="text/event-stream")
def stream_generator():
for output in generator:
ret = {
"text": output,
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
yield f"data: [DONE]\n\n"
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
response = chatbot.predict(query=request.prompt, config=config)
except Exception as e:
raise Exception(e)
logger.error(f"An error occurred: {e}")
else:
logger.info(f"Chat completion finished.")
return ChatCompletionResponse(response=response)
return ChatCompletionResponse(response=response)


router = TextChatAPIRouter()


@router.post("/v1/completions")
async def completion_endpoint(request: ChatCompletionRequest) -> ChatCompletionResponse:
async def completion_endpoint(request: ChatCompletionRequest):
ret = check_completion_request(request)
if ret is not None:
raise RuntimeError("Invalid parameter.")
return await router.handle_completion_request(request)
return router.handle_completion_request(request)


@router.post("/v1/chat/completions")
async def chat_completion_endpoint(chat_request: ChatCompletionRequest) -> ChatCompletionResponse:
async def chat_completion_endpoint(chat_request: ChatCompletionRequest):
ret = check_completion_request(chat_request)
if ret is not None:
raise RuntimeError("Invalid parameter.")
return await router.handle_chat_completion_request(chat_request)
return router.handle_chat_completion_request(chat_request)

@router.post("/v1/models")
async def show_available_models():
models = router.get_chatbot().model_name
models = []
models.append(router.get_chatbot().model_name)
return {"models": models}
@@ -29,7 +29,7 @@ yacs
uvicorn
optimum
optimum[habana]
sentence_transformers
git+https://github.com/UKPLab/sentence-transformers.git
unstructured
markdown
rouge_score
