62 changes: 50 additions & 12 deletions auto_round/compressors/utils.py
@@ -323,9 +323,20 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
cfg.setdefault(key, copy.deepcopy(default_dict.get(key)))

# 5. collect supported modules
embedding_types = (torch.nn.Embedding,)
gguf_name = get_gguf_scheme(default_scheme)
if gguf_name and torch.nn.Embedding not in supported_types:
supported_types = (*supported_types, torch.nn.Embedding)
if gguf_name:
if torch.nn.Embedding not in supported_types:
supported_types = (*supported_types, torch.nn.Embedding)

# handle Embedding classes whose type() is not torch.nn.Embedding,
# e.g. transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding
model_module_name = model.__class__.__module__
module_cls = sys.modules[model_module_name]
for name in module_cls.__dict__:
if name.endswith("Embedding") and not name.endswith("RotaryEmbedding"):
embedding_types = (*embedding_types, getattr(module_cls, name))
supported_types = (*supported_types, *embedding_types)

all_supported_layer_names, embedding_layer_names = [], []
all_module_names = []
@@ -338,7 +349,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
if type(m) not in supported_types and m.__class__.__name__ not in inner_supported_types:
continue
all_supported_layer_names.append(n)
if isinstance(m, torch.nn.Embedding):
if isinstance(m, embedding_types) or m.__class__.__name__.endswith("Embedding"):
embedding_layer_names.append(n)

# 6. expand regex configs
@@ -650,7 +661,7 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model

import gguf # pylint: disable=E0401

from auto_round.utils.common import LazyImport
from auto_round.utils.common import MM_KEYS, LazyImport
from auto_round.utils.model import get_lm_head_name, get_module

# from auto_round.export.export_to_gguf.convert import ModelBase, get_model_architecture
@@ -660,24 +671,41 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model
hparams=model.config.to_dict(), model_type=model_type
)
try:
model_class = convert_hf_to_gguf.ModelBase.from_model_architecture(model_architecture, model_type=model_type)
if model_type != ModelType.TEXT:
model_class_vision = convert_hf_to_gguf.ModelBase.from_model_architecture(
model_architecture, model_type=model_type
)
model_class = convert_hf_to_gguf.ModelBase.from_model_architecture(
model_architecture, model_type=ModelType.TEXT
)

except NotImplementedError:
return layer_config, {}

n_layer = None
for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"]:
sub_attr = "text_config" if model_type == ModelType.TEXT else "vision_config"
if model_type != ModelType.TEXT:
n_layer_vision = None
for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]:
if hasattr(model.config, name):
n_layer = getattr(model.config, name)
break
if hasattr(model.config, sub_attr):
if hasattr(getattr(model.config, sub_attr), name):
n_layer = getattr(getattr(model.config, sub_attr), name)
if model_type != ModelType.TEXT:
if n_layer is not None and hasattr(model.config, "text_config"):
if hasattr(getattr(model.config, "text_config"), name):
n_layer = getattr(getattr(model.config, "text_config"), name)
for config_name in ["vision_config", "vision_encoder"]:
if hasattr(model.config, config_name):
if hasattr(getattr(model.config, config_name), name):
n_layer_vision = getattr(getattr(model.config, config_name), name)
break
if n_layer and n_layer_vision:
break

if n_layer is None:
return layer_config, {}

tensor_map = gguf.get_tensor_name_map(model_class.model_arch, n_layer)
if model_type != ModelType.TEXT:
tensor_map_vision = gguf.get_tensor_name_map(model_class_vision.model_arch, n_layer_vision)

def _set_config(config, target_config):
for k, v in target_config.items():
@@ -733,7 +761,17 @@ def _set_config(config, target_config):
re.search("gguf:q([0-9]{1,})_[01k]", GGUF_CONFIG[target_gguf_format]["embedding"]).group(1)
)

gguf_name = tensor_map.get_name(layer_name)
if model_type != ModelType.TEXT and any([key in layer_name for key in MM_KEYS]):
gguf_name = tensor_map_vision.get_name(layer_name)
if gguf_name is None:
for key in MM_KEYS:
gguf_name = tensor_map_vision.get_name(layer_name.replace(f".{key}", ""))
if gguf_name is not None:
break
else:
gguf_name = tensor_map.get_name(layer_name)
if gguf_name is None:
gguf_name = tensor_map.get_name(layer_name.replace(".language_model", ""))
bits_index = 6
if config.get("fixed_by_user", False):
if "bits" not in config:
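The first hunk above discovers embedding classes by introspecting the module that defines the model, since some checkpoints wrap their word embeddings in subclasses whose `type()` is not `torch.nn.Embedding`. A minimal standalone sketch of the same idea (the helper name is illustrative, not part of the PR):

```python
import sys

import torch


def collect_embedding_types(model, base_types=(torch.nn.Embedding,)):
    """Collect Embedding-like classes defined in the model's own module.

    Some models wrap their embeddings in subclasses whose type() is not
    torch.nn.Embedding (e.g. Gemma3TextScaledWordEmbedding), so the class
    names of the defining module are scanned, skipping rotary embeddings.
    """
    embedding_types = tuple(base_types)
    model_module = sys.modules[model.__class__.__module__]
    for name, obj in vars(model_module).items():
        if not isinstance(obj, type):
            continue
        if name.endswith("Embedding") and not name.endswith("RotaryEmbedding"):
            embedding_types = (*embedding_types, obj)
    return embedding_types
```

Layers are then matched with `isinstance(m, embedding_types)` plus the `endswith("Embedding")` name fallback shown in the second hunk.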
8 changes: 6 additions & 2 deletions auto_round/eval/eval_cli.py
@@ -229,7 +229,6 @@ def eval_task_by_task(
import traceback

from lm_eval import simple_evaluate as lm_simple_evaluate # pylint: disable=E0611
from lm_eval.models.hf_vlms import HFMultimodalLM
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -269,6 +268,8 @@ def eval_task_by_task(
if batch_size is None or batch_size == "auto":
logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16")
batch_size = 16
from lm_eval.models.hf_vlms import HFMultimodalLM

hflm = HFMultimodalLM(
pretrained=model,
tokenizer=tokenizer,
@@ -333,7 +334,10 @@ def eval_task_by_task(
res_all = res
else:
for key in res_keys:
res_all[key].update(res[key])
if key not in res_all:
continue
else:
res_all[key].update(res[key])
print(make_table(res_all))

print("total eval time:", time.time() - st)
3 changes: 2 additions & 1 deletion auto_round/eval/evaluation.py
@@ -21,7 +21,6 @@

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from lm_eval.models.hf_vlms import HFMultimodalLM
from lm_eval.models.huggingface import HFLM


@@ -37,6 +36,8 @@ def simple_evaluate_user_model(
**kwargs
):
if mllm:
from lm_eval.models.hf_vlms import HFMultimodalLM

if batch_size is None or batch_size == "auto":
logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16")
batch_size = 16
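Both eval entry points now defer the `HFMultimodalLM` import into the multimodal branch, so text-only environments without lm-eval's VLM extras can still import these modules. A minimal sketch of the deferred-import pattern (the wrapper function and its defaults are illustrative; the constructor arguments mirror the call sites above):

```python
from typing import Union


def build_lm_eval_wrapper(model, tokenizer, mllm: bool = False, batch_size: Union[int, str, None] = None):
    """Return an lm-eval model wrapper, importing the VLM class only when needed."""
    if mllm:
        # Deferred import: only multimodal runs need lm_eval's VLM extras.
        from lm_eval.models.hf_vlms import HFMultimodalLM

        if batch_size is None or batch_size == "auto":
            batch_size = 16  # hf-multimodal does not support "auto" batch size
        return HFMultimodalLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size)

    from lm_eval.models.huggingface import HFLM

    return HFLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size)
```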
4 changes: 2 additions & 2 deletions auto_round/export/export_to_gguf/export.py
@@ -133,8 +133,8 @@ def pack_gguf_layer(
):
"""Export the model to gguf format."""
global gguf_model_instance_global
if output_dir is not None and os.path.exists(output_dir):
logger.warning_once(f"{output_dir} already exists, this may cause model conflict")
# if output_dir is not None and os.path.exists(output_dir):
# logger.warning_once(f"{output_dir} already exists, this may cause model conflict")
if "gguf_model_instance_global" not in globals():
config = model.config

18 changes: 18 additions & 0 deletions auto_round/utils/common.py
@@ -124,6 +124,24 @@ def __getitem__(self, key):

SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce)

MM_KEYS = [
"multi_modal_projector",
"vision_tower",
"multimodal_projector",
"thinker",
"visual",
"audio",
"talker",
"token2wav",
"vision_model",
"audio_tower",
"vision_encoder",
"vision_language_adapter",
"patch_merger",
"pre_mm_projector_norm",
"vision",
]


def is_debug_mode():
"""Checks if the Python interpreter is running in debug mode.
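The `MM_KEYS` list relocated here from `auto_round/utils/model.py` is consumed with plain substring checks (see `get_layer_config_by_gguf_format` above and `is_mllm_model`). A minimal usage sketch, with an illustrative helper name:

```python
from auto_round.utils.common import MM_KEYS


def belongs_to_mm_tower(layer_name: str) -> bool:
    """Heuristic: the layer is part of a vision/audio tower if its name
    contains any of the shared multimodal keys."""
    return any(key in layer_name for key in MM_KEYS)


# belongs_to_mm_tower("visual.blocks.0.attn.qkv")         -> True
# belongs_to_mm_tower("model.layers.0.self_attn.q_proj")  -> False
```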
19 changes: 16 additions & 3 deletions auto_round/utils/device.py
@@ -193,6 +193,8 @@ def detect_device_count():
"""
if torch.cuda.is_available():
return torch.cuda.device_count()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.xpu.device_count()
else:
try:
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401
@@ -1144,11 +1146,13 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map):
device_list = parse_available_devices(device_map)
gpu_devices = []
for device in device_list:
if device.startswith("hpu") and len(device_list) > 1:
logger.warning_once("Auto-scheme does not support multiple HPUs.")
if device.startswith("cpu") or device.startswith("hpu"):
continue
gpu_devices.append(device)
num_devices = len(gpu_devices)
if num_devices < 1:
if num_devices <= 1:
return

for block_names in block_name_list:
@@ -1272,7 +1276,16 @@ def parse_available_devices(device_map: Union[str, torch.device, int, dict, None
device_map = device_map.strip()
if device_map.lower() == "cpu":
return ["cpu"]

if device_map.lower() == "auto":
device_count = detect_device_count()
if "cuda" in device_types:
return [f"cuda:{i}" for i in range(device_count)]
elif "xpu" in device_types:
return [f"xpu:{i}" for i in range(device_count)]
elif "hpu" in device_types:
return [f"hpu:{i}" for i in range(device_count)]
else:
return ["cpu"]
# Split by commas
parts = [x.strip() for x in device_map.split(",") if x.strip()]
parsed = []
@@ -1283,7 +1296,7 @@ def parse_available_devices(device_map: Union[str, torch.device, int, dict, None
parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
else:
parsed.append(p)
return parsed
return list(set(parsed))

if isinstance(device_map, dict):
# Extract all devices recursively from dict values
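With the new `"auto"` branch and the final `list(set(parsed))` de-duplication, `parse_available_devices` behaves roughly as below; the device counts in the comments are illustrative assumptions:

```python
from auto_round.utils.device import parse_available_devices

# "auto" now expands to every detected accelerator of the preferred type:
#   CUDA box with 2 GPUs   -> ["cuda:0", "cuda:1"]
#   XPU box with 1 device  -> ["xpu:0"]
#   no accelerator at all  -> ["cpu"]  (what test_parse_available_devices asserts on CPU CI)
print(parse_available_devices("auto"))

# Explicit comma-separated maps are still parsed, now de-duplicated via set(),
# so the result order is not guaranteed:
print(parse_available_devices("cuda:0,cuda:1,cuda:0"))  # e.g. ["cuda:0", "cuda:1"]
```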
19 changes: 2 additions & 17 deletions auto_round/utils/model.py
@@ -497,23 +497,8 @@ def is_pure_text_model(model):


def is_mllm_model(model_or_path: Union[str, torch.nn.Module], platform: str = None):
MM_KEYS = [
"multi_modal_projector",
"vision_tower",
"multimodal_projector",
"thinker",
"visual",
"audio",
"talker",
"token2wav",
"vision_model",
"audio_tower",
"vision_encoder",
"vision_language_adapter",
"patch_merger",
"pre_mm_projector_norm",
"vision",
]
from auto_round.utils.common import MM_KEYS

model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path
if not os.path.isdir(model_path):
model_path = download_or_get_path(model_path, platform=platform)
11 changes: 11 additions & 0 deletions test/test_cpu/test_scheme.py
@@ -118,6 +118,17 @@ def test_scheme_in_layer_config(self):
if n == "model.decoder.layers.4.self_attn.k_proj":
self.assertEqual(m.group_size, 64)

def test_parse_available_devices(self):
from auto_round.utils.device import parse_available_devices

device_list = parse_available_devices("auto")
self.assertTrue(len(device_list) == 1 and "cpu" in device_list)
device_list = parse_available_devices("a:cuda:0,b:cuda:1,c:cpu")
self.assertTrue(len(device_list) == 3)
self.assertEqual(device_list, ["cuda:0", "cuda:1", "cpu"])
device_list = parse_available_devices("0,1")
self.assertTrue(len(device_list) == 1 and "cpu" in device_list)


if __name__ == "__main__":
unittest.main()
22 changes: 10 additions & 12 deletions test/test_cuda/test_gguf.py
@@ -88,7 +88,7 @@ def test_q2_k_export(self):
quantized_model_path = "./saved"

autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="gguf:q2_k_s")
gguf_file = "Qwen2.5-1.5B-Instruct-1.5B-Q2_K_S.gguf"
gguf_file = os.listdir(quantized_model_path)[0]
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto")
text = "There is a girl who likes adventure,"
inputs = autoround.tokenizer(text, return_tensors="pt").to(model.device)
@@ -123,7 +123,7 @@ def test_q4_0(self):
quantized_model_path = "./saved"

autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="gguf:q4_0")
gguf_file = "Qwen2.5-0.5B-Instruct-494M-Q4_0.gguf"
gguf_file = os.listdir(quantized_model_path)[0]
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto")
text = "There is a girl who likes adventure,"
inputs = autoround.tokenizer(text, return_tensors="pt").to(model.device)
@@ -144,7 +144,7 @@ def test_q4_1(self):
quantized_model_path = "./saved"

autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="gguf:q4_1")
gguf_file = "Qwen2.5-0.5B-Instruct-494M-Q4_1.gguf"
gguf_file = os.listdir(quantized_model_path)[0]
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto")
text = "There is a girl who likes adventure,"
inputs = autoround.tokenizer(text, return_tensors="pt").to(model.device)
@@ -198,15 +198,13 @@ def test_vlm_gguf(self):
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0")
self.assertTrue("mmproj-model.gguf" in os.listdir("./saved"))
file_size = os.path.getsize("./saved/Qwen2.5-VL-7B-Instruct-7.6B-Q4_0.gguf") / 1024**2
self.assertAlmostEqual(file_size, 4226, delta=1.0)
file_size = os.path.getsize("./saved/Qwen2.5-VL-7B-Instruct-Q4_0.gguf") / 1024**2
self.assertAlmostEqual(file_size, 4226, delta=5.0)
file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2
self.assertAlmostEqual(file_size, 2578, delta=1.0)
self.assertAlmostEqual(file_size, 2580, delta=5.0)
shutil.rmtree("./saved", ignore_errors=True)

model_name = "/models/gemma-3-12b-it"
from auto_round import AutoRoundMLLM
from auto_round.utils import mllm_load_model

model, processor, tokenizer, image_processor = mllm_load_model(model_name)
autoround = AutoRoundMLLM(
@@ -216,15 +214,15 @@
image_processor=image_processor,
device="auto",
nsamples=32,
iters=1,
iters=0,
)
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m")
self.assertTrue("mmproj-model.gguf" in os.listdir("./saved"))
file_size = os.path.getsize("./saved/gemma-3-12b-it-12B-Q4_K_M.gguf") / 1024**2
self.assertAlmostEqual(file_size, 6568, delta=1.0)
file_size = os.path.getsize("./saved/gemma-3-12B-it-Q4_K_M.gguf") / 1024**2
self.assertAlmostEqual(file_size, 6962, delta=5.0)
file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2
self.assertAlmostEqual(file_size, 1599, delta=1.0)
self.assertAlmostEqual(file_size, 1599, delta=5.0)
shutil.rmtree(quantized_model_path, ignore_errors=True)

# @require_gguf
1 change: 0 additions & 1 deletion test/test_cuda/test_vlms.py
@@ -140,7 +140,6 @@ def test_mllm_detect(self):
"/models/Phi-3.5-vision-instruct",
"/models/Qwen2-VL-2B-Instruct",
"/models/SmolVLM-256M-Instruct",
"/models/Llama-4-Maverick-17B-128E-Instruct",
"/models/Mistral-Small-3.1-24B-Instruct-2503",
"/models/InternVL3-1B",
"/models/pixtral-12b",