diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 39602680c..0a682079a 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -1441,7 +1441,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
 
         if is_complex_device_mapping(self.device_map):
             set_auto_device_map_for_block_with_tuning(
-                block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
+                block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
             )
         # Dispatch model if needed
         if is_complex_device_mapping(self.device_map):
@@ -2332,10 +2332,10 @@ def _quantize_layer(
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(wrapper_linear)
+                    best_params = collect_best_params(wrapper_linear, self.cache_device)
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(wrapper_linear)
+                best_params = collect_best_params(wrapper_linear, self.cache_device)
 
             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
@@ -2413,6 +2413,7 @@ def _get_current_q_output(
         input_others: dict,
         indices: list[int],
         device: str,
+        cache_device: str = "cpu",
     ) -> torch.Tensor:
         current_input_ids, current_input_others = self._sampling_inputs(
             input_ids,
@@ -2423,7 +2424,7 @@
             share_cache_keys=self.shared_cache_keys,
         )
         output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
-        return output_q
+        return output_q.to(cache_device)
 
     def _get_current_num_elm(
         self,
@@ -2458,13 +2459,15 @@ def _quantize_block(
             if is_fp8_linear(m):
                 new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
                 set_module(block, n, new_layer)
-
-        if is_complex_device_mapping(self.device_map):
-            set_auto_device_map_for_block_with_tuning(
+        # card_0_in_high_risk indicates that card 0 memory usage is already high (~90%) without any weights
+        # loss_device is the device used to compute the loss: the second device when available and card_0_in_high_risk
+        if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
+            card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning(
                 block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
         else:
             block = block.to(device)
+            card_0_in_high_risk, loss_device = False, device
 
         if is_complex_device_mapping(self.device_map):
             for n, m in block.named_modules():
@@ -2594,13 +2597,13 @@ def _quantize_block(
 
                 current_output = self._get_current_output(output, indices)
 
-                current_output = to_device(current_output, device)
+                current_output = to_device(current_output, loss_device)
 
-                output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+                output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
 
                 if self.attention_mask:
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
-                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
                     tmp_attention_mask.unsqueeze_(-1)
                     num_elm = torch.sum(tmp_attention_mask).item()
                     if num_elm == 0:
@@ -2608,7 +2611,7 @@ def _quantize_block(
                 else:
                     tmp_attention_mask = 1.0
                 if self.amp:
-                    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    with autocast(device_type=loss_device.split(":")[0], dtype=self.amp_dtype):
                         loss = mse_loss(  # pylint: disable=not-callable
                             output_q * tmp_attention_mask, current_output * tmp_attention_mask
                         )
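A minimal sketch of the loss-placement pattern used in the hunks above: the reference and quantized outputs are moved to a separate `loss_device` before the MSE so the loss buffers stay off card 0. The toy block, tensors, and device strings are illustrative placeholders, not the auto-round implementation:

```python
import torch

def mse_on_loss_device(block, x, x_q, device="cpu", loss_device="cpu", amp_dtype=torch.bfloat16):
    """Forward on the tuning device, but compute the MSE on a separate loss device."""
    out = block(x.to(device))      # reference output
    out_q = block(x_q.to(device))  # "quantized" output (toy: same block, perturbed input)
    # Mirror output_q.to(cache_device): only the tensors needed for the loss move to loss_device.
    out, out_q = out.to(loss_device), out_q.to(loss_device)
    with torch.autocast(device_type=loss_device.split(":")[0], dtype=amp_dtype):
        return torch.nn.functional.mse_loss(out_q, out)

toy_block = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)
print(mse_on_loss_device(toy_block, x, x + 0.01).item())
```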
@@ -2619,20 +2622,29 @@ def _quantize_block(
                     )
                 total_loss += loss.item() / num_elm
+
+                if self.low_gpu_mem_usage and card_0_in_high_risk:
+                    # clear memory to avoid OOM due to memory fragmentation
+                    clear_memory_if_reached_threshold(threshold=0.5)
+
                 self._scale_loss_and_backward(scaler, loss)
+                if self.low_gpu_mem_usage and card_0_in_high_risk:
+                    # clear memory to avoid OOM due to memory fragmentation
+                    clear_memory_if_reached_threshold(threshold=0.8)
+
             if i == 0:
                 init_loss = total_loss
 
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(block)
+                    best_params = collect_best_params(block, self.cache_device)
                     # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True)
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(block)
+                best_params = collect_best_params(block, self.cache_device)
 
             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
@@ -2649,6 +2661,8 @@ def _quantize_block(
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
         logger.info(dump_info)
+        if self.low_gpu_mem_usage:
+            clear_memory()  # clear cached memory during training
         if len(unquantized_layer_names) != 0:
             logger.info(f"{unquantized_layer_names} have not been quantized")
         with torch.no_grad():
@@ -2659,8 +2673,6 @@ def _quantize_block(
                 set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
         if self.enable_quanted_input:
-            if self.low_gpu_mem_usage:
-                clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
                 input_ids,
@@ -2786,6 +2798,7 @@ def _quantize_blocks(
             modules = [get_module(model, n) for n in names]
             m = WrapperMultiblock(modules)
 
+        m.config = model.config if hasattr(model, "config") else None
        q_input, input_ids = quantize_block(
             m,
             input_ids,
@@ -2793,6 +2806,8 @@ def _quantize_blocks(
             q_input=q_input,
             device=device,
         )
+        if hasattr(model, "config"):
+            del m.config
 
         if self.is_packing_immediate:
             from auto_round.export import PACKING_LAYER_WITH_FORMAT
diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py
index 7d6f28806..ba5db377c 100644
--- a/auto_round/compressors/utils.py
+++ b/auto_round/compressors/utils.py
@@ -199,13 +199,14 @@ def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=Non
     return True, ""
 
 
-def collect_best_params(block):
+def collect_best_params(block, cache_device="cpu"):
+    """Collect the best tuned parameters from the block and move them to the specified cache device."""
     params = {}
     for n, m in block.named_modules():
         if hasattr(m, "orig_layer"):
             params[n] = {}
             for key in m.params.keys():
-                params[n][key] = copy.deepcopy(m.params[key].data)
+                params[n][key] = m.params[key].data.to(cache_device, copy=True)
     return params
 
 
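For reference, a minimal sketch of the round trip introduced above: `collect_best_params(block, cache_device)` snapshots tuned parameters onto a cache device with a device-to-device copy instead of `copy.deepcopy`, and the wrapper's `unwrapper` (see the auto_round/wrapper.py hunks later in this diff) moves `v` back to the layer's device before applying it. `ToyWrapper` and `unwrap` are illustrative stand-ins, not the real `WrapperLinear`:

```python
import torch

class ToyWrapper(torch.nn.Module):
    def __init__(self, orig_layer, device="cpu"):
        super().__init__()
        self.orig_layer = orig_layer
        self.device = device
        # tuned rounding offsets live here in the real wrapper
        self.params = {"v": torch.zeros_like(orig_layer.weight, device=device)}

def collect_best_params(block, cache_device="cpu"):
    best = {}
    for name, module in block.named_modules():
        if hasattr(module, "orig_layer"):
            # .to(cache_device, copy=True) snapshots the tensor off the tuning device
            best[name] = {k: p.data.to(cache_device, copy=True) for k, p in module.params.items()}
    return best

def unwrap(module, best_params):
    # move the cached offsets back to the layer's device before applying them
    v = best_params["v"].to(module.device)
    return module.orig_layer.weight + v

block = torch.nn.Sequential(ToyWrapper(torch.nn.Linear(4, 4)))
best = collect_best_params(block, cache_device="cpu")
print(unwrap(block[0], best["0"]).shape)
```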
diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py
index 3cfcde3dc..b0ecf9019 100644
--- a/auto_round/utils/device.py
+++ b/auto_round/utils/device.py
@@ -440,31 +440,34 @@ def clear_memory_if_reached_threshold(threshold=0.85):
     """
     # Detect CUDA/XPU devices
     if torch.cuda.is_available():
-        name, device_api = "CUDA", torch.cuda
+        name, device_api = "cuda", torch.cuda
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
-        name, device_api = "XPU", torch.xpu
+        name, device_api = "xpu", torch.xpu
     else:
-        return
+        return False
 
     num_devices = device_api.device_count()
     for i in range(num_devices):
         try:
             total_memory = device_api.get_device_properties(i).total_memory
-            allocated_memory = device_api.memory_reserved(i) if name == "CUDA" else device_api.memory_allocated(i)
-            memory_usage_ratio = allocated_memory / total_memory
+            reserved_memory = device_api.memory_reserved(i)
+            memory_usage_ratio = reserved_memory / total_memory
             if memory_usage_ratio >= threshold:
                 logger.warning_once(
-                    f"{name} device {i}: Memory usage {memory_usage_ratio*100:.2f}% "
-                    f"exceeds threshold {threshold*100:.2f}%. Clearing memory..."
+                    f"Major device ({name}:{i}) has reached the memory threshold. "
+                    + "Memory will be cleared on every iteration, which will increase tuning time."
+                )
+                logger.warning_once(
+                    "To alleviate high memory usage on the major device, consider reducing the `batch_size` "
+                    + "(and correspondingly increasing `gradient_accumulation_steps`) or shortening the seqlen."
                 )
                 clear_memory()
-                allocated_memory = device_api.memory_reserved(i) if name == "CUDA" else device_api.memory_allocated(i)
-                memory_usage_ratio = allocated_memory / total_memory
-                logger.warning_once(f"Cleared memory. {name} device {i}: Memory usage {memory_usage_ratio*100:.2f}%")
                 return True
         except Exception as e:
-            logger.warning_once(f"Failed to check memory for {name} device {i}: {e}")
+            logger.warning_once(f"Failed to check memory for {name}:{i}: {e}")
+    return False
 
 
 def check_memory_availability(device, inputs, weight, org_seqlen, org_bs):
@@ -479,7 +482,7 @@
 
     Returns:
         tuple: A tuple containing availability status (bool), modified sequence length (int),
-        and modified batch size (int).
+            and modified batch size (int).
     """
     weight_memory = weight.numel() * weight.element_size()
     if "cuda" in device:
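A rough sketch of the reserved-memory check that `clear_memory_if_reached_threshold` now performs, under simplifying assumptions (CUDA only, no XPU branch, `warning_once` logging omitted); `clear_if_reserved_above` is an illustrative name, not the auto-round API:

```python
import gc
import torch

def clear_if_reserved_above(threshold: float = 0.85) -> bool:
    """Free the CUDA cache on any device whose reserved memory crosses the threshold."""
    if not torch.cuda.is_available():
        return False
    for i in range(torch.cuda.device_count()):
        total = torch.cuda.get_device_properties(i).total_memory
        reserved = torch.cuda.memory_reserved(i)
        if reserved / total >= threshold:
            gc.collect()
            torch.cuda.empty_cache()  # release cached blocks so fragmented memory can be reused
            return True
    return False

# In the base.py hunks above, the real helper is invoked per iteration only when
# low_gpu_mem_usage is set and card 0 is flagged as high-risk: threshold 0.5 before
# the backward pass and 0.8 after it.
```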
@@ -741,8 +744,93 @@ def find_best_device(layer_name, estimated_memory, layer_idx):
     return ordered_device_map, names
 
 
+def get_first_available_attr(obj, attr_names: list[str], default=None):
+    """
+    Get the first available attribute from a list of attribute names.
+
+    Args:
+        obj: The object to get the attribute from.
+        attr_names (list[str]): List of attribute names to try in order.
+        default: Default value to return if none of the attributes exist.
+
+    Returns:
+        The value of the first available attribute, or default if none exist.
+    """
+    for attr_name in attr_names:
+        value = getattr(obj, attr_name, None)
+        if value is not None:
+            return value
+    return default
+
+
+def get_moe_memory_ratio(block: torch.nn.Module) -> tuple[float, bool]:
+    """
+    Calculate the activation memory ratio for MoE (Mixture of Experts) models.
+
+    For MoE models, only num_experts_per_tok experts are activated per token,
+    not all experts. This function returns the ratio of active experts to total experts.
+
+    Args:
+        block (torch.nn.Module): The model block to analyze.
+
+    Returns:
+        tuple[float, bool]:
+            - Memory ratio (num_experts_per_tok / num_experts); 1.0 for non-MoE models.
+            - True if the block contains MoE modules, False otherwise.
+
+    Examples:
+        - Non-MoE model: returns 1.0
+        - Mixtral (2/8 experts): returns 0.25
+        - Qwen2MoE (4/60 experts): returns ~0.067
+    """
+    from auto_round.utils.model import is_moe
+
+    for name, module in block.named_modules():
+        if not is_moe(module):
+            continue
+
+        config = getattr(block, "config", None)
+        if config is None:
+            break
+
+        # Try to get num_experts_per_tok (active experts count)
+        # Mixtral, Qwen2MoE, DeepSeek, GPT-OSS, Llama4, LLaDAMoE, SmallThinker
+        num_experts_per_tok = get_first_available_attr(
+            config, ["num_experts_per_tok", "moe_num_active_primary_experts"]
+        )
+
+        # HunYuan MoE uses moe_topk (array), get first element
+        if num_experts_per_tok is None:
+            moe_topk = getattr(config, "moe_topk", None)  # HunYuan MoE V1
+            if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0:
+                num_experts_per_tok = moe_topk[0]
+            elif moe_topk is not None:
+                num_experts_per_tok = moe_topk
+
+        if num_experts_per_tok is None:
+            break
+
+        # Get total number of experts
+        # Mixtral, PhiMoE, Grok, Llama4, Qwen2MoE, Olmo, BailingMoE, GroveMoE, HunYuan, LLaDAMoE, SmallThinker, DeepSeek
+        num_experts = get_first_available_attr(
+            config, ["num_local_experts", "num_experts", "moe_num_primary_experts", "n_routed_experts"]
+        )
+
+        if num_experts is not None and num_experts > 0:
+            moe_ratio = num_experts_per_tok / num_experts
+            logger.debug(
+                f"MoE detected: {num_experts_per_tok}/{num_experts} experts active per token, "
+                f"activation memory ratio: {moe_ratio:.2f}"
+            )
+            logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}")
+            return moe_ratio, True
+        break  # Only check once per block
+
+    return 1.0, False  # Default ratio for non-MoE models
+
+
 def estimate_tuning_block_mem(
-    block: torch.nn.Module, input_ids: list[torch.Tensor], pick_samples: int
+    block: torch.nn.Module, input_ids: list[torch.Tensor], batch_size: int
 ) -> tuple[dict, float]:
     """
     Calculates the memory consumption of a specific block in the model.
@@ -750,26 +838,27 @@ def estimate_tuning_block_mem(
     Args:
         block (torch.nn.Module): The block of the model to analyze.
         input_ids (list[torch.Tensor]): A list of input tensors for the block.
-        pick_samples (int): Number of samples to consider for memory estimation.
+        batch_size (int): Number of samples to consider for memory estimation.
 
     Returns:
         tuple: A tuple containing the following:
             - layer_memory_dict (dict): A dictionary mapping layer names to their memory consumption (in GB).
-              Format: {layer_name: {"param_memory": float, "output_memory": float}}
+                Format: {layer_name: {"param_memory": float, "output_memory": float, "is_moe_expert": bool}}
+            - layer_activation_memory (float): Total activation memory (in GB) of the quantized layers, with
+                MoE expert outputs scaled by the active-expert ratio.
             - input_output_memory (float): The memory consumption (in GB) for input and output tensors of the block.
             - additional_memory (float): Additional memory overhead (in GB) for operations like attention.
     """
     # Calculate all block parameters memory and build layer-wise memory dict
-    from auto_round.utils.model import get_layer_features
+    from auto_round.utils.model import get_layer_features, is_moe
 
     layer_memory_dict = {}
-    total_param_mem = 0
 
     # Calculate batch_size and sequence_length from input_ids for output memory estimation
     seq_len = input_ids[0].shape[1] if input_ids and len(input_ids[0].shape) >= 2 else 1
     element_size = input_ids[0].element_size() if input_ids else 2  # Default to 2 bytes (fp16/bf16)
 
+    moe_ratio, has_moe = get_moe_memory_ratio(block)  # Get MoE memory ratio (1.0 for non-MoE models)
+
     for name, module in block.named_modules():
         if check_to_quantized(module):
             enable_act_quant = module.act_bits <= 8
@@ -782,45 +871,75 @@ def estimate_tuning_block_mem(
             in_features, out_features = get_layer_features(module)
             if in_features is not None and out_features is not None:
                 # Output tensor size: batch_size * seq_len * out_features * element_size
-                output_size = pick_samples * seq_len * out_features * element_size
+                output_size = batch_size * seq_len * out_features * element_size
                 output_memory_gb = output_size / 1024**3
                 # If enable_act_quant, add input tensor memory to param_memory
                 if enable_act_quant:
-                    input_size = pick_samples * seq_len * in_features * element_size
+                    input_size = batch_size * seq_len * in_features * element_size
                     input_memory_gb = input_size / 1024**3
                     param_memory_gb += input_memory_gb
             else:
                 output_memory_gb = 0.0
 
+            if has_moe:
+                pparent_module = get_module(block, layer_name.rsplit(".", 2)[0]) if "." in layer_name else block
+                is_moe_expert = "expert" in layer_name.lower() and isinstance(pparent_module, torch.nn.ModuleList)
+            else:
+                is_moe_expert = False
+
             # memory * 2, because it contains grad tensor.
-            layer_memory_dict[layer_name] = {"param_memory": param_memory_gb * 2, "output_memory": output_memory_gb * 2}
+            layer_memory_dict[layer_name] = {
+                "param_memory": param_memory_gb * 2,
+                "output_memory": output_memory_gb * 2,
+                "is_moe_expert": is_moe_expert,
+            }
 
     # Assuming bfloat16 or float32, input and output
     block_input_output_memory = 2 * sum(tensor.nbytes for tensor in input_ids) / 1024**3
 
     # Roughly estimate additional memory for attention and other operations
-    additional_activation_memory = sum(info["output_memory"] for info in layer_memory_dict.values())
+    # For MoE expert layers, multiply activation memory by the ratio of active experts
+    # For non-MoE layers (attention, norm, etc.), use full activation memory
+    layer_activation_memory = 0.0
+    for layer_name, info in layer_memory_dict.items():
+        if info.get("is_moe_expert", False):
+            # MoE expert layer: only a fraction of experts are active
+            layer_activation_memory += info["output_memory"] * moe_ratio
+        else:
+            # Non-MoE layer: use full activation memory
+            layer_activation_memory += info["output_memory"]
+
+    # layer_activation_memory also serves as the estimate for other ops' activation memory
     # 1GB considers norm weight, sdpa, reference_output, etc.
-    additional_memory = additional_activation_memory + 1  # GB
+    additional_memory = layer_activation_memory + 1  # GB
+    if has_moe:
+        # TODO: Cannot estimate the memory usage correctly for MoE models yet.
+        # For MoE models, additional memory usage can be higher due to routing, gating,
+        # and multiple expert activations. Here we use a conservative estimate.
+        moe_additional_memory = additional_memory * 6  # GB
+        additional_memory += moe_additional_memory
 
     if torch.xpu.is_available():
         # https://github.com/intel/torch-xpu-ops/issues/2232
         # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
         xpu_additional_memory = 12  # GB
         additional_memory += xpu_additional_memory
-        logger.warning_once("XPU additional memory usage of SDPA is estimated to be 12 GB.")
-        logger.warning_once("Remove it after https://github.com/intel/torch-xpu-ops/issues/2232 is fixed.")
+        logger.warning_once(
+            "[Memory Estimation]: If you hit an abnormal memory issue, please collect a log with "
+            + "AR_LOG_LEVEL=debug and open an issue."
+        )
 
-    return layer_memory_dict, block_input_output_memory, additional_memory
+    return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory
 
 
 def set_auto_device_map_for_block_with_tuning(
     block: torch.nn.Module,
     device_map,
     input_ids: list[torch.Tensor],
-    low_gpu_mem_usage=False,
-    pick_samples=8,
-    output_device=None,
+    low_gpu_mem_usage: bool = False,
+    batch_size: int = 8,
+    output_device: str | torch.device = None,
+    card_0_threshold: float = 0.9,
 ):
     """
     Automatically sets the device map for the block based on available GPUs and memory constraints.
@@ -830,11 +949,16 @@ def set_auto_device_map_for_block_with_tuning(
         device_map (str | int | dict): Specifies the device mapping.
         input_ids (list[torch.Tensor]): List of input tensors used for estimating memory requirements.
         low_gpu_mem_usage (bool, optional): If True, ignoring input/output memory. Defaults to False.
-        pick_samples (int, optional): Number of samples to consider for memory estimation. Defaults to 8.
+        batch_size (int, optional): Number of samples to consider for memory estimation. Defaults to 8.
         output_device (str | torch.device, optional): Device to move unassigned modules to. Defaults to None.
+        card_0_threshold (float, optional): Threshold ratio to determine if the first device is at high risk of
+            running out of memory. Defaults to 0.9 (90%).
 
     Returns:
-        None
+        tuple: (card_0_in_high_risk, loss_device).
+            card_0_in_high_risk (bool): True if the first device is at risk of running out of memory, i.e.
+                card_0_used_memory / device_0_memory >= card_0_threshold, where card_0_used_memory is the sum of
+                block_input_output_memory, layer_activation_memory and additional_memory. Card 0 memory may need
+                to be cleared more frequently during tuning in that case.
+            loss_device: Device on which the loss is computed (the second device when card_0_in_high_risk,
+                otherwise output_device).
 
     Raises:
         RuntimeError: If no CUDA or XPU devices are found.
@@ -842,6 +966,12 @@ def set_auto_device_map_for_block_with_tuning(
     Note:
        This function is intended for internal use in device memory management and tuning.
     """
+    if not (device_map == "auto" or (isinstance(device_map, str) and "," in device_map)):
+        block = block.to(output_device)
+        card_0_in_high_risk = False  # card 0 holds the block weights, so clearing memory will not help much
+        loss_device = output_device
+        return card_0_in_high_risk, loss_device
+
     if torch.cuda.is_available():
         num_devices = torch.cuda.device_count()
         device_name = "cuda"
@@ -857,35 +987,39 @@
     if device_list:
         gpu_devices = [f"{device_name}:{i}" for i in device_list]
         device_0 = gpu_devices[0]
+        device_1 = gpu_devices[1]
     else:
         gpu_devices = [f"{device_name}:{i}" for i in range(num_devices)]
         device_0 = f"{device_name}:0"
+        device_1 = f"{device_name}:1"
 
     device_0_memory = get_device_memory(device_list[0] if device_list else 0)
-    layer_memory_dict, block_input_output_memory, additional_memory = estimate_tuning_block_mem(
-        block, input_ids, pick_samples
+    device_1_memory = get_device_memory(device_list[1] if device_list else 1)
+    layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory = (
+        estimate_tuning_block_mem(block, input_ids, batch_size)
     )
+    loss_memory = block_input_output_memory / 2  # GB, rough estimate for loss tensor memory
 
     if low_gpu_mem_usage:
         block_input_output_memory = 0
 
-    # Calculate total block memory from layer memory dict (including both param and output memory)
     total_block_param_memory = sum(info["param_memory"] for info in layer_memory_dict.values())
-    total_block_output_memory = sum(info["output_memory"] for info in layer_memory_dict.values())
 
     # Average dispatch strategy
     # card_0_left_memory = card_0_mem - block_input_output_memory - additional_memory - layer_outputs_memory
-    logger.debug("Card 0 used memory details [Estimated]:")
+    card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory
+    logger.debug(f"Card 0 used memory details [Estimated]: {card_0_used_memory} GB")
     logger.debug(f"    Block input output cache memory: {block_input_output_memory} GB")
-    logger.debug(f"    Quantized layer outputs memory: {total_block_output_memory} GB")
+    logger.debug(f"    Quantized layer outputs memory: {layer_activation_memory} GB")
     logger.debug(f"    Additional_memory from other ops: {additional_memory} GB")
 
-    card_0_left_memory = max(
-        0, device_0_memory - block_input_output_memory - total_block_output_memory - additional_memory
-    )
+    card_0_left_memory = max(0, (device_0_memory - card_0_used_memory))
+    card_0_in_high_risk = card_0_used_memory / device_0_memory >= card_0_threshold
+    card_1_left_memory = max(0, device_1_memory - loss_memory) if card_0_in_high_risk else device_1_memory
+    loss_device = device_1 if card_0_in_high_risk else output_device
 
     # Calculate total available memory across all devices
-    total_available_memory = card_0_left_memory
-    for i in range(1, len(gpu_devices)):
+    total_available_memory = card_0_left_memory + card_1_left_memory
+    for i in range(2, len(gpu_devices)):
         device_idx = device_list[i] if device_list else i
         total_available_memory += get_device_memory(device_idx)
@@ -916,6 +1050,8 @@
             if has_params or has_buffers:
                 module = module.to(output_device)
 
+    return card_0_in_high_risk, loss_device
+
 
 def partition_dict_numbers(number_dict, n):
     """
@@ -986,18 +1122,18 @@ def find_optimal_subset(arr, target):
 def set_avg_auto_device_map(model: torch.nn.Module, device_map):
     block_name_list = get_block_names(model)
     device_list = None
+    if torch.cuda.is_available():
+        num_devices = torch.cuda.device_count()
+        device_name = "cuda"
+    elif torch.xpu.is_available():
+        num_devices = torch.xpu.device_count()
+        device_name = "xpu"
+    else:
+        return
+
     if isinstance(device_map, str) and "," in device_map:
         device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]
         num_devices = len(device_list)
-    else:
-        if torch.cuda.is_available():
-            num_devices = torch.cuda.device_count()
-            device_name = "cuda"
-        elif torch.xpu.is_available():
-            num_devices = torch.xpu.device_count()
-            device_name = "xpu"
-        else:
-            return
 
     if device_list:
         gpu_devices = [f"{device_name}:{i}" for i in device_list]
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index f8b42ec7a..4d2c7d5fd 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -550,7 +550,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"):
     def unwrapper(self, best_params):
         if best_params is None:
             return self.orig_layer
-        v = best_params["v"]
+        v = best_params["v"].to(self.device)
         weight_q, _, _ = self.quant_func(
             self.orig_layer.weight, self.bits, self.group_size, v, q_scale_thresh=self.q_scale_thresh
         )
@@ -601,7 +601,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"):
     def unwrapper(self, best_params):
         if best_params is None:
             return self.orig_layer
-        v = best_params["v"]
+        v = best_params["v"].to(self.device)
         weight_q, _, _ = self.quant_func(
             self.orig_layer.weight, self.bits, self.group_size, v, q_scale_thresh=self.q_scale_thresh
         )
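Finally, a minimal sketch of the active-expert ratio that `get_moe_memory_ratio` derives from the model config to scale MoE expert activation memory. The `SimpleNamespace` config is a stand-in; real Hugging Face configs expose fields such as `num_experts_per_tok`/`num_local_experts` (Mixtral) or `n_routed_experts` (DeepSeek), which is why several attribute names are tried:

```python
from types import SimpleNamespace

def get_first_available_attr(obj, attr_names, default=None):
    for attr_name in attr_names:
        value = getattr(obj, attr_name, None)
        if value is not None:
            return value
    return default

def moe_memory_ratio(config) -> float:
    active = get_first_available_attr(config, ["num_experts_per_tok", "moe_num_active_primary_experts"])
    total = get_first_available_attr(
        config, ["num_local_experts", "num_experts", "moe_num_primary_experts", "n_routed_experts"]
    )
    if not active or not total:
        return 1.0  # fall back to treating the block as dense
    return active / total

mixtral_like = SimpleNamespace(num_experts_per_tok=2, num_local_experts=8)
print(moe_memory_ratio(mixtral_like))  # 0.25: only 2 of 8 expert outputs are live per token
# Expert activation memory is then scaled as expert_output_memory * moe_memory_ratio(config).
```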