From 72e89a4e0e8f299b933e1b7800d350ba1d46f406 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Mon, 3 Nov 2025 21:19:49 -0500 Subject: [PATCH 1/4] support model_dtype and fix bug of scheme contains quotes Signed-off-by: n1ck-guo --- auto_round/__main__.py | 2 ++ auto_round/autoround.py | 3 +++ auto_round/compressors/adam.py | 2 ++ auto_round/compressors/base.py | 4 ++++ auto_round/compressors/diffusion/compressor.py | 4 +++- auto_round/compressors/mllm/compressor.py | 6 +++++- 6 files changed, 19 insertions(+), 2 deletions(-) diff --git a/auto_round/__main__.py b/auto_round/__main__.py index c4b2683f5..5b3a77060 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs): help="Path to the pre-trained model or model identifier from huggingface.co/models. " "Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'", ) + basic.add_argument("--model_dtype", default=None, help="model dtype used to load the pre-trained model") basic.add_argument( "--platform", default="hf", @@ -589,6 +590,7 @@ def tune(args): enable_adam=args.adam, extra_config=extra_config, layer_config=layer_config, + model_dtype=args.model_dtype, ) model_name = args.model.rstrip("/") diff --git a/auto_round/autoround.py b/auto_round/autoround.py index ff06ec851..6e6a5de42 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -85,6 +85,7 @@ def __new__( enable_adam: bool = False, # for MLLM and Diffusion extra_config: ExtraConfig = None, + model_dtype: str = None, **kwargs, ) -> BaseCompressor: """Initialize AutoRound with quantization and tuning configuration. @@ -120,6 +121,7 @@ def __new__( device_map (str | dict, optional): Device placement map. Defaults to None. disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to False. enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False. + model_dtype (str): model dtype used to load pre-trained model. **kwargs: Backward compatible options: - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap, super_group_size, super_bits, scale_dtype ("fp16" etc.), @@ -185,6 +187,7 @@ def __new__( device_map=device_map, enable_torch_compile=enable_torch_compile, seed=seed, + model_dtype=model_dtype, **kwargs, ) return ar diff --git a/auto_round/compressors/adam.py b/auto_round/compressors/adam.py index fb79cf39a..e427a0444 100644 --- a/auto_round/compressors/adam.py +++ b/auto_round/compressors/adam.py @@ -101,6 +101,7 @@ def __init__( enable_torch_compile: bool = False, seed: int = 42, optimizer="AdamW", + model_dtype: str = None, **kwargs, ): super(AdamCompressor, self).__init__( @@ -119,6 +120,7 @@ def __init__( gradient_accumulate_steps=gradient_accumulate_steps, enable_torch_compile=enable_torch_compile, device_map=device_map, + model_dtype=model_dtype, **kwargs, ) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index bed19c3e9..39154f22e 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -114,6 +114,7 @@ class BaseCompressor(object): layer_config (dict): Per-layer quantization configuration. nsamples (int): Number of calibration samples. enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers. + model_dtype (str): model dtype used to load pre-trained model. """ bits: int | None @@ -147,6 +148,7 @@ def __init__( enable_alg_ext: bool = False, disable_opt_rtn: bool = False, seed: int = 42, + model_dtype: str = None, **kwargs, ): """Initialize AutoRound with quantization and tuning configuration. @@ -266,6 +268,7 @@ def __init__( model, platform=platform, device="cpu", # always load cpu first + model_dtype=model_dtype, ) elif tokenizer is None and not self.diffusion and iters > 0: raise ValueError("A tokenizer must be set for non-str model input") @@ -481,6 +484,7 @@ def _parse_and_set(scheme, kwargs): # We’d better keep the string scheme instead of the dict config, # since GGUF uses different mixed-bit strategies for q4_k_s and q4_k_m # even though they share the same scheme dict. + scheme = scheme.strip("'\" ") res = scheme scheme = scheme.upper() scheme = asdict(preset_name_to_scheme(scheme)) diff --git a/auto_round/compressors/diffusion/compressor.py b/auto_round/compressors/diffusion/compressor.py index 904767602..fe8749776 100644 --- a/auto_round/compressors/diffusion/compressor.py +++ b/auto_round/compressors/diffusion/compressor.py @@ -63,6 +63,7 @@ class DiffusionCompressor(BaseCompressor): low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). device_map (str | dict | int | torch.device, optional): Device placement map. Defaults to 0. enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer + model_dtype (str): model dtype used to load pre-trained model. **kwargs: Additional keyword arguments. """ @@ -98,6 +99,7 @@ def __init__( device_map: Union[str, torch.device, int, dict] = 0, enable_torch_compile: bool = False, seed: int = 42, + model_dtype: str = None, **kwargs, ): logger.warning("Diffusion model quantization is experimental and is only validated on Flux models.") @@ -112,7 +114,7 @@ def __init__( self._set_device(device_map) if isinstance(model, str): - pipe, model = diffusion_load_model(model, platform=platform, device=self.device) + pipe, model = diffusion_load_model(model, platform=platform, device=self.device, model_dtype=model_dtype) elif isinstance(model, pipeline_utils.DiffusionPipeline): pipe = model model = pipe.transformer diff --git a/auto_round/compressors/mllm/compressor.py b/auto_round/compressors/mllm/compressor.py index 2cfa457b7..71d95d6ad 100644 --- a/auto_round/compressors/mllm/compressor.py +++ b/auto_round/compressors/mllm/compressor.py @@ -127,6 +127,7 @@ class MLLMCompressor(BaseCompressor): to_quant_block_names (str|list): A string or list whose elements are list of block's layer names to be quantized. enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer + model_dtype (str): model dtype used to load pre-trained model. **kwargs: Additional keyword arguments. """ @@ -162,6 +163,7 @@ def __init__( device_map: Union[str, torch.device, int, dict] = 0, enable_torch_compile: bool = False, seed: int = 42, + model_dtype: str = None, **kwargs, ): extra_data_dir = kwargs.pop("extra_data_dir", None) @@ -173,7 +175,9 @@ def __init__( self._set_device(device_map) if isinstance(model, str): - model, processor, tokenizer, image_processor = mllm_load_model(model, platform=platform, device=self.device) + model, processor, tokenizer, image_processor = mllm_load_model( + model, platform=platform, device=self.device, model_dtype=model_dtype + ) self.model = model quant_nontext_module = self._check_quant_nontext(layer_config, quant_nontext_module) From 1a167d575ceeb0c4bd846c385546f544fabb6387 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 4 Nov 2025 00:39:05 -0500 Subject: [PATCH 2/4] fix mllm eval Signed-off-by: n1ck-guo --- auto_round/__main__.py | 15 +++++++++++++- auto_round/eval/eval_cli.py | 37 +++++++++++++++++++++++++---------- auto_round/eval/evaluation.py | 33 +++++++++++++++++++++++-------- 3 files changed, 66 insertions(+), 19 deletions(-) diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 5b3a77060..86efd9a31 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -767,6 +767,7 @@ def tune(args): batch_size=args.eval_bs, limit=args.limit, eval_model_dtype=eval_model_dtype, + mllm=autoround.mllm, ) else: from auto_round.eval.evaluation import simple_evaluate @@ -777,8 +778,15 @@ def tune(args): st = time.time() if "llama" in args.model.lower(): model_args += ",add_bos_token=True" + if autoround.mllm: + model_type = "hf-multimodal" + if args.eval_bs is None or args.eval_bs == "auto": + logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16") + args.eval_bs = 16 + else: + model_type = "hf" res = simple_evaluate( - model="hf", + model=model_type, model_args=model_args, tasks=tasks, device=device_str, @@ -796,10 +804,15 @@ def setup_eval_parser(): def run_eval(): + from auto_round.utils import is_mllm_model + args = setup_eval_parser() assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set." if args.model is None: args.model = args.model_name + if is_mllm_model(args.model): + args.mllm = True + if args.eval_task_by_task: eval_task_by_task( model=args.model, diff --git a/auto_round/eval/eval_cli.py b/auto_round/eval/eval_cli.py index 4311cdeb8..009b6458d 100644 --- a/auto_round/eval/eval_cli.py +++ b/auto_round/eval/eval_cli.py @@ -220,6 +220,7 @@ def eval_task_by_task( trust_remote_code=True, eval_model_dtype=None, retry_times=3, + mllm=False, ): set_cuda_visible_devices(device) device_str, parallelism = get_device_and_parallelism(device) @@ -228,6 +229,7 @@ def eval_task_by_task( import traceback from lm_eval import simple_evaluate as lm_simple_evaluate # pylint: disable=E0611 + from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM from transformers import AutoModelForCausalLM, AutoTokenizer @@ -263,16 +265,31 @@ def eval_task_by_task( ) model.eval() parallelism = False - hflm = HFLM( - pretrained=model, - tokenizer=tokenizer, - device=device_str, - batch_size=batch_size, - max_batch_size=max_batch_size, - parallelize=parallelism, - trust_remote_code=trust_remote_code, - dtype=eval_model_dtype, - ) + if mllm: + if batch_size is None or batch_size == "auto": + logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16") + batch_size = 16 + hflm = HFMultimodalLM( + pretrained=model, + tokenizer=tokenizer, + device=device_str, + batch_size=batch_size, + max_batch_size=max_batch_size, + parallelize=parallelism, + trust_remote_code=trust_remote_code, + dtype=eval_model_dtype, + ) + else: + hflm = HFLM( + pretrained=model, + tokenizer=tokenizer, + device=device_str, + batch_size=batch_size, + max_batch_size=max_batch_size, + parallelize=parallelism, + trust_remote_code=trust_remote_code, + dtype=eval_model_dtype, + ) if isinstance(tasks, str): tasks = tasks.replace(" ", "").split(",") diff --git a/auto_round/eval/evaluation.py b/auto_round/eval/evaluation.py index 9722b6696..00a0fdca0 100644 --- a/auto_round/eval/evaluation.py +++ b/auto_round/eval/evaluation.py @@ -17,8 +17,11 @@ from lm_eval import simple_evaluate as lm_simple_evaluate # pylint: disable=E0611 +from auto_round.logger import logger + os.environ["TOKENIZERS_PARALLELISM"] = "false" +from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM @@ -30,16 +33,30 @@ def simple_evaluate_user_model( max_batch_size: Optional[int] = 64, eval_model_dtype="auto", add_bos_token: bool = False, + mllm: bool = False, **kwargs ): - hflm = HFLM( - pretrained=user_model, - tokenizer=tokenizer, - batch_size=batch_size, - max_batch_size=max_batch_size, - dtype=eval_model_dtype, - add_bos_token=add_bos_token, - ) + if mllm: + if batch_size is None or batch_size == "auto": + logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16") + batch_size = 16 + hflm = HFMultimodalLM( + pretrained=user_model, + tokenizer=tokenizer, + batch_size=batch_size, + max_batch_size=max_batch_size, + dtype=eval_model_dtype, + add_bos_token=add_bos_token, + ) + else: + hflm = HFLM( + pretrained=user_model, + tokenizer=tokenizer, + batch_size=batch_size, + max_batch_size=max_batch_size, + dtype=eval_model_dtype, + add_bos_token=add_bos_token, + ) return lm_simple_evaluate( model=hflm, model_args=None, batch_size=batch_size, max_batch_size=max_batch_size, limit=limit, **kwargs ) From 4abd0226b440b6ea68acfe1a0fba97f1beac4d35 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 4 Nov 2025 00:43:16 -0500 Subject: [PATCH 3/4] update Signed-off-by: n1ck-guo --- auto_round/autoround.py | 2 -- auto_round/compressors/adam.py | 2 -- auto_round/compressors/base.py | 3 +-- auto_round/compressors/diffusion/compressor.py | 3 +-- auto_round/compressors/mllm/compressor.py | 3 +-- 5 files changed, 3 insertions(+), 10 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 6e6a5de42..d52914447 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -85,7 +85,6 @@ def __new__( enable_adam: bool = False, # for MLLM and Diffusion extra_config: ExtraConfig = None, - model_dtype: str = None, **kwargs, ) -> BaseCompressor: """Initialize AutoRound with quantization and tuning configuration. @@ -187,7 +186,6 @@ def __new__( device_map=device_map, enable_torch_compile=enable_torch_compile, seed=seed, - model_dtype=model_dtype, **kwargs, ) return ar diff --git a/auto_round/compressors/adam.py b/auto_round/compressors/adam.py index e427a0444..fb79cf39a 100644 --- a/auto_round/compressors/adam.py +++ b/auto_round/compressors/adam.py @@ -101,7 +101,6 @@ def __init__( enable_torch_compile: bool = False, seed: int = 42, optimizer="AdamW", - model_dtype: str = None, **kwargs, ): super(AdamCompressor, self).__init__( @@ -120,7 +119,6 @@ def __init__( gradient_accumulate_steps=gradient_accumulate_steps, enable_torch_compile=enable_torch_compile, device_map=device_map, - model_dtype=model_dtype, **kwargs, ) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 39154f22e..ef05b3eea 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -114,7 +114,6 @@ class BaseCompressor(object): layer_config (dict): Per-layer quantization configuration. nsamples (int): Number of calibration samples. enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers. - model_dtype (str): model dtype used to load pre-trained model. """ bits: int | None @@ -148,7 +147,6 @@ def __init__( enable_alg_ext: bool = False, disable_opt_rtn: bool = False, seed: int = 42, - model_dtype: str = None, **kwargs, ): """Initialize AutoRound with quantization and tuning configuration. @@ -231,6 +229,7 @@ def __init__( disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True) enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False) static_kv_dtype = kwargs.pop("static_kv_dtype", None) + model_dtype = kwargs.pop("model_dtype", None) device = kwargs.pop("device", None) if envs.AR_USE_MODELSCOPE: platform = "model_scope" diff --git a/auto_round/compressors/diffusion/compressor.py b/auto_round/compressors/diffusion/compressor.py index fe8749776..f8680774f 100644 --- a/auto_round/compressors/diffusion/compressor.py +++ b/auto_round/compressors/diffusion/compressor.py @@ -63,7 +63,6 @@ class DiffusionCompressor(BaseCompressor): low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). device_map (str | dict | int | torch.device, optional): Device placement map. Defaults to 0. enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer - model_dtype (str): model dtype used to load pre-trained model. **kwargs: Additional keyword arguments. """ @@ -99,10 +98,10 @@ def __init__( device_map: Union[str, torch.device, int, dict] = 0, enable_torch_compile: bool = False, seed: int = 42, - model_dtype: str = None, **kwargs, ): logger.warning("Diffusion model quantization is experimental and is only validated on Flux models.") + model_dtype = kwargs.pop("model_dtype", None) self.guidance_scale = guidance_scale self.num_inference_steps = num_inference_steps diff --git a/auto_round/compressors/mllm/compressor.py b/auto_round/compressors/mllm/compressor.py index 71d95d6ad..ada5dcf27 100644 --- a/auto_round/compressors/mllm/compressor.py +++ b/auto_round/compressors/mllm/compressor.py @@ -127,7 +127,6 @@ class MLLMCompressor(BaseCompressor): to_quant_block_names (str|list): A string or list whose elements are list of block's layer names to be quantized. enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer - model_dtype (str): model dtype used to load pre-trained model. **kwargs: Additional keyword arguments. """ @@ -163,11 +162,11 @@ def __init__( device_map: Union[str, torch.device, int, dict] = 0, enable_torch_compile: bool = False, seed: int = 42, - model_dtype: str = None, **kwargs, ): extra_data_dir = kwargs.pop("extra_data_dir", None) template = kwargs.pop("template", None) + model_dtype = kwargs.pop("model_dtype", None) to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None) if device_map is None: From 552c1937cece1f4ad9c1f7f820ae4481f48530ce Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 4 Nov 2025 01:59:06 -0500 Subject: [PATCH 4/4] code scan Signed-off-by: n1ck-guo --- auto_round/__main__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 86efd9a31..76a8f73d1 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -767,7 +767,7 @@ def tune(args): batch_size=args.eval_bs, limit=args.limit, eval_model_dtype=eval_model_dtype, - mllm=autoround.mllm, + mllm=autoround.mllm, # pylint: disable=E1101 ) else: from auto_round.eval.evaluation import simple_evaluate @@ -778,7 +778,7 @@ def tune(args): st = time.time() if "llama" in args.model.lower(): model_args += ",add_bos_token=True" - if autoround.mllm: + if autoround.mllm: # pylint: disable=E1101 model_type = "hf-multimodal" if args.eval_bs is None or args.eval_bs == "auto": logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16")