Merged
Binary file modified auto_round/alg_ext.abi3.so
Binary file not shown.
50 changes: 34 additions & 16 deletions auto_round/compressors/base.py
@@ -80,6 +80,7 @@
get_layer_names_in_block,
get_module,
htcore,
is_complex_device_mapping,
is_debug_mode,
is_fp8_linear,
is_fp8_model,
@@ -472,7 +473,15 @@ def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kw
"""Parse and set the quantization scheme."""

def _parse_and_set(scheme, kwargs):
res = ""
if kwargs.get("data_type", None) and kwargs["data_type"].endswith("_dq") and not scheme.startswith("gguf"):
if "bits" not in kwargs:
data_type = kwargs["data_type"]
raise KeyError(
f"please set bits when setting data_type={data_type}, or use a scheme as an alternative."
)

Review comments:
Contributor: Why is this needed? The default bits are already in the scheme (e.g. scheme="W4A16").
wenhuach21 (Contributor, Nov 3, 2025): All the default values live in the scheme; if users want to override some of them, they set the specific keys.
Contributor (Author): If the auto_round API is called as auto_round(data_type="rtn_int_sym_dq", bits=3), it hits bugs, since the default scheme is W4A16 with super_bits=None and super_group_size=None.
bits = kwargs["bits"]
scheme = f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s"
res = None
if isinstance(scheme, QuantizationScheme):
scheme = asdict(scheme)
elif isinstance(scheme, dict):
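For the review thread above, a minimal sketch of the resolution rule this new branch implements may help. resolve_dq_scheme below is an illustrative helper, not an auto_round API; it mirrors the mapping shown in the diff: a *_dq data type only yields a concrete GGUF scheme once bits is set explicitly, otherwise the default W4A16 scheme (with super_bits=None and super_group_size=None) would be used.

# Illustrative sketch only; resolve_dq_scheme is not part of auto_round.
def resolve_dq_scheme(data_type: str, bits: int | None) -> str:
    """Map a '*_dq' data type plus an explicit bit width to a GGUF scheme name."""
    if not data_type.endswith("_dq"):
        raise ValueError("only dequant-style ('*_dq') data types are handled here")
    if bits is None:
        # Falling back to the default W4A16 scheme would leave super_bits and
        # super_group_size unset, so the caller must pick the bit width.
        raise KeyError(f"please set bits when setting data_type={data_type}")
    return f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s"

# resolve_dq_scheme("rtn_int_sym_dq", 3) -> "gguf:q3_k_s"
# resolve_dq_scheme("rtn_int_sym_dq", 6) -> "gguf:q6_k"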
@@ -1205,7 +1214,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
m = get_module(self.model, name)

if is_fp8_linear(m):
m = convert_fp8_layer_to_linear(m, self.amp_dtype)
m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device)
set_module(self.model, name, m)

# Step 1: Try quantization on GPU first, fall back to CPU if OOM
@@ -1358,7 +1367,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
cnt += 1
# Convert remaining fp8
if is_fp8_model(self.model):
convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
convert_fp8_model_to_16b_model(self.model, self.amp_dtype, self.device)
self.quantized = True
return self.model, self.layer_config

@@ -1424,16 +1433,15 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
for block_name in block_names:
pbar.set_description(f"Quantizing {block_name}")
block = get_module(self.model, block_name)
block = block.to(self.device)
if is_fp8_model(self.model):
convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype)
convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype, device=self.device)

if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
if is_complex_device_mapping(self.device_map):
set_auto_device_map_for_block_with_tuning(
block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
)
# Dispatch model if needed
if self.device_map is not None:
if is_complex_device_mapping(self.device_map):
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

for _, m in block.named_modules():
@@ -1451,7 +1459,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
self.device,
self.cache_device,
)
if self.device_map is not None:
if is_complex_device_mapping(self.device_map):
accelerate.hooks.remove_hook_from_submodules(block)

if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
@@ -1630,7 +1638,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
if is_fp8_model(self.model):
for n, m in self.model.named_modules():
if is_fp8_linear(m):
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to("cpu")
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to("cpu")
set_module(self.model, n, new_layer)

end_time = time.time()
@@ -1678,8 +1686,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

layer = get_module(self.model, layer_name)
layer = layer.to(self.device)
if is_fp8_model(self.model):
new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype).to(self.device)
if is_fp8_linear(layer):
new_layer = convert_fp8_layer_to_linear(layer, self.amp_dtype, self.device).to(self.device)
set_module(self.model, layer_name, new_layer)
layer = new_layer

@@ -2445,17 +2453,17 @@ def _quantize_block(
if is_fp8_model(self.model):
for n, m in block.named_modules():
if is_fp8_linear(m):
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
set_module(block, n, new_layer)

if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)):
if is_complex_device_mapping(self.device_map):
set_auto_device_map_for_block_with_tuning(
block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
)
else:
block = block.to(device)

if self.device_map is not None:
if is_complex_device_mapping(self.device_map):
for n, m in block.named_modules():
if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
continue
@@ -2653,15 +2661,15 @@ def _quantize_block(
device,
cache_device=self.cache_device,
)
if self.device_map is not None:
if is_complex_device_mapping(self.device_map):
accelerate.hooks.remove_hook_from_submodules(block)
mv_module_from_gpu(block)
clear_memory(input_ids)

return q_outputs, output

else:
if self.device_map is not None:
if is_complex_device_mapping(self.device_map):
accelerate.hooks.remove_hook_from_submodules(block)
mv_module_from_gpu(block)
clear_memory(input_ids)
@@ -2741,6 +2749,16 @@ def _quantize_blocks(
except (ImportError, ModuleNotFoundError):
logger.error("algorithm extension import error, fallback to default mode")
quantize_block = self._quantize_block
elif self.enable_alg_ext and self.data_type.endswith("dq"):
try:
from auto_round.alg_ext import dq_quantize_block_ext

BaseCompressor.dq_quantize_block_ext = dq_quantize_block_ext
quantize_block = self.dq_quantize_block_ext
logger.info("using algorithm extension for quantization.")
except (ImportError, ModuleNotFoundError):
logger.error("algorithm extension import error, fallback to default mode")
quantize_block = self._quantize_block
else:
quantize_block = self._quantize_block

25 changes: 17 additions & 8 deletions auto_round/utils/device.py
@@ -355,6 +355,19 @@ def get_packing_device(device: str | torch.device | None = "auto") -> torch.devi
raise TypeError(f"Unsupported device type: {type(device)} ({device})")


def is_complex_device_mapping(device_map):
if device_map is None or isinstance(device_map, int):
return False
elif device_map == "auto":
return True
elif isinstance(device_map, str) and "," in device_map:
return True
elif isinstance(device_map, dict):
return True
else:
return False
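A few expected classifications for the helper above, written as assertions (illustrative usage, derived directly from the definition):

# Expected behavior of is_complex_device_mapping, per the definition above.
assert is_complex_device_mapping(None) is False        # nothing to dispatch
assert is_complex_device_mapping(0) is False           # single device index
assert is_complex_device_mapping("cuda:0") is False    # plain single-device string
assert is_complex_device_mapping("auto") is True       # accelerate-style auto map
assert is_complex_device_mapping("0,1") is True        # multi-device string
assert is_complex_device_mapping({"*": "cpu"}) is True # per-module dict mapping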


class CpuInfo(object):
"""Get CPU Info."""

@@ -598,15 +611,11 @@ def set_tuning_device_for_layer(model, name: str, device: str) -> None:
def set_non_auto_device_map(
model: torch.nn.Module, device_map: Union[str, int, dict], quant_layer_names: Union[None, list, tuple] = None
) -> None:
if not device_map:
return
if device_map == "auto":
return
if isinstance(device_map, str) and "," in device_map: # auto device map
return
if isinstance(device_map, int):
if not device_map or device_map == "auto" or isinstance(device_map, int):
return
if isinstance(device_map, str):
if "," in device_map: # auto device map
return
device_map = device_map.replace(" ", "")
infos = device_map.split(",")
device_map_dict = {}
@@ -840,7 +849,7 @@ def set_auto_device_map_for_block_with_tuning(
num_devices = torch.xpu.device_count()
device_name = "xpu"
else:
raise RuntimeError("No CUDA or XPU devices found.")
return
Review comments:
Contributor: Is it possible to support all devices that PyTorch already supports, at least for single-device setups, such as NPUs or other accelerators?
Contributor (Author): This change is to handle device_map={"*": "cpu"}.

device_list = None
if isinstance(device_map, str) and "," in device_map:
device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]
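Regarding the thread above: the early return replaces the old RuntimeError, so a CPU-only mapping such as device_map={"*": "cpu"} (which is_complex_device_mapping treats as complex) no longer fails on machines without CUDA or XPU. A self-contained sketch of that detection pattern follows; it is not the real function, which also builds the per-layer device map.

# Minimal sketch of the device-detection fallback; not auto_round code.
import torch

def pick_tuning_device():
    if torch.cuda.is_available():
        return "cuda", torch.cuda.device_count()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu", torch.xpu.device_count()
    # Previously this case raised "No CUDA or XPU devices found."; returning
    # instead lets a mapping like device_map={"*": "cpu"} proceed on CPU.
    return None, 0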
7 changes: 4 additions & 3 deletions auto_round/utils/model.py
@@ -927,7 +927,7 @@ def check_seqlen_compatible(input_seqlen, tokenizer=None, model=None):
)


def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16):
def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16, device: str = "cpu"):
""" """
from auto_round.schemes import QuantizationScheme

@@ -939,6 +939,7 @@ def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16):
for key in keys:
setattr(new_layer, key, getattr(layer, key, None))

layer = layer.to(device)
if layer.__class__.__name__ == "CompressedLinear":
dq_weight = layer.compressor.decompress_module(layer)
else:
@@ -948,7 +949,7 @@ def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16):
return new_layer


def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16):
def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16, device: str = "cpu"):
"""
Convert a model with FP8 quantized layers to a model with 16-bit linear layers.
This is useful for compatibility with other frameworks or for further processing.
@@ -958,7 +959,7 @@ def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16):
cnt = 0
for n, m in model.named_modules():
if m.__class__.__name__ == "FP8Linear":
new_module = convert_fp8_layer_to_linear(m, dtype=dtype)
new_module = convert_fp8_layer_to_linear(m, dtype=dtype, device=device)
set_module(model, n, new_module)
cnt += 1
if cnt % 10 == 0: # Tricky setting
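A hedged usage sketch of the extended signature: model here is a placeholder for an already loaded FP8-quantized module, and the new device argument only controls where dequantization runs.

# Hypothetical usage; `model` is assumed to be an FP8-quantized nn.Module.
import torch

convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16, device="cuda:0")
# Each FP8Linear is moved to cuda:0, dequantized, and replaced in place with a
# regular 16-bit linear layer via set_module.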
16 changes: 16 additions & 0 deletions test/test_cuda/test_alg_ext.py
@@ -38,3 +38,19 @@ def test_2bits(self):
# without alg ext: 0.2084, with alg ext: 0.2364
self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.22)
shutil.rmtree(self.save_folder, ignore_errors=True)

def test_cli(self):
import os

model_name = "/models/opt-125m"
python_path = sys.executable

res = os.system(
f"cd ../.. && CUDA_VISIBLE_DEVICES=0 {python_path} -m auto_round --model {model_name} --device auto --enable_alg_ext --avg_bits 2 --options=W2A16,W4A16 --ignore_scale_zp_bits"
)
if res > 0 or res == -1:
assert False, "command line test failed, please check"


if __name__ == "__main__":
unittest.main()