Binary file modified auto_round/alg_ext.abi3.so
100644 → 100755
Binary file not shown.
2 changes: 1 addition & 1 deletion auto_round/auto_scheme/utils.py
@@ -19,13 +19,13 @@
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

from auto_round.low_cpu_mem import get_module
from auto_round.schemes import QuantizationScheme, preset_name_to_scheme
from auto_round.utils import (
SUPPORTED_LAYER_TYPES,
check_to_quantized,
get_block_names,
get_layer_features,
get_module,
is_hpex_available,
)

6 changes: 2 additions & 4 deletions auto_round/autoround.py
@@ -118,7 +118,7 @@ def __new__(
**kwargs: Backward compatible options:
- enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
super_group_size, super_bits, scale_dtype ("fp16" etc.),
nblocks, low_cpu_mem_usage, to_quant_block_names,
nblocks, to_quant_block_names,
enable_norm_bias_tuning, enable_quanted_input,
disable_deterministic_algorithms, vlm, static_kv_dtype
Raises:
@@ -271,7 +271,7 @@ class AutoRoundLLM(LLMCompressor):
**kwargs: Backward compatible options:
- enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
super_group_size, super_bits, scale_dtype ("fp16" etc.),
nblocks, low_cpu_mem_usage, to_quant_block_names,
nblocks, to_quant_block_names,
enable_norm_bias_tuning, enable_quanted_input,
disable_deterministic_algorithms, mllm, static_kv_dtype
Raises:
@@ -366,7 +366,6 @@ class AutoRoundAdam(AdamCompressor):
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
nsamples (int): Number of samples (default is 128).
@@ -473,7 +472,6 @@ class AutoRoundMLLM(MLLMCompressor):
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
nsamples (int): Number of samples (default is 128).
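For orientation, a minimal call against the updated wrappers with the retired keyword simply dropped; this is a sketch only, and the checkpoint name plus the bits/group_size keywords are illustrative assumptions rather than part of this diff:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from auto_round import AutoRound

    model_name = "facebook/opt-125m"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # low_cpu_mem_usage is no longer among the backward-compatible kwargs; simply omit it
    ar = AutoRound(model, tokenizer, bits=4, group_size=128)
    ar.quantize()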
56 changes: 14 additions & 42 deletions auto_round/compressors/base.py
@@ -36,7 +36,6 @@
from auto_round.export.export_to_autoround import AutoRoundFormat
from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType
from auto_round.logger import logger
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme
from auto_round.sign_sgd import SignSGD
from auto_round.special_model_handler import _handle_moe_model
@@ -181,7 +180,7 @@ def __init__(
**kwargs: Backward compatible options:
- enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
super_group_size, super_bits, scale_dtype ("fp16" etc.),
nblocks, low_cpu_mem_usage, to_quant_block_names,
nblocks, to_quant_block_names,
enable_norm_bias_tuning, enable_quanted_input,
disable_deterministic_algorithms, mllm, static_kv_dtype
Raises:
@@ -254,7 +253,6 @@ def __init__(
not_use_best_mse = kwargs.pop("not_use_best_mse", False)
dynamic_max_gap = kwargs.pop("dynamic_max_gap", -1)
nblocks = kwargs.pop("nblocks", 1)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None)
enable_norm_bias_tuning: bool = kwargs.pop("enable_norm_bias_tuning", False)
enable_quanted_input: bool = kwargs.pop("enable_quanted_input", True)
@@ -292,14 +290,12 @@ def __init__(
# Model related
self.quantized = False
if isinstance(model, str):
model, tokenizer, low_cpu_mem_usage = llm_load_model(
model, tokenizer = llm_load_model(
model,
device="cpu", # always load cpu first
low_cpu_mem_mode=low_cpu_mem_usage,
)
elif tokenizer is None and not self.diffusion and iters > 0:
raise ValueError("A tokenizer must be set for non-str model input")
self.low_cpu_mem_usage = bool(low_cpu_mem_usage)
if unsupported_meta_device(model):
raise RuntimeError(
"AutoRound does not support parameters on meta device. "
@@ -572,7 +568,6 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
and TORCH_VERSION_AT_LEAST_2_6
and self.act_bits > 8
and not is_debug_mode()
and not self.low_cpu_mem_usage
and "fp8" not in self.data_type
and "fp8" not in self.act_data_type
):
@@ -581,10 +576,6 @@
"Enabling it can reduce tuning cost by 20%, but it might throw an exception."
)

if self.low_cpu_mem_usage and self.enable_torch_compile:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")

# if is_debug_mode() and self.enable_torch_compile:
# self.enable_torch_compile = False
# logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")
@@ -1605,7 +1596,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
self._quantize_layer_via_rtn(m.tmp_name)
all_to_quantized_module_names.remove(m.tmp_name)

mv_module_from_gpu(block, self.low_cpu_mem_usage)
mv_module_from_gpu(block)
pbar.update(1)

pbar.close()
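Every mv_module_from_gpu call site in this file now passes only the module itself. A minimal sketch of the signature those calls assume; illustrative only, since the real helper is defined elsewhere in auto_round and may handle device placement more carefully:

    import torch

    def mv_module_from_gpu(module: torch.nn.Module) -> torch.nn.Module:
        # With the former low-CPU-memory flag gone (it used to allow releasing
        # weights to the meta device), the module is simply returned to CPU.
        return module.to("cpu")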
@@ -1717,11 +1708,11 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
all_q_inputs = self.try_cache_inter_data_gpucpu(
all_first_block_names, self.nsamples, layer_names=layer_names
)
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
self.model = mv_module_from_gpu(self.model)
clear_memory()
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model) # self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
self.model = mv_module_from_gpu(self.model)
logger.info("caching done")
if len(all_blocks) > 1:
pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks))
@@ -1855,7 +1846,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
self.model
) ##self.model.hf_device_map has not been changed

self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
self.model = mv_module_from_gpu(self.model)
clear_memory()
quant_layer = self._quantize_layer
for layer_name in layer_names:
@@ -1945,12 +1936,6 @@ def calib(self, nsamples, bs):
self.dataloader = self.dataset
total_cnt = 0

# load embed weight if use low_cpu_mem_usage
if self.low_cpu_mem_usage:
embed_layers = get_layers_before_block(self.model)
for n, m in embed_layers:
m = m.to(self.device)

for data in self.dataloader:
if data is None:
continue
@@ -2013,11 +1998,6 @@ def calib(self, nsamples, bs):
f"Target samples count is {nsamples}, while valid samples count is {total_cnt}"
)

# clean embed weight to save memory
if self.low_cpu_mem_usage:
for n, m in embed_layers:
m = m.to("meta")

@torch.no_grad()
def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None):
"""Attempts to cache intermediate data on GPU, if failed, then using CPU.
@@ -2091,7 +2071,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None):
accelerate.hooks.remove_hook_from_submodules(
self.model
) ##self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
self.model = mv_module_from_gpu(self.model)
clear_memory()
## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
all_inputs = self.cache_inter_data(
@@ -2348,7 +2328,7 @@ def _quantize_layer(
logger.info(dump_info)
with torch.no_grad():
unwrapper_layer(self.model, wrapper_linear, layer_name, {})
mv_module_from_gpu(layer, self.low_cpu_mem_usage)
mv_module_from_gpu(layer)

lr = torch.tensor(self.lr)
minmax_lr = torch.tensor(self.minmax_lr)
@@ -2441,7 +2421,7 @@ def _quantize_layer(
best_iter = last_best_iter
with torch.no_grad():
unwrapper_layer(self.model, wrapper_linear, layer_name, best_params)
mv_module_from_gpu(layer, self.low_cpu_mem_usage)
mv_module_from_gpu(layer)
dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
logger.info(dump_info)

@@ -2636,7 +2616,7 @@ def _quantize_block(
)
logger.info(dump_info)
unwrapper_block(block, {}) # TODO Quant layer should change
mv_module_from_gpu(block, self.low_cpu_mem_usage)
mv_module_from_gpu(block)
return output, output

if self.lr_scheduler is None:
@@ -2731,8 +2711,6 @@ def _quantize_block(
set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")

if self.enable_quanted_input:
if self.low_cpu_mem_usage:
block = block.to(device)
clear_memory()
q_outputs = self._get_block_outputs(
block,
@@ -2744,15 +2722,15 @@
)
if self.device_map is not None:
accelerate.hooks.remove_hook_from_submodules(block)
mv_module_from_gpu(block, self.low_cpu_mem_usage)
mv_module_from_gpu(block)
clear_memory(input_ids)

return q_outputs, output

else:
if self.device_map is not None:
accelerate.hooks.remove_hook_from_submodules(block)
mv_module_from_gpu(block, self.low_cpu_mem_usage)
mv_module_from_gpu(block)
clear_memory(input_ids)
return None, output

@@ -2848,9 +2826,7 @@ def _quantize_blocks(
modules = [get_module(model, n) for n in names]
m = WrapperMultiblock(modules)

if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
m = m.to(device)

m = m.to(device)
q_input, input_ids = quantize_block(
m,
input_ids,
@@ -2889,7 +2865,7 @@ def _quantize_blocks(
if pbar is not None:
pbar.update(1)

self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
self.model = mv_module_from_gpu(self.model)
for n, m in self.model.named_modules():
if hasattr(m, "name"):
delattr(m, "name")
@@ -2917,9 +2893,6 @@ def save_quantized(
"""
format = self._check_supported_format(format)

if self.low_cpu_mem_usage:
self.model = self.model.to("cpu")

if not self.quantized:
logger.warning("please run autoround.quantize first")
return
@@ -3207,7 +3180,6 @@ class AdamCompressor(BaseCompressor):
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
nsamples (int): Number of samples (default is 128).
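From the caller's side the export path is unchanged now that save_quantized no longer relocates the model itself; a short usage sketch continuing from the `ar` object in the earlier example, with a placeholder output directory and the `format` default visible in the compressor signatures below:

    ar.save_quantized(output_dir="./opt-125m-w4g128", format="auto_round", inplace=True)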
10 changes: 0 additions & 10 deletions auto_round/compressors/diffusion/compressor.py
@@ -22,7 +22,6 @@
from auto_round.compressors.base import BaseCompressor
from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader
from auto_round.logger import logger
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.schemes import QuantizationScheme
from auto_round.utils import (
LazyImport,
@@ -302,11 +301,6 @@ def calib(self, nsamples, bs):
self.dataloader = self.dataset
total_cnt = 0

if self.low_cpu_mem_usage:
embed_layers = get_layers_before_block(self.model)
for n, m in embed_layers:
m = m.to(self.device)

total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader))
if self.pipe.dtype != self.model.dtype:
self.pipe.to(self.model.dtype)
@@ -358,10 +352,6 @@ def calib(self, nsamples, bs):
if isinstance(v[key], list) and len(v[key]) == total_cnt:
self.inputs[k][key] = v[key][:max_len]

# clean embed weight to save memory
if self.low_cpu_mem_usage:
for n, m in embed_layers:
m = m.to("meta")
# torch.cuda.empty_cache()

def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
11 changes: 0 additions & 11 deletions auto_round/compressors/mllm/compressor.py
@@ -25,7 +25,6 @@
from auto_round.compressors.mllm.dataset import get_mllm_dataloader
from auto_round.compressors.mllm.template import Template, get_template
from auto_round.logger import logger
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.schemes import QuantizationScheme
from auto_round.special_model_handler import (
NOT_SUPPORT_ONLY_TEXT_MODELS,
@@ -108,7 +107,6 @@ class MLLMCompressor(BaseCompressor):
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
nsamples (int): Number of samples (default is 128).
@@ -315,11 +313,6 @@ def calib(self, nsamples, bs):
self.dataloader = self.dataset
total_cnt = 0

if self.low_cpu_mem_usage:
embed_layers = get_layers_before_block(self.model)
for n, m in embed_layers:
m = m.to(self.device)

total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader))
with tqdm(range(1, total + 1), desc="cache block inputs") as pbar:
for data in self.dataloader:
@@ -417,10 +410,6 @@ def calib(self, nsamples, bs):
if isinstance(v[key], list) and len(v[key]) == total_cnt:
self.inputs[k][key] = v[key][:max_len]

# clean embed weight to save memory
if self.low_cpu_mem_usage:
for n, m in embed_layers:
m = m.to("meta")
# torch.cuda.empty_cache()

def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
18 changes: 0 additions & 18 deletions auto_round/low_cpu_mem/__init__.py

This file was deleted.
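For any external code that imported from the deleted low_cpu_mem package, the one import rewritten in auto_round/auto_scheme/utils.py above gives the migration pattern; the checkpoint and the dotted layer path below are illustrative assumptions:

    # from auto_round.low_cpu_mem import get_module   # old location, removed by this PR
    from auto_round.utils import get_module  # new location

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder checkpoint
    block = get_module(model, "model.decoder.layers.0")  # dotted submodule lookup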
