diff --git a/auto_round/auto_scheme/default_alg.abi3.so b/auto_round/auto_scheme/default_alg.abi3.so index 83da28647..220fb3ce5 100644 Binary files a/auto_round/auto_scheme/default_alg.abi3.so and b/auto_round/auto_scheme/default_alg.abi3.so differ diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py index ca0abdfe0..d650bccca 100644 --- a/auto_round/auto_scheme/gen_auto_scheme.py +++ b/auto_round/auto_scheme/gen_auto_scheme.py @@ -82,6 +82,9 @@ def _check_configs(self) -> None: def get_layer_config(self) -> dict[str, dict]: method_name = self.auto_scheme.method method_func = AUTO_SCHEME_METHODS[method_name] + if self.auto_scheme.low_gpu_mem_usage: + self.enable_torch_compile = False + layer_config = method_func( self.auto_scheme, self.model, @@ -92,6 +95,7 @@ def get_layer_config(self) -> dict[str, dict]: device_map=self.device_map, enable_torch_compile=self.enable_torch_compile, disable_opt_rtn=self.disable_opt_rtn, + low_gpu_mem_usage=self.auto_scheme.low_gpu_mem_usage, ) layer_config = self.fallback_gguf_layer_config(layer_config) return layer_config diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 4716f173b..a78c94737 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Callable, Union +from typing import Union import torch diff --git a/auto_round/compressors/__init__.py b/auto_round/compressors/__init__.py index dbf47b9c2..03983b2c7 100644 --- a/auto_round/compressors/__init__.py +++ b/auto_round/compressors/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from auto_round.compressors.base import * +from auto_round.compressors.adam import AdamCompressor +from auto_round.compressors.base import BaseCompressor +from auto_round.compressors.base import BaseCompressor as LLMCompressor from auto_round.compressors.mllm.compressor import MLLMCompressor from auto_round.compressors.diffusion.compressor import DiffusionCompressor from auto_round.compressors.config import ( diff --git a/auto_round/compressors/adam.py b/auto_round/compressors/adam.py new file mode 100644 index 000000000..4606eab3a --- /dev/null +++ b/auto_round/compressors/adam.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import torch + +from auto_round.compressors.base import BaseCompressor +from auto_round.schemes import QuantizationScheme +from auto_round.utils import check_is_cpu, htcore, is_hpex_available + + +class AdamCompressor(BaseCompressor): + """Class for quantization with optimizers like adamw of a PyTorch model. + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. 
+        scheme (str | dict | QuantizationScheme): A preset scheme that defines the quantization configurations.
+        bits (int): Number of bits for quantization (default is 4).
+        group_size (int): Size of the quantization group (default is 128).
+        sym (bool): Whether to use symmetric quantization (default is True).
+        layer_config (dict): Configuration for weight quantization (default is None).
+        batch_size (int): Batch size for training (default is 8).
+        amp (bool): Whether to use automatic mixed precision (default is True).
+        device: The device to be used for training (default is "auto").
+        lr_scheduler: The learning rate scheduler to be used.
+        dataset: The default dataset name (default is "NeelNanda/pile-10k").
+        enable_quanted_input (bool): Whether to use quantized input data (default is True).
+        enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
+        lr (float): The learning rate (default is 0.005).
+        minmax_lr (float): The learning rate for min-max tuning (default is None).
+        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
+        iters (int): Number of iterations (default is 200).
+        seqlen (int): Length of the sequence.
+        nsamples (int): Number of samples (default is 128).
+        sampler (str): The sampling method (default is "rand").
+        seed (int): The random seed (default is 42).
+        nblocks (int): Number of blocks (default is 1).
+        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+        not_use_best_mse (bool): Whether to skip keeping the parameters with the best mean squared error during tuning (default is False).
+        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+        data_type (str): The data type to be used (default is "int").
+        scale_dtype (str): The data type of quantization scale to be used (default is "float16"); different kernels
+            have different choices.
+        act_bits (int): Number of bits for activation quantization. Default is 16.
+        act_group_size (int): Group size for activation quantization. Default is None.
+        act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+        act_data_type (str): Specifies the data type for activations.
+            Defaults to None, in which case it inherits the weight data type.
+        act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
+        to_quant_block_names (str|list): A string or list whose elements are lists of
+            block's layer names to be quantized.
+        enable_norm_bias_tuning (bool): Whether to enable fast norm/layer-bias tuning.
+        enable_torch_compile (bool): Whether to enable torch.compile to optimize the quant_block/layer function.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        The quantized model.
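+
+    Example:
+        A minimal usage sketch; the model name is illustrative, and ``quantize()``
+        is assumed to be inherited from ``BaseCompressor``:
+
+        >>> compressor = AdamCompressor(model="facebook/opt-125m", scheme="W4A16", iters=200)
+        >>> model, layer_config = compressor.quantize()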
+ """ + + bits: int | None + group_size: int | None + sym: bool | None + data_type: str | None + act_bits: int | None + act_group_size: int | None + act_sym: bool | None + act_data_type: str | None + act_dynamic: bool | None + super_bits: int | None + super_group_size: int | None + + def __init__( + self, + model: Union[torch.nn.Module, str], + tokenizer=None, + scheme: Union[str, dict, QuantizationScheme] = "W4A16", + layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + iters: int = 200, + seqlen: int = 2048, + nsamples: int = 128, + batch_size: int = 8, + gradient_accumulate_steps: int = 1, + low_gpu_mem_usage: bool = False, + device_map: Union[str, int, torch.device, dict] = 0, + enable_torch_compile: bool = False, + seed: int = 42, + optimizer="AdamW", + **kwargs, + ): + super(AdamCompressor, self).__init__( + model=model, + tokenizer=tokenizer, + scheme=scheme, + layer_config=layer_config, + batch_size=batch_size, + dataset=dataset, + low_gpu_mem_usage=low_gpu_mem_usage, + iters=iters, + seqlen=seqlen, + nsamples=nsamples, + seed=seed, + gradient_accumulate_steps=gradient_accumulate_steps, + enable_torch_compile=enable_torch_compile, + device_map=device_map, + **kwargs, + ) + + self.optimizer = self._get_optimizer(optimizer) + + def _get_optimizer(self, optimizer): + if optimizer is None: + optimizer = torch.optim.AdamW + elif isinstance(optimizer, str): + optimizer = getattr(torch.optim, optimizer) + else: + optimizer = optimizer + return optimizer + + def _get_scaler(self): + scaler = None + if self.amp and not check_is_cpu(self.device): + from torch.cuda.amp import GradScaler + + scaler = GradScaler(init_scale=1024, growth_interval=100000) + return scaler + + def _scale_loss_and_backward(self, scaler, loss): + if scaler is not None: + loss = scaler.scale(loss) + + loss.backward() + if is_hpex_available(): + htcore.mark_step() + return loss + + def _step(self, scaler, optimizer, lr_schedule): + if scaler is not None: + scaler.step(optimizer) + optimizer.zero_grad() + lr_schedule.step() + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + lr_schedule.step() + if is_hpex_available(): + htcore.mark_step() diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 9608e6ee4..297f8a50e 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -20,7 +20,6 @@ import traceback from collections import defaultdict from dataclasses import asdict, fields -from enum import Enum from typing import Any, Callable, Union import accelerate @@ -36,8 +35,6 @@ check_need_act_calibration, check_skippable_keywords, collect_best_params, - get_fp_layer_names, - get_layer_config_by_gguf_format, get_shared_keys, gguf_args_check, infer_bits_by_data_type, @@ -53,7 +50,7 @@ from auto_round.data_type import QUANT_FUNC_WITH_DTYPE from auto_round.data_type.utils import reshape_pad_tensor_by_group_size from auto_round.export.export_to_autoround import AutoRoundFormat -from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType +from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG, ModelType from auto_round.logger import logger from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme from auto_round.sign_sgd import SignSGD @@ -66,7 +63,6 @@ TORCH_VERSION_AT_LEAST_2_6, CpuInfo, check_and_mark_fp8_model, - check_is_cpu, check_seqlen_compatible, 
check_to_quantized, clear_memory, @@ -76,15 +72,10 @@ convert_fp8_model_to_16b_model, copy_python_files_from_model_cache, detect_device, - estimate_tuning_block_mem, find_matching_blocks, flatten_list, get_block_names, - get_device_memory, - get_layer_features, get_layer_names_in_block, - get_lm_head_name, - get_max_vram, get_module, htcore, is_debug_mode, @@ -99,6 +90,11 @@ to_dtype, unsupported_meta_device, ) +from auto_round.utils.device import ( + get_major_device, + set_auto_device_map_for_block_with_tuning, + set_non_auto_device_map, +) from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block @@ -234,11 +230,12 @@ def __init__( enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False) static_kv_dtype = kwargs.pop("static_kv_dtype", None) device = kwargs.pop("device", None) + # Scale factor for RAM usage per parameter. + mem_per_param_scale = kwargs.pop("mem_per_param_scale", None) self.quant_lm_head = kwargs.pop("quant_lm_head", False) self.mllm = kwargs.pop("mllm") if "mllm" in kwargs else False self.diffusion = kwargs.pop("diffusion") if "diffusion" in kwargs else False - # Scale factor for RAM usage per parameter. - self.mem_per_param_scale = kwargs.pop("mem_per_param_scale", None) + self.fp_layers = kwargs.pop("fp_layers", "") self.layer_config = layer_config self.supported_types = SUPPORTED_LAYER_TYPES @@ -294,21 +291,16 @@ def __init__( self.enable_torch_compile = enable_torch_compile self._adjust_torch_compile(enable_torch_compile) + self.device_map = device_map + if isinstance(self.device_map, str): + self.device_map = self.device_map.replace(" ", "") + if isinstance(scheme, AutoScheme): - self.layer_config = self._gen_auto_scheme(model, scheme, dataset, device_map) + self.layer_config = self._gen_auto_scheme(model, scheme, dataset, self.device_map) # Set device, must place after model loading - self._set_device(device_map) - - if (isinstance(device_map, dict) and device_map) or device_map == "auto": - self.device_map = device_map - elif isinstance(device_map, str) and "," in device_map: - device_map = device_map.replace(" ", "") # Remove any spaces - self.device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()] - self.device_map = "auto" - else: - self.device_map = None - self._set_device_map_in_blocks(self.device_map) + self.device = get_major_device(device_map) + set_non_auto_device_map(self.model, self.device_map) # Tuning hyperparameters self.seed = seed @@ -340,6 +332,10 @@ def __init__( self.optimizer = self._get_optimizer(None) self.disable_opt_rtn = disable_opt_rtn self.is_packing_immediate = False # whether to pack the layer immediately after tuning + if mem_per_param_scale is None: + self.mem_per_param_scale = 13 if self.iters != 0 else 1 + else: + self.mem_per_param_scale = mem_per_param_scale # KV cache, this one does not affect tuning but will collect some infos during tuning self.static_kv_dtype = static_kv_dtype @@ -435,7 +431,7 @@ def _gen_auto_scheme( # mainly using quant_layers and fixed by users from auto_round.auto_scheme.gen_auto_scheme import GenScheme - if not self.enable_torch_compile and self.super_bits is None: + if not self.enable_torch_compile and self.super_bits is None and not scheme.low_gpu_mem_usage: logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM") gen_scheme = GenScheme( scheme, @@ -592,134 +588,10 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: "Enabling it can 
reduce tuning cost by 20%, but it might throw an exception." ) - # if is_debug_mode() and self.enable_torch_compile: - # self.enable_torch_compile = False - # logger.warning("reset enable_torch_compile to `False` as debug mode is enabled") - if (self.data_type.startswith("fp") or self.act_data_type.startswith("fp")) and self.enable_torch_compile: self.enable_torch_compile = False logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") - def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None: - """Sets the device map for specific blocks in the model. - - Args: - device_map (Union[str, dict, None]): A mapping of module names to devices. - If provided as a string, it should be in the format - "module_name:device,module_name:device". Devices can be integers - (GPU IDs) or strings (e.g., 'cpu', 'cuda:0'). - """ - if self.device_map is None or len(self.device_map) == 0: - self.device_map = None - if not device_map: - return - if self.device_map == "auto" and device_map == "auto": - return - if isinstance(device_map, str): - device_map = device_map.replace(" ", "") - infos = device_map.split(",") - device_map_dict = {} - for info in infos: - index = info.find(":") - key = info[:index] - value = info[index + 1 :] - device_map_dict[key] = value - device_map = device_map_dict - - names = [n for n, m in self.model.named_modules() if len(list(m.children())) == 0] - - for key, device in device_map.items(): - if isinstance(device, str) and device.isdigit(): - device = int(device) - device = detect_device(device) - try: - module = get_module(self.model, key) - module.tuning_device = device - except: - matching_names = [name for name in names if re.match(key, name)] - if len(matching_names) > 0: - for name in matching_names: - self._set_device_for_matching_module(name, device) - else: - for name in names: - if key in name: - self._set_device_for_matching_module(name, device) - - def _set_device_for_matching_module(self, name: str, device: str) -> None: - """Sets the device for a module if it matches the given name.""" - module = get_module(self.model, name) - if hasattr(module, "tuning_device") and module.tuning_device != device: - logger.warning( - f"multiple devices have been set for layer {name}, keeping original device {module.tuning_device}" - ) - else: - module.tuning_device = device - - def _set_auto_device_map_in_block(self, block: torch.nn.Module, input_ids: list[torch.Tensor]) -> None: - """Automatically sets the device map for the block based on available GPUs and memory constraints.""" - if torch.cuda.is_available(): - num_gpus = torch.cuda.device_count() - elif torch.xpu.is_available(): - logger.warning_once("XPU does not support auto device map yet, using device 0 for tuning.") - return - else: - raise RuntimeError("No CUDA or XPU devices found.") - if num_gpus <= 1: - self.device_map = None - return - - if hasattr(self, "device_list") and self.device_list: - cuda_devices = [f"cuda:{i}" for i in self.device_list] - device_0 = cuda_devices[0] - else: - cuda_devices = [f"cuda:{i}" for i in range(num_gpus)] - device_0 = "cuda:0" - - device_0_memory = get_device_memory( - self.device_list[0] if hasattr(self, "device_list") and self.device_list else 0 - ) - block_memory, input_output_memory = estimate_tuning_block_mem(block, input_ids) - if self.low_gpu_mem_usage: - input_output_memory = 0 - - mem_per_param_scale = 13 if self.mem_per_param_scale is None else self.mem_per_param_scale - if self.iters == 0: - mem_per_param_scale = 1 # for rtn - - if 
(block_memory * mem_per_param_scale + input_output_memory) < device_0_memory: - return # fit in one GPU - - device_map = {} - device_memory = {device: get_device_memory(int(device.split(":")[1])) for device in cuda_devices} - device_memory[device_0] = device_0_memory - input_output_memory - - device_idx = 0 - # First, fill device 0 to its maximum capacity, then distribute the remaining layers evenly across other devices - for n, m in block.named_modules(): - if check_to_quantized(m): - layer_name = block.tmp_name + "." + n - layer_memory = m.weight.nbytes / 1024**3 - if device_idx == 0 and layer_memory * mem_per_param_scale < device_memory[cuda_devices[device_idx]]: - device_map[layer_name] = cuda_devices[device_idx] - device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale - elif device_idx == 0: - device_idx += 1 # Move to the next device once device 0 is full - device_map[layer_name] = cuda_devices[device_idx] - device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale - else: - # Calculate the target device index based on even distribution - sorted_devices = sorted(cuda_devices, key=lambda d: device_memory[d], reverse=True) - device_idx = sorted_devices[0] - if layer_memory * mem_per_param_scale < device_memory[device_idx]: - device_map[layer_name] = device_idx - device_memory[device_idx] -= layer_memory * mem_per_param_scale - else: - logger.warning_once( - f"Block {block.tmp_name} not fit in available GPU memory. " - "Consider using more GPUs or reducing mem_per_param_scale if OOM occurs." - ) - self._set_device_map_in_blocks(device_map) - def _dq_check(self) -> None: """Reset the default value of super_bits and super_group_size""" if self.data_type.endswith("_dq"): @@ -1335,24 +1207,19 @@ def _quantize_layer_via_rtn(self, name: str) -> None: if is_fp8_linear(m): m = convert_fp8_layer_to_linear(m, self.amp_dtype) set_module(self.model, name, m) - # - # # Step 1: Use optimized RTN data type if available - # if not self.disable_opt_rtn: - # rtn_data_type = self._check_rtn_dytpe(m.data_type, m.bits, m.sym) - # if rtn_data_type is not None: - # m.data_type = rtn_data_type - # self.layer_config[name]["data_type"] = m.data_type - - # Step 2: Try quantization on GPU first, fall back to CPU if OOM + + # Step 1: Try quantization on GPU first, fall back to CPU if OOM # if only export gguf, using gguf-packing instead of rtn if self.is_packing_immediate and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn: m.scale = None m.zp = None else: try: - m = m.to(m.tuning_device if hasattr(m, "tuning_device") else self.device) + tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device + m = m.to(tuning_device) m = WrapperLinear( m, + device=tuning_device, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False, @@ -1379,7 +1246,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None: except Exception as e: raise - # Step 3: Optional immediate packing/export + # Step 2: Optional immediate packing/export if self.is_packing_immediate: from auto_round.export import PACKING_LAYER_WITH_FORMAT @@ -1405,9 +1272,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None: ) else: PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0], device=self.device) - - # if self.low_gpu_mem_usage: - # clear_memory() else: set_module(self.model, name, m) @@ -1563,9 +1427,10 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) if is_fp8_model(self.model): 
convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype) - if self.device_map == "auto": - self._set_auto_device_map_in_block(block, input_ids) - + if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): + set_auto_device_map_for_block_with_tuning( + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale + ) # Dispatch model if needed if self.device_map is not None: from accelerate.hooks import AlignDevicesHook, add_hook_to_module @@ -2388,8 +2253,8 @@ def _quantize_layer( best_loss = torch.finfo(torch.float).max scaler = self._get_scaler() # pylint: disable=assignment-from-none init_loss = None - gradient_accumulate_steps = self.batch_size ##Force to low gpu - batch_size = 1 ##Force to low gpu + gradient_accumulate_steps = self.batch_size # Force to low gpu + batch_size = 1 # Force to low gpu pick_samples = batch_size * gradient_accumulate_steps pick_samples = min(nsamples, pick_samples) if self.sampler != "rand": @@ -2578,8 +2443,10 @@ def _quantize_block( new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - if self.device_map == "auto": - self._set_auto_device_map_in_block(block, input_ids) + if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): + set_auto_device_map_for_block_with_tuning( + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale + ) if self.device_map is not None: for n, m in block.named_modules(): @@ -3210,147 +3077,3 @@ def _sampling_inputs( current_input_others[key] = input_others[key] return current_input_ids, current_input_others - - -class LLMCompressor(BaseCompressor): - pass - - -class AdamCompressor(BaseCompressor): - """Class for quantization with optimizers like adamw of a PyTorch model. - - Args: - model: The PyTorch model to be quantized. - tokenizer: An optional tokenizer for processing input data. - scheme (str| dict | QuantizationScheme ): A preset scheme that defines the quantization configurations - bits (int): Number of bits for quantization (default is 4). - group_size (int): Size of the quantization group (default is 128). - sym (bool): Whether sym to be used (default is True). - layer_config (dict): Configuration for weight quantization (default is None). - batch_size (int): Batch size for training (default is 8). - amp (bool): Whether to use automatic mixed precision (default is True). - device: The device to be used for training (default is "auto"). - lr_scheduler: The learning rate scheduler to be used. - dataset: The default dataset name (default is "NeelNanda/pile-10k"). - enable_quanted_input (bool): Whether to use quantized input data (default is True). - enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). - lr (float): The learning rate (default is 0.005). - minmax_lr (float): The learning rate for min-max tuning (default is None). - low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). - iters (int): Number of iterations (default is 200). - seqlen (int): Length of the sequence. - nsamples (int): Number of samples (default is 128). - sampler (str): The sampling method (default is "rand"). - seed (int): The random seed (default is 42). - nblocks (int): Number of blocks (default is 1). - gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). - not_use_best_mse (bool): Whether to use mean squared error (default is False). 
- dynamic_max_gap (int): The dynamic maximum gap (default is -1). - data_type (str): The data type to be used (default is "int"). - scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels - have different choices. - act_bits (int): Number of bits for activation quantization. Default is 16. - act_group_size (int): Group size for activation quantization. Default is None. - act_sym (bool): Whether to use symmetric activation quantization. Default is None. - act_data_type (str): Specifies the data type for activations. - Defaults to None, in which case it inherits the weight data type. - act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. - to_quant_block_names (str|list): A string or list whose elements are list of - block's layer names to be quantized. - enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning - enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer function - **kwargs: Additional keyword arguments. - - Returns: - The quantized model. - """ - - bits: int | None - group_size: int | None - sym: bool | None - data_type: str | None - act_bits: int | None - act_group_size: int | None - act_sym: bool | None - act_data_type: str | None - act_dynamic: bool | None - super_bits: int | None - super_group_size: int | None - - def __init__( - self, - model: Union[torch.nn.Module, str], - tokenizer=None, - scheme: Union[str, dict, QuantizationScheme] = "W4A16", - layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, - dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", - iters: int = 200, - seqlen: int = 2048, - nsamples: int = 128, - batch_size: int = 8, - gradient_accumulate_steps: int = 1, - low_gpu_mem_usage: bool = False, - device_map: Union[str, int, torch.device, dict] = 0, - enable_torch_compile: bool = False, - seed: int = 42, - optimizer="AdamW", - **kwargs, - ): - super(AdamCompressor, self).__init__( - model=model, - tokenizer=tokenizer, - scheme=scheme, - layer_config=layer_config, - batch_size=batch_size, - dataset=dataset, - low_gpu_mem_usage=low_gpu_mem_usage, - iters=iters, - seqlen=seqlen, - nsamples=nsamples, - seed=seed, - gradient_accumulate_steps=gradient_accumulate_steps, - enable_torch_compile=enable_torch_compile, - device_map=device_map, - **kwargs, - ) - - self.optimizer = self._get_optimizer(optimizer) - - def _get_optimizer(self, optimizer): - if optimizer is None: - optimizer = torch.optim.AdamW - elif isinstance(optimizer, str): - optimizer = getattr(torch.optim, optimizer) - else: - optimizer = optimizer - return optimizer - - def _get_scaler(self): - scaler = None - if self.amp and not check_is_cpu(self.device): - from torch.cuda.amp import GradScaler - - scaler = GradScaler(init_scale=1024, growth_interval=100000) - return scaler - - def _scale_loss_and_backward(self, scaler, loss): - if scaler is not None: - loss = scaler.scale(loss) - - loss.backward() - if is_hpex_available(): - htcore.mark_step() - return loss - - def _step(self, scaler, optimizer, lr_schedule): - if scaler is not None: - scaler.step(optimizer) - optimizer.zero_grad() - lr_schedule.step() - scaler.update() - else: - optimizer.step() - optimizer.zero_grad() - lr_schedule.step() - if is_hpex_available(): - htcore.mark_step() diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index fb6c5f3db..bd0533581 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ 
-22,7 +22,7 @@ def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor: nmax = pow(2, bits - 1) - imax = abs(data).argmax(axis=-1, keepdims=True) + imax = torch.abs(data).argmax(dim=-1, keepdim=True) group_max = torch.take_along_dim(data, imax, dim=-1) iscales = -nmax * get_reciprocal(group_max) scales = get_reciprocal(iscales) diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 96faf62c0..1e4dcdf2b 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -294,12 +294,14 @@ class AutoScheme: shared_layers: Optional[Iterable[Iterable[str]]] = None method: str = "default" ignore_scale_zp_bits: bool = False + batch_size: Optional[int] = None nsamples: Optional[int] = None seqlen: Optional[int] = None dataset: Optional[str] = None # Import Notice no comma for each item device_map: Optional[Union[str, torch.device, int, dict]] = None enable_torch_compile: Optional[bool] = None disable_opt_rtn: bool = True + low_gpu_mem_usage: bool = True def __post_init__(self): if isinstance(self.options, str): diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 53abc57aa..850c95343 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -13,13 +13,16 @@ # limitations under the License. import gc import os +import re from functools import lru_cache -from typing import Any, Callable, Dict, List, Tuple, Union +from itertools import combinations +from typing import Callable, Union import cpuinfo import torch from auto_round.logger import logger +from auto_round.utils.model import check_to_quantized, get_block_names, get_layer_features, get_module # Note on HPU usage: # There are two modes available for enabling auto-round on HPU: @@ -199,7 +202,7 @@ def detect_device_count(): return 0 -def detect_device(device: Union[str, int, torch.device] = None) -> str: +def detect_device(device: Union[None, str, int, torch.device] = None) -> str: """Detects the appropriate computation device. This function determines the device to use for computations. 
It can take @@ -262,7 +265,7 @@ def is_valid_digit(s): return device -def get_device_and_parallelism(device: Union[str, torch.device, int]) -> Tuple[str, bool]: +def get_device_and_parallelism(device: Union[str, torch.device, int]) -> tuple[str, bool]: if isinstance(device, str): devices = device.replace(" ", "").split(",") elif isinstance(device, int): @@ -539,3 +542,302 @@ def get_device_memory(i: int = 0) -> int: else: raise RuntimeError("No supported device found (CUDA or XPU).") return total_memory + + +def get_major_device(device_map: Union[None, str, torch.device, int, dict]) -> str: + if device_map is None or isinstance(device_map, (str, torch.device, int)): + device = detect_device(device_map) + return device + + if isinstance(device_map, dict) and device_map: + tmp_devices = [] + for val in device_map.values(): + if isinstance(val, (str, torch.device, int)): # could optimize + tmp_device = detect_device(val) + tmp_device = tmp_device.split(":")[0] + tmp_devices.append(tmp_device) + tmp_devices = list(set(tmp_devices)) + device = None + for tmp_device in tmp_devices: + if tmp_device != "cpu": + device = tmp_device + break + if device is None: + device = tmp_devices[0] + if len(tmp_devices) > 1: + logger.warning_once( + f"there are multiple device types in the device_map, " + f"please make sure they are correct,use the first none-cpu device {device} as the core device " + ) + + return device + logger.warning_once(f"device_map should be [str, torch.device, int, dict], but got {type(device_map)}") + return "cpu" + + +def set_tuning_device_for_layer(model, name: str, device: str) -> None: + """Sets the device for a module if it matches the given name.""" + module = get_module(model, name) + if hasattr(module, "tuning_device") and module.tuning_device != device: + logger.warning( + f"multiple devices have been set for layer {name}, keeping original device {module.tuning_device}" + ) + else: + module.tuning_device = device + + +def set_non_auto_device_map( + model: torch.nn.Module, device_map: Union[str, int, dict], quant_layer_names: Union[None, list, tuple] = None +) -> None: + if not device_map: + return + if device_map == "auto": + return + if isinstance(device_map, str) and "," in device_map: # auto device map + return + if isinstance(device_map, int): + return + if isinstance(device_map, str): + device_map = device_map.replace(" ", "") + infos = device_map.split(",") + device_map_dict = {} + for info in infos: + index = info.find(":") + key = info[:index] + value = info[index + 1 :] + device_map_dict[key] = value + device_map = device_map_dict + if quant_layer_names is not None: + names = quant_layer_names + else: + names = [ + n for n, m in model.named_modules() if len(list(m.children())) == 0 + ] # if it's a block, it will be incorrect + for key, device in device_map.items(): + if isinstance(device, str) and device.isdigit(): + device = int(device) + device = detect_device(device) + if key in names: + module = get_module(model, key) + module.tuning_device = device + else: + matching_names = [name for name in names if re.match(key, name)] + for name in matching_names: + set_tuning_device_for_layer(model, name, device) + if not matching_names: + logger.warning(f"{key} in `device_map` dose not match any modules, please have a check") + + +def set_auto_device_map_for_block_with_tuning( + block: torch.nn.Module, device_map, input_ids: list[torch.Tensor], low_gpu_mem_usage=False, mem_per_param_scale=13.0 +): + """ + Automatically sets the device map for the block based on available 
GPUs and memory constraints. + + Args: + block (torch.nn.Module): The model block whose device map is to be set. + device_map (str | int | dict): Specifies the device mapping. + input_ids (list[torch.Tensor]): List of input tensors used for estimating memory requirements. + low_gpu_mem_usage (bool, optional): If True, ignoring input/output memory. Defaults to False. + mem_per_param_scale (float, optional): Scaling factor for estimating memory usage per parameter in the block. + Typical values range from 10.0 to 20.0 depending on model size and GPU memory characteristics. + Higher values are more conservative and help avoid out-of-memory errors. Defaults to 13.0. + + Returns: + None + + Raises: + RuntimeError: If no CUDA or XPU devices are found. + + Note: + This function is intended for internal use in device memory management and tuning. + The mem_per_param_scale parameter should be adjusted based on empirical memory usage observations. + """ + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + elif torch.xpu.is_available(): + logger.warning_once("XPU does not support auto device map yet, using device 0 for tuning.") + return + else: + raise RuntimeError("No CUDA or XPU devices found.") + device_list = None + if isinstance(device_map, str) and "," in device_map: + device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()] + + if device_list: + cuda_devices = [f"cuda:{i}" for i in device_list] + device_0 = cuda_devices[0] + else: + cuda_devices = [f"cuda:{i}" for i in range(num_gpus)] + device_0 = "cuda:0" + + device_0_memory = get_device_memory(device_list[0] if device_list else 0) + block_memory, input_output_memory = estimate_tuning_block_mem(block, input_ids) + if low_gpu_mem_usage: + input_output_memory = 0 + + if (block_memory * mem_per_param_scale + input_output_memory) < device_0_memory: + return # fit in one GPU + + device_map = {} + device_memory = {device: get_device_memory(int(device.split(":")[1])) for device in cuda_devices} + device_memory[device_0] = device_0_memory - input_output_memory + + device_idx = 0 + names = [] + # First, fill device 0 to its maximum capacity, then distribute the remaining layers evenly across other devices + for n, m in block.named_modules(): + if check_to_quantized(m): + layer_name = m.tmp_name + names.append(layer_name) + layer_memory = m.weight.nbytes / 1024**3 + if device_idx == 0 and layer_memory * mem_per_param_scale < device_memory[cuda_devices[device_idx]]: + device_map[layer_name] = cuda_devices[device_idx] + device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale + elif device_idx == 0: + device_idx += 1 # Move to the next device once device 0 is full + device_map[layer_name] = cuda_devices[device_idx] + device_memory[cuda_devices[device_idx]] -= layer_memory * mem_per_param_scale + else: + # Calculate the target device index based on even distribution + sorted_devices = sorted(cuda_devices, key=lambda d: device_memory[d], reverse=True) + device_idx = sorted_devices[0] + if layer_memory * mem_per_param_scale < device_memory[device_idx]: + device_map[layer_name] = device_idx + device_memory[device_idx] -= layer_memory * mem_per_param_scale + else: + logger.warning_once( + f"Block {block.tmp_name} not fit in available GPU memory. " + "Consider using more GPUs or reducing mem_per_param_scale if OOM occurs." 
+ ) + + set_non_auto_device_map(block, device_map, names) + + +def partition_dict_numbers(number_dict, n): + """ + Partition a dictionary of numbers into N groups with approximately equal sums + """ + # Edge cases + if n > len(number_dict): + groups = [] + for key, value in number_dict.items(): + groups.append({key: value}) + for _ in range(n - len(number_dict)): + groups.append({}) + return groups + + if n == len(number_dict): + return [{key: value} for key, value in number_dict.items()] + + total_sum = sum(number_dict.values()) + # target = total_sum / n # Use float for better precision + + items = list(number_dict.items()) + result = [] + remaining = items.copy() + + def find_optimal_subset(arr, target): + """Find subset with sum closest to target""" + best_subset = [] + best_diff = float("inf") + + # Try all possible subset sizes + for r in range(1, len(arr) + 1): + for combo in combinations(arr, r): + current_sum = sum(value for _, value in combo) + current_diff = abs(current_sum - target) + + # If we found a perfect match, return immediately + if current_diff == 0: + return list(combo) + + # Update the best subset if this is better + if current_diff < best_diff and current_sum <= total_sum: + best_diff = current_diff + best_subset = list(combo) + + return best_subset + + # Distribute items into n-1 groups + for i in range(n - 1): + if not remaining: + break + + # Calculate dynamic target based on remaining items + remaining_target = sum(value for _, value in remaining) / (n - i) + subset = find_optimal_subset(remaining, remaining_target) + + result.append(dict(subset)) + + # Remove allocated items + for item in subset: + remaining.remove(item) + + # Last group gets all remaining items + result.append(dict(remaining)) + + return result + + +def set_avg_auto_device_map(model: torch.nn.Module, device_map): + block_name_list = get_block_names(model) + device_list = None + if isinstance(device_map, str) and "," in device_map: + device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()] + num_devices = len(device_list) + else: + if torch.cuda.is_available(): + num_devices = torch.cuda.device_count() + elif torch.xpu.is_available(): + logger.warning_once("XPU does not support auto device map yet, using device 0 for tuning.") + return + else: + return + + if device_list: + cuda_devices = [f"cuda:{i}" for i in device_list] + else: + cuda_devices = [f"cuda:{i}" for i in range(num_devices)] + + for block_names in block_name_list: + for block_name in block_names: + params_dict = {} + block_module = get_module(model, block_name) + for n, m in block_module.named_modules(): + in_features, out_features = get_layer_features(m) + if in_features is None: + continue + params_dict[n] = in_features * out_features + + res_list = partition_dict_numbers(params_dict, num_devices) + device_index = 0 + for res in res_list: + for key in res.keys(): + set_tuning_device_for_layer(block_module, key, cuda_devices[device_index]) + device_index += 1 + + +if __name__ == "__main__": + # Example usage + number_dict = {"item1": 90, "item2": 20, "item3": 30, "item4": 40, "item5": 50, "item6": 60} + + groups = partition_dict_numbers(number_dict, 10) + for i, group in enumerate(groups): + print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}") + + groups = partition_dict_numbers(number_dict, 6) + for i, group in enumerate(groups): + print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}") + + groups = partition_dict_numbers(number_dict, 4) + for i, group in enumerate(groups): + print(f"Group {i + 1}: 
{group}, Sum: {sum(group.values())}") + + groups = partition_dict_numbers(number_dict, 3) + for i, group in enumerate(groups): + print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}") + + groups = partition_dict_numbers(number_dict, 2) + for i, group in enumerate(groups): + print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}") diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 63cb24a9f..7a1c66de8 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -18,7 +18,7 @@ from collections import UserDict from dataclasses import asdict, fields from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Union import torch import transformers @@ -263,7 +263,6 @@ def mllm_load_model( **kwargs, ): import transformers - from huggingface_hub import HfApi, HfFileSystem, hf_hub_download from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer from auto_round.utils.device import get_device_and_parallelism, set_fake_cuda_device_capability @@ -732,7 +731,7 @@ def set_nested_attr(module, attr_name: str, value): setattr(module, attrs[-1], value) -def _pad_weight(weight: torch.Tensor, block_size: list) -> Tuple[torch.Tensor, int, int]: +def _pad_weight(weight: torch.Tensor, block_size: list) -> tuple[torch.Tensor, int, int]: """Pads a matrix to make its dimensions multiples of block_size.""" M, N = weight.shape[-2:] block_size_m, block_size_n = block_size @@ -757,7 +756,7 @@ def _unpad_weight(weight: torch.Tensor, original_M: int, original_N: int, keep_f def pad_block_fp8_weight_naive( weight: torch.Tensor, weight_scale: torch.Tensor, block_size: list -) -> Tuple[torch.Tensor, int, int]: +) -> tuple[torch.Tensor, int, int]: assert len(block_size) == 2 block_size_m, block_size_n = block_size @@ -946,8 +945,10 @@ def set_module(model, key, new_module): def get_layer_features(layer): """Extracts input and output feature dimensions for supported layers.""" - from auto_round.utils.common import LinearAllreduce, LinearLayer, deepspeed_exists + from auto_round.utils import deepspeed_exists + if deepspeed_exists: + from deepspeed.module_inject import LinearAllreduce, LinearLayer if type(layer) == torch.nn.Linear: return layer.in_features, layer.out_features elif type(layer) == transformers.pytorch_utils.Conv1D: # TODO: Verify correctness diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index ea7e2ec5b..67072d762 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -155,8 +155,12 @@ def _init_tuning_params_and_quant_func(self): if type(self.orig_layer) == transformers.pytorch_utils.Conv1D: orig_weight = orig_weight.t() weight_reshape = reshape_and_pad_tensor(orig_weight.data, orig_layer.group_size) - self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0) - self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0) + if self.enable_round_tuning: + self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0) + self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0) + else: + self.weight_min = None + self.weight_max = None self._init_params( "value", p_dtype, weight_reshape.shape, 0, self.enable_round_tuning and self.orig_layer.bits < 16 ) @@ -232,7 +236,7 @@ def _qdq_weight(self, value, min_scale, max_scale): quant_kwargs["super_group_size"] = self.orig_layer.super_group_size weight_q, scale, zp = self.weight_quant_func( - weight, + weight.to(self.device), bits=self.orig_layer.bits, group_size=self.orig_layer.group_size, v=value, @@ -362,7 +366,7 @@ 
def _set_dict_attr(attr_dict, attr_name): assert global_scale.numel() == 1 self.orig_layer.weight_global_scale = global_scale.to("cpu") - ##unwrapper bias + # Unwrap bias if self.enable_norm_bias_tuning and "bias_v" in best_params.keys(): ##fake quant bias_v = best_params["bias_v"].to(self.device) bias = self.orig_layer.bias @@ -466,6 +470,7 @@ def forward(self, x): Returns: torch.Tensor: Output tensor after applying the wrapped layer. """ + # logger.info(self.orig_layer.tmp_name) x = x.to(self.device) weight_q, *_ = self._qdq_weight(self.value, self.min_scale, self.max_scale)
diff --git a/docs/step_by_step.md b/docs/step_by_step.md
index 4f2a8e9bd..f692543b7 100644
--- a/docs/step_by_step.md
+++ b/docs/step_by_step.md
@@ -306,16 +306,20 @@ ar.quantize_and_save()
~~~
#### Hyperparameters in AutoScheme
-`avg_bits(float)`: Target average bits for the whole model, only to be quantized layer will be counted in the average bits calculation.
+`avg_bits(float)` Target average bits for the whole model; only layers to be quantized will be counted in the average bits calculation.
-`options(Union[str, list[Union[QuantizationScheme, str]])`: the options of quantization schemes to choose from. It could be a string like "W4A16", or a list of strings or QuantizationScheme objects.
+`options(Union[str, list[Union[QuantizationScheme, str]])` The options of quantization schemes to choose from. It could be a string like "W4A16", or a list of strings or QuantizationScheme objects.
-`ignore_scale_zp_bits(bool)`: Whether to ignore the bits of scale and zero point in average bits calculation. Default is False.
+`ignore_scale_zp_bits(bool)` Whether to ignore the bits of scale and zero point in the average bits calculation. Default is False.
`device_map (Optional[str,dict,torch.device])` only supported in API now, as auto-scheme used more VRAM than auto-round tuning, so you could set a different device_map for it.
`shared_layers (Optional[Iterable[Iterable[str]]])` only supported in API now
+`batch_size (Optional[int])` could be set to 1 to reduce VRAM usage at the cost of longer tuning time
+
+`low_gpu_mem_usage(bool=True)` whether to reduce GPU memory usage at the cost of longer tuning time
+
In some serving frameworks, certain layers (e.g., QKV or MoE) are fused to accelerate inference. These fused layers may require the same data type and bit configuration. The shared_layers option simplifies this setup by supporting both regex and full-name matching. **Note that regex matching is applied in a block-wise manner.**
@@ -348,10 +352,23 @@ ar.quantize_and_save()
```
#### AutoScheme Cost
-The tuning cost of AutoScheme is approximately 2 to 4 times that of model's bf16 size, depending on the options.
+
We tested it on Nvidia A100 80G using torch v2.8.
-We will try to optimize the VRAM usage in the future.
+We will try to optimize the RAM usage in the future. The RAM usage is about 1.1-1.5x of the model's BF16 size.
+
+| Models | Scheme | VRAM Cost | Time Cost |
| ------------- | --------------------- | --------- | --------------------- |
| Qwen3-8B | W2A16 / W4A16 / W8A16 | 14G | 60s * len of options |
| Qwen3-8B | MXFP4 / MXFP8 | 18G | 60s * len of options |
| Qwen3-8B | GGUF* | 14G | 80s * len of options |
| Qwen3-32B | W2A16 / W4A16 / W8A16 | 29G | 180s * len of options |
| Qwen3-32B | MXFP4 / MXFP8 | 29G | 180s * len of options |
| Qwen3-32B | GGUF* | 18G | 300s * len of options |
| Llama-3.3-70B | W2A16 / W4A16 / W8A16 | 32G | 420s * len of options |
+
+
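+The numbers above use the default `low_gpu_mem_usage=True`; the table below ("Cost w/o low_gpu_mem_usage") disables it. A minimal sketch of setting these options explicitly is shown below; the model path, bit target, and option list are illustrative:
+
+```python
+from auto_round import AutoRound
+from auto_round.schemes import AutoScheme
+
+# low_gpu_mem_usage defaults to True and trades extra time for lower VRAM;
+# batch_size=1 further reduces VRAM during the scheme search.
+scheme = AutoScheme(
+    avg_bits=5.0,
+    options=("W4A16", "W8A16"),
+    low_gpu_mem_usage=True,
+    batch_size=1,
+)
+ar = AutoRound(model="Qwen/Qwen3-8B", scheme=scheme)
+ar.quantize_and_save()
+```
+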
+Cost w/o low_gpu_mem_usage
| Models | Scheme | VRAM Cost<br>(torch compile) | Time Cost<br>torch compile | VRAM Cost<br>wo torch compile | Time Cost<br>wo torch compile |
| --------- | ----------------- | ------------------------------- | ----------------------------- | -------------------------------- | -------------------------------- |
@@ -361,7 +378,7 @@ We will try to optimize the VRAM usage in the future.
| Qwen3-32B | W2A16/W4A16/W8A16 | OOM with 240G | --- | OOM with 240G | --- |
| Qwen3-32B | MXFP4/MXFP8 | 160G | 200s * len of options | 200G | 240s * len of options |
| Qwen3-32B | GGUF* | 210G | 80s * len of options | 200G | 60s * len of options |
-
+
#### Limitations diff --git a/test/test_cpu/test_gguf_format.py b/test/test_cpu/test_gguf_format.py index 3a5cb3d43..308425cd1 100644 --- a/test/test_cpu/test_gguf_format.py +++ b/test/test_cpu/test_gguf_format.py @@ -339,7 +339,7 @@ def test_qtype_setting(self): # Qwen3-0.6B output q6_k, token_embed q4_0 448M # Qwen3-8B output q6_k, token_embed q4_0 4.5G # Llama-3.2-1B-Instruct o output, token_embed q6_k 736M - from auto_round.compressors import get_layer_config_by_gguf_format, set_layer_config + from auto_round.compressors.utils import set_layer_config from auto_round.export.export_to_gguf.config import ModelType model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct" diff --git a/test/test_cuda/test_auto_scheme.py b/test/test_cuda/test_auto_scheme.py index 5b7116768..ac486fef4 100644 --- a/test/test_cuda/test_auto_scheme.py +++ b/test/test_cuda/test_auto_scheme.py @@ -83,10 +83,37 @@ def test_shared_layers(self): # @multi_card def test_multi_card(self): - model_name = "/models/Qwen3-8B" - target_bits = 5.254 - # for device_map in ["auto", "0,1", "0", None]: + model_name = "/models/Qwen3-0.6B" + target_bits = 5.265 + for device_map in ["auto", "0,1", "0", None]: + scheme = AutoScheme(avg_bits=target_bits, options=("NVFP4")) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1, device_map=device_map) + model, layer_config = ar.quantize() + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + @multi_card + def test_multi_card_1(self): + model_name = "/models/Qwen3-0.6B" + target_bits = 5.265 + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") scheme = AutoScheme(avg_bits=target_bits, options=("NVFP4")) + ar = AutoRound(model=model, tokenizer=tokenizer, scheme=scheme, iters=0, nsamples=1) + model, layer_config = ar.quantize() + avg_bits, _ = compute_avg_bits_for_model(model) + print(avg_bits) + assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3 + + def test_non_low_gpu_mem_usage(self): + model_name = "/models/Qwen3-0.6B" + target_bits = 5.265 + # for device_map in ["auto", "0,1", "0", None]: + scheme = AutoScheme(avg_bits=target_bits, options=("NVFP4"), low_gpu_mem_usage=False, device_map="auto") + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) model, layer_config = ar.quantize() avg_bits, _ = compute_avg_bits_for_model(model)