From 22fe3d7bdcb2e42d523780e5cc36746ffaa7d7f5 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 27 Jul 2023 01:30:41 +0530 Subject: [PATCH] Faster controlnet model instantiation, and allow controlnets to be loaded (from ckpt) in a parallel thread with a SD model (ckpt) without tensor errors (race condition) --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index fdbe1dfaeffb..1560966f211c 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1070,7 +1070,9 @@ def convert_controlnet_checkpoint( if cross_attention_dim is not None: ctrlnet_config["cross_attention_dim"] = cross_attention_dim - controlnet = ControlNetModel(**ctrlnet_config) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**ctrlnet_config) # Some controlnet ckpt files are distributed independently from the rest of the # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ @@ -1088,7 +1090,11 @@ def convert_controlnet_checkpoint( skip_extract_state_dict=skip_extract_state_dict, ) - controlnet.load_state_dict(converted_ctrl_checkpoint) + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) return controlnet