diff --git a/intel_extension_for_transformers/llm/quantization/optimization.py b/intel_extension_for_transformers/llm/quantization/optimization.py index 6710f8d85fb..30062aa7e58 100644 --- a/intel_extension_for_transformers/llm/quantization/optimization.py +++ b/intel_extension_for_transformers/llm/quantization/optimization.py @@ -56,6 +56,7 @@ def optimize(self, model, use_llm_runtime=False): or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE) or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE) or re.search("starcoder", model_name, re.IGNORECASE) + or re.search("solar", model_name, re.IGNORECASE) ): from intel_extension_for_transformers.transformers import AutoModelForCausalLM optimized_model = AutoModelForCausalLM.from_pretrained( diff --git a/intel_extension_for_transformers/neural_chat/chatbot.py b/intel_extension_for_transformers/neural_chat/chatbot.py index b119d9622f7..b4d9613365f 100644 --- a/intel_extension_for_transformers/neural_chat/chatbot.py +++ b/intel_extension_for_transformers/neural_chat/chatbot.py @@ -87,6 +87,9 @@ def build_chatbot(config: PipelineConfig=None): elif "mistral" in config.model_name_or_path.lower(): from .models.mistral_model import MistralModel adapter = MistralModel() + elif "solar" in config.model_name_or_path.lower(): + from .models.solar_model import SolarModel + adapter = SolarModel() elif "opt" in config.model_name_or_path.lower() or \ "gpt" in config.model_name_or_path.lower() or \ "flan-t5" in config.model_name_or_path.lower() or \ diff --git a/intel_extension_for_transformers/neural_chat/models/solar_model.py b/intel_extension_for_transformers/neural_chat/models/solar_model.py new file mode 100644 index 00000000000..d1f29ae1028 --- /dev/null +++ b/intel_extension_for_transformers/neural_chat/models/solar_model.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
# use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Chatbot adapter for the upstage SOLAR-10.7B instruct models."""

import logging

from fastchat.conversation import (
    Conversation,
    SeparatorStyle,
    get_conv_template,
    register_conv_template,
)

from .base_model import BaseModel, register_model_adapter

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Conversation template for the SOLAR-10.7B instruct chat format.
# Reference: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json
# NOTE(review): fastchat's own "solar" template uses stop_str="</s>"; the empty
# stop_str here relies solely on the EOS token to end generation — confirm
# this is intended. Also verify the installed fastchat version does not
# already register a template named "solar" (register_conv_template asserts
# on duplicate names).
register_conv_template(
    Conversation(
        name="solar",
        system_message="",
        roles=("### User", "### Assistant"),
        sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
        sep="\n\n",
        stop_str="",
    )
)


class SolarModel(BaseModel):
    """Model adapter that routes SOLAR instruct checkpoints to the
    "solar" conversation template."""

    def match(self, model_path: str):
        """Check if the provided model_path matches the current model.

        Args:
            model_path (str): Path to a model.

        Returns:
            bool: True if the model_path matches, False otherwise.
        """
        lowered = model_path.lower()
        # Both tokens must appear, e.g. "upstage/SOLAR-10.7B-Instruct-v1.0".
        return all(token in lowered for token in ("solar-", "instruct"))

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Get the default conversation template for the given model path.

        Args:
            model_path (str): Path to the model (unused; the adapter
                interface passes it to every implementation).

        Returns:
            Conversation: A default conversation template.
        """
        return get_conv_template("solar")


register_model_adapter(SolarModel)