diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index cfc7c5896f97..b2e91e3f6ac3 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -150,7 +150,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            if is_safetensors_available():
+            if (is_safetensors_available() and weight_name is None) or (weight_name is not None and weight_name.endswith(".safetensors")):
                 if weight_name is None:
                     weight_name = LORA_WEIGHT_NAME_SAFE
                 try:
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index bc025f6eeb56..c1f3bc05d7c6 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -445,6 +445,43 @@ def test_lora_save_load_safetensors(self):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_save_load_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        model.set_attn_processor(lora_attn_procs)
+        # Saved in torch format; should reload correctly when the weight filename is passed directly
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
     def test_lora_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
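
For context (not part of the patch), a minimal usage sketch of the behavior the loaders.py change targets, assuming the LoRA attention processors were previously saved with save_attn_procs in torch format; the model and LoRA directory paths below are placeholders:

# Usage sketch (assumption: a UNet checkpoint at "path/to/unet" and LoRA weights saved
# earlier via model.save_attn_procs("path/to/lora_dir"), which writes pytorch_lora_weights.bin).
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("path/to/unet")
# Explicitly request the torch-format weights instead of the safetensors default.
unet.load_attn_procs("path/to/lora_dir", weight_name="pytorch_lora_weights.bin")

With the updated condition, passing a .bin weight_name routes loading to the torch branch even when the safetensors package is installed, which is what the new test exercises.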