From 22906abaea974456f29616491ba74f0752f47d1b Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Tue, 11 Nov 2025 22:24:21 -0500 Subject: [PATCH] dispatch model with real max memory Signed-off-by: He, Xin3 --- auto_round/compressors/base.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index a2b61eae3..c098320f9 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -25,7 +25,7 @@ import accelerate import torch from accelerate.big_modeling import dispatch_model, infer_auto_device_map -from accelerate.utils import get_balanced_memory +from accelerate.utils import get_max_memory from torch import autocast from tqdm import tqdm from transformers import set_seed @@ -1992,11 +1992,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")): no_split_modules = getattr(self.model, "_no_split_modules", []) devices = parse_available_devices(self.device_map) - max_memory = get_balanced_memory( - self.model, - max_memory=None, - no_split_module_classes=no_split_modules, - ) + max_memory = get_max_memory() new_max_memory = {} for device in devices: if ":" in device: