diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index c4b2683f5..76a8f73d1 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
             help="Path to the pre-trained model or model identifier from huggingface.co/models. "
             "Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'",
         )
+        basic.add_argument("--model_dtype", default=None, help="model dtype used to load the pre-trained model")
         basic.add_argument(
             "--platform",
             default="hf",
@@ -589,6 +590,7 @@ def tune(args):
         enable_adam=args.adam,
         extra_config=extra_config,
         layer_config=layer_config,
+        model_dtype=args.model_dtype,
     )
 
     model_name = args.model.rstrip("/")
@@ -765,6 +767,7 @@ def tune(args):
             batch_size=args.eval_bs,
             limit=args.limit,
             eval_model_dtype=eval_model_dtype,
+            mllm=autoround.mllm,  # pylint: disable=E1101
         )
     else:
         from auto_round.eval.evaluation import simple_evaluate
@@ -775,8 +778,15 @@ def tune(args):
         st = time.time()
         if "llama" in args.model.lower():
             model_args += ",add_bos_token=True"
+        if autoround.mllm:  # pylint: disable=E1101
+            model_type = "hf-multimodal"
+            if args.eval_bs is None or args.eval_bs == "auto":
+                logger.warning("hf-multimodal models do not support auto batch size currently, resetting eval_bs to 16")
+                args.eval_bs = 16
+        else:
+            model_type = "hf"
         res = simple_evaluate(
-            model="hf",
+            model=model_type,
             model_args=model_args,
             tasks=tasks,
             device=device_str,
@@ -794,10 +804,15 @@ def setup_eval_parser():
 
 
 def run_eval():
+    from auto_round.utils import is_mllm_model
+
     args = setup_eval_parser()
     assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set."
     if args.model is None:
         args.model = args.model_name
+    if is_mllm_model(args.model):
+        args.mllm = True
+
     if args.eval_task_by_task:
         eval_task_by_task(
             model=args.model,
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index ff06ec851..d52914447 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -120,6 +120,7 @@ def __new__(
             device_map (str | dict, optional): Device placement map. Defaults to None.
             disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to False.
             enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False.
+            model_dtype (str, optional): Model dtype used to load the pre-trained model. Defaults to None.
             **kwargs: Backward compatible options:
                 - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
                   super_group_size, super_bits, scale_dtype ("fp16" etc.),
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index bed19c3e9..ef05b3eea 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -229,6 +229,7 @@ def __init__(
         disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
         enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
+        model_dtype = kwargs.pop("model_dtype", None)
         device = kwargs.pop("device", None)
         if envs.AR_USE_MODELSCOPE:
             platform = "model_scope"
@@ -266,6 +267,7 @@ def __init__(
                 model,
                 platform=platform,
                 device="cpu",  # always load cpu first
+                model_dtype=model_dtype,
             )
         elif tokenizer is None and not self.diffusion and iters > 0:
             raise ValueError("A tokenizer must be set for non-str model input")
@@ -481,6 +483,7 @@ def _parse_and_set(scheme, kwargs):
             # We’d better keep the string scheme instead of the dict config,
             # since GGUF uses different mixed-bit strategies for q4_k_s and q4_k_m
             # even though they share the same scheme dict.
+            scheme = scheme.strip("'\" ")
             res = scheme
             scheme = scheme.upper()
             scheme = asdict(preset_name_to_scheme(scheme))
diff --git a/auto_round/compressors/diffusion/compressor.py b/auto_round/compressors/diffusion/compressor.py
index 904767602..f8680774f 100644
--- a/auto_round/compressors/diffusion/compressor.py
+++ b/auto_round/compressors/diffusion/compressor.py
@@ -101,6 +101,7 @@ def __init__(
         **kwargs,
     ):
         logger.warning("Diffusion model quantization is experimental and is only validated on Flux models.")
+        model_dtype = kwargs.pop("model_dtype", None)
 
         self.guidance_scale = guidance_scale
         self.num_inference_steps = num_inference_steps
@@ -112,7 +113,7 @@ def __init__(
         self._set_device(device_map)
 
         if isinstance(model, str):
-            pipe, model = diffusion_load_model(model, platform=platform, device=self.device)
+            pipe, model = diffusion_load_model(model, platform=platform, device=self.device, model_dtype=model_dtype)
         elif isinstance(model, pipeline_utils.DiffusionPipeline):
             pipe = model
             model = pipe.transformer
diff --git a/auto_round/compressors/mllm/compressor.py b/auto_round/compressors/mllm/compressor.py
index 2cfa457b7..ada5dcf27 100644
--- a/auto_round/compressors/mllm/compressor.py
+++ b/auto_round/compressors/mllm/compressor.py
@@ -166,6 +166,7 @@ def __init__(
     ):
         extra_data_dir = kwargs.pop("extra_data_dir", None)
         template = kwargs.pop("template", None)
+        model_dtype = kwargs.pop("model_dtype", None)
         to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None)
 
         if device_map is None:
@@ -173,7 +174,9 @@ def __init__(
         self._set_device(device_map)
 
         if isinstance(model, str):
-            model, processor, tokenizer, image_processor = mllm_load_model(model, platform=platform, device=self.device)
+            model, processor, tokenizer, image_processor = mllm_load_model(
+                model, platform=platform, device=self.device, model_dtype=model_dtype
+            )
         self.model = model
 
         quant_nontext_module = self._check_quant_nontext(layer_config, quant_nontext_module)
diff --git a/auto_round/eval/eval_cli.py b/auto_round/eval/eval_cli.py
index 4311cdeb8..009b6458d 100644
--- a/auto_round/eval/eval_cli.py
+++ b/auto_round/eval/eval_cli.py
@@ -220,6 +220,7 @@ def eval_task_by_task(
     trust_remote_code=True,
     eval_model_dtype=None,
     retry_times=3,
+    mllm=False,
 ):
     set_cuda_visible_devices(device)
     device_str, parallelism = get_device_and_parallelism(device)
@@ -228,6 +229,7 @@ def eval_task_by_task(
     import traceback
 
     from lm_eval import simple_evaluate as lm_simple_evaluate  # pylint: disable=E0611
+    from lm_eval.models.hf_vlms import HFMultimodalLM
     from lm_eval.models.huggingface import HFLM
     from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -263,16 +265,31 @@ def eval_task_by_task(
             )
             model.eval()
         parallelism = False
-    hflm = HFLM(
-        pretrained=model,
-        tokenizer=tokenizer,
-        device=device_str,
-        batch_size=batch_size,
-        max_batch_size=max_batch_size,
-        parallelize=parallelism,
-        trust_remote_code=trust_remote_code,
-        dtype=eval_model_dtype,
-    )
+    if mllm:
+        if batch_size is None or batch_size == "auto":
+            logger.warning("hf-multimodal models do not support auto batch size currently, resetting eval_bs to 16")
+            batch_size = 16
+        hflm = HFMultimodalLM(
+            pretrained=model,
+            tokenizer=tokenizer,
+            device=device_str,
+            batch_size=batch_size,
+            max_batch_size=max_batch_size,
+            parallelize=parallelism,
+            trust_remote_code=trust_remote_code,
+            dtype=eval_model_dtype,
+        )
+    else:
+        hflm = HFLM(
+            pretrained=model,
+            tokenizer=tokenizer,
+            device=device_str,
+            batch_size=batch_size,
+            max_batch_size=max_batch_size,
+            parallelize=parallelism,
+            trust_remote_code=trust_remote_code,
+            dtype=eval_model_dtype,
+        )
 
     if isinstance(tasks, str):
         tasks = tasks.replace(" ", "").split(",")
diff --git a/auto_round/eval/evaluation.py b/auto_round/eval/evaluation.py
index 9722b6696..00a0fdca0 100644
--- a/auto_round/eval/evaluation.py
+++ b/auto_round/eval/evaluation.py
@@ -17,8 +17,11 @@
 
 from lm_eval import simple_evaluate as lm_simple_evaluate  # pylint: disable=E0611
 
+from auto_round.logger import logger
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+from lm_eval.models.hf_vlms import HFMultimodalLM
 from lm_eval.models.huggingface import HFLM
 
 
@@ -30,16 +33,30 @@ def simple_evaluate_user_model(
     max_batch_size: Optional[int] = 64,
     eval_model_dtype="auto",
     add_bos_token: bool = False,
+    mllm: bool = False,
     **kwargs
 ):
-    hflm = HFLM(
-        pretrained=user_model,
-        tokenizer=tokenizer,
-        batch_size=batch_size,
-        max_batch_size=max_batch_size,
-        dtype=eval_model_dtype,
-        add_bos_token=add_bos_token,
-    )
+    if mllm:
+        if batch_size is None or batch_size == "auto":
+            logger.warning("hf-multimodal models do not support auto batch size currently, resetting eval_bs to 16")
+            batch_size = 16
+        hflm = HFMultimodalLM(
+            pretrained=user_model,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_batch_size=max_batch_size,
+            dtype=eval_model_dtype,
+            add_bos_token=add_bos_token,
+        )
+    else:
+        hflm = HFLM(
+            pretrained=user_model,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_batch_size=max_batch_size,
+            dtype=eval_model_dtype,
+            add_bos_token=add_bos_token,
+        )
     return lm_simple_evaluate(
         model=hflm, model_args=None, batch_size=batch_size, max_batch_size=max_batch_size, limit=limit, **kwargs
     )
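Usage note (illustrative sketch, not part of the diff): only the model_dtype keyword/--model_dtype flag and the mllm handling come from the hunks above; the scheme name, output directory, and quantize_and_save call are assumptions about the existing auto-round API, and the accepted dtype strings depend on the underlying model loader.

# Hypothetical example only.
from auto_round import AutoRound

ar = AutoRound(
    "facebook/opt-125m",     # model id string; loaded internally (on CPU first) by the compressor
    scheme="W4A16",          # assumed existing scheme preset, unrelated to this change
    model_dtype="bfloat16",  # new kwarg: popped in BaseCompressor.__init__ and forwarded to the loader
)
ar.quantize_and_save("./opt-125m-w4a16")  # assumed existing method

The CLI equivalent would be the new flag, e.g. python -m auto_round --model facebook/opt-125m --model_dtype bfloat16. For evaluation, multimodal checkpoints are now detected via is_mllm_model() and routed through lm-eval's HFMultimodalLM wrapper with a fixed integer batch size, since "auto" is rejected for hf-multimodal models.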