[NeuralChat] Sync RESTful API with latest OpenAI protocol (#1205)

lvliang-intel committed Jan 31, 2024
1 parent 1fdf6f0 commit 2e1c79d

Showing 31 changed files with 892 additions and 192 deletions.
19 changes: 10 additions & 9 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -115,25 +115,25 @@ def build_chatbot(config: PipelineConfig=None):
     # create model adapter
     if "llama" in config.model_name_or_path.lower():
         from .models.llama_model import LlamaModel
-        adapter = LlamaModel()
+        adapter = LlamaModel(config.model_name_or_path, config.task)
     elif "mpt" in config.model_name_or_path.lower():
         from .models.mpt_model import MptModel
-        adapter = MptModel()
+        adapter = MptModel(config.model_name_or_path, config.task)
     elif "neural-chat" in config.model_name_or_path.lower():
         from .models.neuralchat_model import NeuralChatModel
-        adapter = NeuralChatModel()
+        adapter = NeuralChatModel(config.model_name_or_path, config.task)
     elif "chatglm" in config.model_name_or_path.lower():
         from .models.chatglm_model import ChatGlmModel
-        adapter = ChatGlmModel()
+        adapter = ChatGlmModel(config.model_name_or_path, config.task)
     elif "qwen" in config.model_name_or_path.lower():
         from .models.qwen_model import QwenModel
-        adapter = QwenModel()
+        adapter = QwenModel(config.model_name_or_path, config.task)
     elif "mistral" in config.model_name_or_path.lower():
         from .models.mistral_model import MistralModel
-        adapter = MistralModel()
+        adapter = MistralModel(config.model_name_or_path, config.task)
     elif "solar" in config.model_name_or_path.lower():
         from .models.solar_model import SolarModel
-        adapter = SolarModel()
+        adapter = SolarModel(config.model_name_or_path, config.task)
     elif "opt" in config.model_name_or_path.lower() or \
         "gpt" in config.model_name_or_path.lower() or \
         "flan-t5" in config.model_name_or_path.lower() or \
@@ -144,12 +144,13 @@ def build_chatbot(config: PipelineConfig=None):
         "mixtral" in config.model_name_or_path.lower() or \
         "phi-2" in config.model_name_or_path.lower():
         from .models.base_model import BaseModel
-        adapter = BaseModel()
+        adapter = BaseModel(config.model_name_or_path, config.task)
     else:
         set_latest_error(ErrorCodes.ERROR_MODEL_NOT_SUPPORTED)
         logging.error("build_chatbot: unknown model")
         return

+    from .models.base_model import register_model_adapter
+    register_model_adapter(adapter)
     # register plugin instance in model adapter
     if config.plugins:
         for plugin_name, plugin_value in config.plugins.items():
3 changes: 2 additions & 1 deletion intel_extension_for_transformers/neural_chat/config.py
@@ -450,6 +450,7 @@ def __init__(self,
                  tokenizer_name_or_path=None,
                  hf_access_token=None,
                  device="auto",
+                 task="",
                  plugins=plugins,
                  loading_config=None,
                  optimization_config=None,
@@ -462,7 +463,7 @@ def __init__(self,
             self.device = get_device_type()
         else:
             self.device = device
-
+        self.task = task
         self.plugins = plugins

         self.loading_config = loading_config if loading_config is not None else \
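Taken together with the chatbot.py hunk above, the new task field flows from PipelineConfig into each adapter's constructor. A minimal usage sketch under assumed defaults (the model id here is illustrative, not taken from this commit):

    from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig

    # task is forwarded to the adapter's __init__, which now resolves the
    # conversation template once at construction time instead of per request.
    config = PipelineConfig(
        model_name_or_path="Intel/neural-chat-7b-v3-1",  # assumed model id
        task="chat",  # one of "", "completion", "chat", "summarization"
    )
    chatbot = build_chatbot(config)
    response = chatbot.predict(query="Tell me about Intel Xeon Scalable processors.")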
34 changes: 15 additions & 19 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -63,11 +63,11 @@ class BaseModel(ABC):
     A base class for LLM.
     """

-    def __init__(self):
+    def __init__(self, model_name, task="chat"):
         """
         Initializes the BaseModel class.
         """
-        self.model_name = ""
+        self.model_name = model_name
         self.asr = None
         self.tts = None
         self.face_animation = None
@@ -81,16 +81,14 @@ def __init__(self):
         self.device = None
         self.conv_template = None
         self.ipex_int8 = None
+        self.get_conv_template(task)

-    def match(self, model_path: str):
+    def match(self):
         """
-        Check if the provided model_path matches the current model.
-        Args:
-            model_path (str): Path to a model.
+        Check if the provided model_name matches the current model.
         Returns:
-            bool: True if the model_path matches, False otherwise.
+            bool: True if the model_name matches, False otherwise.
         """
         return True

@@ -167,7 +165,6 @@ def predict_stream(self, query, origin_query="", config=None):
             raise ValueError(f"The audio file path {query} is invalid.")

         query_include_prompt = False
-        self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
             "starcoder" in self.model_name.lower() or "codellama" in self.model_name.lower() or \
             "codegen" in self.model_name.lower() or "magicoder" in self.model_name.lower() or \
@@ -219,7 +216,7 @@ def predict_stream(self, query, origin_query="", config=None):
         assert query is not None, "Query cannot be None."

         if not query_include_prompt and not is_plugin_enabled("retrieval"):
-            query = self.prepare_prompt(query, self.model_name, config.task)
+            query = self.prepare_prompt(query, config.task)

         # Phind/Phind-CodeLlama-34B-v2 model accepts the Alpaca/Vicuna instruction format.
         if "phind" in self.model_name.lower():
@@ -278,7 +275,6 @@ def predict(self, query, origin_query="", config=None):
             raise ValueError(f"The audio file path {query} is invalid.")

         query_include_prompt = False
-        self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
             "starcoder" in self.model_name.lower() or "codellama" in self.model_name.lower() or \
             "codegen" in self.model_name.lower() or "magicoder" in self.model_name.lower():
@@ -326,7 +322,7 @@ def predict(self, query, origin_query="", config=None):

         if not query_include_prompt and not is_plugin_enabled("retrieval") \
             and not 'vllm' in str(MODELS[self.model_name]['model']):
-            query = self.prepare_prompt(query, self.model_name, config.task)
+            query = self.prepare_prompt(query, config.task)

         # Phind/Phind-CodeLlama-34B-v2 model accepts the Alpaca/Vicuna instruction format.
         if "phind" in self.model_name.lower():
@@ -406,7 +402,7 @@ def face_animate(self, image_path, audio_path=None, text=None, voice=None) -> st
             os.remove(audio_path)
         return video_path

-    def get_default_conv_template(self, model_path: str) -> Conversation:
+    def get_default_conv_template(self) -> Conversation:
         """
         Get the default conversation template for the given model path.
@@ -432,22 +428,22 @@ def get_conv_template(self, model_path: str, task: str = "") -> Conversation:
         if self.conv_template:
             return
         if not task:
-            self.conv_template = PromptTemplate(self.get_default_conv_template(model_path).name, clear_history=True)
+            self.conv_template = PromptTemplate(self.get_default_conv_template().name, clear_history=True)
         else:
             clear_history = True
             if task == "completion":
                 name = "alpaca_without_input"
             elif task == "chat":
-                name = "neural-chat-7b-v2"
                 clear_history = False
+                name = self.get_default_conv_template().name
             elif task == "summarization":
                 name = "summarization"
             else:
                 raise NotImplementedError(f"Unsupported task {task}.")
             self.conv_template = PromptTemplate(name, clear_history=clear_history)

-    def prepare_prompt(self, prompt: str, model_path: str, task: str = ""):
-        self.get_conv_template(model_path, task)
+    def prepare_prompt(self, prompt: str, task: str = ""):
+        self.get_conv_template(task)
         self.conv_template.append_message(self.conv_template.roles[0], prompt)
         self.conv_template.append_message(self.conv_template.roles[1], None)
         return self.conv_template.get_prompt()
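The task dispatch above reduces to a small standalone mapping. This sketch mirrors the hunk (template names are the ones in the diff; PromptTemplate construction is omitted for brevity):

    # Mirrors get_conv_template's dispatch: returns (template_name, clear_history).
    def resolve_template(task: str, default_name: str):
        if not task:
            return default_name, True
        if task == "completion":
            return "alpaca_without_input", True
        if task == "chat":
            # "chat" keeps history and now reuses the model's own default
            # template instead of the hard-coded "neural-chat-7b-v2".
            return default_name, False
        if task == "summarization":
            return "summarization", True
        raise NotImplementedError(f"Unsupported task {task}.")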
@@ -485,15 +481,15 @@ def register_plugin_instance(self, plugin_name, instance):

 def register_model_adapter(cls):
     """Register a model adapter."""
-    model_adapters.append(cls())
+    model_adapters.append(cls)


 def get_model_adapter(model_name_path: str) -> BaseModel:
     """Get a model adapter for a model_name_path."""
     model_path_basename = os.path.basename(os.path.normpath(model_name_path)).lower()

     for adapter in model_adapters:
-        if adapter.match(model_path_basename) and type(adapter) != BaseModel:
+        if adapter.match(model_path_basename):
             return adapter

     raise ValueError(f"No valid model adapter for {model_name_path}")
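Adapters are now appended to the registry as passed in, rather than eagerly instantiated with cls(), and registration moves from module import time into build_chatbot. A self-contained sketch of the resulting pattern (DemoAdapter is hypothetical; match reads the name from the instance, as in the new BaseModel):

    model_adapters = []

    def register_model_adapter(adapter):
        # Registration now happens from build_chatbot with a constructed adapter.
        model_adapters.append(adapter)

    class DemoAdapter:
        def __init__(self, model_name):
            self.model_name = model_name

        def match(self):
            return "demo" in self.model_name.lower()

    register_model_adapter(DemoAdapter("demo-7b-chat"))

    def get_registered_adapter():
        for candidate in model_adapters:
            if candidate.match():
                return candidate
        raise ValueError("No valid model adapter registered")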
intel_extension_for_transformers/neural_chat/models/chatglm_model.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base_model import BaseModel, register_model_adapter
+from .base_model import BaseModel
 from fastchat.conversation import get_conv_template, Conversation
 import logging

@@ -27,7 +27,7 @@
 logger = logging.getLogger(__name__)

 class ChatGlmModel(BaseModel):
-    def match(self, model_path: str):
+    def match(self):
         """
         Check if the provided model_path matches the current model.
@@ -37,9 +37,9 @@ def match(self, model_path: str):
         Returns:
             bool: True if the model_path matches, False otherwise.
         """
-        return "chatglm" in model_path.lower()
+        return "chatglm" in self.model_name.lower()

-    def get_default_conv_template(self, model_path: str) -> Conversation:
+    def get_default_conv_template(self) -> Conversation:
         """
         Get the default conversation template for the given model path.
@@ -49,9 +49,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         Returns:
             Conversation: A default conversation template.
         """
-        model_path = model_path.lower()
-        if "chatglm2" in model_path.lower():
+        if "chatglm2" in self.model_name.lower():
             return get_conv_template("chatglm2")
         return get_conv_template("chatglm")
-
-register_model_adapter(ChatGlmModel)
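The same mechanical change repeats in each adapter module below: the explicit model_path parameter disappears, match reads self.model_name, and the module-level register_model_adapter call is dropped. A new adapter under the updated interface would look roughly like this (ExampleModel is hypothetical; one_shot is an existing fastchat template name):

    from fastchat.conversation import get_conv_template, Conversation
    from .base_model import BaseModel

    class ExampleModel(BaseModel):
        def match(self):
            # model_name is set on the instance by BaseModel.__init__.
            return "example" in self.model_name.lower()

        def get_default_conv_template(self) -> Conversation:
            return get_conv_template("one_shot")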
16 changes: 4 additions & 12 deletions intel_extension_for_transformers/neural_chat/models/llama_model.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base_model import BaseModel, register_model_adapter
+from .base_model import BaseModel
 import logging
 from fastchat.conversation import get_conv_template, Conversation

@@ -27,28 +27,20 @@
 logger = logging.getLogger(__name__)

 class LlamaModel(BaseModel):
-    def match(self, model_path: str):
+    def match(self):
         """
         Check if the provided model_path matches the current model.
-        Args:
-            model_path (str): Path to a model.
         Returns:
             bool: True if the model_path matches, False otherwise.
         """
-        return "llama" in model_path.lower()
+        return "llama" in self.model_name.lower()

-    def get_default_conv_template(self, model_path: str) -> Conversation:
+    def get_default_conv_template(self) -> Conversation:
         """
         Get the default conversation template for the given model path.
-        Args:
-            model_path (str): Path to the model.
         Returns:
             Conversation: A default conversation template.
         """
         return get_conv_template("llama-2")
-
-register_model_adapter(LlamaModel)
intel_extension_for_transformers/neural_chat/models/mistral_model.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base_model import BaseModel, register_model_adapter
+from .base_model import BaseModel
 import logging
 from fastchat.conversation import get_conv_template, Conversation

@@ -27,28 +27,20 @@
 logger = logging.getLogger(__name__)

 class MistralModel(BaseModel):
-    def match(self, model_path: str):
+    def match(self):
         """
         Check if the provided model_path matches the current model.
-        Args:
-            model_path (str): Path to a model.
         Returns:
             bool: True if the model_path matches, False otherwise.
         """
-        return "mistral" in model_path.lower()
+        return "mistral" in self.model_name.lower()

-    def get_default_conv_template(self, model_path: str) -> Conversation:
+    def get_default_conv_template(self) -> Conversation:
         """
         Get the default conversation template for the given model path.
-        Args:
-            model_path (str): Path to the model.
         Returns:
             Conversation: A default conversation template.
         """
         return get_conv_template("mistral")
-
-register_model_adapter(MistralModel)
16 changes: 4 additions & 12 deletions intel_extension_for_transformers/neural_chat/models/mpt_model.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base_model import BaseModel, register_model_adapter
+from .base_model import BaseModel
 import logging
 from fastchat.conversation import get_conv_template, Conversation

@@ -27,28 +27,20 @@
 logger = logging.getLogger(__name__)

 class MptModel(BaseModel):
-    def match(self, model_path: str):
+    def match(self):
         """
         Check if the provided model_path matches the current model.
-        Args:
-            model_path (str): Path to a model.
         Returns:
             bool: True if the model_path matches, False otherwise.
         """
-        return "mpt" in model_path.lower()
+        return "mpt" in self.model_name.lower()

-    def get_default_conv_template(self, model_path: str) -> Conversation:
+    def get_default_conv_template(self) -> Conversation:
         """
         Get the default conversation template for the given model path.
-        Args:
-            model_path (str): Path to the model.
         Returns:
             Conversation: A default conversation template.
         """
         return get_conv_template("mpt-7b-chat")
-
-register_model_adapter(MptModel)
