6 changes: 3 additions & 3 deletions README.md
@@ -29,10 +29,10 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri).
## 🆕 What's New

[2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
all bits other than 3 bits. **A more advanced algorithm** tailored for specific configurations may be available in
v0.6.1. Example
all bits other than 3 bits. Example
models: [Intel/Qwen3-235B-A22B-q2ks-mixed-ar](https://huggingface.co/Intel/Qwen3-235B-A22B-q2ks-ar)
and [Intel/DeepSeek-R1-0528-q2ks-mixed-ar](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-ar).
and [Intel/DeepSeek-R1-0528-q2ks-mixed-ar](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-ar). **A more advanced algorithm** tailored for specific configurations may be available in
v0.6.1.

[2025/05] AutoRound provides some recipes for **DeepSeek-R1-0528**, please refer
to [Intel/DeepSeek-R1-0528-int2-mixed-ar](https://huggingface.co/Intel/DeepSeek-R1-0528-int2-mixed-ar), [Intel/DeepSeek-R1-0528-int4-ar](https://huggingface.co/Intel/DeepSeek-R1-0528-int4-ar)
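For context on the RTN recommendation above: setting `iters=0` (the `--iters 0` CLI flag) switches AutoRound to the optimized RTN path. Below is a minimal sketch of what such a call could look like through the Python API; it assumes the `quantize_and_save` helper and the `gguf:q2_k_s` format tag, and the model ID is only a placeholder.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_id = "Qwen/Qwen2.5-7B-Instruct"  # placeholder model, not taken from this PR
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# iters=0 selects the optimized RTN mode recommended for all GGUF bit widths except 3-bit
autoround = AutoRound(model, tokenizer, iters=0)
autoround.quantize_and_save("./qwen-gguf-q2ks", format="gguf:q2_k_s")  # assumed format tag
```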
112 changes: 74 additions & 38 deletions auto_round/autoround.py
@@ -31,6 +31,7 @@
from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.utils import (
INNER_SUPPORTED_LAYER_TYPES,
SUPPORTED_DTYPES,
SUPPORTED_LAYER_TYPES,
TORCH_VERSION_AT_LEAST_2_6,
@@ -46,6 +47,8 @@
collect_best_params,
compile_func,
convert_dtype_str2torch,
convert_fp8_layer_to_linear,
convert_fp8_model_to_16b_model,
detect_device,
find_matching_blocks,
flatten_list,
@@ -55,6 +58,7 @@
get_layer_names_in_block,
get_lm_head_name,
get_module,
get_quant_keys,
get_shared_keys,
htcore,
infer_bits_by_data_type,
@@ -141,23 +145,23 @@ class AutoRound(object):

def __init__(
self,
model: torch.nn.Module,
tokenizer,
model: Union[torch.nn.Module, str],
tokenizer=None,
bits: int = 4,
group_size: int = 128,
sym: bool = True,
layer_config: dict = None,
batch_size: int = 8,
amp: bool = True,
device: str = None,
device: Union[str, torch.device, int] = 0,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = False,
low_cpu_mem_usage: bool = False,
low_cpu_mem_usage: int = 0,
iters: int = 200,
seqlen: int = 2048,
nsamples: int = 128,
@@ -188,14 +192,28 @@ def __init__(
if kwargs:
logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")
self.quantized = False
self.model_orig_dtype = model.dtype
self.seed = seed
set_seed(self.seed)
if device is not None and "," in str(device):
raise ValueError(
"API does not support explicit set multiple devices,"
" please set CUDA_VISIBLE_DEVICES yourself and use `device=auto` instead"
)

if isinstance(model, str):
model, tokenizer, low_cpu_mem_usage = llm_load_model(
model, device=device, low_cpu_mem_mode=low_cpu_mem_usage
)
elif tokenizer is None and iters > 0:
raise ValueError("A tokenizer must be set for non-str model input")
self.low_cpu_mem_usage = bool(low_cpu_mem_usage)
if unsupport_meta_device(model):
raise RuntimeError(
"AutoRound does not support parameters on meta device. "
"Please use more GPUs by setting `--device 0,1,2,3` or just place the model on CPU."
)
# self.model_orig_dtype = model.dtype
self.device = detect_device(device)  ## must be placed after llm_load_model, since the "auto" device value is resolved by then

## important tuning hyper-parameters
self.amp = amp
@@ -208,7 +226,7 @@ def __init__(
self.sym = sym

self.low_gpu_mem_usage = low_gpu_mem_usage
self.low_cpu_mem_usage = low_cpu_mem_usage

self.layer_config = {} if layer_config is None else layer_config
self.seqlen = seqlen
self.batch_size, self.gradient_accumulate_steps = batch_size, gradient_accumulate_steps
@@ -232,9 +250,10 @@
)
self.bits = tmp_bits
self.supported_types = SUPPORTED_LAYER_TYPES
self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES
self.model = model.eval()
self.tokenizer = tokenizer
self.device = detect_device(device)

self.scale_dtype = convert_dtype_str2torch(scale_dtype)
self.set_amp_dtype()
self.to_quant_block_names = to_quant_block_names
@@ -837,7 +856,12 @@ def quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> No
# Load dataset
from .calib_dataset import get_dataloader

if hasattr(self.model, "is_fp8"):
convert_fp8_model_to_16b_model(self.model, self.amp_dtype)

if isinstance(self.dataset, str):
if self.tokenizer is None:
raise ValueError("A tokenizer must be set for the model when using a dataset string.")
dataset_name = self.dataset.replace(" ", "")
self.dataloader = get_dataloader(
self.tokenizer, self.seqlen, dataset_name, self.seed, self.batch_size, self.nsamples
@@ -1063,6 +1087,10 @@ def quantize_layer_via_rtn(self, name: str) -> None:
"""
m = get_module(self.model, name)

if m.__class__.__name__ == "FP8Linear":
m = convert_fp8_layer_to_linear(m, self.amp_dtype)
set_module(self.model, name, m)

# Step 1: Use optimized RTN data type if available
if not self.disable_opt_rtn and not m.data_type.startswith("rtn_"):
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
@@ -1192,7 +1220,9 @@ def quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
clear_memory()
cnt = 1
cnt += 1

## convert remaining fp8 layers
if hasattr(self.model, "is_fp8"):
convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
self.quantized = True
return self.model, self.layer_config

@@ -1250,6 +1280,8 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -
pbar.set_description(f"Quantizing {block_name}")
block = get_module(self.model, block_name)
block = block.to(self.device)
if hasattr(self.model, "is_fp8"):
convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype)
# Dispatch model if needed
if self.device_map is not None:
from accelerate import dispatch_model
@@ -1401,7 +1433,7 @@ def quantize(self):
self.batch_size = total_samples
logger.warning(f"force the train batch size to {total_samples}")

self.quant_blocks(
self.quantize_blocks(
self.model,
inputs,
block_names,
@@ -1419,6 +1451,12 @@

self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately

if hasattr(self.model, "is_fp8"):
for n, m in self.model.named_modules():
if m.__class__.__name__ == "FP8Linear":
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to("cpu")
set_module(self.model, n, new_layer)

end_time = time.time()
cost_time = end_time - self.start_time
logger.info(f"quantization tuning time {cost_time}")
@@ -1456,17 +1494,22 @@ def quant_layers(self, layer_names, layer_inputs):
"""
##TODO currently we take all the layers outside blocks as post block layers which is not optimal
## if there is no input for layer, we use rtn

for layer_name in copy.deepcopy(layer_names):
if layer_name not in layer_inputs:
logger.info(f"using rtn to quantize {layer_name}")
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

layer = get_module(self.model, layer_name)
if layer.__class__.__name__ == "FP8Linear":
new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype).to(self.device)
set_module(self.model, layer_name, new_layer)
layer = new_layer

if not self.disable_opt_rtn and "rtn_" + layer.data_type in QUANT_FUNC_WITH_DTYPE:
layer.data_type = "rtn_" + layer.data_type
logger.info("using optimized rtn method for quantizing %s", layer_name)
self.layer_config[layer_name]["data_type"] = layer.data_type
layer.to(self.device)
wrapper_layer = WrapperLinear(
layer,
enable_round_tuning=False,
@@ -1529,34 +1572,19 @@ def set_layerwise_config(self, layer_config):
otherwise returns False.
"""
# Get the names of layers in quantization blocks
layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list)
supported_types = self.supported_types + self.inner_supported_types
layers_in_blocks = get_layer_names_in_block(self.model, supported_types, self.quant_block_list)
##process regex in layer_config
all_supported_layer_names = []
# List of configuration keys
keys = [
"bits",
"group_size",
"sym",
"data_type",
"scale_dtype",
"act_bits",
"act_group_size",
"act_sym",
"act_dynamic",
"act_data_type",
"super_bits",
"super_group_size",
]
keys = get_quant_keys()

for n, m in self.model.named_modules():
# Delete previous configuration to avoid conflicts with prior tuning
for key in keys:
if hasattr(m, key):
delattr(m, key)

# Skip unsupported types
supported_types = self.supported_types

if not isinstance(m, supported_types):
continue
all_supported_layer_names.append(n)
@@ -1794,6 +1822,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
Raises:
Exception: If caching on GPU fails, switches to CPU and caches there.
"""
if hasattr(self.model, "is_fp8"):
layer_names = []
if layer_names is None:
layer_names = []

@@ -2211,7 +2241,7 @@ def get_act_max_hook(module, input, output):
continue
return hook_handles

def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")):
def quantize_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")):
"""Quantize the weights of a given block of the model.

Args:
@@ -2224,6 +2254,12 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
Returns:
Tuple: (q_outputs, output) if self.enable_quanted_input is True, else (None, output)
"""
if hasattr(self.model, "is_fp8"):
for n, m in block.named_modules():
if m.__class__.__name__ == "FP8Linear":
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
set_module(block, n, new_layer)

if self.device_map is not None:
from accelerate import dispatch_model

@@ -2421,7 +2457,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
clear_memory(input_ids)
return None, output

def quant_blocks(
def quantize_blocks(
self, model: torch.nn.Module, inputs, block_names, q_input=None, nblocks=1, device="cpu", pbar=None
):
"""Quantize and dequantize the weights of the specified blocks in the model.
Expand Down Expand Up @@ -2459,9 +2495,9 @@ def quant_blocks(
for i in range(len(input_others[key])):
to_dtype(input_others[key][i], tmp_dtype)
if self.enable_torch_compile:
quant_block = compile_func(self.quant_block, device)
quant_block = compile_func(self.quantize_block, device)
else:
quant_block = self.quant_block
quant_block = self.quantize_block

if pbar is None:
pbar = tqdm(range(0, len(block_names), nblocks))
@@ -2806,23 +2842,23 @@ class AutoRoundOPT(AutoRound):

def __init__(
self,
model,
model: Union[torch.nn.Module, str],
tokenizer=None,
bits: int = 4,
group_size: int = 128,
sym: bool = True,
layer_config=None,
batch_size: int = 8,
amp: bool = True,
device=None,
device: Union[str, torch.device, int] = 0,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = False,
low_cpu_mem_usage: bool = False,
low_cpu_mem_usage: int = 0,
iters: int = 200,
seqlen: int = 2048,
nsamples: int = 128,
@@ -2985,23 +3021,23 @@ class AutoRoundAdam(AutoRoundOPT):

def __init__(
self,
model,
model: Union[torch.nn.Module, str],
tokenizer=None,
bits: int = 4,
group_size: int = 128,
sym: bool = True,
layer_config=None,
batch_size: int = 8,
amp: bool = True,
device=None,
device: Union[str, torch.device, int] = 0,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = False,
low_cpu_mem_usage: bool = False,
low_cpu_mem_usage: int = 0,
iters: int = 200,
seqlen: int = 2048,
nsamples: int = 128,
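Taken together, the autoround.py changes above relax the constructor: `model` may now be a Hugging Face model ID or local path (loaded internally via `llm_load_model`), `tokenizer` becomes optional in that case, `device` accepts an int index and defaults to 0, and FP8 checkpoints are converted to 16-bit linear layers on the fly. A minimal sketch of the new calling convention, with a placeholder model ID:

```python
from auto_round import AutoRound

# Passing a string lets AutoRound load the model and tokenizer itself;
# "facebook/opt-125m" is only a placeholder, and device=0 is the new default GPU index.
ar = AutoRound("facebook/opt-125m", bits=4, group_size=128, iters=0, device=0)

# quantize() is expected to return (model, layer_config), mirroring quantize_rtn above
quantized_model, layer_config = ar.quantize()
```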
2 changes: 1 addition & 1 deletion auto_round/eval/evaluation.py
@@ -15,7 +15,7 @@
import os
from typing import Optional, Union

from lm_eval import simple_evaluate as lm_simple_evaluate
from lm_eval import simple_evaluate as lm_simple_evaluate # pylint: disable=E0611

os.environ["TOKENIZERS_PARALLELISM"] = "false"

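The evaluation.py change only silences pylint's E0611 false positive on the `lm_eval` import; the entry point itself is unchanged. For reference, a minimal sketch of how `simple_evaluate` is typically driven, assuming lm-eval 0.4.x and a placeholder checkpoint path:

```python
from lm_eval import simple_evaluate  # same symbol imported as lm_simple_evaluate above

# Evaluate a (quantized) checkpoint on a single task; the path and task are placeholders.
results = simple_evaluate(
    model="hf",
    model_args="pretrained=./opt-125m-w4,dtype=float16",
    tasks=["lambada_openai"],
    batch_size=8,
)
print(results["results"]["lambada_openai"])
```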