
Model finetuned using finetune_adapter not directly usable in generate/chat... How to convert? #78

Closed
RDouglasSharp opened this issue May 19, 2023 · 3 comments
RDouglasSharp commented May 19, 2023

I used the finetune_adapter.py script to generate a tuned model. When I load that tuned checkpoint back into chat.py, I get the following error:

RuntimeError: Error(s) in loading state_dict for Parrot:
        Missing key(s) in state_dict: "lm_head.weight", "transformer.wte.weight", "transformer.h.0.norm_1.weight", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.attn.weight", "transformer.h.0.attn.attn.bias", "transformer.h.0.attn.proj.weight",
"transformer.h.0.attn.proj.bias", "transformer.h.0.norm_2.weight", "transformer.h.0.norm_2.bias", "transformer.h.0.mlp.fc.weight", "transformer.h.0.mlp.fc.bias", "transformer.h.0.mlp.proj.weight", "transformer.h.0.mlp.proj.bias", "transformer.h.1.norm_1.weight",
"transformer.h.1.norm_1.bias", "transformer.h.1.attn.attn.weight", "transformer.h.1.attn.attn.bias", "transformer.h.1.attn.proj.weight", "transformer.h.1.attn.proj.bias", "transformer.h.1.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.1.mlp.fc.weight",
"transformer.h.1.mlp.fc.bias", "transformer.h.1.mlp.proj.weight", "transformer.h.1.mlp.proj.bias", "transformer.h.2.norm_1.weight", "transformer.h.2.norm_1.bias", "transformer.h.2.attn.attn.weight", "transformer.h.2.attn.attn.bias", "transformer.h.2.attn.proj.weight",
"transformer.h.2.attn.proj.bias", "transformer.h.2.norm_2.weight", "transformer.h.2.norm_2.bias", "transformer.h.2.mlp.fc.weight", "transformer.h.2.mlp.fc.bias", "transformer.h.2.mlp.proj.weight", "transformer.h.2.mlp.proj.bias", "transformer.h.3.norm_1.weight",
"transformer.h.3.norm_1.bias", "transformer.h.3.attn.attn.weight", "transformer.h.3.attn.attn.bias", "transformer.h.3.attn.proj.weight", "transformer.h.3.attn.proj.bias", "transformer.h.3.norm_2.weight", "transformer.h.3.norm_2.bias", "transformer.h.3.mlp.fc.weight",
"transformer.h.3.mlp.fc.bias", "transformer.h.3.mlp.proj.weight", "transformer.h.3.mlp.proj.bias", "transformer.h.4.norm_1.weight", "transformer.h.4.norm_1.bias", "transformer.h.4.attn.attn.weight", "transformer.h.4.attn.attn.bias", "transformer.h.4.attn.proj.weight",
"transformer.h.4.attn.proj.bias", "transformer.h.4.norm_2.weight", "transformer.h.4.norm_2.bias", "transformer.h.4.mlp.fc.weight", "transformer.h.4.mlp.fc.bias", "transformer.h.4.mlp.proj.weight", "transformer.h.4.mlp.proj.bias", "transformer.h.5.norm_1.weight",
"transformer.h.5.norm_1.bias", "transformer.h.5.attn.attn.weight", "transformer.h.5.attn.attn.bias", "transformer.h.5.attn.proj.weight", "transformer.h.5.attn.proj.bias", "transformer.h.5.norm_2.weight", "transformer.h.5.norm_2.bias", "transformer.h.5.mlp.fc.weight",
"transformer.h.5.mlp.fc.bias", "transformer.h.5.mlp.proj.weight", "transformer.h.5.mlp.proj.bias", "transformer.h.6.norm_1.weight", "transformer.h.6.norm_1.bias", "transformer.h.6.attn.attn.weight", "transformer.h.6.attn.attn.bias", "transformer.h.6.attn.proj.weight",
"transformer.h.6.attn.proj.bias", "transformer.h.6.norm_2.weight", "transformer.h.6.norm_2.bias", "transformer.h.6.mlp.fc.weight", "transformer.h.6.mlp.fc.bias", "transformer.h.6.mlp.proj.weight", "transformer.h.6.mlp.proj.bias", "transformer.h.7.norm_1.weight",
"transformer.h.7.norm_1.bias", "transformer.h.7.attn.attn.weight", "transformer.h.7.attn.attn.bias", "transformer.h.7.attn.proj.weight", "transformer.h.7.attn.proj.bias", "transformer.h.7.norm_2.weight", "transformer.h.7.norm_2.bias", "transformer.h.7.mlp.fc.weight",
"transformer.h.7.mlp.fc.bias", "transformer.h.7.mlp.proj.weight", "transformer.h.7.mlp.proj.bias", "transformer.h.8.norm_1.weight", "transformer.h.8.norm_1.bias", "transformer.h.8.attn.attn.weight", "transformer.h.8.attn.attn.bias", "transformer.h.8.attn.proj.weight",
"transformer.h.8.attn.proj.bias", "transformer.h.8.norm_2.weight", "transformer.h.8.norm_2.bias", "transformer.h.8.mlp.fc.weight", "transformer.h.8.mlp.fc.bias", "transformer.h.8.mlp.proj.weight", "transformer.h.8.mlp.proj.bias", "transformer.h.9.norm_1.weight",
"transformer.h.9.norm_1.bias", "transformer.h.9.attn.attn.weight", "transformer.h.9.attn.attn.bias", "transformer.h.9.attn.proj.weight", "transformer.h.9.attn.proj.bias", "transformer.h.9.norm_2.weight", "transformer.h.9.norm_2.bias", "transformer.h.9.mlp.fc.weight",
"transformer.h.9.mlp.fc.bias", "transformer.h.9.mlp.proj.weight", "transformer.h.9.mlp.proj.bias", "transformer.h.10.norm_1.weight", "transformer.h.10.norm_1.bias", "transformer.h.10.attn.attn.weight", "transformer.h.10.attn.attn.bias",
"transformer.h.10.attn.proj.weight", "transformer.h.10.attn.proj.bias", "transformer.h.10.norm_2.weight", "transformer.h.10.norm_2.bias", "transformer.h.10.mlp.fc.weight", "transformer.h.10.mlp.fc.bias", "transformer.h.10.mlp.proj.weight",
"transformer.h.10.mlp.proj.bias", "transformer.h.11.norm_1.weight", "transformer.h.11.norm_1.bias", "transformer.h.11.attn.attn.weight", "transformer.h.11.attn.attn.bias", "transformer.h.11.attn.proj.weight", "transformer.h.11.attn.proj.bias",
"transformer.h.11.norm_2.weight", "transformer.h.11.norm_2.bias", "transformer.h.11.mlp.fc.weight", "transformer.h.11.mlp.fc.bias", "transformer.h.11.mlp.proj.weight", "transformer.h.11.mlp.proj.bias", "transformer.h.12.norm_1.weight", "transformer.h.12.norm_1.bias",
"transformer.h.12.attn.attn.weight", "transformer.h.12.attn.attn.bias", "transformer.h.12.attn.proj.weight", "transformer.h.12.attn.proj.bias", "transformer.h.12.norm_2.weight", "transformer.h.12.norm_2.bias", "transformer.h.12.mlp.fc.weight",
"transformer.h.12.mlp.fc.bias", "transformer.h.12.mlp.proj.weight", "transformer.h.12.mlp.proj.bias", "transformer.h.13.norm_1.weight", "transformer.h.13.norm_1.bias", "transformer.h.13.attn.attn.weight", "transformer.h.13.attn.attn.bias",
"transformer.h.13.attn.proj.weight", "transformer.h.13.attn.proj.bias", "transformer.h.13.norm_2.weight", "transformer.h.13.norm_2.bias", "transformer.h.13.mlp.fc.weight", "transformer.h.13.mlp.fc.bias", "transformer.h.13.mlp.proj.weight",
"transformer.h.13.mlp.proj.bias", "transformer.h.14.norm_1.weight", "transformer.h.14.norm_1.bias", "transformer.h.14.attn.attn.weight", "transformer.h.14.attn.attn.bias", "transformer.h.14.attn.proj.weight", "transformer.h.14.attn.proj.bias",
"transformer.h.14.norm_2.weight", "transformer.h.14.norm_2.bias", "transformer.h.14.mlp.fc.weight", "transformer.h.14.mlp.fc.bias", "transformer.h.14.mlp.proj.weight", "transformer.h.14.mlp.proj.bias", "transformer.h.15.norm_1.weight", "transformer.h.15.norm_1.bias",
"transformer.h.15.attn.attn.weight", "transformer.h.15.attn.attn.bias", "transformer.h.15.attn.proj.weight", "transformer.h.15.attn.proj.bias", "transformer.h.15.norm_2.weight", "transformer.h.15.norm_2.bias", "transformer.h.15.mlp.fc.weight",
"transformer.h.15.mlp.fc.bias", "transformer.h.15.mlp.proj.weight", "transformer.h.15.mlp.proj.bias", "transformer.ln_f.weight", "transformer.ln_f.bias".
        Unexpected key(s) in state_dict: "transformer.h.2.attn.gating_factor", "transformer.h.2.attn.adapter_wte.weight", "transformer.h.3.attn.gating_factor", "transformer.h.3.attn.adapter_wte.weight", "transformer.h.4.attn.gating_factor",
"transformer.h.4.attn.adapter_wte.weight", "transformer.h.5.attn.gating_factor", "transformer.h.5.attn.adapter_wte.weight", "transformer.h.6.attn.gating_factor", "transformer.h.6.attn.adapter_wte.weight", "transformer.h.7.attn.gating_factor",
"transformer.h.7.attn.adapter_wte.weight", "transformer.h.8.attn.gating_factor", "transformer.h.8.attn.adapter_wte.weight", "transformer.h.9.attn.gating_factor", "transformer.h.9.attn.adapter_wte.weight", "transformer.h.10.attn.gating_factor",
"transformer.h.10.attn.adapter_wte.weight", "transformer.h.11.attn.gating_factor", "transformer.h.11.attn.adapter_wte.weight", "transformer.h.12.attn.gating_factor", "transformer.h.12.attn.adapter_wte.weight", "transformer.h.13.attn.gating_factor",
"transformer.h.13.attn.adapter_wte.weight", "transformer.h.14.attn.gating_factor", "transformer.h.14.attn.adapter_wte.weight", "transformer.h.15.attn.gating_factor", "transformer.h.15.attn.adapter_wte.weight".

What am I doing wrong? How do I convert a tuned model checkpoint into the format expected by generate/chat?

carmocca (Contributor) commented May 19, 2023

We would need a chat_adapter.py script that loads the adapter checkpoint and processes the prompt as expected. If you'd like to contribute this, you can look at the differences between https://github.com/Lightning-AI/lit-parrot/blob/main/generate.py and https://github.com/Lightning-AI/lit-parrot/blob/main/generate_adapter.py to see what needs to change.
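
For context, the traceback above shows that finetune_adapter.py checkpoints only the adapter-specific tensors (the gating_factor and adapter_wte "unexpected" keys), while chat.py builds the plain model and expects the full set of base weights (the "missing" keys). An adapter-aware script therefore has to instantiate the adapter variant of the model and load the pretrained and adapter weights together. A minimal sketch of that loading step, assuming a lit_parrot.adapter module layout, with illustrative file paths and model name:

```python
import torch
from lit_parrot.adapter import Parrot, Config  # adapter variant (module path assumed)

# Build the adapter-aware model, not the plain Parrot from lit_parrot.model,
# so the gating_factor / adapter_wte parameters exist in the module tree.
config = Config.from_name("pythia-1b")  # model name is illustrative
model = Parrot(config)

# finetune_adapter.py saves only the adapter parameters, so merge them
# on top of the original pretrained weights before loading.
pretrained = torch.load("checkpoints/lit_model.pth", map_location="cpu")        # path illustrative
adapter = torch.load("out/adapter/lit_model_adapter.pth", map_location="cpu")   # path illustrative
pretrained.update(adapter)

model.load_state_dict(pretrained)  # now reports neither missing nor unexpected keys
model.eval()
```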

carmocca added the enhancement and generation labels on May 19, 2023
carmocca (Contributor) commented May 19, 2023

Oh, you also tried with generate.py. generate_adapter.py is the one you need to run.
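
An invocation would pass both the adapter checkpoint and the pretrained checkpoint location on the command line, roughly along these lines (the flag names here are assumed, not confirmed against the script; check python generate_adapter.py --help for the exact arguments):

        python generate_adapter.py --adapter_path out/adapter/lit_model_adapter.pth --checkpoint_dir checkpoints/ --prompt "..."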

RDouglasSharp (Author) commented May 19, 2023 via email
