diff --git a/auto_round/compressors/mllm/compressor.py b/auto_round/compressors/mllm/compressor.py
index 2cfa457b7..94210388f 100644
--- a/auto_round/compressors/mllm/compressor.py
+++ b/auto_round/compressors/mllm/compressor.py
@@ -27,6 +27,7 @@
 from auto_round.logger import logger
 from auto_round.schemes import QuantizationScheme
 from auto_round.special_model_handler import (
+    MISTRAL_3_2_MODELS,
     NOT_SUPPORT_ONLY_TEXT_MODELS,
     SUPPORT_ONLY_TEXT_MODELS,
     _handle_special_model,
@@ -164,6 +165,7 @@ def __init__(
         seed: int = 42,
         **kwargs,
     ):
+
         extra_data_dir = kwargs.pop("extra_data_dir", None)
         template = kwargs.pop("template", None)
 
@@ -189,7 +191,7 @@
 
         if model.config.model_type == "llava" and isinstance(model, PreTrainedModel):
             template = "default"
-        if hasattr(model, "name_or_path") and "Mistral-Small-3.2" in model.name_or_path:
+        if hasattr(model, "name_or_path") and any([name in model.name_or_path for name in MISTRAL_3_2_MODELS]):
             template = "mistral3_2"
         if iters > 0:
             self.template = template if template is not None else model.config.model_type
diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py
index d0fa8c962..a999b358b 100644
--- a/auto_round/special_model_handler.py
+++ b/auto_round/special_model_handler.py
@@ -39,6 +39,8 @@
 
 CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"]
 
+MISTRAL_3_2_MODELS = ["Mistral-Small-3.2", "Magistral-Small", "Devstral-Small"]
+
 
 def _get_moe_converter(config):
     # Dispatch table for model_type to replacement_info functions
diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py
index 427ec7f9c..1aab07983 100644
--- a/auto_round/utils/model.py
+++ b/auto_round/utils/model.py
@@ -307,6 +307,8 @@ def mllm_load_model(
     model_dtype: str = None,
     **kwargs,
 ):
+    from auto_round.special_model_handler import MISTRAL_3_2_MODELS
+
     assert platform.lower() in [
         "hf",
         "model_scope",
@@ -410,7 +412,7 @@
     else:
         raise
 
-    if "Mistral-Small-3.2" in pretrained_model_name_or_path:
+    if any([name in model.name_or_path for name in MISTRAL_3_2_MODELS]):
         from mistral_common.tokens.tokenizers.mistral import MistralTokenizer  # pylint: disable=E0401
 
         if os.path.isdir(pretrained_model_name_or_path):