
Infer problem about loading lora weights #57

Open
dcfucheng opened this issue Nov 24, 2023 · 3 comments

Comments
@dcfucheng

Hey~

Good job~ I have trained an SD LoRA on my custom dataset, but I have a problem with inference only.

We saved the state dict with:
'''
lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
'''

The keys of the saved model are named like
'''
base_model.model.mid_block.resnets.1.time_emb_proj.lora_B.weight
'''

But I checked pytorch_lora_weights.safetensors files whose keys look like
'''
lora_unet_up_blocks_2_attentions_0_proj_in.lora_up.weight
'''
which can be loaded correctly by "pipe.load_lora_weights()".
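
To see the mismatch concretely, here is a minimal sketch that prints one key from each file (the file paths and variable names are placeholders, not from this thread):

from safetensors.torch import load_file

ours = load_file("unet_lora/pytorch_lora_weights.safetensors")       # saved via save_lora_weights above
reference = load_file("reference/pytorch_lora_weights.safetensors")  # a file that load_lora_weights accepts

print(next(iter(ours)))       # e.g. base_model.model.mid_block.resnets.1.time_emb_proj.lora_B.weight
print(next(iter(reference)))  # e.g. lora_unet_up_blocks_2_attentions_0_proj_in.lora_up.weight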

But the model we saved cannot be loaded directly.
So, the question is: how do we load the LoRA weights we save? Or should we convert the LoRA weights before saving?

Thanks~

@JingyeChen

I have encountered the same problem.

@dcfucheng
Author

> I have encountered the same problem.

I tried to load the LoRA weights this way. The weights can be loaded, but I trained on SD 2.1 and it generates a noisy picture. #65

So, I am not sure this is correct. You can try it. Welcome to discuss~

import torch
from safetensors.torch import load_file

def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
    # Remap PEFT-style keys to kohya-style keys so pipe.load_lora_weights() accepts them.
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        kohya_key = peft_key.replace("unet.base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        # Turn all dots except the last two into underscores (kohya module-path style).
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # Set the alpha parameter for each LoRA module.
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(8).to(dtype)  # lora_alpha hard-coded to 8

    return kohya_ss_state_dict

lora_weight = load_file('/path/unet_lora/pytorch_lora_weights.safetensors')
lora_state_dict = get_module_kohya_state_dict(lora_weight, "lora_unet", torch.float16)
pipe.load_lora_weights(lora_state_dict)
pipe.fuse_lora()
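
For a quick sanity check after the conversion above, something like the following should run end-to-end (a minimal sketch; the base model ID and prompt are placeholders, so use whatever checkpoint the LoRA was actually trained on):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16  # placeholder base model
).to("cuda")
pipe.load_lora_weights(lora_state_dict)  # the converted dict from above
pipe.fuse_lora()
image = pipe("a photo of a cat").images[0]  # placeholder prompt
image.save("lora_check.png")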

@zjysteven

zjysteven commented Apr 23, 2024

> I tried to load the LoRA weights this way. [...] (full workaround quoted from @dcfucheng above)

This works, but one caveat: in the snippet above, kohya_ss_state_dict[alpha_key] = torch.tensor(8).to(dtype) hard-codes the lora_alpha value of 8. Remember to change it according to your config. A more robust alternative is to save the LoRA weights with the provided get_module_kohya_state_dict function in the training script:

import torch
from peft import get_peft_model_state_dict

def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
    kohya_ss_state_dict = {}
    for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
        kohya_key = peft_key.replace("base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # Set the alpha parameter, read from the adapter config instead of hard-coded.
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
    return kohya_ss_state_dict

with which you can save the trained LoRA weights:

import os
from safetensors.torch import save_file

lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)  # weight_dtype as in the training script
# StableDiffusionPipeline.save_lora_weights() would add a 'unet.' prefix to the state_dict keys:
# StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
# Instead, save the state_dict directly:
save_file(lora_state_dict, os.path.join(output_dir, "unet_lora", "pytorch_lora_weights.safetensors"))

Then you should be able to load directly with pipe.load_lora_weights().
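
For example (a minimal sketch, assuming output_dir is the directory used in the save snippet above; the base model ID is a placeholder):

import os
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16  # placeholder base model
)
pipe.load_lora_weights(os.path.join(output_dir, "unet_lora"))  # picks up pytorch_lora_weights.safetensors
pipe.fuse_lora()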
