Describe the bug
I noticed that when I manually add an `AlignDevicesHook` to the modules in a pipeline, `load_lora_weights()` enables sequential CPU offload. Digging deeper, I found that `load_lora_weights()` uses the `_optionally_disable_offloading()` function to decide whether the pipeline is sequentially CPU offloaded. That function is:
```python
def _optionally_disable_offloading(cls, _pipeline):
    """
    Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

    Args:
        _pipeline (`DiffusionPipeline`):
            The pipeline to disable offloading for.

    Returns:
        tuple:
            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

    if _pipeline is not None:
        for _, component in _pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)

                logger.info(
                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                )
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    return (is_model_cpu_offload, is_sequential_cpu_offload)
```

So I was curious: why is `is_sequential_cpu_offload` set to `True` whenever a component has an `AlignDevicesHook`? Shouldn't it be `True` only when the component's device is CPU?
Reproduction
```python
import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
from accelerate.hooks import attach_align_device_hook_on_blocks

# controlnet1 and controlnet2 are ControlNetModel instances loaded beforehand
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "/media/74nvme/checkpoints/diffusers_models/stable-diffusion-v1-5/",
    controlnet=[controlnet1, controlnet2],
    torch_dtype=torch.float16,
).to("cuda:0")

module_names, _ = pipe._get_signature_keys(pipe)
modules = [getattr(pipe, n, None) for n in module_names]
module_names = [name for m, name in zip(modules, module_names) if isinstance(m, torch.nn.Module)]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
print(module_names)

for module, name in zip(modules, module_names):
    # Keep the UNet and ControlNets on cuda:0, move everything else to cuda:1,
    # then attach a device-alignment hook (offload=False by default) on each module.
    if name in ("unet", "controlnet"):
        module.to("cuda:0")
    else:
        module.to("cuda:1")
    attach_align_device_hook_on_blocks(
        module,
        execution_device=module.device,
    )
```
Then calling `pipe.load_lora_weights(lora_weights_path)` changes the device of every component in the pipeline.
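To observe the device change, a small helper can print each component's device before and after loading. This continues the reproduction script above; `report_devices` is a hypothetical helper of mine, and `lora_weights_path` is a placeholder for any LoRA checkpoint:

```python
def report_devices(pipe):
    # Print the device of the first parameter of each nn.Module component.
    for name, component in pipe.components.items():
        if isinstance(component, torch.nn.Module):
            print(name, next(component.parameters()).device)


report_devices(pipe)                       # before: unet/controlnet on cuda:0, the rest on cuda:1
pipe.load_lora_weights(lora_weights_path)  # placeholder path to a LoRA checkpoint
report_devices(pipe)                       # after: component devices have changed, since the
                                           # AlignDevicesHook was treated as sequential CPU offload
```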
Logs
No response
System Info
diffusers: 0.25.1
torch: 2.2.0+cu118
Who can help?
No response