[NeuralChat] Support chatbot inference with huggingface endpoint (#1310)
* Support access huggingface service

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* fix chatbot issue

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* fix chatbot issue

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

* remove redundant huggingface_hub

Signed-off-by: lvliang-intel <liang1.lv@intel.com>

---------

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel committed Feb 27, 2024
1 parent e22998e commit 5b84e58
Showing 9 changed files with 119 additions and 49 deletions.
97 changes: 54 additions & 43 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -113,50 +113,59 @@ def build_chatbot(config: PipelineConfig=None):
if not config:
config = PipelineConfig()

# create model adapter
if "llama" in config.model_name_or_path.lower():
from .models.llama_model import LlamaModel
adapter = LlamaModel(config.model_name_or_path, config.task)
elif "mpt" in config.model_name_or_path.lower():
from .models.mpt_model import MptModel
adapter = MptModel(config.model_name_or_path, config.task)
elif "neural-chat" in config.model_name_or_path.lower():
from .models.neuralchat_model import NeuralChatModel
adapter = NeuralChatModel(config.model_name_or_path, config.task)
elif "chatglm" in config.model_name_or_path.lower():
from .models.chatglm_model import ChatGlmModel
adapter = ChatGlmModel(config.model_name_or_path, config.task)
elif "qwen" in config.model_name_or_path.lower():
from .models.qwen_model import QwenModel
adapter = QwenModel(config.model_name_or_path, config.task)
elif "mistral" in config.model_name_or_path.lower():
from .models.mistral_model import MistralModel
adapter = MistralModel(config.model_name_or_path, config.task)
elif "solar" in config.model_name_or_path.lower():
from .models.solar_model import SolarModel
adapter = SolarModel(config.model_name_or_path, config.task)
elif "decilm" in config.model_name_or_path.lower():
from .models.decilm_model import DeciLMModel
adapter = DeciLMModel(config.model_name_or_path, config.task)
elif "deepseek-coder" in config.model_name_or_path.lower():
from .models.deepseek_coder_model import DeepseekCoderModel
adapter = DeepseekCoderModel(config.model_name_or_path, config.task)
elif "opt" in config.model_name_or_path.lower() or \
"gpt" in config.model_name_or_path.lower() or \
"flan-t5" in config.model_name_or_path.lower() or \
"bloom" in config.model_name_or_path.lower() or \
"starcoder" in config.model_name_or_path.lower() or \
"codegen" in config.model_name_or_path.lower() or \
"magicoder" in config.model_name_or_path.lower() or \
"mixtral" in config.model_name_or_path.lower() or \
"phi-2" in config.model_name_or_path.lower() or \
"sqlcoder" in config.model_name_or_path.lower():
from .models.base_model import BaseModel
adapter = BaseModel(config.model_name_or_path, config.task)
if config.hf_endpoint_url:
if not config.hf_access_token:
set_latest_error(ErrorCodes.ERROR_HF_TOKEN_NOT_PROVIDED)
logging.error("build_chatbot: \
the huggingface token must be provided to access the huggingface endpoint service.")
return
from .models.huggingface_model import HuggingfaceModel
adapter = HuggingfaceModel(config.hf_endpoint_url, config.hf_access_token)
else:
set_latest_error(ErrorCodes.ERROR_MODEL_NOT_SUPPORTED)
logging.error("build_chatbot: unknown model")
return
# create model adapter
if "llama" in config.model_name_or_path.lower():
from .models.llama_model import LlamaModel
adapter = LlamaModel(config.model_name_or_path, config.task)
elif "mpt" in config.model_name_or_path.lower():
from .models.mpt_model import MptModel
adapter = MptModel(config.model_name_or_path, config.task)
elif "neural-chat" in config.model_name_or_path.lower():
from .models.neuralchat_model import NeuralChatModel
adapter = NeuralChatModel(config.model_name_or_path, config.task)
elif "chatglm" in config.model_name_or_path.lower():
from .models.chatglm_model import ChatGlmModel
adapter = ChatGlmModel(config.model_name_or_path, config.task)
elif "qwen" in config.model_name_or_path.lower():
from .models.qwen_model import QwenModel
adapter = QwenModel(config.model_name_or_path, config.task)
elif "mistral" in config.model_name_or_path.lower():
from .models.mistral_model import MistralModel
adapter = MistralModel(config.model_name_or_path, config.task)
elif "solar" in config.model_name_or_path.lower():
from .models.solar_model import SolarModel
adapter = SolarModel(config.model_name_or_path, config.task)
elif "decilm" in config.model_name_or_path.lower():
from .models.decilm_model import DeciLMModel
adapter = DeciLMModel(config.model_name_or_path, config.task)
elif "deepseek-coder" in config.model_name_or_path.lower():
from .models.deepseek_coder_model import DeepseekCoderModel
adapter = DeepseekCoderModel(config.model_name_or_path, config.task)
elif "opt" in config.model_name_or_path.lower() or \
"gpt" in config.model_name_or_path.lower() or \
"flan-t5" in config.model_name_or_path.lower() or \
"bloom" in config.model_name_or_path.lower() or \
"starcoder" in config.model_name_or_path.lower() or \
"codegen" in config.model_name_or_path.lower() or \
"magicoder" in config.model_name_or_path.lower() or \
"mixtral" in config.model_name_or_path.lower() or \
"phi-2" in config.model_name_or_path.lower() or \
"sqlcoder" in config.model_name_or_path.lower():
from .models.base_model import BaseModel
adapter = BaseModel(config.model_name_or_path, config.task)
else:
set_latest_error(ErrorCodes.ERROR_MODEL_NOT_SUPPORTED)
logging.error("build_chatbot: unknown model")
return
from .models.base_model import register_model_adapter
register_model_adapter(adapter)
# register plugin instance in model adaptor
@@ -284,6 +293,8 @@ def build_chatbot(config: PipelineConfig=None):
else:
parameters["use_vllm"] = False
parameters["vllm_engine_params"] = None
if config.hf_endpoint_url:
return adapter
adapter.load_model(parameters)
if get_latest_error():
return
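With this change, build_chatbot can target a remote Hugging Face endpoint instead of loading a model locally: when hf_endpoint_url is set it checks for an access token, wraps the endpoint in a HuggingfaceModel adapter, and returns the adapter before adapter.load_model is ever called. A minimal usage sketch follows; the endpoint URL and token are placeholders, and the top-level imports assume the package's usual API:

from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

# Placeholder endpoint URL and token; substitute a real Inference Endpoint and HF token.
config = PipelineConfig(
    hf_endpoint_url="https://<your-endpoint>.endpoints.huggingface.cloud",
    hf_access_token="hf_xxx",
)

# Returns the HuggingfaceModel adapter immediately; no local weights are loaded.
chatbot = build_chatbot(config)
response = chatbot.predict(query="Tell me about Intel Xeon Scalable Processors.")
print(response)
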
2 changes: 2 additions & 0 deletions intel_extension_for_transformers/neural_chat/config.py
@@ -450,6 +450,7 @@ class PipelineConfig:
def __init__(self,
model_name_or_path="Intel/neural-chat-7b-v3-1",
tokenizer_name_or_path=None,
hf_endpoint_url=None,
hf_access_token=None,
device="auto",
task="",
@@ -461,6 +462,7 @@ def __init__(self,
self.model_name_or_path = model_name_or_path
self.tokenizer_name_or_path = tokenizer_name_or_path
self.hf_access_token = hf_access_token
self.hf_endpoint_url = hf_endpoint_url
if device == "auto":
self.device = get_device_type()
else:
1 change: 1 addition & 0 deletions intel_extension_for_transformers/neural_chat/errorcode.py
@@ -35,6 +35,7 @@ class ErrorCodes:
ERROR_CACHE_DIR_NO_WRITE_PERMISSION = 2004
ERROR_INVALID_MODEL_VERSION = 2005
ERROR_MODEL_NOT_SUPPORTED = 2006
ERROR_HF_TOKEN_NOT_PROVIDED = 2007
WARNING_INPUT_EXCEED_MAX_SEQ_LENGTH = 2101

# General Service Error Code - Dataset related
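The new ERROR_HF_TOKEN_NOT_PROVIDED code is set when an endpoint URL is configured without an access token, in which case build_chatbot logs an error and returns None. A small sketch of that failure path, assuming get_latest_error is exposed from neural_chat.utils.error_utils as used in chatbot.py:

from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot
from intel_extension_for_transformers.neural_chat.utils.error_utils import get_latest_error

# Endpoint URL set but no hf_access_token: build_chatbot records an error and returns None.
config = PipelineConfig(hf_endpoint_url="https://<your-endpoint>.endpoints.huggingface.cloud")
chatbot = build_chatbot(config)
if chatbot is None:
    # The recorded error should correspond to ErrorCodes.ERROR_HF_TOKEN_NOT_PROVIDED (2007).
    print("build_chatbot failed:", get_latest_error())
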
33 changes: 27 additions & 6 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -81,6 +81,7 @@ def __init__(self, model_name, task="chat"):
self.device = None
self.conv_template = None
self.ipex_int8 = None
self.hf_client = None
self.get_conv_template(task)

def match(self):
@@ -169,7 +170,7 @@ def predict_stream(self, query, origin_query="", config=None):
if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
"starcoder" in self.model_name.lower() or "codellama" in self.model_name.lower() or \
"codegen" in self.model_name.lower() or "magicoder" in self.model_name.lower() or \
"phi-2" in self.model_name.lower() or "sqlcoder" in self.model_name.lower():
"phi-2" in self.model_name.lower() or "sqlcoder" in self.model_name.lower() or self.hf_client:
query_include_prompt = True

# plugin pre actions
@@ -233,8 +234,18 @@ def predict_stream(self, query, origin_query="", config=None):
query = generate_sqlcoder_prompt(query, config.sql_metadata)

try:
response = predict_stream(
**construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
if self.hf_client:
response = self.hf_client.text_generation(prompt=query,
max_new_tokens=config.max_new_tokens,
do_sample=config.do_sample,
repetition_penalty=config.repetition_penalty,
temperature=config.temperature,
top_k=config.top_k,
top_p=config.top_p,
stream=True)
else:
response = predict_stream(
**construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
except Exception as e:
set_latest_error(ErrorCodes.ERROR_MODEL_INFERENCE_FAIL)
return
@@ -282,7 +293,7 @@ def predict(self, query, origin_query="", config=None):
if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
"starcoder" in self.model_name.lower() or "codellama" in self.model_name.lower() or \
"codegen" in self.model_name.lower() or "magicoder" in self.model_name.lower() or \
"sqlcoder" in self.model_name.lower():
"sqlcoder" in self.model_name.lower() or self.hf_client:
query_include_prompt = True

# plugin pre actions
@@ -344,8 +355,18 @@ def predict(self, query, origin_query="", config=None):

# LLM inference
try:
response = predict(
**construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
if self.hf_client:
response = self.hf_client.text_generation(prompt=query,
max_new_tokens=config.max_new_tokens,
do_sample=config.do_sample,
repetition_penalty=config.repetition_penalty,
temperature=config.temperature,
top_k=config.top_k,
top_p=config.top_p,
stream=False)
else:
response = predict(
**construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
except Exception as e:
set_latest_error(ErrorCodes.ERROR_MODEL_INFERENCE_FAIL)
return
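Both predict and predict_stream now route generation through huggingface_hub's InferenceClient whenever hf_client is set, mapping the usual generation parameters onto text_generation. Stripped of the adapter, the underlying calls look roughly like this (endpoint URL and token are placeholders):

from huggingface_hub import InferenceClient

# Placeholder endpoint URL and token.
client = InferenceClient(model="https://<your-endpoint>.endpoints.huggingface.cloud",
                         token="hf_xxx")

# Non-streaming call: returns the generated text as a single string.
text = client.text_generation(prompt="Tell me about Intel Xeon Scalable Processors.",
                              max_new_tokens=128,
                              do_sample=True,
                              repetition_penalty=1.1,
                              temperature=0.7,
                              top_k=40,
                              top_p=0.9,
                              stream=False)
print(text)

# Streaming call: yields text chunks as the endpoint produces them.
for chunk in client.text_generation(prompt="Tell me about Intel Xeon Scalable Processors.",
                                    max_new_tokens=128,
                                    stream=True):
    print(chunk, end="", flush=True)
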
31 changes: 31 additions & 0 deletions intel_extension_for_transformers/neural_chat/models/huggingface_model.py
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_model import BaseModel
import logging
from huggingface_hub import InferenceClient

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)

class HuggingfaceModel(BaseModel):
def __init__(self, hf_endpoint_url, hf_access_token):
self.hf_client = InferenceClient(model=hf_endpoint_url, token=hf_access_token)
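The adapter stays deliberately thin: it only stores the InferenceClient as hf_client, and BaseModel's predict/predict_stream pick that client up, so this path needs no local model, tokenizer, or device setup. A tiny sketch, assuming the module lands at neural_chat/models/huggingface_model.py as the import in chatbot.py indicates (URL and token are placeholders):

from intel_extension_for_transformers.neural_chat.models.huggingface_model import HuggingfaceModel

# Constructing the adapter only creates the InferenceClient; no network call happens yet.
adapter = HuggingfaceModel("https://<your-endpoint>.endpoints.huggingface.cloud", "hf_xxx")
assert adapter.hf_client is not None
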
@@ -7,6 +7,7 @@ ExifRead
fastapi
fschat==0.2.32
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
huggingface_hub
intel-extension-for-transformers
intel_extension_for_pytorch==2.2.0
neural-compressor
@@ -6,6 +6,7 @@ ExifRead
fastapi
fschat==0.2.35
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
huggingface_hub
neural-compressor
numpy==1.23.5
optimum
@@ -7,6 +7,7 @@ ExifRead
fastapi
fschat==0.2.35
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
huggingface_hub
intel-extension-for-transformers
neural-compressor
numpy==1.23.5
@@ -6,6 +6,7 @@ evaluate
ExifRead
fastapi
fschat==0.2.35
huggingface_hub
intel-extension-for-pytorch==2.1.10+xpu
intel-extension-for-transformers
neural-compressor