diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index deb6f2122..62ac63eb9 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -987,12 +987,7 @@ def set_auto_device_map_for_block_with_tuning( Note: This function is intended for internal use in device memory management and tuning. """ - if not (device_map == "auto" or ((isinstance(device_map, str) and "," in device_map))): - block = block.to(output_device) - card_0_in_high_risk = False # card 0 contains weight, clear_memory will not help much - loss_device = output_device - return card_0_in_high_risk, loss_device - + card_0_in_high_risk, loss_device = False, output_device if torch.cuda.is_available(): num_devices = torch.cuda.device_count() device_name = "cuda" @@ -1000,7 +995,14 @@ def set_auto_device_map_for_block_with_tuning( num_devices = torch.xpu.device_count() device_name = "xpu" else: - return + return card_0_in_high_risk, loss_device + + if not ( + device_map == "auto" or ((isinstance(device_map, str) and "," in device_map)) or num_devices > 1 + ): # Only 1 card is available or non-auto device map + block = block.to(output_device) + return card_0_in_high_risk, loss_device + device_list = None if isinstance(device_map, str) and "," in device_map: device_list = [int(dev) for dev in device_map.split(",") if dev.isdigit()]