From 0055d30e84d54a4f012d75e1e30e8ce7c6d12217 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Sun, 2 Nov 2025 20:52:59 -0500 Subject: [PATCH 01/19] add moe support Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 1 + auto_round/utils/device.py | 109 +++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index da325c989..7f4cfbdda 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2760,6 +2760,7 @@ def _quantize_blocks( modules = [get_module(model, n) for n in names] m = WrapperMultiblock(modules) + m.config = model.config if hasattr(model, "config") else None q_input, input_ids = quantize_block( m, input_ids, diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 3ea3e63e0..91295ec7d 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -732,6 +732,74 @@ def find_best_device(layer_name, estimated_memory, layer_idx): return ordered_device_map, names +def get_moe_memory_ratio(block: torch.nn.Module) -> float: + """ + Calculate the memory ratio for MoE (Mixture of Experts) models. + + For MoE models, only num_experts_per_tok experts are activated per token, + not all experts. This function returns the ratio of active experts to total experts. + + Args: + block (torch.nn.Module): The model block to analyze. + + Returns: + float: Memory ratio (num_experts_per_tok / num_experts). + Returns 1.0 for non-MoE models. + + Examples: + - Non-MoE model: returns 1.0 + - Mixtral (2/8 experts): returns 0.25 + - Qwen2MoE (4/60 experts): returns ~0.067 + """ + from auto_round.utils.model import is_moe + + for name, module in block.named_modules(): + if is_moe(module): + config = getattr(block, "config", None) + if config is None: + break + + # Try to get num_experts_per_tok (active experts count) + num_experts_per_tok = getattr( + config, "num_experts_per_tok", None + ) # Mixtral, Qwen2MoE, DeepSeek, GPT-OSS, Llama4, LLaDAMoE + if num_experts_per_tok is None: + num_experts_per_tok = getattr(config, "moe_num_active_primary_experts", None) # SmallThinker + if num_experts_per_tok is None: + # HunYuan MoE uses moe_topk (array), get first element + moe_topk = getattr(config, "moe_topk", None) # HunYuan MoE V1 + if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0: + num_experts_per_tok = moe_topk[0] + elif moe_topk is not None: + num_experts_per_tok = moe_topk + + if num_experts_per_tok is None: + break + + # Get total number of experts + num_experts = getattr(config, "num_local_experts", None) # Mixtral, PhiMoE, Grok, Llama4 + if num_experts is None: + num_experts = getattr( + config, "num_experts", None + ) # Qwen2MoE, Olmo, BailingMoE, GroveMoE, HunYuan, LLaDAMoE + if num_experts is None: + num_experts = getattr(config, "moe_num_primary_experts", None) # SmallThinker + if num_experts is None: + num_experts = getattr(config, "n_routed_experts", None) # DeepSeek + + if num_experts is not None and num_experts > 0: + moe_ratio = num_experts_per_tok / num_experts + logger.debug( + f"MoE detected: {num_experts_per_tok}/{num_experts} experts active per token, " + f"activation memory ratio: {moe_ratio:.2f}" + ) + logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}") + return moe_ratio + break # Only check once per block + + return 1.0 # Default ratio for non-MoE models + + def estimate_tuning_block_mem( block: torch.nn.Module, input_ids: list[torch.Tensor], pick_samples: int ) -> tuple[dict, float]: @@ -746,21 +814,22 @@ def estimate_tuning_block_mem( Returns: tuple: A tuple containing the following: - layer_memory_dict (dict): A dictionary mapping layer names to their memory consumption (in GB). - Format: {layer_name: {"param_memory": float, "output_memory": float}} + Format: {layer_name: {"param_memory": float, "output_memory": float}} - input_output_memory (float): The memory consumption (in GB) for input and output tensors of the block. - additional_memory (float): Additional memory overhead (in GB) for operations like attention. """ # Calculate all block parameters memory and build layer-wise memory dict - from auto_round.utils.model import get_layer_features + from auto_round.utils.model import get_layer_features, is_moe layer_memory_dict = {} - total_param_mem = 0 # Calculate batch_size and sequence_length from input_ids for output memory estimation seq_len = input_ids[0].shape[1] if input_ids and len(input_ids[0].shape) >= 2 else 1 element_size = input_ids[0].element_size() if input_ids else 2 # Default to 2 bytes (fp16/bf16) + moe_ratio = get_moe_memory_ratio(block) # Get MoE memory ratio (1.0 for non-MoE models) + for name, module in block.named_modules(): if check_to_quantized(module): enable_act_quant = module.act_bits <= 8 @@ -785,15 +854,33 @@ def estimate_tuning_block_mem( output_memory_gb = 0.0 # memory * 2, because it contains grad tensor. - layer_memory_dict[layer_name] = {"param_memory": param_memory_gb * 2, "output_memory": output_memory_gb * 2} + # Check if this is a MoE expert layer by layer name (e.g., "mlp.experts.0.gate_proj") + is_moe_expert = "expert" in layer_name.lower() + layer_memory_dict[layer_name] = { + "param_memory": param_memory_gb * 2, + "output_memory": output_memory_gb * 2, + "is_moe_expert": is_moe_expert, + } # Assuming bfloat16 or float32, input and output block_input_output_memory = 2 * sum(tensor.nbytes for tensor in input_ids) / 1024**3 # Roughly estimate additional memory for attention and other operations - additional_activation_memory = sum(info["output_memory"] for info in layer_memory_dict.values()) + # For MoE expert layers, multiply activation memory by the ratio of active experts + # For non-MoE layers (attention, norm, etc.), use full activation memory + layer_activation_memory = 0.0 + for layer_name, info in layer_memory_dict.items(): + if info.get("is_moe_expert", False): + # MoE expert layer: only a fraction of experts are active + # memory * 2 records intermediate activation memory of GeLU, or etc. + layer_activation_memory += info["output_memory"] * moe_ratio * 2 + else: + # Non-MoE layer: use full activation memory + layer_activation_memory += info["output_memory"] + + # layer_activation_memory considers other ops activation memory # 1GB considers norm weight, sdpa, reference_output, etc. - additional_memory = additional_activation_memory + 1 # GB + additional_memory = layer_activation_memory + 1 # GB if torch.xpu.is_available(): # https://github.com/intel/torch-xpu-ops/issues/2232 # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB @@ -802,7 +889,7 @@ def estimate_tuning_block_mem( logger.warning_once("XPU additional memory usage of SDPA is estimated to be 12 GB.") logger.warning_once("Remove it after https://github.com/intel/torch-xpu-ops/issues/2232 is fixed.") - return layer_memory_dict, block_input_output_memory, additional_memory + return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory def set_auto_device_map_for_block_with_tuning( @@ -853,21 +940,21 @@ def set_auto_device_map_for_block_with_tuning( device_0 = f"{device_name}:0" device_0_memory = get_device_memory(device_list[0] if device_list else 0) - layer_memory_dict, block_input_output_memory, additional_memory = estimate_tuning_block_mem( - block, input_ids, pick_samples + layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory = ( + estimate_tuning_block_mem(block, input_ids, pick_samples) ) if low_gpu_mem_usage: block_input_output_memory = 0 # Calculate total block memory from layer memory dict (including both param and output memory) total_block_param_memory = sum(info["param_memory"] for info in layer_memory_dict.values()) - total_block_output_memory = sum(info["output_memory"] for info in layer_memory_dict.values()) + total_block_output_memory = layer_activation_memory # Average dispatch strategy # card_0_left_memory = card_0_mem - block_input_output_memory - additional_memory - layer_outputs_memory logger.debug("Card 0 used memory details [Estimated]:") logger.debug(f" Block input output cache memory: {block_input_output_memory} GB") - logger.debug(f" Quantized layer outputs memory: {total_block_output_memory} GB") + logger.debug(f" Quantized layer outputs memory: {layer_activation_memory} GB") logger.debug(f" Additional_memory from other ops: {additional_memory} GB") card_0_left_memory = max( From 29df3579a1acefb4376d10bcc62f4ccebd58600c Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 01:15:32 -0500 Subject: [PATCH 02/19] use low_gpu_mem_usage to cache best params Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 15 +++++++++------ auto_round/compressors/utils.py | 4 +++- auto_round/utils/device.py | 3 ++- auto_round/wrapper.py | 4 ++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 7f4cfbdda..dea3dd2fb 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2321,10 +2321,10 @@ def _quantize_layer( if total_loss < best_loss: best_loss = total_loss if not self.not_use_best_mse: - best_params = collect_best_params(wrapper_linear) + best_params = collect_best_params(wrapper_linear, self.low_gpu_mem_usage) last_best_iter = i if self.not_use_best_mse and i == self.iters - 1: - best_params = collect_best_params(wrapper_linear) + best_params = collect_best_params(wrapper_linear, self.low_gpu_mem_usage) if not self.not_use_best_mse: if 0 < self.dynamic_max_gap <= i - last_best_iter: @@ -2603,8 +2603,10 @@ def _quantize_block( ) total_loss += loss.item() / num_elm + # Sometimes the cached memory is not released during training and cause OOM + if self.low_gpu_mem_usage: + clear_memory_if_reached_threshold(threshold=0.85) self._scale_loss_and_backward(scaler, loss) - clear_memory_if_reached_threshold(threshold=0.85) if i == 0: init_loss = total_loss @@ -2612,12 +2614,12 @@ def _quantize_block( if total_loss < best_loss: best_loss = total_loss if not self.not_use_best_mse: - best_params = collect_best_params(block) + best_params = collect_best_params(block, self.low_gpu_mem_usage) # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True) last_best_iter = i if self.not_use_best_mse and i == self.iters - 1: - best_params = collect_best_params(block) + best_params = collect_best_params(block, self.low_gpu_mem_usage) if not self.not_use_best_mse: if 0 < self.dynamic_max_gap <= i - last_best_iter: @@ -2634,6 +2636,8 @@ def _quantize_block( f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" ) logger.info(dump_info) + if self.low_gpu_mem_usage: + clear_memory() # clear cached memory during training if len(unquantized_layer_names) != 0: logger.info(f"{unquantized_layer_names} have not been quantized") with torch.no_grad(): @@ -2644,7 +2648,6 @@ def _quantize_block( set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") if self.enable_quanted_input: - clear_memory() q_outputs = self._get_block_outputs( block, input_ids, diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 79e8b7133..dc541534e 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -199,13 +199,15 @@ def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=Non return True, "" -def collect_best_params(block): +def collect_best_params(block, low_gpu_mem_usage: bool = False): params = {} for n, m in block.named_modules(): if hasattr(m, "orig_layer"): params[n] = {} for key in m.params.keys(): params[n][key] = copy.deepcopy(m.params[key].data) + if low_gpu_mem_usage: + params[n][key] = params[n][key].cpu() return params diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 91295ec7d..0130745a2 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -431,7 +431,7 @@ def clear_memory_if_reached_threshold(threshold=0.85): elif hasattr(torch, "xpu") and torch.xpu.is_available(): name, device_api = "XPU", torch.xpu else: - return + return False num_devices = device_api.device_count() for i in range(num_devices): @@ -452,6 +452,7 @@ def clear_memory_if_reached_threshold(threshold=0.85): return True except Exception as e: logger.warning_once(f"Failed to check memory for {name} device {i}: {e}") + return False def check_memory_availability(device, inputs, weight, org_seqlen, org_bs): diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 67072d762..9b820c291 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -550,7 +550,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"): def unwrapper(self, best_params): if best_params is None: return self.orig_layer - v = best_params["v"] + v = best_params["v"].to(self.device) weight_q, _, _ = self.quant_func( self.orig_layer.weight, self.bits, self.group_size, v, q_scale_thresh=self.q_scale_thresh ) @@ -601,7 +601,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"): def unwrapper(self, best_params): if best_params is None: return self.orig_layer - v = best_params["v"] + v = best_params["v"].to(self.device) weight_q, _, _ = self.quant_func( self.orig_layer.weight, self.bits, self.group_size, v, q_scale_thresh=self.q_scale_thresh ) From d8917f9dcd534e425c2cdfd1e478b5c4326278d6 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 02:59:16 -0500 Subject: [PATCH 03/19] add warning for memory estimation Signed-off-by: He, Xin3 --- auto_round/utils/device.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 0130745a2..6e017f264 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -746,6 +746,7 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float: Returns: float: Memory ratio (num_experts_per_tok / num_experts). Returns 1.0 for non-MoE models. + bool: True if the model is MoE, False otherwise. Examples: - Non-MoE model: returns 1.0 @@ -795,10 +796,10 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float: f"activation memory ratio: {moe_ratio:.2f}" ) logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}") - return moe_ratio + return moe_ratio, True break # Only check once per block - return 1.0 # Default ratio for non-MoE models + return 1.0, False # Default ratio for non-MoE models def estimate_tuning_block_mem( @@ -829,7 +830,7 @@ def estimate_tuning_block_mem( seq_len = input_ids[0].shape[1] if input_ids and len(input_ids[0].shape) >= 2 else 1 element_size = input_ids[0].element_size() if input_ids else 2 # Default to 2 bytes (fp16/bf16) - moe_ratio = get_moe_memory_ratio(block) # Get MoE memory ratio (1.0 for non-MoE models) + moe_ratio, has_moe = get_moe_memory_ratio(block) # Get MoE memory ratio (1.0 for non-MoE models) for name, module in block.named_modules(): if check_to_quantized(module): @@ -856,7 +857,8 @@ def estimate_tuning_block_mem( # memory * 2, because it contains grad tensor. # Check if this is a MoE expert layer by layer name (e.g., "mlp.experts.0.gate_proj") - is_moe_expert = "expert" in layer_name.lower() + parent_module = get_module(block, layer_name.rsplit(".", 1)[0]) if "." in layer_name else block + is_moe_expert = "expert" in layer_name.lower() and isinstance(parent_module, torch.nn.ModuleList) layer_memory_dict[layer_name] = { "param_memory": param_memory_gb * 2, "output_memory": output_memory_gb * 2, @@ -873,8 +875,7 @@ def estimate_tuning_block_mem( for layer_name, info in layer_memory_dict.items(): if info.get("is_moe_expert", False): # MoE expert layer: only a fraction of experts are active - # memory * 2 records intermediate activation memory of GeLU, or etc. - layer_activation_memory += info["output_memory"] * moe_ratio * 2 + layer_activation_memory += info["output_memory"] * moe_ratio else: # Non-MoE layer: use full activation memory layer_activation_memory += info["output_memory"] @@ -882,13 +883,20 @@ def estimate_tuning_block_mem( # layer_activation_memory considers other ops activation memory # 1GB considers norm weight, sdpa, reference_output, etc. additional_memory = layer_activation_memory + 1 # GB + if has_moe: + # TODO: Cannot estimate the memory usage correctly for MoE models yet. + # For MoE models, additional memory usage can be higher due to routing, gating, + # and multiple expert activations. Here we use a conservative estimate. + moe_additional_memory = additional_memory * 3 # GB + additional_memory += moe_additional_memory if torch.xpu.is_available(): # https://github.com/intel/torch-xpu-ops/issues/2232 # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB xpu_additional_memory = 12 # GB additional_memory += xpu_additional_memory - logger.warning_once("XPU additional memory usage of SDPA is estimated to be 12 GB.") - logger.warning_once("Remove it after https://github.com/intel/torch-xpu-ops/issues/2232 is fixed.") + logger.warning_once( + "[Memory Estimation]: If there is an abnormal memory issue, please collect log with AR_LOG_LEVEL=debug and raise issue to us." + ) return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory @@ -947,20 +955,17 @@ def set_auto_device_map_for_block_with_tuning( if low_gpu_mem_usage: block_input_output_memory = 0 - # Calculate total block memory from layer memory dict (including both param and output memory) total_block_param_memory = sum(info["param_memory"] for info in layer_memory_dict.values()) - total_block_output_memory = layer_activation_memory # Average dispatch strategy # card_0_left_memory = card_0_mem - block_input_output_memory - additional_memory - layer_outputs_memory - logger.debug("Card 0 used memory details [Estimated]:") + card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory + logger.debug(f"Card 0 used memory details [Estimated]: {card_0_used_memory} GB") logger.debug(f" Block input output cache memory: {block_input_output_memory} GB") logger.debug(f" Quantized layer outputs memory: {layer_activation_memory} GB") logger.debug(f" Additional_memory from other ops: {additional_memory} GB") - card_0_left_memory = max( - 0, device_0_memory - block_input_output_memory - total_block_output_memory - additional_memory - ) + card_0_left_memory = max(0, (device_0_memory - card_0_used_memory)) # Calculate total available memory across all devices total_available_memory = card_0_left_memory From e04db30c21d26ed2d8750f9e1ee9e9f4f6a6016c Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 03:48:59 -0500 Subject: [PATCH 04/19] fix bug Signed-off-by: He, Xin3 --- auto_round/compressors/utils.py | 5 +++-- auto_round/utils/device.py | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index dc541534e..94b67b3f3 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -205,9 +205,10 @@ def collect_best_params(block, low_gpu_mem_usage: bool = False): if hasattr(m, "orig_layer"): params[n] = {} for key in m.params.keys(): - params[n][key] = copy.deepcopy(m.params[key].data) if low_gpu_mem_usage: - params[n][key] = params[n][key].cpu() + params[n][key] = m.params[key].data.cpu() + else: + params[n][key] = copy.deepcopy(m.params[key].data) return params diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 6e017f264..50fa0d1e3 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -855,10 +855,13 @@ def estimate_tuning_block_mem( else: output_memory_gb = 0.0 + if has_moe: + pparent_module = get_module(block, layer_name.rsplit(".", 2)[0]) if "." in layer_name else block + is_moe_expert = "expert" in layer_name.lower() and isinstance(pparent_module, torch.nn.ModuleList) + else: + is_moe_expert = False + # memory * 2, because it contains grad tensor. - # Check if this is a MoE expert layer by layer name (e.g., "mlp.experts.0.gate_proj") - parent_module = get_module(block, layer_name.rsplit(".", 1)[0]) if "." in layer_name else block - is_moe_expert = "expert" in layer_name.lower() and isinstance(parent_module, torch.nn.ModuleList) layer_memory_dict[layer_name] = { "param_memory": param_memory_gb * 2, "output_memory": output_memory_gb * 2, From c1568dcf9ec29a398fca389a8492b05454128441 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 05:00:33 -0500 Subject: [PATCH 05/19] support ds on CUDA and 70b on XPU Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 6 ++++-- auto_round/utils/device.py | 12 ++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index dea3dd2fb..4cd2d352d 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2604,9 +2604,11 @@ def _quantize_block( total_loss += loss.item() / num_elm # Sometimes the cached memory is not released during training and cause OOM - if self.low_gpu_mem_usage: - clear_memory_if_reached_threshold(threshold=0.85) + if self.low_gpu_mem_usage and torch.xpu.is_available(): + clear_memory_if_reached_threshold(threshold=0.5) self._scale_loss_and_backward(scaler, loss) + if self.low_gpu_mem_usage: + clear_memory_if_reached_threshold(threshold=0.8) if i == 0: init_loss = total_loss diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 50fa0d1e3..2cc8379a4 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -437,18 +437,14 @@ def clear_memory_if_reached_threshold(threshold=0.85): for i in range(num_devices): try: total_memory = device_api.get_device_properties(i).total_memory - allocated_memory = device_api.memory_reserved(i) if name == "CUDA" else device_api.memory_allocated(i) - memory_usage_ratio = allocated_memory / total_memory + reserved_memory = device_api.memory_reserved(i) + memory_usage_ratio = reserved_memory / total_memory if memory_usage_ratio >= threshold: logger.warning_once( - f"{name} device {i}: Memory usage {memory_usage_ratio*100:.2f}% " - f"exceeds threshold {threshold*100:.2f}%. Clearing memory..." + f"{name} device {i} has reached memory threshold. During the tuning process, a memory clearing operation will be called, which will result in more time consumption." ) clear_memory() - allocated_memory = device_api.memory_reserved(i) if name == "CUDA" else device_api.memory_allocated(i) - memory_usage_ratio = allocated_memory / total_memory - logger.warning_once(f"Cleared memory. {name} device {i}: Memory usage {memory_usage_ratio*100:.2f}%") return True except Exception as e: logger.warning_once(f"Failed to check memory for {name} device {i}: {e}") @@ -890,7 +886,7 @@ def estimate_tuning_block_mem( # TODO: Cannot estimate the memory usage correctly for MoE models yet. # For MoE models, additional memory usage can be higher due to routing, gating, # and multiple expert activations. Here we use a conservative estimate. - moe_additional_memory = additional_memory * 3 # GB + moe_additional_memory = additional_memory * 6 # GB additional_memory += moe_additional_memory if torch.xpu.is_available(): # https://github.com/intel/torch-xpu-ops/issues/2232 From 550439d85d882e1fbc58bebaa22142be3008823e Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 06:58:54 -0500 Subject: [PATCH 06/19] fix oom of deepseek Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 4cd2d352d..444c436af 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2603,12 +2603,17 @@ def _quantize_block( ) total_loss += loss.item() / num_elm - # Sometimes the cached memory is not released during training and cause OOM - if self.low_gpu_mem_usage and torch.xpu.is_available(): - clear_memory_if_reached_threshold(threshold=0.5) + if self.low_gpu_mem_usage: + # Sometimes the cached memory is not released and cause OOM during backward + if torch.xpu.is_available(): + # TODO: whether to improve threshold for llama3.3 70b on 2x 24GB cards + clear_memory_if_reached_threshold(threshold=0.5) + else: + clear_memory_if_reached_threshold(threshold=0.85) self._scale_loss_and_backward(scaler, loss) if self.low_gpu_mem_usage: - clear_memory_if_reached_threshold(threshold=0.8) + # clear memory to avoid OOM due to memory fragmentation + clear_memory_if_reached_threshold(threshold=0.9) if i == 0: init_loss = total_loss From cbd02db78115fbf67a72d8d343a0145c6af02d1c Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 07:01:09 -0500 Subject: [PATCH 07/19] fix pylint Signed-off-by: He, Xin3 --- auto_round/utils/device.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 2cc8379a4..823184a82 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -442,7 +442,9 @@ def clear_memory_if_reached_threshold(threshold=0.85): if memory_usage_ratio >= threshold: logger.warning_once( - f"{name} device {i} has reached memory threshold. During the tuning process, a memory clearing operation will be called, which will result in more time consumption." + f"{name} device {i} has reached memory threshold. " + + "Memory clearing operation will be called during each iteration, which " + + "will result in more time consumption." ) clear_memory() return True @@ -894,7 +896,8 @@ def estimate_tuning_block_mem( xpu_additional_memory = 12 # GB additional_memory += xpu_additional_memory logger.warning_once( - "[Memory Estimation]: If there is an abnormal memory issue, please collect log with AR_LOG_LEVEL=debug and raise issue to us." + "[Memory Estimation]: If there is an abnormal memory issue, please collect log with " + + "AR_LOG_LEVEL=debug and raise issue to us." ) return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory From f466677251b97ca3b16beb1bda7f474a99fb6784 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 07:10:50 -0500 Subject: [PATCH 08/19] threshold is 0.8 Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 444c436af..76db8bd31 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2609,11 +2609,11 @@ def _quantize_block( # TODO: whether to improve threshold for llama3.3 70b on 2x 24GB cards clear_memory_if_reached_threshold(threshold=0.5) else: - clear_memory_if_reached_threshold(threshold=0.85) + clear_memory_if_reached_threshold(threshold=0.8) self._scale_loss_and_backward(scaler, loss) if self.low_gpu_mem_usage: # clear memory to avoid OOM due to memory fragmentation - clear_memory_if_reached_threshold(threshold=0.9) + clear_memory_if_reached_threshold(threshold=0.8) if i == 0: init_loss = total_loss From 8942f58f40abd27f4326ee52f6de3e7864e8e14c Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Mon, 3 Nov 2025 21:43:51 -0500 Subject: [PATCH 09/19] Tighten constraints to avoid performance degradation cases Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 17 ++++++++--------- auto_round/utils/device.py | 8 +++++++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 76db8bd31..0e63536b8 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2449,7 +2449,7 @@ def _quantize_block( set_module(block, n, new_layer) if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)): - set_auto_device_map_for_block_with_tuning( + card_0_in_high_risk = set_auto_device_map_for_block_with_tuning( block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device ) else: @@ -2603,15 +2603,14 @@ def _quantize_block( ) total_loss += loss.item() / num_elm - if self.low_gpu_mem_usage: - # Sometimes the cached memory is not released and cause OOM during backward - if torch.xpu.is_available(): - # TODO: whether to improve threshold for llama3.3 70b on 2x 24GB cards - clear_memory_if_reached_threshold(threshold=0.5) - else: - clear_memory_if_reached_threshold(threshold=0.8) + + if self.low_gpu_mem_usage and card_0_in_high_risk: + # clear memory to avoid OOM due to memory fragmentation + clear_memory_if_reached_threshold(threshold=0.5) + self._scale_loss_and_backward(scaler, loss) - if self.low_gpu_mem_usage: + + if self.low_gpu_mem_usage and card_0_in_high_risk: # clear memory to avoid OOM due to memory fragmentation clear_memory_if_reached_threshold(threshold=0.8) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 823184a82..4b64c77b7 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -923,7 +923,10 @@ def set_auto_device_map_for_block_with_tuning( output_device (str | torch.device, optional): Device to move unassigned modules to. Defaults to None. Returns: - None + card_0_in_high_risk (bool): True if the first device is at risk of running out of memory, False otherwise. + card_0_in_high_risk = card_0_used_memory / device_0_memory > 0.8 + card_0_used_memory = card_0_left_memory + block_input_output_memory + additional_memory + We may need to clear card 0 memory more frequently during training/inference in that case. Raises: RuntimeError: If no CUDA or XPU devices are found. @@ -968,6 +971,7 @@ def set_auto_device_map_for_block_with_tuning( logger.debug(f" Additional_memory from other ops: {additional_memory} GB") card_0_left_memory = max(0, (device_0_memory - card_0_used_memory)) + card_0_in_high_risk = card_0_used_memory / device_0_memory >= 0.8 # Calculate total available memory across all devices total_available_memory = card_0_left_memory @@ -1002,6 +1006,8 @@ def set_auto_device_map_for_block_with_tuning( if has_params or has_buffers: module = module.to(output_device) + return card_0_in_high_risk + def partition_dict_numbers(number_dict, n): """ From a5d455e5ee663c9caa430126b1c281ecef88c191 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Tue, 4 Nov 2025 00:30:57 -0500 Subject: [PATCH 10/19] update per review comments Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 12 ++-- auto_round/compressors/utils.py | 2 +- auto_round/utils/device.py | 108 +++++++++++++++++++------------- 3 files changed, 72 insertions(+), 50 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 0e63536b8..952449990 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1430,7 +1430,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): set_auto_device_map_for_block_with_tuning( - block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device ) # Dispatch model if needed if self.device_map is not None: @@ -2448,12 +2448,9 @@ def _quantize_block( new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)): - card_0_in_high_risk = set_auto_device_map_for_block_with_tuning( - block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device - ) - else: - block = block.to(device) + card_0_in_high_risk = set_auto_device_map_for_block_with_tuning( + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device + ) if self.device_map is not None: for n, m in block.named_modules(): @@ -2777,6 +2774,7 @@ def _quantize_blocks( q_input=q_input, device=device, ) + del m.config if self.is_packing_immediate: from auto_round.export import PACKING_LAYER_WITH_FORMAT diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 94b67b3f3..db7f0169d 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -206,7 +206,7 @@ def collect_best_params(block, low_gpu_mem_usage: bool = False): params[n] = {} for key in m.params.keys(): if low_gpu_mem_usage: - params[n][key] = m.params[key].data.cpu() + params[n][key] = m.params[key].data.to("cpu", copy=True) else: params[n][key] = copy.deepcopy(m.params[key].data) return params diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 4b64c77b7..d4826e646 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -731,6 +731,25 @@ def find_best_device(layer_name, estimated_memory, layer_idx): return ordered_device_map, names +def get_first_available_attr(obj, attr_names: list[str], default=None): + """ + Get the first available attribute from a list of attribute names. + + Args: + obj: The object to get the attribute from. + attr_names (list[str]): List of attribute names to try in order. + default: Default value to return if none of the attributes exist. + + Returns: + The value of the first available attribute, or default if none exist. + """ + for attr_name in attr_names: + value = getattr(obj, attr_name, None) + if value is not None: + return value + return default + + def get_moe_memory_ratio(block: torch.nn.Module) -> float: """ Calculate the memory ratio for MoE (Mixture of Experts) models. @@ -754,48 +773,45 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float: from auto_round.utils.model import is_moe for name, module in block.named_modules(): - if is_moe(module): - config = getattr(block, "config", None) - if config is None: - break + if not is_moe(module): + continue - # Try to get num_experts_per_tok (active experts count) - num_experts_per_tok = getattr( - config, "num_experts_per_tok", None - ) # Mixtral, Qwen2MoE, DeepSeek, GPT-OSS, Llama4, LLaDAMoE - if num_experts_per_tok is None: - num_experts_per_tok = getattr(config, "moe_num_active_primary_experts", None) # SmallThinker - if num_experts_per_tok is None: - # HunYuan MoE uses moe_topk (array), get first element - moe_topk = getattr(config, "moe_topk", None) # HunYuan MoE V1 - if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0: - num_experts_per_tok = moe_topk[0] - elif moe_topk is not None: - num_experts_per_tok = moe_topk - - if num_experts_per_tok is None: - break + config = getattr(block, "config", None) + if config is None: + break - # Get total number of experts - num_experts = getattr(config, "num_local_experts", None) # Mixtral, PhiMoE, Grok, Llama4 - if num_experts is None: - num_experts = getattr( - config, "num_experts", None - ) # Qwen2MoE, Olmo, BailingMoE, GroveMoE, HunYuan, LLaDAMoE - if num_experts is None: - num_experts = getattr(config, "moe_num_primary_experts", None) # SmallThinker - if num_experts is None: - num_experts = getattr(config, "n_routed_experts", None) # DeepSeek - - if num_experts is not None and num_experts > 0: - moe_ratio = num_experts_per_tok / num_experts - logger.debug( - f"MoE detected: {num_experts_per_tok}/{num_experts} experts active per token, " - f"activation memory ratio: {moe_ratio:.2f}" - ) - logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}") - return moe_ratio, True - break # Only check once per block + # Try to get num_experts_per_tok (active experts count) + # Mixtral, Qwen2MoE, DeepSeek, GPT-OSS, Llama4, LLaDAMoE, SmallThinker + num_experts_per_tok = get_first_available_attr( + config, ["num_experts_per_tok", "moe_num_active_primary_experts"] + ) + + # HunYuan MoE uses moe_topk (array), get first element + if num_experts_per_tok is None: + moe_topk = getattr(config, "moe_topk", None) # HunYuan MoE V1 + if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0: + num_experts_per_tok = moe_topk[0] + elif moe_topk is not None: + num_experts_per_tok = moe_topk + + if num_experts_per_tok is None: + break + + # Get total number of experts + # Mixtral, PhiMoE, Grok, Llama4, Qwen2MoE, Olmo, BailingMoE, GroveMoE, HunYuan, LLaDAMoE, SmallThinker, DeepSeek + num_experts = get_first_available_attr( + config, ["num_local_experts", "num_experts", "moe_num_primary_experts", "n_routed_experts"] + ) + + if num_experts is not None and num_experts > 0: + moe_ratio = num_experts_per_tok / num_experts + logger.debug( + f"MoE detected: {num_experts_per_tok}/{num_experts} experts active per token, " + f"activation memory ratio: {moe_ratio:.2f}" + ) + logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}") + return moe_ratio, True + break # Only check once per block return 1.0, False # Default ratio for non-MoE models @@ -910,6 +926,7 @@ def set_auto_device_map_for_block_with_tuning( low_gpu_mem_usage=False, pick_samples=8, output_device=None, + card_0_threshold=0.9, ): """ Automatically sets the device map for the block based on available GPUs and memory constraints. @@ -921,10 +938,12 @@ def set_auto_device_map_for_block_with_tuning( low_gpu_mem_usage (bool, optional): If True, ignoring input/output memory. Defaults to False. pick_samples (int, optional): Number of samples to consider for memory estimation. Defaults to 8. output_device (str | torch.device, optional): Device to move unassigned modules to. Defaults to None. + card_0_threshold (float, optional): Threshold ratio to determine if the first device is at high risk of + running out of memory. Defaults to 0.9 (90%). Returns: card_0_in_high_risk (bool): True if the first device is at risk of running out of memory, False otherwise. - card_0_in_high_risk = card_0_used_memory / device_0_memory > 0.8 + card_0_in_high_risk = card_0_used_memory / device_0_memory > card_0_threshold card_0_used_memory = card_0_left_memory + block_input_output_memory + additional_memory We may need to clear card 0 memory more frequently during training/inference in that case. @@ -934,6 +953,11 @@ def set_auto_device_map_for_block_with_tuning( Note: This function is intended for internal use in device memory management and tuning. """ + if not (device_map == "auto" or ((isinstance(device_map, str) and "," in device_map))): + block = block.to(output_device) + card_0_in_high_risk = False # card 0 contains weight, clear_memory will not help much + return card_0_in_high_risk + if torch.cuda.is_available(): num_devices = torch.cuda.device_count() device_name = "cuda" @@ -971,7 +995,7 @@ def set_auto_device_map_for_block_with_tuning( logger.debug(f" Additional_memory from other ops: {additional_memory} GB") card_0_left_memory = max(0, (device_0_memory - card_0_used_memory)) - card_0_in_high_risk = card_0_used_memory / device_0_memory >= 0.8 + card_0_in_high_risk = card_0_used_memory / device_0_memory >= card_0_threshold # Calculate total available memory across all devices total_available_memory = card_0_left_memory From cde0aa10bf399b46b92c7e9fce21849af9a90e66 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Tue, 4 Nov 2025 00:37:32 -0500 Subject: [PATCH 11/19] add clear_memory back Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 952449990..d73b0c002 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2651,6 +2651,8 @@ def _quantize_block( set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") if self.enable_quanted_input: + if not self.low_gpu_mem_usage: # In case of clearing memory twice + clear_memory() # clear cached memory during training q_outputs = self._get_block_outputs( block, input_ids, From 518078bdaca3fc5736c471670b6d7fdd2e9700f5 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Tue, 4 Nov 2025 00:40:34 -0500 Subject: [PATCH 12/19] fix bug Signed-off-by: He, Xin3 --- auto_round/utils/device.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index d4826e646..56bd18589 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -1102,18 +1102,18 @@ def find_optimal_subset(arr, target): def set_avg_auto_device_map(model: torch.nn.Module, device_map): block_name_list = get_block_names(model) device_list = None + if torch.cuda.is_available(): + num_devices = torch.cuda.device_count() + device_name = "cuda" + elif torch.xpu.is_available(): + num_devices = torch.xpu.device_count() + device_name = "xpu" + else: + return + if isinstance(device_map, str) and "," in device_map: device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()] num_devices = len(device_list) - else: - if torch.cuda.is_available(): - num_devices = torch.cuda.device_count() - device_name = "cuda" - elif torch.xpu.is_available(): - num_devices = torch.xpu.device_count() - device_name = "xpu" - else: - return if device_list: gpu_devices = [f"{device_name}:{i}" for i in device_list] From a27e1dd042d6e19a930a09da6f07ddd0b4c17ce1 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Tue, 4 Nov 2025 02:46:39 -0500 Subject: [PATCH 13/19] move loss device to the second card if card_0_in_high_risk Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 13 +++++++------ auto_round/utils/device.py | 15 +++++++++++---- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index d73b0c002..24faa7cf5 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2402,6 +2402,7 @@ def _get_current_q_output( input_others: dict, indices: list[int], device: str, + output_device: str = "cpu", ) -> torch.Tensor: current_input_ids, current_input_others = self._sampling_inputs( input_ids, @@ -2412,7 +2413,7 @@ def _get_current_q_output( share_cache_keys=self.shared_cache_keys, ) output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device) - return output_q + return output_q.to(output_device) def _get_current_num_elm( self, @@ -2448,7 +2449,7 @@ def _quantize_block( new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - card_0_in_high_risk = set_auto_device_map_for_block_with_tuning( + card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device ) @@ -2578,18 +2579,18 @@ def _quantize_block( current_output = self._get_current_output(output, indices) - current_output = to_device(current_output, device) + current_output = to_device(current_output, loss_device) - output_q = self._get_current_q_output(block, input_ids, input_others, indices, device) + output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device) if self.attention_mask: tmp_attention_mask = [self.attention_mask[i] for i in indices] - tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) + tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device) tmp_attention_mask.unsqueeze_(-1) else: tmp_attention_mask = 1.0 if self.amp: - with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + with autocast(device_type=loss_device.split(":")[0], dtype=self.amp_dtype): loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask ) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 56bd18589..e56f00f30 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -956,7 +956,8 @@ def set_auto_device_map_for_block_with_tuning( if not (device_map == "auto" or ((isinstance(device_map, str) and "," in device_map))): block = block.to(output_device) card_0_in_high_risk = False # card 0 contains weight, clear_memory will not help much - return card_0_in_high_risk + loss_device = output_device + return card_0_in_high_risk, loss_device if torch.cuda.is_available(): num_devices = torch.cuda.device_count() @@ -973,14 +974,18 @@ def set_auto_device_map_for_block_with_tuning( if device_list: gpu_devices = [f"{device_name}:{i}" for i in device_list] device_0 = gpu_devices[0] + device_1 = gpu_devices[1] else: gpu_devices = [f"{device_name}:{i}" for i in range(num_devices)] device_0 = f"{device_name}:0" + device_1 = f"{device_name}:1" device_0_memory = get_device_memory(device_list[0] if device_list else 0) + device_1_memory = get_device_memory(device_list[1] if device_list else 1) layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory = ( estimate_tuning_block_mem(block, input_ids, pick_samples) ) + loss_memory = block_input_output_memory / 2 # GB, rough estimate for loss tensor memory if low_gpu_mem_usage: block_input_output_memory = 0 @@ -996,10 +1001,12 @@ def set_auto_device_map_for_block_with_tuning( card_0_left_memory = max(0, (device_0_memory - card_0_used_memory)) card_0_in_high_risk = card_0_used_memory / device_0_memory >= card_0_threshold + card_1_left_memory = max(0, device_1_memory - loss_memory) if card_0_in_high_risk else device_1_memory + loss_device = device_1 if card_0_in_high_risk else output_device # Calculate total available memory across all devices - total_available_memory = card_0_left_memory - for i in range(1, len(gpu_devices)): + total_available_memory = card_0_left_memory + card_1_left_memory + for i in range(2, len(gpu_devices)): device_idx = device_list[i] if device_list else i total_available_memory += get_device_memory(device_idx) @@ -1030,7 +1037,7 @@ def set_auto_device_map_for_block_with_tuning( if has_params or has_buffers: module = module.to(output_device) - return card_0_in_high_risk + return card_0_in_high_risk, loss_device def partition_dict_numbers(number_dict, n): From 74c45716d88fd28e3cbef4dc33980063bcbff3f5 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Wed, 5 Nov 2025 00:11:10 -0500 Subject: [PATCH 14/19] update per review comments Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 11 ++++++----- auto_round/compressors/utils.py | 8 +++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 7d6f470b2..333c41027 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2332,10 +2332,10 @@ def _quantize_layer( if total_loss < best_loss: best_loss = total_loss if not self.not_use_best_mse: - best_params = collect_best_params(wrapper_linear, self.low_gpu_mem_usage) + best_params = collect_best_params(wrapper_linear, self.cache_device) last_best_iter = i if self.not_use_best_mse and i == self.iters - 1: - best_params = collect_best_params(wrapper_linear, self.low_gpu_mem_usage) + best_params = collect_best_params(wrapper_linear, self.cache_device) if not self.not_use_best_mse: if 0 < self.dynamic_max_gap <= i - last_best_iter: @@ -2459,7 +2459,8 @@ def _quantize_block( if is_fp8_linear(m): new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device) set_module(block, n, new_layer) - + # card_0_in_high_risk indicates that card_0 memory is already in high usage w/o any weights + # card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device ) @@ -2629,12 +2630,12 @@ def _quantize_block( if total_loss < best_loss: best_loss = total_loss if not self.not_use_best_mse: - best_params = collect_best_params(block, self.low_gpu_mem_usage) + best_params = collect_best_params(block, self.cache_device) # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True) last_best_iter = i if self.not_use_best_mse and i == self.iters - 1: - best_params = collect_best_params(block, self.low_gpu_mem_usage) + best_params = collect_best_params(block, self.cache_device) if not self.not_use_best_mse: if 0 < self.dynamic_max_gap <= i - last_best_iter: diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index db7f0169d..fb94611a5 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -199,16 +199,14 @@ def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=Non return True, "" -def collect_best_params(block, low_gpu_mem_usage: bool = False): +def collect_best_params(block, cache_device="cpu"): + """Collect the best parameters from the block to the specified device.""" params = {} for n, m in block.named_modules(): if hasattr(m, "orig_layer"): params[n] = {} for key in m.params.keys(): - if low_gpu_mem_usage: - params[n][key] = m.params[key].data.to("cpu", copy=True) - else: - params[n][key] = copy.deepcopy(m.params[key].data) + params[n][key] = m.params[key].data.to(cache_device, copy=True) return params From f155941ea63733a717039d1765c91e169a7e3b47 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Wed, 5 Nov 2025 00:49:09 -0500 Subject: [PATCH 15/19] add comments Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 333c41027..2898cd1fd 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2459,8 +2459,8 @@ def _quantize_block( if is_fp8_linear(m): new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device) set_module(block, n, new_layer) - # card_0_in_high_risk indicates that card_0 memory is already in high usage w/o any weights - # card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory + # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights + # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device ) From 9e209cd4f012865f2a58cfb6bae57ee21c293ebc Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Wed, 5 Nov 2025 00:52:40 -0500 Subject: [PATCH 16/19] add sync Signed-off-by: He, Xin3 --- auto_round/utils/device.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 13ca7275b..57c2c67bc 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -412,8 +412,10 @@ def _clear_memory_for_cpu_and_cuda(tensor=None): del tensor gc.collect() if torch.cuda.is_available(): + torch.cuda.synchronize() torch.cuda.empty_cache() if torch.xpu.is_available(): + torch.xpu.synchronize() torch.xpu.empty_cache() From b7d13915e399d468aa218d9c7cae99539c1be9fa Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Wed, 5 Nov 2025 21:57:55 -0500 Subject: [PATCH 17/19] remove duplicate clear_memory and add warning message Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 2 -- auto_round/utils/device.py | 14 +++++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 2898cd1fd..e455e9396 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2664,8 +2664,6 @@ def _quantize_block( set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") if self.enable_quanted_input: - if not self.low_gpu_mem_usage: # In case of clearing memory twice - clear_memory() # clear cached memory during training q_outputs = self._get_block_outputs( block, input_ids, diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 57c2c67bc..84bc67ed9 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -442,9 +442,9 @@ def clear_memory_if_reached_threshold(threshold=0.85): """ # Detect CUDA/XPU devices if torch.cuda.is_available(): - name, device_api = "CUDA", torch.cuda + name, device_api = "cuda", torch.cuda elif hasattr(torch, "xpu") and torch.xpu.is_available(): - name, device_api = "XPU", torch.xpu + name, device_api = "xpu", torch.xpu else: return False @@ -457,14 +457,18 @@ def clear_memory_if_reached_threshold(threshold=0.85): if memory_usage_ratio >= threshold: logger.warning_once( - f"{name} device {i} has reached memory threshold. " + f"Major device ({name}:{i}) has reached memory threshold. " + "Memory clearing operation will be called during each iteration, which " + "will result in more time consumption." ) + logger.warning_once( + "To alleviate high memory usage on the major device, consider reducing the `batch_size` " + + "(and correspondingly increasing `gradient_accumulation_steps) or shortening the seqlen." + ) clear_memory() return True except Exception as e: - logger.warning_once(f"Failed to check memory for {name} device {i}: {e}") + logger.warning_once(f"Failed to check memory for {name}:{i}: {e}") return False @@ -480,7 +484,7 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs): Returns: tuple: A tuple containing availability status (bool), modified sequence length (int), - and modified batch size (int). + and modified batch size (int). """ weight_memory = weight.numel() * weight.element_size() if "cuda" in device: From 0aaaa09b45697bed499562b7afea9792716b4c2a Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Thu, 6 Nov 2025 00:01:11 -0500 Subject: [PATCH 18/19] update per review comments Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 17 +++++++++++------ auto_round/utils/device.py | 10 ++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index e455e9396..a77ef3bc6 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2413,7 +2413,7 @@ def _get_current_q_output( input_others: dict, indices: list[int], device: str, - output_device: str = "cpu", + cache_device: str = "cpu", ) -> torch.Tensor: current_input_ids, current_input_others = self._sampling_inputs( input_ids, @@ -2424,7 +2424,7 @@ def _get_current_q_output( share_cache_keys=self.shared_cache_keys, ) output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device) - return output_q.to(output_device) + return output_q.to(cache_device) def _get_current_num_elm( self, @@ -2461,9 +2461,13 @@ def _quantize_block( set_module(block, n, new_layer) # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk - card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( - block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device - ) + if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)): + card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device + ) + else: + block = block.to(device) + card_0_in_high_risk, loss_device = False, device if is_complex_device_mapping(self.device_map): for n, m in block.named_modules(): @@ -2797,7 +2801,8 @@ def _quantize_blocks( q_input=q_input, device=device, ) - del m.config + if hasattr(model, "config"): + del m.config if self.is_packing_immediate: from auto_round.export import PACKING_LAYER_WITH_FORMAT diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 84bc67ed9..46ad865d4 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -412,10 +412,8 @@ def _clear_memory_for_cpu_and_cuda(tensor=None): del tensor gc.collect() if torch.cuda.is_available(): - torch.cuda.synchronize() torch.cuda.empty_cache() if torch.xpu.is_available(): - torch.xpu.synchronize() torch.xpu.empty_cache() @@ -938,10 +936,10 @@ def set_auto_device_map_for_block_with_tuning( block: torch.nn.Module, device_map, input_ids: list[torch.Tensor], - low_gpu_mem_usage=False, - pick_samples=8, - output_device=None, - card_0_threshold=0.9, + low_gpu_mem_usage: bool = False, + pick_samples: int = 8, + output_device: str | torch.device = None, + card_0_threshold: float = 0.9, ): """ Automatically sets the device map for the block based on available GPUs and memory constraints. From f530032caa25390964ae7897949fa57d81cc5c39 Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Thu, 6 Nov 2025 00:07:05 -0500 Subject: [PATCH 19/19] rename pick samples to batch size Signed-off-by: He, Xin3 --- auto_round/utils/device.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 46ad865d4..b0ecf9019 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -830,7 +830,7 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float: def estimate_tuning_block_mem( - block: torch.nn.Module, input_ids: list[torch.Tensor], pick_samples: int + block: torch.nn.Module, input_ids: list[torch.Tensor], batch_size: int ) -> tuple[dict, float]: """ Calculates the memory consumption of a specific block in the model. @@ -838,7 +838,7 @@ def estimate_tuning_block_mem( Args: block (torch.nn.Module): The block of the model to analyze. input_ids (list[torch.Tensor]): A list of input tensors for the block. - pick_samples (int): Number of samples to consider for memory estimation. + batch_size (int): Number of samples to consider for memory estimation. Returns: tuple: A tuple containing the following: @@ -871,12 +871,12 @@ def estimate_tuning_block_mem( in_features, out_features = get_layer_features(module) if in_features is not None and out_features is not None: # Output tensor size: batch_size * seq_len * out_features * element_size - output_size = pick_samples * seq_len * out_features * element_size + output_size = batch_size * seq_len * out_features * element_size output_memory_gb = output_size / 1024**3 # If enable_act_quant, add input tensor memory to param_memory if enable_act_quant: - input_size = pick_samples * seq_len * in_features * element_size + input_size = batch_size * seq_len * in_features * element_size input_memory_gb = input_size / 1024**3 param_memory_gb += input_memory_gb else: @@ -937,7 +937,7 @@ def set_auto_device_map_for_block_with_tuning( device_map, input_ids: list[torch.Tensor], low_gpu_mem_usage: bool = False, - pick_samples: int = 8, + batch_size: int = 8, output_device: str | torch.device = None, card_0_threshold: float = 0.9, ): @@ -949,7 +949,7 @@ def set_auto_device_map_for_block_with_tuning( device_map (str | int | dict): Specifies the device mapping. input_ids (list[torch.Tensor]): List of input tensors used for estimating memory requirements. low_gpu_mem_usage (bool, optional): If True, ignoring input/output memory. Defaults to False. - pick_samples (int, optional): Number of samples to consider for memory estimation. Defaults to 8. + batch_size (int, optional): Number of samples to consider for memory estimation. Defaults to 8. output_device (str | torch.device, optional): Device to move unassigned modules to. Defaults to None. card_0_threshold (float, optional): Threshold ratio to determine if the first device is at high risk of running out of memory. Defaults to 0.9 (90%). @@ -996,7 +996,7 @@ def set_auto_device_map_for_block_with_tuning( device_0_memory = get_device_memory(device_list[0] if device_list else 0) device_1_memory = get_device_memory(device_list[1] if device_list else 1) layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory = ( - estimate_tuning_block_mem(block, input_ids, pick_samples) + estimate_tuning_block_mem(block, input_ids, batch_size) ) loss_memory = block_input_output_memory / 2 # GB, rough estimate for loss tensor memory if low_gpu_mem_usage: