Binary file modified auto_round/alg_ext.abi3.so
Binary file not shown.
Binary file modified auto_round/auto_scheme/default_alg.abi3.so
Binary file not shown.
4 changes: 2 additions & 2 deletions auto_round/auto_scheme/utils.py
@@ -27,7 +27,7 @@
get_layer_features,
get_module,
is_hpex_available,
- parse_all_available_device,
+ parse_available_devices,
)


@@ -223,7 +223,7 @@ def dispatch_model_by_all_available_devices(
model = dispatch_model(model, device_map=device_map)
return model

- devices = parse_all_available_device(device_map)
+ devices = parse_available_devices(device_map)

if len(devices) == 1:
model.to(devices[0])
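The utils.py hunk swaps parse_all_available_device for the renamed parse_available_devices. For orientation, below is a minimal sketch of what such a device-map parser could do, assuming the map arrives as a comma-separated string, an integer index, or a module-to-device dict; the name parse_devices_sketch and its body are illustrative assumptions, not auto_round's actual implementation.

from typing import Union


def parse_devices_sketch(device_map: Union[str, int, dict, None]) -> list[str]:
    """Turn a device_map such as "0,1,cpu" or {"layer": "cuda:0"} into a device list."""
    if device_map is None:
        return ["cpu"]
    if isinstance(device_map, int):
        return [f"cuda:{device_map}"]
    if isinstance(device_map, dict):
        # Deduplicate the mapped target devices while preserving order.
        return list(dict.fromkeys(str(v) for v in device_map.values()))
    devices = []
    for item in str(device_map).replace(" ", "").split(","):
        devices.append(f"cuda:{item}" if item.isdigit() else item)
    return devices


print(parse_devices_sketch("0,1,cpu"))  # ['cuda:0', 'cuda:1', 'cpu']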
62 changes: 32 additions & 30 deletions auto_round/compressors/base.py
@@ -81,7 +81,7 @@
get_layer_names_in_block,
get_module,
htcore,
- is_complex_device_mapping,
+ is_auto_device_mapping,
is_debug_mode,
is_fp8_linear,
is_fp8_model,
@@ -97,7 +97,7 @@
from auto_round.utils.device import (
clear_memory_if_reached_threshold,
get_major_device,
- parse_all_available_device,
+ parse_available_devices,
set_auto_device_map_for_block_with_tuning,
set_non_auto_device_map,
)
@@ -305,6 +305,8 @@ def __init__(
if isinstance(self.device_map, str):
self.device_map = self.device_map.replace(" ", "")

+ self.device_list = parse_available_devices(device_map)
+
if isinstance(scheme, AutoScheme):
self.layer_config = self._gen_auto_scheme(model, scheme, dataset, self.device_map)

@@ -1108,7 +1110,7 @@ def _quantize_embedding_layer(self):
self.layer_config.setdefault(name, {}).update(config)

# Release memory
- clear_memory()
+ clear_memory(device_list=self.device_list)

return is_quantized

@@ -1177,7 +1179,7 @@ def get_imatrix_hook(module, input, output):

accelerate.hooks.remove_hook_from_submodules(model)
model = model.to("cpu")
- clear_memory()
+ clear_memory(device_list=self.device_list)
self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
except torch.OutOfMemoryError:
cuda_error_msg = traceback.format_exc()
@@ -1189,7 +1191,7 @@ def get_imatrix_hook(module, input, output):
"Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`."
)
model = model.to("cpu")
- clear_memory()
+ clear_memory(device_list=self.device_list)
if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
import accelerate

@@ -1361,7 +1363,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
except torch.OutOfMemoryError:
logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
self.model = self.model.to("cpu")
- clear_memory()
+ clear_memory(device_list=self.device_list)
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
import accelerate

@@ -1383,7 +1385,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
pbar.set_description(f"Quantizing {name}")
self._quantize_layer_via_rtn(name)
if cnt % clear_mem_freq == 0:
- clear_memory()
+ clear_memory(device_list=self.device_list)
cnt = 1
cnt += 1
# Convert remaining fp8
@@ -1432,7 +1434,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
)
inputs["input_ids"] = inputs.pop(input_keys[0])

- clear_memory(self.inputs)
+ clear_memory(self.inputs, device_list=self.device_list)

total_samples = len(inputs["input_ids"])
if total_samples < self.batch_size:
@@ -1457,12 +1459,12 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
if is_fp8_model(self.model):
convert_fp8_model_to_16b_model(block, dtype=self.amp_dtype, device=self.device)

- if is_complex_device_mapping(self.device_map):
+ if is_auto_device_mapping(self.device_map):
set_auto_device_map_for_block_with_tuning(
block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
)
# Dispatch model if needed
- if is_complex_device_mapping(self.device_map):
+ if len(self.device_list) > 0:
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

for _, m in block.named_modules():
@@ -1480,7 +1482,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
self.device,
self.cache_device,
)
- if is_complex_device_mapping(self.device_map):
+ if len(self.device_list) > 1:
accelerate.hooks.remove_hook_from_submodules(block)

if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
@@ -1509,7 +1511,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
for name in all_to_quantized_module_names:
self._quantize_layer_via_rtn(name)
if cnt % clear_mem_freq == 0:
- clear_memory()
+ clear_memory(device_list=self.device_list)
cnt = 1
cnt += 1

@@ -1609,12 +1611,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
all_q_inputs = None
if is_quantized_embedding:
all_inputs = copy.deepcopy(self.inputs)
- clear_memory(self.inputs)
+ clear_memory(self.inputs, device_list=self.device_list)
all_q_inputs = self.try_cache_inter_data_gpucpu(
all_first_block_names, self.nsamples, layer_names=layer_names
)
self.model = mv_module_from_gpu(self.model)
- clear_memory()
+ clear_memory(device_list=self.device_list)
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model) # self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model)
@@ -1634,7 +1636,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:

inputs, q_inputs = self._update_inputs(inputs, q_inputs)

- clear_memory(self.inputs)
+ clear_memory(self.inputs, device_list=self.device_list)

if "input_ids" in inputs.keys():
total_samples = len(inputs["input_ids"])
@@ -1751,7 +1753,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
) ##self.model.hf_device_map has not been changed
if not self.immediate_saving:
self.model = mv_module_from_gpu(self.model)
- clear_memory()
+ clear_memory(device_list=self.device_list)
quant_layer = self._quantize_layer
for layer_name in layer_names:
layer_input = layer_inputs[layer_name]
@@ -1766,7 +1768,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
m = get_module(self.model, layer_name)
immediate_saving(self, m, name=layer_name, last_group=True)
del layer_input
- clear_memory(q_layer_input)
+ clear_memory(q_layer_input, device_list=self.device_list)

@torch.no_grad()
def _get_block_outputs(
@@ -1811,7 +1813,7 @@ def _get_block_outputs(
else:
output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim)))
if self.low_gpu_mem_usage:
- clear_memory()
+ clear_memory(device_list=self.device_list)

return output

@@ -1983,7 +1985,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
# Change this if new device is supported
if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")):
no_split_modules = getattr(self.model, "_no_split_modules", [])
- devices = parse_all_available_device(self.device_map)
+ devices = parse_available_devices(self.device_map)
max_memory = get_balanced_memory(
self.model,
max_memory=None,
@@ -2026,7 +2028,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
self.model
) # self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model)
- clear_memory()
+ clear_memory(device_list=self.device_list)
# Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
all_inputs = self.cache_inter_data(
block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
@@ -2504,7 +2506,7 @@ def _quantize_block(
block = block.to(device)
card_0_in_high_risk, loss_device = False, device

- if is_complex_device_mapping(self.device_map):
+ if len(self.device_list) > 1:
for n, m in block.named_modules():
if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"):
continue
@@ -2543,9 +2545,9 @@

if q_input is not None:
if input_ids is not q_input:
- clear_memory(input_ids)
+ clear_memory(input_ids, device_list=self.device_list)
else:
- clear_memory()
+ clear_memory(device_list=self.device_list)
input_ids = q_input

quantized_layer_names, unquantized_layer_names = wrapper_block(
@@ -2660,13 +2662,13 @@ def _quantize_block(

if self.low_gpu_mem_usage and card_0_in_high_risk:
# clear memory to avoid OOM due to memory fragmentation
- clear_memory_if_reached_threshold(threshold=0.5)
+ clear_memory_if_reached_threshold(threshold=0.5, device_list=self.device_list)

self._scale_loss_and_backward(scaler, loss)

if self.low_gpu_mem_usage and card_0_in_high_risk:
# clear memory to avoid OOM due to memory fragmentation
- clear_memory_if_reached_threshold(threshold=0.8)
+ clear_memory_if_reached_threshold(threshold=0.8, device_list=self.device_list)

if i == 0:
init_loss = total_loss
@@ -2716,15 +2718,15 @@ def _quantize_block(
device,
cache_device=self.cache_device,
)
- if is_complex_device_mapping(self.device_map):
+ if len(self.device_list) > 1:
accelerate.hooks.remove_hook_from_submodules(block)
mv_module_from_gpu(block)
clear_memory(input_ids)

return q_outputs, output

else:
- if is_complex_device_mapping(self.device_map):
+ if len(self.device_list) > 1:
accelerate.hooks.remove_hook_from_submodules(block)
mv_module_from_gpu(block)
clear_memory(input_ids)
@@ -2758,12 +2760,12 @@ def _quantize_blocks(
Returns:
None
"""
- clear_memory()
+ clear_memory(device_list=self.device_list)
for n, m in model.named_parameters():
m.requires_grad_(False)

input_ids, input_others = self._split_inputs(inputs)
- clear_memory()
+ clear_memory(device_list=self.device_list)
input_ids = to_device(input_ids, self.cache_device)
input_others = to_device(input_others, self.cache_device)
# As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
@@ -2867,7 +2869,7 @@ def _quantize_blocks(
del input_others
del inputs

- clear_memory()
+ clear_memory(device_list=self.device_list)

def save_quantized(
self, output_dir: str = None, format: str = "auto_round", inplace: bool = True, **kwargs
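In base.py, every clear_memory and clear_memory_if_reached_threshold call now carries device_list=self.device_list, captured once in __init__ via parse_available_devices. A minimal sketch of how a device-aware cache clear could work is shown below, assuming each entry is a torch device string; clear_memory_sketch is a hypothetical stand-in and not auto_round's real helper.

import gc

import torch


def clear_memory_sketch(tensor=None, device_list=None):
    """Release the given tensor, collect garbage, then empty the cache on each listed CUDA device."""
    if tensor is not None:
        del tensor
    gc.collect()
    for device in device_list or []:
        # Skip non-CUDA entries such as "cpu"; empty_cache() acts on the current device.
        if str(device).startswith("cuda") and torch.cuda.is_available():
            with torch.cuda.device(torch.device(device)):
                torch.cuda.empty_cache()


# Safe to call on CPU-only machines: the CUDA branch is simply skipped.
clear_memory_sketch(device_list=["cuda:0", "cuda:1"])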