Support Mistral model in NeuralChat (#710)
lvliang-intel committed Nov 20, 2023
1 parent f7d6baa commit fcee612
Showing 6 changed files with 102 additions and 12 deletions.
22 changes: 12 additions & 10 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -50,24 +50,26 @@ def build_chatbot(config: PipelineConfig=None):
     if "llama" in config.model_name_or_path.lower():
         from .models.llama_model import LlamaModel
         adapter = LlamaModel()
-    elif "mpt" in config.model_name_or_path:
+    elif "mpt" in config.model_name_or_path.lower():
         from .models.mpt_model import MptModel
         adapter = MptModel()
-    elif "neural-chat" in config.model_name_or_path:
+    elif "neural-chat" in config.model_name_or_path.lower():
         from .models.neuralchat_model import NeuralChatModel
         adapter = NeuralChatModel()
-    elif "chatglm" in config.model_name_or_path:
+    elif "chatglm" in config.model_name_or_path.lower():
         from .models.chatglm_model import ChatGlmModel
         adapter = ChatGlmModel()
-    elif "Qwen" in config.model_name_or_path:
+    elif "qwen" in config.model_name_or_path.lower():
         from .models.qwen_model import QwenModel
         adapter = QwenModel()
-    elif "opt" in config.model_name_or_path or \
-            "gpt" in config.model_name_or_path or \
-            "Mistral" in config.model_name_or_path or \
-            "flan-t5" in config.model_name_or_path or \
-            "bloom" in config.model_name_or_path or \
-            "starcoder" in config.model_name_or_path:
+    elif "mistral" in config.model_name_or_path.lower():
+        from .models.mistral_model import MistralModel
+        adapter = MistralModel()
+    elif "opt" in config.model_name_or_path.lower() or \
+            "gpt" in config.model_name_or_path.lower() or \
+            "flan-t5" in config.model_name_or_path.lower() or \
+            "bloom" in config.model_name_or_path.lower() or \
+            "starcoder" in config.model_name_or_path.lower():
         from .models.base_model import BaseModel
         adapter = BaseModel()
     else:
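With this dispatch in place, a model path containing "mistral" is routed to its own adapter instead of the generic BaseModel. A minimal usage sketch, mirroring the unit test added later in this commit:

from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig

# Any model_name_or_path containing "mistral" (case-insensitive) now selects MistralModel.
config = PipelineConfig(model_name_or_path="mistralai/Mistral-7B-v0.1")
chatbot = build_chatbot(config=config)
response = chatbot.predict("Tell me about Intel Xeon Scalable Processors.")
print(response)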
54 changes: 54 additions & 0 deletions intel_extension_for_transformers/neural_chat/models/mistral_model.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_model import BaseModel, register_model_adapter
+import logging
+from fastchat.conversation import get_conv_template, Conversation
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+class MistralModel(BaseModel):
+    def match(self, model_path: str):
+        """
+        Check if the provided model_path matches the current model.
+
+        Args:
+            model_path (str): Path to a model.
+
+        Returns:
+            bool: True if the model_path matches, False otherwise.
+        """
+        return "mistral" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        """
+        Get the default conversation template for the given model path.
+
+        Args:
+            model_path (str): Path to the model.
+
+        Returns:
+            Conversation: A default conversation template.
+        """
+        return get_conv_template("mistral")
+
+register_model_adapter(MistralModel)
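A quick sanity check of the adapter's behavior, consistent with the unit tests further down:

from intel_extension_for_transformers.neural_chat.models.mistral_model import MistralModel

adapter = MistralModel()
# match() is a case-insensitive substring test on the model path.
assert adapter.match("mistralai/Mistral-7B-v0.1")
assert not adapter.match("meta-llama/Llama-2-7b-hf")
# The default template is fastchat's registered "mistral" Conversation.
conv = adapter.get_default_conv_template("mistralai/Mistral-7B-v0.1")
assert "[INST]{system_message}" in str(conv)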
@@ -366,7 +366,7 @@ def load_model(
             or re.search("qwen", model_name, re.IGNORECASE)
             or re.search("starcoder", model_name, re.IGNORECASE)
             or re.search("codellama", model_name, re.IGNORECASE)
-            or re.search("Mistral", model_name, re.IGNORECASE)
+            or re.search("mistral", model_name, re.IGNORECASE)
         ) and not ipex_int8) or re.search("opt", model_name, re.IGNORECASE):
             with smart_context_manager(use_deepspeed=use_deepspeed):
                 model = AutoModelForCausalLM.from_pretrained(
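This hunk is behavior-preserving: re.IGNORECASE already made the old "Mistral" literal match any casing, so the lowercase pattern only brings the literal in line with the .lower() convention used in chatbot.py. For illustration:

import re

name = "mistralai/Mistral-7B-v0.1"
# Old and new patterns are equivalent under re.IGNORECASE.
assert re.search("Mistral", name, re.IGNORECASE)
assert re.search("mistral", name, re.IGNORECASE)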
@@ -19,6 +19,7 @@
 from intel_extension_for_transformers.neural_chat.models.llama_model import LlamaModel
 from intel_extension_for_transformers.neural_chat.models.mpt_model import MptModel
 from intel_extension_for_transformers.neural_chat.models.neuralchat_model import NeuralChatModel
+from intel_extension_for_transformers.neural_chat.models.mistral_model import MistralModel
 from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig
 import unittest

@@ -144,6 +145,26 @@ def test_get_default_conv_template_v3_1(self):
         print(result)
         self.assertIn('The Intel Xeon Scalable Processors', str(result))
 
+class TestMistralModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_match(self):
+        result = MistralModel().match(model_path='mistralai/Mistral-7B-v0.1')
+        self.assertTrue(result)
+
+    def test_get_default_conv_template(self):
+        result = MistralModel().get_default_conv_template(model_path='mistralai/Mistral-7B-v0.1')
+        self.assertIn("[INST]{system_message}", str(result))
+        config = PipelineConfig(model_name_or_path="mistralai/Mistral-7B-v0.1")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("Tell me about Intel Xeon Scalable Processors.")
+        print(result)
+        self.assertIn('Intel Xeon Scalable processors', str(result))
+
 class TestStarCoderModel(unittest.TestCase):
     def setUp(self):
         return super().setUp()
@@ -322,7 +322,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, topk, request:
 
     if len(state.messages) == state.offset + 2:
         # model conversation name: "mpt-7b-chat", "chatglm", "chatglm2", "llama-2",
-        #                          "neural-chat-7b-v3-1", "neural-chat-7b-v3",
+        #                          "mistral", "neural-chat-7b-v3-1", "neural-chat-7b-v3",
         #                          "neural-chat-7b-v2", "neural-chat-7b-v1-1"
         # First round of Conversation
         if "Llama-2-7b-chat-hf" in model_name:
@@ -818,6 +818,19 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# Mistral template
+# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
+register_conv_template(
+    Conversation(
+        name="mistral",
+        system_template="[INST]{system_message}\n",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
 # llama2 template
 # reference: https://huggingface.co/blog/codellama#conversational-instructions
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
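As a sketch of what this template renders, using fastchat's standard Conversation API (the exact whitespace depends on fastchat's LLAMA2 separator handling, so treat the output shown in the comment as approximate):

from fastchat.conversation import get_conv_template

conv = get_conv_template("mistral")
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "Tell me about Intel Xeon Scalable Processors.")
conv.append_message(conv.roles[1], None)  # None leaves a slot for the model's reply
prompt = conv.get_prompt()
# Approximately:
# [INST]You are a helpful assistant.
# Tell me about Intel Xeon Scalable Processors. [/INST]
print(prompt)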
