Merged

27 commits
c910f44  fix multple device issue (wenhuach21, Nov 7, 2025)
c744521  refine (wenhuach21, Nov 7, 2025)
7ac0451  update (wenhuach21, Nov 7, 2025)
064f2a7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2025)
9a7410d  update (wenhuach21, Nov 7, 2025)
6a7bfc4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2025)
90cb314  update (wenhuach21, Nov 7, 2025)
90f54be  Merge branch 'fix_multi_device_issue' of https://github.com/intel/aut… (wenhuach21, Nov 7, 2025)
ca7000e  Update auto_round/utils/device.py (wenhuach21, Nov 7, 2025)
a3139c1  Update auto_round/utils/device.py (wenhuach21, Nov 7, 2025)
af420dd  fix (wenhuach21, Nov 7, 2025)
4758f11  Merge branch 'main' into fix_multi_device_issue (wenhuach21, Nov 7, 2025)
46ea267  Update auto_round/utils/device.py (wenhuach21, Nov 7, 2025)
f383322  Update auto_round/utils/device.py (wenhuach21, Nov 7, 2025)
5fc528a  fix (wenhuach21, Nov 7, 2025)
edbecee  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2025)
ec05a9d  add comments (wenhuach21, Nov 7, 2025)
a4bad33  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2025)
72b2a68  update alt_ext (wenhuach21, Nov 8, 2025)
cb25e15  Merge branch 'fix_multi_device_issue' of https://github.com/intel/aut… (wenhuach21, Nov 8, 2025)
7e11734  add device list to clear memory (wenhuach21, Nov 10, 2025)
8101c25  Merge branch 'main' into fix_multi_device_issue (wenhuach21, Nov 10, 2025)
0a225c4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 10, 2025)
64f3085  fix bug (wenhuach21, Nov 10, 2025)
0ff0cbe  Merge branch 'fix_multi_device_issue' of https://github.com/intel/aut… (wenhuach21, Nov 10, 2025)
5ed0257  fix bug (wenhuach21, Nov 10, 2025)
73cb6c1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 10, 2025)
auto_round/compressors/base.py (12 changes: 6 additions & 6 deletions)
@@ -1459,12 +1459,12 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
             if is_fp8_model(self.model):
                 convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype, device=self.device)

-            if is_auto_device_mapping(self.device_map):
+            if is_auto_device_mapping(self.device_map) and len(self.device_list) > 1:
                 set_auto_device_map_for_block_with_tuning(
                     block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
                 )
             # Dispatch model if needed
-            if len(self.device_list) > 0:
+            if len(self.device_list) > 1:
                 from accelerate.hooks import AlignDevicesHook, add_hook_to_module

                 for _, m in block.named_modules():
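This hunk tightens two guards: auto device mapping and hook-based dispatch now run only when more than one device is in `self.device_list`, so single-device runs skip the hook machinery entirely. The dispatch loop is truncated here; below is a minimal sketch of the pattern its imports suggest, using accelerate's real `AlignDevicesHook` and `add_hook_to_module` but a hypothetical `tuning_device` attribute as the per-module target, since the hunk does not show how auto-round records it.

```python
# Minimal sketch of hook-based block dispatch, NOT auto-round's actual loop:
# the per-module `tuning_device` attribute is a hypothetical stand-in for
# whatever set_auto_device_map_for_block_with_tuning records.
import torch
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

def dispatch_block(block: torch.nn.Module) -> None:
    for _, module in block.named_modules():
        target = getattr(module, "tuning_device", None)  # hypothetical attribute
        if target is None:
            continue
        # AlignDevicesHook moves the module's weights and inputs onto
        # `target` at each forward call, letting one block straddle cards.
        add_hook_to_module(module, AlignDevicesHook(execution_device=target, io_same_device=True))
```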
@@ -2498,7 +2498,7 @@ def _quantize_block(
             set_module(block, n, new_layer)
         # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights
         # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk
-        if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)):
+        if is_auto_device_mapping(self.device_map) and len(self.device_list) > 1:
             card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning(
                 block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
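The string test removed here is the clearest hint at what `is_auto_device_mapping` encapsulates. A sketch of an equivalent helper, inferred purely from this hunk; the real one lives in `auto_round/utils/device.py`, which this PR updates several times, and may handle more input shapes:

```python
# Helper equivalent to the condition this hunk replaces; inferred from the
# diff, not copied from auto_round/utils/device.py.
def is_auto_device_mapping(device_map) -> bool:
    """True when placement is delegated to auto-round: either the literal
    "auto" or a comma-separated device string such as "0,1"."""
    return device_map == "auto" or (isinstance(device_map, str) and "," in device_map)
```

Folding the old condition's doubled parentheses into a named helper also keeps the two call sites (here and in `_quantize_via_rtn_blockwise`) from drifting apart.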
@@ -2699,7 +2699,7 @@ def _quantize_block(
         )
         logger.info(dump_info)
         if self.low_gpu_mem_usage:
-            clear_memory()  # clear cached memory during training
+            clear_memory(self.device_list)  # clear cached memory during training
         if len(unquantized_layer_names) != 0:
             logger.info(f"{unquantized_layer_names} have not been quantized")
         with torch.no_grad():
@@ -2721,15 +2721,15 @@ def _quantize_block(
             if len(self.device_list) > 1:
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
-            clear_memory(input_ids)
+            clear_memory(input_ids, self.device_list)

             return q_outputs, output

         else:
             if len(self.device_list) > 1:
                 accelerate.hooks.remove_hook_from_submodules(block)
             mv_module_from_gpu(block)
-            clear_memory(input_ids)
+            clear_memory(input_ids, self.device_list)
             return None, output

     def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]:
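Both exit paths now share the same teardown order: remove the accelerate hooks, move the block off the GPU, then clear caches on every device. A condensed restatement under stated assumptions; `finish_block` is a hypothetical wrapper, only `remove_hook_from_submodules` is a confirmed public accelerate API, and `block.to("cpu")` stands in for auto-round's `mv_module_from_gpu`:

```python
import gc

import torch
from accelerate.hooks import remove_hook_from_submodules

def finish_block(block: torch.nn.Module, input_ids, device_list) -> None:
    # Hypothetical wrapper restating the teardown order of this hunk.
    if len(device_list) > 1:
        # Drop AlignDevicesHook wrappers first, so the move below is not
        # undone by a hook realigning tensors on the next forward.
        remove_hook_from_submodules(block)
    block.to("cpu")  # stand-in for auto-round's mv_module_from_gpu(block)
    del input_ids    # drop this frame's reference to the activations
    gc.collect()
    if torch.cuda.is_available():
        for dev in device_list:  # then empty each card's allocator pool
            with torch.cuda.device(dev):
                torch.cuda.empty_cache()
```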