Merged
23 commits
719e5ab  fp8 exporting bugfix (WeiweiZhang1, Oct 9, 2025)
8e8b04f  Merge branch 'main' of https://github.com/intel/auto-round into main (WeiweiZhang1, Oct 16, 2025)
57bb2f4  fix act related config saving (WeiweiZhang1, Oct 16, 2025)
ad000e4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 16, 2025)
6bba765  add ut for act_config check (WeiweiZhang1, Oct 16, 2025)
d9fcbd0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 16, 2025)
2f34fc2  Merge branch 'main' into fix_act_config_exporting (WeiweiZhang1, Oct 16, 2025)
71c9e96  Merge branch 'main' into fix_act_config_exporting (WeiweiZhang1, Oct 20, 2025)
535602e  refine extra_config saving, add UTs (WeiweiZhang1, Oct 20, 2025)
f8bad15  fix ut typo (WeiweiZhang1, Oct 20, 2025)
a5c22b7  fix ut typo (WeiweiZhang1, Oct 20, 2025)
2d6bde0  fixtypo (WeiweiZhang1, Oct 20, 2025)
5b8d188  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 20, 2025)
dff0dd7  fix CI (WeiweiZhang1, Oct 21, 2025)
04278dc  fix scan issue (WeiweiZhang1, Oct 21, 2025)
73761da  fix scan issue (WeiweiZhang1, Oct 21, 2025)
4cf21ad  rm global variable (WeiweiZhang1, Oct 21, 2025)
0454c57  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 21, 2025)
78bc23f  rerun ut (WeiweiZhang1, Oct 21, 2025)
d6ffe3b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 21, 2025)
5c5fb5f  refine ut (WeiweiZhang1, Oct 21, 2025)
0a42ce6  Merge branch 'fix_act_config_exporting' of https://github.com/intel/a… (WeiweiZhang1, Oct 21, 2025)
5510afa  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 21, 2025)
30 changes: 13 additions & 17 deletions auto_round/export/export_to_autoround/export.py
@@ -18,6 +18,7 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import fields
from enum import Enum

import threadpoolctl as tctl
@@ -26,9 +27,10 @@
import transformers
from tqdm import tqdm

from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
from auto_round.export.export_to_autoround.utils import check_neq_config
from auto_round.export.utils import save_model
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme
from auto_round.utils import (
SUPPORTED_FORMATS,
SUPPORTED_LAYER_TYPES,
@@ -324,26 +326,20 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
for i in range(len(block_name_to_quantize)):
block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")

for layer_name in layer_config:
if (
not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
): ##lm head ##TODO fix act and so on
extra_config[layer_name] = {}
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
elif layer_config[layer_name]["in_blocks"] or (
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for layer_name, cfg in layer_config.items():
if not cfg["in_blocks"] and cfg["bits"] <= 8: # lm head
extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
elif cfg["in_blocks"] or (
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
):
neq_keys = check_neq_config(
layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
)
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
if len(neq_keys) > 0:
extra_config[layer_name] = {}
for key in neq_keys:
if layer_config[layer_name][key] is not None:
extra_config[layer_name][key] = layer_config[layer_name][key]
for key in scheme_keys:
if cfg[key] is not None:
extra_config[layer_name][key] = cfg[key]

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
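Reviewer note on the change above: the per-layer keys written into extra_config are now derived from the QuantizationScheme dataclass rather than the hand-maintained REQUIRED_CONFIG_KEYS tuple, so scheme fields (including the act_* settings) can no longer be silently dropped from the saved config. A minimal, self-contained sketch of that derivation, using a stand-in dataclass whose fields are assumed to match the nine keys the removed tuple listed:

```python
from dataclasses import dataclass, fields
from typing import Optional

# Stand-in for auto_round.schemes.QuantizationScheme; the real field set is assumed
# to cover at least the keys previously hard-coded in REQUIRED_CONFIG_KEYS.
@dataclass
class QuantizationScheme:
    data_type: Optional[str] = None
    bits: Optional[int] = None
    group_size: Optional[int] = None
    sym: Optional[bool] = None
    act_bits: Optional[int] = None
    act_data_type: Optional[str] = None
    act_group_size: Optional[int] = None
    act_sym: Optional[bool] = None
    act_dynamic: Optional[bool] = None

# Single source of truth for the per-layer config keys, as in the new exporter code.
scheme_keys = [f.name for f in fields(QuantizationScheme)]
print(scheme_keys)
# ['data_type', 'bits', 'group_size', 'sym', 'act_bits', 'act_data_type',
#  'act_group_size', 'act_sym', 'act_dynamic']
```

Deriving the keys from the dataclass keeps the exporters in sync with the scheme definition: a field added to QuantizationScheme should propagate to the exported config without further edits.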
30 changes: 13 additions & 17 deletions auto_round/export/export_to_autoround/export_to_fp8.py
@@ -16,16 +16,18 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import fields

import threadpoolctl as tctl
import torch
import transformers
from tqdm import tqdm

from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
from auto_round.export.export_to_autoround.utils import check_neq_config
from auto_round.export.utils import save_model
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme
from auto_round.utils import (
SUPPORTED_LAYER_TYPES,
_get_packing_device,
@@ -169,26 +171,20 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round",
for i in range(len(block_name_to_quantize)):
block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")

for layer_name in layer_config:
if (
not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
): ##lm head ##TODO fix act and so on
extra_config[layer_name] = {}
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
elif layer_config[layer_name]["in_blocks"] or (
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for layer_name, cfg in layer_config.items():
if not cfg["in_blocks"] and cfg["bits"] <= 8: # lm head
extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
elif cfg["in_blocks"] or (
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
):
neq_keys = check_neq_config(
layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
)
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
if len(neq_keys) > 0:
extra_config[layer_name] = {}
for key in neq_keys:
if layer_config[layer_name][key] is not None:
extra_config[layer_name][key] = layer_config[layer_name][key]
for key in scheme_keys:
if cfg[key] is not None:
extra_config[layer_name][key] = cfg[key]

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
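The same rewrite is applied to the fp8 export path. For layers outside the quantized blocks (the lm-head branch), the practical effect is that activation-related settings are no longer dropped on save; previously only bits, data_type, group_size and sym were copied. A hedged sketch of that branch, with an invented layer name and values:

```python
# Illustration only: the "lm_head" name and cfg values are made up for this sketch.
scheme_keys = ["data_type", "bits", "group_size", "sym",
               "act_bits", "act_data_type", "act_group_size", "act_sym", "act_dynamic"]
cfg = {"in_blocks": False, "bits": 8, "data_type": "fp", "group_size": 0, "sym": True,
       "act_bits": 8, "act_data_type": "fp", "act_group_size": 0, "act_sym": True,
       "act_dynamic": True}

extra_config = {}
if not cfg["in_blocks"] and cfg["bits"] <= 8:  # lm-head branch from the diff
    extra_config["lm_head"] = {key: cfg.get(key) for key in scheme_keys}

print(extra_config["lm_head"]["act_data_type"])  # 'fp' -- the act config is now preserved
```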
30 changes: 13 additions & 17 deletions auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
@@ -17,16 +17,18 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import fields

import threadpoolctl as tctl
import torch
import torch.nn as nn
import transformers
from tqdm import tqdm

from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
from auto_round.export.export_to_autoround.utils import check_neq_config
from auto_round.export.utils import save_model
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme
from auto_round.utils import (
SUPPORTED_LAYER_TYPES,
_get_packing_device,
@@ -195,26 +197,20 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
for i in range(len(block_name_to_quantize)):
block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")

for layer_name in layer_config:
if (
not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
): ##lm head # TODO fix act and so on
extra_config[layer_name] = {}
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
elif layer_config[layer_name]["in_blocks"] or (
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for layer_name, cfg in layer_config.items():
if not cfg["in_blocks"] and cfg["bits"] <= 8: # lm head
extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
elif cfg["in_blocks"] or (
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
):
neq_keys = check_neq_config(
layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
)
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
if len(neq_keys) > 0:
extra_config[layer_name] = {}
for key in neq_keys:
if layer_config[layer_name][key] is not None:
extra_config[layer_name][key] = layer_config[layer_name][key]
for key in scheme_keys:
if cfg[key] is not None:
extra_config[layer_name][key] = cfg[key]

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
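The same change again for the nvfp/mxfp exporter. For layers inside the quantized blocks there is also a small behavioural shift: once any scheme key differs from the model-wide quantization_config, the layer's full set of non-None scheme values is written out, not only the mismatched keys. A sketch with an invented layer name and values:

```python
scheme_keys = ["data_type", "bits", "group_size", "sym",
               "act_bits", "act_data_type", "act_group_size", "act_sym", "act_dynamic"]
quantization_config = {"data_type": "int", "bits": 4, "group_size": 128, "sym": True,
                       "act_bits": 16, "act_data_type": "float", "act_group_size": None,
                       "act_sym": None, "act_dynamic": None}
cfg = dict(quantization_config, bits=8)  # this layer differs in a single key

# Same comparison the new check_neq_config performs.
neq_keys = [k for k in scheme_keys if cfg[k] != quantization_config[k] and cfg[k] is not None]

extra_config = {}
if len(neq_keys) > 0:
    # Full non-None snapshot of the layer's scheme, not just the differing keys.
    extra_config["model.layers.0.mlp.down_proj"] = {k: cfg[k] for k in scheme_keys if cfg[k] is not None}

print(extra_config)
# {'model.layers.0.mlp.down_proj': {'data_type': 'int', 'bits': 8, 'group_size': 128,
#   'sym': True, 'act_bits': 16, 'act_data_type': 'float'}}
```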
26 changes: 10 additions & 16 deletions auto_round/export/export_to_autoround/utils.py
@@ -12,36 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

REQUIRED_CONFIG_KEYS = (
"data_type",
"bits",
"group_size",
"sym",
"act_bits",
"act_data_type",
"act_group_size",
"act_sym",
"act_dynamic",
)
from dataclasses import fields
from typing import List

from auto_round.schemes import QuantizationScheme

def check_neq_config(config: dict, **expected) -> dict[str, tuple]:

def check_neq_config(config: dict, **expected) -> List[str]:
"""
Compare a config dict against expected values.
Ensures all required keys are present in both config and expected.

Returns:
dict[str, tuple]: {key: (actual, expected)} for mismatched values.
List[str]: [keys] for mismatched values.
"""
scheme_keys = [f.name for f in fields(QuantizationScheme)]
# 1. Check missing from expected
missing_expected = [k for k in REQUIRED_CONFIG_KEYS if k not in expected]
missing_expected = [k for k in scheme_keys if k not in expected]
if missing_expected:
raise ValueError(f"Missing expected values for keys: {missing_expected}")

# 2. Check missing from layer config
missing_config = [k for k in REQUIRED_CONFIG_KEYS if k not in config]
missing_config = [k for k in scheme_keys if k not in config]
if missing_config:
raise ValueError(f"Missing config values for keys: {missing_config}")

# 3. Collect mismatches
return {key: (config[key], expected[key]) for key in REQUIRED_CONFIG_KEYS if config[key] != expected[key]}
return [key for key in scheme_keys if config[key] != expected[key] and config[key] is not None]
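Behavioural note on the helper above: check_neq_config previously returned a {key: (actual, expected)} mapping; it now returns just the names of keys whose layer value differs from the expected value and is not None. A hedged usage sketch, assuming the QuantizationScheme fields match the nine keys the removed REQUIRED_CONFIG_KEYS tuple listed:

```python
from auto_round.export.export_to_autoround.utils import check_neq_config

# Assumption for this sketch: the scheme fields are exactly the nine keys below.
layer_cfg = {"data_type": "int", "bits": 8, "group_size": 128, "sym": True,
             "act_bits": None, "act_data_type": "float", "act_group_size": None,
             "act_sym": None, "act_dynamic": None}
expected = {"data_type": "int", "bits": 4, "group_size": 128, "sym": True,
            "act_bits": 16, "act_data_type": "float", "act_group_size": None,
            "act_sym": None, "act_dynamic": None}

neq_keys = check_neq_config(layer_cfg, **expected)
print(neq_keys)  # ['bits'] -- 'act_bits' also differs, but the layer value is None, so it is skipped
```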