[NeuralChat] Support magicoder model (#1067)
* Support magicoder model and refine load model

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel committed Dec 25, 2023
1 parent 3e14b05 commit f29c1ec
Showing 5 changed files with 89 additions and 43 deletions.
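For context (not part of the diff): after this change a Magicoder checkpoint is served through the generic BaseModel adapter. A minimal usage sketch follows; the checkpoint id and the top-level import path are assumptions based on common NeuralChat usage, not taken from this commit.

from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

# Assumed checkpoint id; any model name containing "magicoder" is routed to the
# BaseModel adapter by build_chatbot() after this change.
config = PipelineConfig(model_name_or_path="ise-uiuc/Magicoder-S-DS-6.7B")
chatbot = build_chatbot(config)

# build_chatbot() returns None and records an error code on failure.
if chatbot is not None:
    print(chatbot.predict("Write a Python function that checks whether a number is prime."))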
@@ -52,10 +52,11 @@ def optimize(self, model, use_llm_runtime=False):
or re.search("bloom", model_name, re.IGNORECASE)
or re.search("llama", model_name, re.IGNORECASE)
or re.search("opt", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
or re.search("neural-chat", model_name, re.IGNORECASE)
or re.search("starcoder", model_name, re.IGNORECASE)
or re.search("codegen", model_name, re.IGNORECASE)
or re.search("mistral", model_name, re.IGNORECASE)
or re.search("magicoder", model_name, re.IGNORECASE)
or re.search("solar", model_name, re.IGNORECASE)
):
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
12 changes: 10 additions & 2 deletions intel_extension_for_transformers/neural_chat/chatbot.py
@@ -69,7 +69,7 @@ def build_chatbot(config: PipelineConfig=None):
return

# create model adapter
if "llama" in config.model_name_or_path.lower() or "magicoder" in config.model_name_or_path.lower():
if "llama" in config.model_name_or_path.lower():
from .models.llama_model import LlamaModel
adapter = LlamaModel()
elif "mpt" in config.model_name_or_path.lower():
@@ -95,7 +95,8 @@ def build_chatbot(config: PipelineConfig=None):
"flan-t5" in config.model_name_or_path.lower() or \
"bloom" in config.model_name_or_path.lower() or \
"starcoder" in config.model_name_or_path.lower() or \
"codegen" in config.model_name_or_path.lower():
"codegen" in config.model_name_or_path.lower() or \
"magicoder" in config.model_name_or_path.lower():
from .models.base_model import BaseModel
adapter = BaseModel()
else:
@@ -163,6 +164,7 @@ def build_chatbot(config: PipelineConfig=None):
try:
adapter.load_model(parameters)
except RuntimeError as e:
logger.error(f"Exception: {e}")
if "out of memory" in str(e):
set_latest_error(ErrorCodes.ERROR_OUT_OF_MEMORY)
elif "devices are busy or unavailable" in str(e):
@@ -173,6 +175,7 @@
set_latest_error(ErrorCodes.ERROR_GENERIC)
return
except ValueError as e:
logger.error(f"Exception: {e}")
if "load_model: unsupported device" in str(e):
set_latest_error(ErrorCodes.ERROR_DEVICE_NOT_SUPPORTED)
elif "load_model: unsupported model" in str(e):
@@ -187,6 +190,7 @@
set_latest_error(ErrorCodes.ERROR_GENERIC)
return
except Exception as e:
logger.error(f"Exception: {e}")
set_latest_error(ErrorCodes.ERROR_GENERIC)
return
return adapter
@@ -204,14 +208,17 @@ def finetune_model(config: BaseFinetuningConfig):
try:
finetuning.finetune()
except FileNotFoundError as e:
logger.error(f"Exception: {e}")
if "Couldn't find a dataset script" in str(e):
set_latest_error(ErrorCodes.ERROR_DATASET_NOT_FOUND)
except ValueError as e:
logger.error(f"Exception: {e}")
if "--do_eval requires a validation dataset" in str(e):
set_latest_error(ErrorCodes.ERROR_VALIDATION_FILE_NOT_FOUND)
elif "--do_train requires a train dataset" in str(e):
set_latest_error(ErrorCodes.ERROR_TRAIN_FILE_NOT_FOUND)
except Exception as e:
logger.error(f"Exception: {e}")
if config.finetune_args.peft == "lora":
set_latest_error(ErrorCodes.ERROR_LORA_FINETUNE_FAIL)
elif config.finetune_args.peft == "llama_adapter":
@@ -237,6 +244,7 @@ def optimize_model(model, config, use_llm_runtime=False):
try:
model = optimization.optimize(model, use_llm_runtime)
except Exception as e:
logger.error(f"Exception: {e}")
from intel_extension_for_transformers.transformers import (
MixedPrecisionConfig,
WeightOnlyQuantConfig,
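In short (a condensed, illustrative sketch rather than the literal diff): build_chatbot() now routes Magicoder checkpoints to the generic BaseModel adapter instead of LlamaModel, and each exception handler logs the exception before mapping it to an error code. Only the substrings visible in the hunks above are listed here.

# Condensed sketch of the adapter dispatch in build_chatbot() after this change.
def pick_adapter(model_name_or_path: str) -> str:
    name = model_name_or_path.lower()
    if "llama" in name:
        return "LlamaModel"
    if any(key in name for key in ("flan-t5", "bloom", "starcoder", "codegen", "magicoder")):
        return "BaseModel"  # Magicoder now uses the generic adapter
    return "unsupported"

print(pick_adapter("ise-uiuc/Magicoder-S-DS-6.7B"))  # -> BaseModel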
18 changes: 16 additions & 2 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -24,6 +24,7 @@
from ..utils.common import is_audio_file
from .model_utils import load_model, predict, predict_stream, MODELS
from ..prompts import PromptTemplate
from ..prompts.prompt import MAGICODER_PROMPT
from ..utils.error_utils import set_latest_error
from ..errorcode import ErrorCodes
import logging
@@ -163,7 +164,7 @@ def predict_stream(self, query, origin_query="", config=None):
self.get_conv_template(self.model_name, config.task)
if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
"starcoder" in self.model_name.lower() or "codellama" in self.model_name.lower() or \
"codegen" in self.model_name.lower():
"codegen" in self.model_name.lower() or "magicoder" in self.model_name.lower():
query_include_prompt = True

# plugin pre actions
@@ -207,6 +208,16 @@ def predict_stream(self, query, origin_query="", config=None):
if not query_include_prompt and not is_plugin_enabled("retrieval"):
query = self.prepare_prompt(query, self.model_name, config.task)

# Phind/Phind-CodeLlama-34B-v2 model accepts the Alpaca/Vicuna instruction format.
if "phind" in self.model_name.lower():
conv_template = PromptTemplate(name="phind")
conv_template.append_message(conv_template.roles[0], query)
conv_template.append_message(conv_template.roles[1], None)
query = conv_template.get_prompt()

if "magicoder" in self.model_name.lower():
query = MAGICODER_PROMPT.format(instruction=query)

try:
response = predict_stream(
**construct_parameters(query, self.model_name, self.device, self.assistant_model, config))
@@ -256,7 +267,7 @@ def predict(self, query, origin_query="", config=None):
self.get_conv_template(self.model_name, config.task)
if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
"starcoder" in self.model_name.lower() or "codellama" in self.model_name.lower() or \
"codegen" in self.model_name.lower():
"codegen" in self.model_name.lower() or "magicoder" in self.model_name.lower():
query_include_prompt = True

# plugin pre actions
@@ -298,6 +309,9 @@ def predict(self, query, origin_query="", config=None):
conv_template.append_message(conv_template.roles[1], None)
query = conv_template.get_prompt()

if "magicoder" in self.model_name.lower():
query = MAGICODER_PROMPT.format(instruction=query)

# LLM inference
try:
response = predict(
84 changes: 49 additions & 35 deletions intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -416,30 +416,45 @@ def load_model(
else:
MODELS[model_name]["assistant_model"] = None

try:
config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token, trust_remote_code=True \
if (re.search("chatglm", model_name, re.IGNORECASE) or \
re.search("qwen", model_name, re.IGNORECASE)) else False)
except ValueError as e:
logging.error(f"Exception: {e}")
if "Unrecognized model in" in str(e):
raise ValueError(f"load_model: model config is not found, {e}")
else:
raise ValueError(f"load_model: unknown ValueError occurred, {e}")
except EnvironmentError as e:
logging.error(f"Exception: {e}")
if "not a local folder and is not a valid model identifier" in str(e):
raise ValueError(f"load_model: model name or path is not found, {e}")
else:
raise ValueError(f"load_model: unknown EnvironmentError occurred, {e}")
except Exception as e:
logging.error(f"Exception: {e}")
raise ValueError(f"load_model: an unexpected error occurred, {e}")

MODELS[model_name]["model_type"] = config.model_type

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
use_fast=False if (re.search("llama", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)) else True,
use_fast=False if config.model_type == "llama" else True,
use_auth_token=hf_access_token,
trust_remote_code=True if (re.search("qwen", model_name, re.IGNORECASE) or \
re.search("chatglm", model_name, re.IGNORECASE)) else False,
)
except EnvironmentError as e:
logging.error(f"Exception: {e}")
if "not a local folder and is not a valid model identifier" in str(e):
raise ValueError("load_model: tokenizer is not found")
else:
raise

try:
config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token, trust_remote_code=True \
if (re.search("chatglm", model_name, re.IGNORECASE) or \
re.search("qwen", model_name, re.IGNORECASE)) else False)
except ValueError as e:
if "Unrecognized model in" in str(e):
raise ValueError("load_model: model config is not found")
raise ValueError(f"load_model: tokenizer is not found, {e}")
else:
raise
raise ValueError(f"load_model: unknown EnvironmentError occurred, {e}")
except Exception as e:
logging.error(f"Exception: {e}")
raise ValueError(f"load_model: an unexpected error occurred, {e}")

load_to_meta = model_on_meta(config)

@@ -478,33 +493,26 @@ def load_model(
trust_remote_code=True)
elif ((
re.search("gpt", model_name, re.IGNORECASE)
or re.search("mpt", model_name, re.IGNORECASE)
or re.search("bloom", model_name, re.IGNORECASE)
or re.search("llama", model_name, re.IGNORECASE)
or re.search("magicoder", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
or re.search("qwen", model_name, re.IGNORECASE)
or re.search("starcoder", model_name, re.IGNORECASE)
or re.search("codellama", model_name, re.IGNORECASE)
or re.search("mistral", model_name, re.IGNORECASE)
or re.search("codegen", model_name, re.IGNORECASE)
) and not ipex_int8) or re.search("opt", model_name, re.IGNORECASE):
or config.model_type == "bloom"
or config.model_type == "qwen"
or config.model_type == "gpt_bigcode"
or config.model_type == "mpt"
or config.model_type == "llama"
or config.model_type == "mistral"
) and not ipex_int8) or config.model_type == "opt":
with smart_context_manager(use_deepspeed=use_deepspeed):
model = AutoModelForCausalLM.from_pretrained(
model_name,
use_auth_token=hf_access_token,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
quantization_config=bitsandbytes_quant_config,
trust_remote_code=True if (re.search("qwen", model_name, re.IGNORECASE) or \
trust_remote_code=True if (config.model_type == "qwen" or \
re.search("codegen", model_name, re.IGNORECASE)) else False
)
elif (
(re.search("starcoder", model_name, re.IGNORECASE)
or re.search("codellama", model_name, re.IGNORECASE)
or re.search("codegen", model_name, re.IGNORECASE)
(config.model_type == "gpt_bigcode"
or config.model_type == "llama"
) and ipex_int8
):
with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -520,9 +528,9 @@
model_name,
file_name="best_model.pt",
)
elif(
(re.search("llama", model_name, re.IGNORECASE)
or re.search("opt", model_name, re.IGNORECASE)
elif (
(config.model_type == "llama"
or config.model_type == "opt"
or re.search("gpt_neox", model_name, re.IGNORECASE)
or re.search("gptj", model_name, re.IGNORECASE)
or re.search("falcon", model_name, re.IGNORECASE)
@@ -547,10 +555,14 @@
raise ValueError(f"unsupported model name or path {model_name}, \
only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER/CODEGEN now.")
except EnvironmentError as e:
logging.error(f"Exception: {e}")
if "not a local folder and is not a valid model identifier" in str(e):
raise ValueError("load_model: model name or path is not found")
else:
raise
raise ValueError(f"load_model: unknown EnvironmentError occurred, {e}")
except Exception as e:
logging.error(f"Exception: {e}")
raise ValueError(f"load_model: an unexpected error occurred, {e}")

if re.search("llama", model.config.architectures[0], re.IGNORECASE):
# unwind broken decapoda-research config
@@ -1192,6 +1204,8 @@ def predict(**params):
output = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
if "### Response:" in output:
return output.split("### Response:")[1].strip()
if "@@ Response" in output:
return output.split("@@ Response")[1].strip()
if "### Assistant" in output:
return output.split("### Assistant:")[1].strip()
if "\nassistant\n" in output:
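A side note on the load_model() changes above: model detection now relies on config.model_type from AutoConfig rather than name regexes, which is why no extra "magicoder" branch is needed there. Magicoder checkpoints are Llama-architecture fine-tunes and should report model_type == "llama" — an assumption about the published checkpoints, verifiable with the sketch below.

from transformers import AutoConfig

# Assumed checkpoint id, used only for illustration.
config = AutoConfig.from_pretrained("ise-uiuc/Magicoder-S-DS-6.7B")
print(config.model_type)  # expected: "llama"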
11 changes: 10 additions & 1 deletion intel_extension_for_transformers/neural_chat/prompts/prompt.py
@@ -220,4 +220,13 @@ def get_prompt(self) -> str:
return res

def clear_messages(self) -> str:
self.conv.messages = []
self.conv.messages = []

# pylint: disable=C0301
MAGICODER_PROMPT = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
@@ Instruction
{instruction}
@@ Response
"""
