diff --git a/auto_round/alg_ext.abi3.so b/auto_round/alg_ext.abi3.so
index 13829d331..b89989d35 100755
Binary files a/auto_round/alg_ext.abi3.so and b/auto_round/alg_ext.abi3.so differ
diff --git a/auto_round/auto_scheme/default_alg.abi3.so b/auto_round/auto_scheme/default_alg.abi3.so
index 41d8d5634..2651f8cf5 100644
Binary files a/auto_round/auto_scheme/default_alg.abi3.so and b/auto_round/auto_scheme/default_alg.abi3.so differ
diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py
index 67b7fd428..b355b4703 100644
--- a/auto_round/auto_scheme/utils.py
+++ b/auto_round/auto_scheme/utils.py
@@ -27,7 +27,7 @@
     get_layer_features,
     get_module,
     is_hpex_available,
-    parse_all_available_device,
+    parse_available_devices,
 )
@@ -223,7 +223,7 @@ def dispatch_model_by_all_available_devices(
         model = dispatch_model(model, device_map=device_map)
         return model

-    devices = parse_all_available_device(device_map)
+    devices = parse_available_devices(device_map)
     if len(devices) == 1:
         model.to(devices[0])
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 9bffdbc12..45ed0c14f 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -81,7 +81,7 @@
     get_layer_names_in_block,
     get_module,
     htcore,
-    is_complex_device_mapping,
+    is_auto_device_mapping,
     is_debug_mode,
     is_fp8_linear,
     is_fp8_model,
@@ -97,7 +97,7 @@
 from auto_round.utils.device import (
     clear_memory_if_reached_threshold,
     get_major_device,
-    parse_all_available_device,
+    parse_available_devices,
     set_auto_device_map_for_block_with_tuning,
     set_non_auto_device_map,
 )
@@ -305,6 +305,8 @@ def __init__(
         if isinstance(self.device_map, str):
             self.device_map = self.device_map.replace(" ", "")

+        self.device_list = parse_available_devices(device_map)
+
         if isinstance(scheme, AutoScheme):
             self.layer_config = self._gen_auto_scheme(model, scheme, dataset, self.device_map)
@@ -1108,7 +1110,7 @@ def _quantize_embedding_layer(self):
             self.layer_config.setdefault(name, {}).update(config)

             # Release memory
-            clear_memory()
+            clear_memory(device_list=self.device_list)

         return is_quantized

@@ -1177,7 +1179,7 @@ def get_imatrix_hook(module, input, output):
                 accelerate.hooks.remove_hook_from_submodules(model)
                 model = model.to("cpu")
-                clear_memory()
+                clear_memory(device_list=self.device_list)
             self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
         except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
@@ -1189,7 +1191,7 @@
                 "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`."
             )
             model = model.to("cpu")
-            clear_memory()
+            clear_memory(device_list=self.device_list)
             if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
                 import accelerate
@@ -1361,7 +1363,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         except torch.OutOfMemoryError:
             logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
             self.model = self.model.to("cpu")
-            clear_memory()
+            clear_memory(device_list=self.device_list)
             if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                 import accelerate
@@ -1383,7 +1385,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             pbar.set_description(f"Quantizing {name}")
             self._quantize_layer_via_rtn(name)
             if cnt % clear_mem_freq == 0:
-                clear_memory()
+                clear_memory(device_list=self.device_list)
                 cnt = 1
             cnt += 1
         # Convert remaining fp8
@@ -1432,7 +1434,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
             )
             inputs["input_ids"] = inputs.pop(input_keys[0])

-        clear_memory(self.inputs)
+        clear_memory(self.inputs, device_list=self.device_list)

        total_samples = len(inputs["input_ids"])
        if total_samples < self.batch_size:
@@ -1457,12 +1459,12 @@
             if is_fp8_model(self.model):
                 convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype, device=self.device)

-            if is_complex_device_mapping(self.device_map):
+            if is_auto_device_mapping(self.device_map):
                 set_auto_device_map_for_block_with_tuning(
                     block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
                 )
             # Dispatch model if needed
-            if is_complex_device_mapping(self.device_map):
+            if len(self.device_list) > 0:
                 from accelerate.hooks import AlignDevicesHook, add_hook_to_module

                 for _, m in block.named_modules():
@@ -1480,7 +1482,7 @@
                 self.device,
                 self.cache_device,
             )
-            if is_complex_device_mapping(self.device_map):
+            if len(self.device_list) > 1:
                 accelerate.hooks.remove_hook_from_submodules(block)

             if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
@@ -1509,7 +1511,7 @@
         for name in all_to_quantized_module_names:
             self._quantize_layer_via_rtn(name)
             if cnt % clear_mem_freq == 0:
-                clear_memory()
+                clear_memory(device_list=self.device_list)
                 cnt = 1
             cnt += 1

@@ -1609,12 +1611,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         all_q_inputs = None
         if is_quantized_embedding:
             all_inputs = copy.deepcopy(self.inputs)
-            clear_memory(self.inputs)
+            clear_memory(self.inputs, device_list=self.device_list)
             all_q_inputs = self.try_cache_inter_data_gpucpu(
                 all_first_block_names, self.nsamples, layer_names=layer_names
             )
             self.model = mv_module_from_gpu(self.model)
-            clear_memory()
+            clear_memory(device_list=self.device_list)
         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
             accelerate.hooks.remove_hook_from_submodules(self.model)  # self.model.hf_device_map has not been changed
             self.model = mv_module_from_gpu(self.model)
@@ -1634,7 +1636,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         inputs, q_inputs = self._update_inputs(inputs, q_inputs)

-        clear_memory(self.inputs)
+        clear_memory(self.inputs, device_list=self.device_list)

         if "input_ids" in inputs.keys():
             total_samples = len(inputs["input_ids"])
@@ -1751,7 +1753,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
             )  ##self.model.hf_device_map has not been changed
         if not self.immediate_saving:
             self.model = mv_module_from_gpu(self.model)
-            clear_memory()
+            clear_memory(device_list=self.device_list)
         quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
@@ -1766,7 +1768,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 m = get_module(self.model, layer_name)
                 immediate_saving(self, m, name=layer_name, last_group=True)
             del layer_input
-            clear_memory(q_layer_input)
+            clear_memory(q_layer_input, device_list=self.device_list)

     @torch.no_grad()
     def _get_block_outputs(
@@ -1811,7 +1813,7 @@ def _get_block_outputs(
             else:
                 output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim)))
         if self.low_gpu_mem_usage:
-            clear_memory()
+            clear_memory(device_list=self.device_list)

         return output

@@ -1983,7 +1985,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             # Change this if new device is supported
             if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")):
                 no_split_modules = getattr(self.model, "_no_split_modules", [])
-                devices = parse_all_available_device(self.device_map)
+                devices = parse_available_devices(self.device_map)
                 max_memory = get_balanced_memory(
                     self.model,
                     max_memory=None,
@@ -2026,7 +2028,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                     self.model
                 )  # self.model.hf_device_map has not been changed
                 self.model = mv_module_from_gpu(self.model)
-                clear_memory()
+                clear_memory(device_list=self.device_list)
                 # Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
                 all_inputs = self.cache_inter_data(
                     block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
                 )
@@ -2504,7 +2506,7 @@ def _quantize_block(
             block = block.to(device)

         card_0_in_high_risk, loss_device = False, device
-        if is_complex_device_mapping(self.device_map):
+        if len(self.device_list) > 1:
            for n, m in block.named_modules():
                if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
                    continue
@@ -2543,9 +2545,9 @@ def _quantize_block(
         if q_input is not None:
             if input_ids is not q_input:
-                clear_memory(input_ids)
+                clear_memory(input_ids, device_list=self.device_list)
             else:
-                clear_memory()
+                clear_memory(device_list=self.device_list)
             input_ids = q_input

         quantized_layer_names, unquantized_layer_names = wrapper_block(
@@ -2660,13 +2662,13 @@ def _quantize_block(
             if self.low_gpu_mem_usage and card_0_in_high_risk:
                 # clear memory to avoid OOM due to memory fragmentation
-                clear_memory_if_reached_threshold(threshold=0.5)
+                clear_memory_if_reached_threshold(threshold=0.5, device_list=self.device_list)

             self._scale_loss_and_backward(scaler, loss)

             if self.low_gpu_mem_usage and card_0_in_high_risk:
                 # clear memory to avoid OOM due to memory fragmentation
-                clear_memory_if_reached_threshold(threshold=0.8)
+                clear_memory_if_reached_threshold(threshold=0.8, device_list=self.device_list)

             if i == 0:
                 init_loss = total_loss
@@ -2716,7 +2718,7 @@ def _quantize_block(
                 device,
                 cache_device=self.cache_device,
             )
-            if is_complex_device_mapping(self.device_map):
+            if len(self.device_list) > 1:
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
             clear_memory(input_ids)
@@ -2724,7 +2726,7 @@ def _quantize_block(
             return q_outputs, output

         else:
-            if is_complex_device_mapping(self.device_map):
+            if len(self.device_list) > 1:
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
             clear_memory(input_ids)
@@ -2758,12 +2760,12 @@ def _quantize_blocks(
         Returns:
             None
         """
-        clear_memory()
+        clear_memory(device_list=self.device_list)
         for n, m in model.named_parameters():
             m.requires_grad_(False)

         input_ids, input_others = self._split_inputs(inputs)
-        clear_memory()
+        clear_memory(device_list=self.device_list)
         input_ids = to_device(input_ids, self.cache_device)
         input_others = to_device(input_others, self.cache_device)
         # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
@@ -2867,7 +2869,7 @@ def _quantize_blocks(
         del input_others
         del inputs
-        clear_memory()
+        clear_memory(device_list=self.device_list)

     def save_quantized(
         self, output_dir: str = None, format: str = "auto_round", inplace: bool = True, **kwargs
diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py
index d95fd6435..deb6f2122 100644
--- a/auto_round/utils/device.py
+++ b/auto_round/utils/device.py
@@ -355,7 +355,7 @@ def get_packing_device(device: str | torch.device | None = "auto") -> torch.devi
     raise TypeError(f"Unsupported device type: {type(device)} ({device})")


-def is_complex_device_mapping(device_map):
+def is_auto_device_mapping(device_map: str | int | dict | None):
     if device_map is None or isinstance(device_map, int):
         return False
     elif device_map == "auto":
@@ -363,7 +363,7 @@
     elif isinstance(device_map, str) and "," in device_map:
         return True
     elif isinstance(device_map, dict):
-        return True
+        return False
     else:
         return False

@@ -404,7 +404,9 @@ def bytes_to_gigabytes(bytes) -> int:
     return bytes / 1024 / 1024 / 1024


-def _clear_memory_for_cpu_and_cuda(tensor=None):
+def _clear_memory_for_cpu_and_cuda(
+    tensor: torch.Tensor | list[torch.Tensor] | None = None, device_list: tuple | list | None = None
+):
     if isinstance(tensor, list):
         for i in range(len(tensor)):
             tensor[i] = None
@@ -412,23 +414,40 @@
     del tensor
     gc.collect()
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+        if device_list is None:
+            torch.cuda.synchronize()
+            # Fix https://github.com/intel/auto-round/issues/1004
+            torch.cuda.empty_cache()
+
+        elif len(device_list) > 1:
+            devices = []
+            for device in device_list:
+                if not device.startswith("cuda"):
+                    continue
+                if ":" in device:
+                    device = device.split(":")[-1]
+                else:
+                    device = 0
+                devices.append(int(device))
+            for device in devices:
+                torch.cuda.synchronize(device)
+                torch.cuda.empty_cache()
     if torch.xpu.is_available():
         torch.xpu.empty_cache()


 @torch._dynamo.disable()
-def clear_memory(tensor=None):
+def clear_memory(tensor: torch.Tensor | None | list[torch.Tensor] = None, device_list: list | tuple | None = None):
     from auto_round.utils.device import is_hpex_available

     if is_hpex_available():
         # hpu does not have empty_cache
         return
     else:
-        _clear_memory_for_cpu_and_cuda(tensor)
+        _clear_memory_for_cpu_and_cuda(tensor, device_list)


-def clear_memory_if_reached_threshold(threshold=0.85):
+def clear_memory_if_reached_threshold(threshold=0.85, device_list=None):
     """Check all available devices and clear memory if any device is using close to the threshold.

     Args:
@@ -463,7 +482,7 @@
                     "To alleviate high memory usage on the major device, consider reducing the `batch_size` "
                     + "(and correspondingly increasing `gradient_accumulation_steps) or shortening the seqlen."
                 )
-                clear_memory()
+                clear_memory(device_list=device_list)
                 return True
         except Exception as e:
             logger.warning_once(f"Failed to check memory for {name}:{i}: {e}")
@@ -1122,25 +1141,16 @@ def find_optimal_subset(arr, target):


 def set_avg_auto_device_map(model: torch.nn.Module, device_map):
     block_name_list = get_block_names(model)
-    device_list = None
-    if torch.cuda.is_available():
-        num_devices = torch.cuda.device_count()
-        device_name = "cuda"
-    elif torch.xpu.is_available():
-        num_devices = torch.xpu.device_count()
-        device_name = "xpu"
-    else:
+    device_list = parse_available_devices(device_map)
+    gpu_devices = []
+    for device in device_list:
+        if device.startswith("cpu") or device.startswith("hpu"):
+            continue
+        gpu_devices.append(device)
+    num_devices = len(gpu_devices)
+    if num_devices < 1:
         return
-    if isinstance(device_map, str) and "," in device_map:
-        device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]
-        num_devices = len(device_list)
-
-    if device_list:
-        gpu_devices = [f"{device_name}:{i}" for i in device_list]
-    else:
-        gpu_devices = [f"{device_name}:{i}" for i in range(num_devices)]
-
     for block_names in block_name_list:
         for block_name in block_names:
             params_dict = {}
@@ -1184,7 +1194,7 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map):
         print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}")


-def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
+def parse_available_devices(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
     """
     Parse the device map and return a list of all available devices.

@@ -1242,6 +1252,21 @@ def parse_all_available_device(device_map: Union[str, torch.device, int, dict, N
         device_type = device_types[0]
         return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]

+    # ---- dict-like string ----
+    if isinstance(device_map, str) and ":" in device_map and "," in device_map:
+        pairs = [p.strip() for p in device_map.split(",") if ":" in p]
+        devices = []
+        for pair in pairs:
+            try:
+                key, *value_parts = pair.split(":")
+                value = ":".join(value_parts).strip()
+                if value.isdigit() and device_types[0] != "cpu":
+                    value = device_types[0] + ":" + value
+                devices.append(value)
+            except ValueError:
+                continue
+        return devices
+
     if isinstance(device_map, str):
         # Remove whitespace
         device_map = device_map.strip()
@@ -1264,7 +1289,7 @@ def parse_all_available_device(device_map: Union[str, torch.device, int, dict, N
         # Extract all devices recursively from dict values
         devices = set()
         for v in device_map.values():
-            devices.update(parse_all_available_device(v))
+            devices.update(parse_available_devices(v))
         return sorted(devices)

     raise TypeError(f"Unsupported device_map type: {type(device_map)}")
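
Note (editorial, not part of the diff): `is_complex_device_mapping` is renamed to `is_auto_device_mapping`, and a dict-style `device_map` no longer counts as an "auto" mapping (the dict branch now returns False). The snippet below is a minimal sketch of the new truth table, written against the new signature and assuming the function is imported from `auto_round.utils.device`:

    from auto_round.utils.device import is_auto_device_mapping

    # "auto" and multi-device strings are treated as auto mappings ...
    assert is_auto_device_mapping("auto") is True
    assert is_auto_device_mapping("0,1") is True
    # ... while None, a single device index, and explicit dict maps are not.
    assert is_auto_device_mapping(None) is False
    assert is_auto_device_mapping(0) is False
    assert is_auto_device_mapping({"lm_head": "cpu"}) is False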
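
Note (editorial, not part of the diff): `parse_available_devices` (renamed from `parse_all_available_device`) gains a branch for dict-like strings of the form "module:device,module:device". The helper below is a hypothetical, self-contained re-implementation of just that branch, shown so the expected output is easy to see; the example `device_map` string and the assumed accelerator type ("cuda") are illustrative, not taken from the repository:

    def parse_dict_like_string(device_map: str, device_type: str = "cuda") -> list:
        """Sketch of the dict-like-string branch: keep the device part of each pair."""
        pairs = [p.strip() for p in device_map.split(",") if ":" in p]
        devices = []
        for pair in pairs:
            _key, *value_parts = pair.split(":")
            value = ":".join(value_parts).strip()
            if value.isdigit() and device_type != "cpu":
                # Bare indices such as "0" are promoted to "<device_type>:0".
                value = device_type + ":" + value
            devices.append(value)
        return devices

    print(parse_dict_like_string("model.embed_tokens:0,model.layers:1,lm_head:cpu"))
    # ['cuda:0', 'cuda:1', 'cpu']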
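
Note (editorial, not part of the diff): the compressor now records `self.device_list = parse_available_devices(device_map)` once in `__init__` and passes it to every `clear_memory(...)` call, so each CUDA card in the list is synchronized before the caching allocator is emptied. The snippet below is a simplified sketch of that per-device pattern, not a drop-in copy of `_clear_memory_for_cpu_and_cuda` (the real helper also releases the passed tensors and handles XPU); the `device_list` value is a hypothetical example:

    import gc

    import torch


    def clear_cuda_caches(device_list=None):
        """Simplified per-device variant of the cache clearing introduced in this diff."""
        gc.collect()
        if not torch.cuda.is_available():
            return
        if not device_list:
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            return
        for device in device_list:
            if not device.startswith("cuda"):
                continue  # skip cpu/hpu/xpu entries
            index = int(device.split(":")[-1]) if ":" in device else 0
            torch.cuda.synchronize(index)  # wait for pending work on that card
            torch.cuda.empty_cache()


    clear_cuda_caches(["cuda:0", "cuda:1", "cpu"])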