Support CodeLlama model in NeuralChat (#711)
* Support neural-chat-7b-v3 and neural-chat-7b-v3-1

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel committed Nov 20, 2023
1 parent df40d5e commit 7baa96b
Showing 4 changed files with 34 additions and 4 deletions.
intel_extension_for_transformers/neural_chat/chatbot.py (1 addition, 1 deletion)
@@ -72,7 +72,7 @@ def build_chatbot(config: PipelineConfig=None):
         adapter = BaseModel()
     else:
         raise ValueError("NeuralChat Error: Unsupported model name or path, \
-            only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL now.")
+            only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER now.")

     # register plugin instance in model adaptor
     if config.plugins:
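For reference, the model added by this commit is driven through the same public entry point exercised in the new unit tests below. A minimal usage sketch, assuming the package-root exports used elsewhere in the project (build_chatbot, PipelineConfig) and that the codellama/CodeLlama-7b-hf checkpoint can be fetched from the Hugging Face Hub:

# Usage sketch (not part of the diff): build a CodeLlama chatbot through NeuralChat.
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

# The model path mirrors the one used in the new TestCodeLlamaModel unit test.
config = PipelineConfig(model_name_or_path="codellama/CodeLlama-7b-hf")
chatbot = build_chatbot(config=config)

# Code models are prompted with the raw code snippet; see the prompt handling below.
print(chatbot.predict("def print_hello_world():"))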
@@ -145,7 +145,7 @@ def predict_stream(self, query, config=None):
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
-            "starcoder" in self.model_name:
+            "starcoder" in self.model_name or "codellama" in self.model_name.lower():
             query_include_prompt = True

         # plugin pre actions
@@ -220,7 +220,7 @@ def predict(self, query, config=None):
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
-            "starcoder" in self.model_name:
+            "starcoder" in self.model_name or "codellama" in self.model_name.lower():
             query_include_prompt = True

         # plugin pre actions
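The two hunks above apply the same rule to streaming and non-streaming prediction: a query is treated as a complete prompt if it already embeds both conversation roles, or if the model is a code model such as StarCoder or CodeLlama. A standalone sketch of that check, with conv_roles and model_name standing in for self.conv_template.roles and self.model_name:

# Illustrative restatement of the check added above (names are stand-ins).
def query_includes_prompt(query: str, conv_roles, model_name: str) -> bool:
    return (conv_roles[0] in query and conv_roles[1] in query) \
        or "starcoder" in model_name \
        or "codellama" in model_name.lower()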
@@ -365,6 +365,7 @@ def load_model(
             or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
             or re.search("qwen", model_name, re.IGNORECASE)
             or re.search("starcoder", model_name, re.IGNORECASE)
+            or re.search("codellama", model_name, re.IGNORECASE)
             or re.search("Mistral", model_name, re.IGNORECASE)
         ) and not ipex_int8) or re.search("opt", model_name, re.IGNORECASE):
         with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -377,6 +378,7 @@
             )
     elif (
         (re.search("starcoder", model_name, re.IGNORECASE)
+            or re.search("codellama", model_name, re.IGNORECASE)
         ) and ipex_int8
     ):
         with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -389,7 +391,7 @@
     else:
         raise ValueError(
             f"Unsupported model {model_name}, only supports "
-            "FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL now."
+            "FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER now."
         )

     if re.search("llama", model.config.architectures[0], re.IGNORECASE):
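In load_model, CodeLlama now takes the default loading branch when ipex_int8 is off and the StarCoder/CodeLlama branch when it is on; dispatch is done by case-insensitive regex matching on the model name. A minimal sketch of that matching style, with a supported-model list that mirrors the error message above (the helper name is illustrative):

import re

# Illustrative reduction of the dispatch above; not the project's actual code path.
SUPPORTED_MODELS = ["flan-t5", "llama", "mpt", "gpt", "bloom", "opt",
                    "qwen", "neural-chat", "mistral", "codellama", "starcoder"]

def is_supported(model_name: str) -> bool:
    return any(re.search(name, model_name, re.IGNORECASE) for name in SUPPORTED_MODELS)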
@@ -144,5 +144,33 @@ def test_get_default_conv_template_v3_1(self):
         print(result)
         self.assertIn('The Intel Xeon Scalable Processors', str(result))

+class TestStarCoderModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_code_gen(self):
+        config = PipelineConfig(model_name_or_path="bigcode/starcoder")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("def print_hello_world():")
+        print(result)
+        self.assertIn("""print('Hello World')""", str(result))
+
+class TestCodeLlamaModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_code_gen(self):
+        config = PipelineConfig(model_name_or_path="codellama/CodeLlama-7b-hf")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("def print_hello_world():")
+        print(result)
+        self.assertIn("""print('Hello World')""", str(result))
+
 if __name__ == "__main__":
     unittest.main()
