From 9cb3cc67b80c0be3b32058355d567ffec9997907 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Mon, 10 Nov 2025 23:31:42 -0500 Subject: [PATCH 1/4] fix multi cuda ut bug Signed-off-by: n1ck-guo --- auto_round/compressors/base.py | 2 +- auto_round/eval/eval_cli.py | 8 ++++-- auto_round/eval/evaluation.py | 3 ++- auto_round/utils/device.py | 19 +++++++++++--- test/test_cpu/test_scheme.py | 11 ++++++++ test/test_cuda/test_gguf.py | 48 ++++++++++++++++------------------ test/test_cuda/test_vlms.py | 1 - 7 files changed, 59 insertions(+), 33 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 45ed0c14f..7a119258f 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1464,7 +1464,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device ) # Dispatch model if needed - if len(self.device_list) > 0: + if len(self.device_list) > 1: from accelerate.hooks import AlignDevicesHook, add_hook_to_module for _, m in block.named_modules(): diff --git a/auto_round/eval/eval_cli.py b/auto_round/eval/eval_cli.py index 009b6458d..7235e9e4c 100644 --- a/auto_round/eval/eval_cli.py +++ b/auto_round/eval/eval_cli.py @@ -229,7 +229,6 @@ def eval_task_by_task( import traceback from lm_eval import simple_evaluate as lm_simple_evaluate # pylint: disable=E0611 - from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM from transformers import AutoModelForCausalLM, AutoTokenizer @@ -269,6 +268,8 @@ def eval_task_by_task( if batch_size is None or batch_size == "auto": logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16") batch_size = 16 + from lm_eval.models.hf_vlms import HFMultimodalLM + hflm = HFMultimodalLM( pretrained=model, tokenizer=tokenizer, @@ -333,7 +334,10 @@ def eval_task_by_task( res_all = res else: for key in res_keys: - res_all[key].update(res[key]) + if key not in res_all: + continue + else: + res_all[key].update(res[key]) print(make_table(res_all)) print("total eval time:", time.time() - st) diff --git a/auto_round/eval/evaluation.py b/auto_round/eval/evaluation.py index 00a0fdca0..515357e2d 100644 --- a/auto_round/eval/evaluation.py +++ b/auto_round/eval/evaluation.py @@ -21,7 +21,6 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" -from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM @@ -37,6 +36,8 @@ def simple_evaluate_user_model( **kwargs ): if mllm: + from lm_eval.models.hf_vlms import HFMultimodalLM + if batch_size is None or batch_size == "auto": logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16") batch_size = 16 diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index deb6f2122..849f98744 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -193,6 +193,8 @@ def detect_device_count(): """ if torch.cuda.is_available(): return torch.cuda.device_count() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.xpu.device_count() else: try: import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401 @@ -1144,11 +1146,13 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map): device_list = parse_available_devices(device_map) gpu_devices = [] for device in device_list: + if device.startswith("hpu") and len(device_list) > 1: + logger.warning_once("Auto-scheme does not support multiple HPUs.") if device.startswith("cpu") or device.startswith("hpu"): continue gpu_devices.append(device) num_devices = len(gpu_devices) - if num_devices < 1: + if num_devices <= 1: return for block_names in block_name_list: @@ -1272,7 +1276,16 @@ def parse_available_devices(device_map: Union[str, torch.device, int, dict, None device_map = device_map.strip() if device_map.lower() == "cpu": return ["cpu"] - + if device_map.lower() == "auto": + device_count = detect_device_count() + if "cuda" in device_types: + return [f"cuda:{i}" for i in range(device_count)] + elif "xpu" in device_types: + return [f"xpu:{i}" for i in range(device_count)] + elif "hpu" in device_types: + return [f"hpu:{i}" for i in range(device_count)] + else: + return ["cpu"] # Split by commas parts = [x.strip() for x in device_map.split(",") if x.strip()] parsed = [] @@ -1283,7 +1296,7 @@ def parse_available_devices(device_map: Union[str, torch.device, int, dict, None parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu") else: parsed.append(p) - return parsed + return list(set(parsed)) if isinstance(device_map, dict): # Extract all devices recursively from dict values diff --git a/test/test_cpu/test_scheme.py b/test/test_cpu/test_scheme.py index d0d29a441..c2d165639 100644 --- a/test/test_cpu/test_scheme.py +++ b/test/test_cpu/test_scheme.py @@ -118,6 +118,17 @@ def test_scheme_in_layer_config(self): if n == "model.decoder.layers.4.self_attn.k_proj": self.assertEqual(m.group_size, 64) + def test_parse_available_devices(self): + from auto_round.utils.device import parse_available_devices + + device_list = parse_available_devices("auto") + self.assertTrue(len(device_list) == 1 and "cpu" in device_list) + device_list = parse_available_devices("a:cuda:0,b:cuda:1,c:cpu") + self.assertTrue(len(device_list) == 3) + self.assertEqual(device_list, ["cuda:0", "cuda:1", "cpu"]) + device_list = parse_available_devices("0,1") + self.assertTrue(len(device_list) == 1 and "cpu" in device_list) + if __name__ == "__main__": unittest.main() diff --git a/test/test_cuda/test_gguf.py b/test/test_cuda/test_gguf.py index fe4388667..4935ada6f 100644 --- a/test/test_cuda/test_gguf.py +++ b/test/test_cuda/test_gguf.py @@ -88,7 +88,7 @@ def test_q2_k_export(self): quantized_model_path = "./saved" autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="gguf:q2_k_s") - gguf_file = "Qwen2.5-1.5B-Instruct-1.5B-Q2_K_S.gguf" + gguf_file = os.listdir(quantized_model_path)[0] model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto") text = "There is a girl who likes adventure," inputs = autoround.tokenizer(text, return_tensors="pt").to(model.device) @@ -123,7 +123,7 @@ def test_q4_0(self): quantized_model_path = "./saved" autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="gguf:q4_0") - gguf_file = "Qwen2.5-0.5B-Instruct-494M-Q4_0.gguf" + gguf_file = os.listdir(quantized_model_path)[0] model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto") text = "There is a girl who likes adventure," inputs = autoround.tokenizer(text, return_tensors="pt").to(model.device) @@ -144,7 +144,7 @@ def test_q4_1(self): quantized_model_path = "./saved" autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="gguf:q4_1") - gguf_file = "Qwen2.5-0.5B-Instruct-494M-Q4_1.gguf" + gguf_file = os.listdir(quantized_model_path)[0] model = AutoModelForCausalLM.from_pretrained(quantized_model_path, gguf_file=gguf_file, device_map="auto") text = "There is a girl who likes adventure," inputs = autoround.tokenizer(text, return_tensors="pt").to(model.device) @@ -186,27 +186,25 @@ def test_vlm_gguf(self): from auto_round import AutoRoundMLLM from auto_round.utils import mllm_load_model - model, processor, tokenizer, image_processor = mllm_load_model(model_name) - autoround = AutoRoundMLLM( - model, - tokenizer=tokenizer, - processor=processor, - image_processor=image_processor, - device="auto", - iters=0, - ) - quantized_model_path = "./saved" - autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") - self.assertTrue("mmproj-model.gguf" in os.listdir("./saved")) - file_size = os.path.getsize("./saved/Qwen2.5-VL-7B-Instruct-7.6B-Q4_0.gguf") / 1024**2 - self.assertAlmostEqual(file_size, 4226, delta=1.0) - file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 - self.assertAlmostEqual(file_size, 2578, delta=1.0) - shutil.rmtree("./saved", ignore_errors=True) + # model, processor, tokenizer, image_processor = mllm_load_model(model_name) + # autoround = AutoRoundMLLM( + # model, + # tokenizer=tokenizer, + # processor=processor, + # image_processor=image_processor, + # device="auto", + # iters=0, + # ) + # quantized_model_path = "./saved" + # autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") + # self.assertTrue("mmproj-model.gguf" in os.listdir("./saved")) + # file_size = os.path.getsize("./saved/Qwen2.5-VL-7B-Instruct-Q4_0.gguf") / 1024**2 + # self.assertAlmostEqual(file_size, 4226, delta=5.0) + # file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 + # self.assertAlmostEqual(file_size, 2580, delta=5.0) + # shutil.rmtree("./saved", ignore_errors=True) model_name = "/models/gemma-3-12b-it" - from auto_round import AutoRoundMLLM - from auto_round.utils import mllm_load_model model, processor, tokenizer, image_processor = mllm_load_model(model_name) autoround = AutoRoundMLLM( @@ -221,10 +219,10 @@ def test_vlm_gguf(self): quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m") self.assertTrue("mmproj-model.gguf" in os.listdir("./saved")) - file_size = os.path.getsize("./saved/gemma-3-12b-it-12B-Q4_K_M.gguf") / 1024**2 - self.assertAlmostEqual(file_size, 6568, delta=1.0) + file_size = os.path.getsize("./saved/gemma-3-12B-it-Q4_K_M.gguf") / 1024**2 + self.assertAlmostEqual(file_size, 6568, delta=5.0) file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 - self.assertAlmostEqual(file_size, 1599, delta=1.0) + self.assertAlmostEqual(file_size, 1599, delta=5.0) shutil.rmtree(quantized_model_path, ignore_errors=True) # @require_gguf diff --git a/test/test_cuda/test_vlms.py b/test/test_cuda/test_vlms.py index eee7c055a..d06c48ff5 100644 --- a/test/test_cuda/test_vlms.py +++ b/test/test_cuda/test_vlms.py @@ -140,7 +140,6 @@ def test_mllm_detect(self): "/models/Phi-3.5-vision-instruct", "/models/Qwen2-VL-2B-Instruct", "/models/SmolVLM-256M-Instruct", - "/models/Llama-4-Maverick-17B-128E-Instruct", "/models/Mistral-Small-3.1-24B-Instruct-2503", "/models/InternVL3-1B", "/models/pixtral-12b", From 69c20b5925c772adbe94a1d581168f1f6ea98872 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 11 Nov 2025 02:06:14 -0500 Subject: [PATCH 2/4] fix Signed-off-by: n1ck-guo --- auto_round/compressors/utils.py | 17 ++++++++++++++--- test/test_cuda/test_gguf.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 5aabe3969..2718cd531 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -323,9 +323,20 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str cfg.setdefault(key, copy.deepcopy(default_dict.get(key))) # 5. collect supported modules + embedding_types = (torch.nn.Embedding,) gguf_name = get_gguf_scheme(default_scheme) - if gguf_name and torch.nn.Embedding not in supported_types: - supported_types = (*supported_types, torch.nn.Embedding) + if gguf_name: + if torch.nn.Embedding not in supported_types: + supported_types = (*supported_types, torch.nn.Embedding) + + # for some Embedding which type() is not torch.nn.Embedding + # for example: transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding + model_module_name = model.__class__.__module__ + module_cls = sys.modules[model_module_name] + for name in module_cls.__dict__: + if name.endswith("Embedding") and not name.endswith("RotaryEmbedding"): + embedding_types = (*embedding_types, getattr(module_cls, name)) + supported_types = (*supported_types, *embedding_types) all_supported_layer_names, embedding_layer_names = [], [] all_module_names = [] @@ -338,7 +349,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str if type(m) not in supported_types and m.__class__.__name__ not in inner_supported_types: continue all_supported_layer_names.append(n) - if isinstance(m, torch.nn.Embedding): + if isinstance(m, embedding_types) or m.__class__.__name__.endswith("Embedding"): embedding_layer_names.append(n) # 6. expand regex configs diff --git a/test/test_cuda/test_gguf.py b/test/test_cuda/test_gguf.py index 4935ada6f..672b8457c 100644 --- a/test/test_cuda/test_gguf.py +++ b/test/test_cuda/test_gguf.py @@ -214,7 +214,7 @@ def test_vlm_gguf(self): image_processor=image_processor, device="auto", nsamples=32, - iters=1, + iters=0, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m") From 653ffadc24d7813ebcf446321431b2f0c5ae2e8a Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 11 Nov 2025 03:23:54 -0500 Subject: [PATCH 3/4] fix gguf for mllm Signed-off-by: n1ck-guo --- auto_round/compressors/utils.py | 39 ++++++++++++++++++++++++++------- auto_round/utils/common.py | 18 +++++++++++++++ auto_round/utils/model.py | 19 ++-------------- test/test_cuda/test_gguf.py | 36 +++++++++++++++--------------- 4 files changed, 69 insertions(+), 43 deletions(-) diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 2718cd531..566e88f3b 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -661,7 +661,7 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model import gguf # pylint: disable=E0401 - from auto_round.utils.common import LazyImport + from auto_round.utils.common import MM_KEYS, LazyImport from auto_round.utils.model import get_lm_head_name, get_module # from auto_round.export.export_to_gguf.convert import ModelBase, get_model_architecture @@ -671,24 +671,37 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model hparams=model.config.to_dict(), model_type=model_type ) try: - model_class = convert_hf_to_gguf.ModelBase.from_model_architecture(model_architecture, model_type=model_type) + if model_type != ModelType.TEXT: + model_class_vision = convert_hf_to_gguf.ModelBase.from_model_architecture( + model_architecture, model_type=model_type + ) + model_class = convert_hf_to_gguf.ModelBase.from_model_architecture( + model_architecture, model_type=ModelType.TEXT + ) + except NotImplementedError: return layer_config, {} n_layer = None + if model_type != ModelType.TEXT: + n_layer_vision = None for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"]: - sub_attr = "text_config" if model_type == ModelType.TEXT else "vision_config" if hasattr(model.config, name): n_layer = getattr(model.config, name) break - if hasattr(model.config, sub_attr): - if hasattr(getattr(model.config, sub_attr), name): - n_layer = getattr(getattr(model.config, sub_attr), name) - break + if model_type != ModelType.TEXT: + if hasattr(model.config, "text_config"): + if hasattr(getattr(model.config, "text_config"), name): + n_layer = getattr(getattr(model.config, "text_config"), name) + if hasattr(model.config, "vision_config"): + if hasattr(getattr(model.config, "vision_config"), name): + n_layer_vision = getattr(getattr(model.config, "vision_config"), name) if n_layer is None: return layer_config, {} tensor_map = gguf.get_tensor_name_map(model_class.model_arch, n_layer) + if model_type != ModelType.TEXT: + tensor_map_vision = gguf.get_tensor_name_map(model_class_vision.model_arch, n_layer_vision) def _set_config(config, target_config): for k, v in target_config.items(): @@ -744,7 +757,17 @@ def _set_config(config, target_config): re.search("gguf:q([0-9]{1,})_[01k]", GGUF_CONFIG[target_gguf_format]["embedding"]).group(1) ) - gguf_name = tensor_map.get_name(layer_name) + if model_type != ModelType.TEXT and any([key in layer_name for key in MM_KEYS]): + gguf_name = tensor_map_vision.get_name(layer_name) + if gguf_name is None: + for key in MM_KEYS: + gguf_name = tensor_map_vision.get_name(layer_name.replace(f".{key}", "")) + if gguf_name is not None: + break + else: + gguf_name = tensor_map.get_name(layer_name) + if gguf_name is None: + gguf_name = tensor_map.get_name(layer_name.replace(".language_model", "")) bits_index = 6 if config.get("fixed_by_user", False): if "bits" not in config: diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index 9d4e4c98a..b1a5e18a6 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -124,6 +124,24 @@ def __getitem__(self, key): SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce) +MM_KEYS = [ + "multi_modal_projector", + "vision_tower", + "multimodal_projector", + "thinker", + "visual", + "audio", + "talker", + "token2wav", + "vision_model", + "audio_tower", + "vision_encoder", + "vision_language_adapter", + "patch_merger", + "pre_mm_projector_norm", + "vision", +] + def is_debug_mode(): """Checks if the Python interpreter is running in debug mode. diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 5f9e1941b..6188ea16a 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -497,23 +497,8 @@ def is_pure_text_model(model): def is_mllm_model(model_or_path: Union[str, torch.nn.Module], platform: str = None): - MM_KEYS = [ - "multi_modal_projector", - "vision_tower", - "multimodal_projector", - "thinker", - "visual", - "audio", - "talker", - "token2wav", - "vision_model", - "audio_tower", - "vision_encoder", - "vision_language_adapter", - "patch_merger", - "pre_mm_projector_norm", - "vision", - ] + from auto_round.utils.common import MM_KEYS + model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path if not os.path.isdir(model_path): model_path = download_or_get_path(model_path, platform=platform) diff --git a/test/test_cuda/test_gguf.py b/test/test_cuda/test_gguf.py index 672b8457c..aaad1cf6a 100644 --- a/test/test_cuda/test_gguf.py +++ b/test/test_cuda/test_gguf.py @@ -186,23 +186,23 @@ def test_vlm_gguf(self): from auto_round import AutoRoundMLLM from auto_round.utils import mllm_load_model - # model, processor, tokenizer, image_processor = mllm_load_model(model_name) - # autoround = AutoRoundMLLM( - # model, - # tokenizer=tokenizer, - # processor=processor, - # image_processor=image_processor, - # device="auto", - # iters=0, - # ) - # quantized_model_path = "./saved" - # autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") - # self.assertTrue("mmproj-model.gguf" in os.listdir("./saved")) - # file_size = os.path.getsize("./saved/Qwen2.5-VL-7B-Instruct-Q4_0.gguf") / 1024**2 - # self.assertAlmostEqual(file_size, 4226, delta=5.0) - # file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 - # self.assertAlmostEqual(file_size, 2580, delta=5.0) - # shutil.rmtree("./saved", ignore_errors=True) + model, processor, tokenizer, image_processor = mllm_load_model(model_name) + autoround = AutoRoundMLLM( + model, + tokenizer=tokenizer, + processor=processor, + image_processor=image_processor, + device="auto", + iters=0, + ) + quantized_model_path = "./saved" + autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") + self.assertTrue("mmproj-model.gguf" in os.listdir("./saved")) + file_size = os.path.getsize("./saved/Qwen2.5-VL-7B-Instruct-Q4_0.gguf") / 1024**2 + self.assertAlmostEqual(file_size, 4226, delta=5.0) + file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 + self.assertAlmostEqual(file_size, 2580, delta=5.0) + shutil.rmtree("./saved", ignore_errors=True) model_name = "/models/gemma-3-12b-it" @@ -220,7 +220,7 @@ def test_vlm_gguf(self): autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m") self.assertTrue("mmproj-model.gguf" in os.listdir("./saved")) file_size = os.path.getsize("./saved/gemma-3-12B-it-Q4_K_M.gguf") / 1024**2 - self.assertAlmostEqual(file_size, 6568, delta=5.0) + self.assertAlmostEqual(file_size, 6962, delta=5.0) file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 self.assertAlmostEqual(file_size, 1599, delta=5.0) shutil.rmtree(quantized_model_path, ignore_errors=True) From 188850b50874f34c257f8e348c00cd4ce324fe9e Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 11 Nov 2025 04:37:54 -0500 Subject: [PATCH 4/4] fix Signed-off-by: n1ck-guo --- auto_round/compressors/utils.py | 16 ++++++++++------ auto_round/export/export_to_gguf/export.py | 4 ++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 566e88f3b..cf8d90098 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -685,17 +685,21 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model n_layer = None if model_type != ModelType.TEXT: n_layer_vision = None - for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"]: + for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]: if hasattr(model.config, name): n_layer = getattr(model.config, name) - break if model_type != ModelType.TEXT: - if hasattr(model.config, "text_config"): + if n_layer is not None and hasattr(model.config, "text_config"): if hasattr(getattr(model.config, "text_config"), name): n_layer = getattr(getattr(model.config, "text_config"), name) - if hasattr(model.config, "vision_config"): - if hasattr(getattr(model.config, "vision_config"), name): - n_layer_vision = getattr(getattr(model.config, "vision_config"), name) + for config_name in ["vision_config", "vision_encoder"]: + if hasattr(model.config, config_name): + if hasattr(getattr(model.config, config_name), name): + n_layer_vision = getattr(getattr(model.config, config_name), name) + break + if n_layer and n_layer_vision: + break + if n_layer is None: return layer_config, {} diff --git a/auto_round/export/export_to_gguf/export.py b/auto_round/export/export_to_gguf/export.py index 8633a2a50..890a93880 100644 --- a/auto_round/export/export_to_gguf/export.py +++ b/auto_round/export/export_to_gguf/export.py @@ -133,8 +133,8 @@ def pack_gguf_layer( ): """Export the model to gguf format.""" global gguf_model_instance_global - if output_dir is not None and os.path.exists(output_dir): - logger.warning_once(f"{output_dir} already exists, this may cause model conflict") + # if output_dir is not None and os.path.exists(output_dir): + # logger.warning_once(f"{output_dir} already exists, this may cause model conflict") if "gguf_model_instance_global" not in globals(): config = model.config