Regarding setup_chat_format overwriting existing special tokens #1819

Closed
zyzhang1130 opened this issue Jul 9, 2024 · 9 comments

Comments

@zyzhang1130

Because of https://github.com/huggingface/trl/blob/2860ce5091e689bab167454453e9ddbe2337de3d/trl/models/utils.py#L90C2-L90C48, setup_chat_format overwrites the existing eos_token, pad_token, and bos_token. Why is this necessary? My understanding is that setup_chat_format is meant to facilitate instruction fine-tuning by adding special tokens that mark the beginning and end of each turn in the dialogue between the user and the model (i.e., the assistant). If that is the case, simply adding chat_format.bos_token and chat_format.eos_token should suffice, right? The problem with overwriting the existing eos_token, pad_token, and bos_token is that the pre-trained models we are fine-tuning were trained with those tokens, so overwriting them would surely cause problems, right?
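
For anyone reading along, a minimal sketch of the behaviour described above, assuming a recent TRL release (the model id is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

model_id = "your-base-model"  # placeholder: a base model without a chat template
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)  # the pre-training tokens

model, tokenizer = setup_chat_format(model, tokenizer)

# After the call the special tokens have been replaced by the ChatML markers
# and the embedding matrix has been resized for the new vocabulary.
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)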

@AIR-hl
Contributor

AIR-hl commented Jul 12, 2024

If you fine-tune the model with full parameters, it usually does not cause problems. But if you fine-tune with a PEFT method such as LoRA, it may cause problems.

@zyzhang1130
Author

zyzhang1130 commented Jul 16, 2024

@AIR-hl I see. I can see why that is the case, since full-parameter tuning updates the embeddings of the new tokens in the input embedding layer. I have one closely related question: how is training a model with a chat template enabled different from using formatting_func and data_collator? Conceptually they seem to aim at the same goal, and the latter is easy to find in a lot of tutorials and code online, yet the official Hugging Face documentation does not address the distinction explicitly. Is there something special that only a chat template can achieve?

Update: actually there might still be a problem. If setup_chat_format only added special tokens for the beginning/end of a turn in a dialogue, that would be fine. But the current implementation also replaces the original bos and eos tokens regardless of which model is used, and I think that throws away what the model learned about those tokens during pre-training.
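
To make the formatting_func question above concrete, here is a rough sketch of the two approaches being compared (the model id and the turn markers are placeholders, and the exact formatting_func signature expected by SFTTrainer depends on the TRL version):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-chat-model")  # placeholder id

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]

# 1) Chat template: the formatting lives on the tokenizer and ships with the model.
text_from_template = tokenizer.apply_chat_template(messages, tokenize=False)

# 2) formatting_func: the same flattening done by hand and passed to SFTTrainer.
def formatting_func(example):
    parts = []
    for turn in example["messages"]:
        parts.append(f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n")
    return "".join(parts)

text_from_func = formatting_func({"messages": messages})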

@deema-A

deema-A commented Jul 17, 2024

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.model.embed_tokens.weight: copying a param with shape torch.Size([32002, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).
	size mismatch for base_model.model.lm_head.weight: copying a param with shape torch.Size([32002, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).

any idea?

@zyzhang1130
Author


Were you using LoRA to fine-tune your model?

@AIR-hl
Contributor

AIR-hl commented Jul 19, 2024


@zyzhang1130 In fact, setup_chat_format just provides a convenient way to format chat data stored as JSON; you can also customize the chat template based on the model's existing bos and eos tokens. The above is just my understanding.
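
A minimal sketch of that suggestion, assuming the standard Jinja chat-template mechanism in transformers (the template itself is only an illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-base-model")  # placeholder id

# Reuse the model's existing bos/eos tokens instead of introducing new ones,
# so the embedding matrix and the pre-trained special tokens stay untouched.
tokenizer.chat_template = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}{{ eos_token }}\n"
    "{% endfor %}"
)

messages = [{"role": "user", "content": "Hi"}]
print(tokenizer.apply_chat_template(messages, tokenize=False))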

@deema-A

deema-A commented Jul 19, 2024


@zyzhang1130 yes

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    base_model_name_or_path=model_id,  # model_id: the base model being fine-tuned
    modules_to_save=["lm_head", "embed_tokens"],
)

@zyzhang1130
Author


LoRA doesn't play well with setup_chat_format, because setup_chat_format adds additional token embeddings to the input embedding layer, and when you save a PEFT model only the adapter is saved, so there will be issues. You will need to save the input embedding layer separately yourself.
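
For reference, a hedged sketch of one way around the size-mismatch error reported above: resize the base model to match the tokenizer that was saved alongside the adapter before loading the PEFT checkpoint (paths are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("your-base-model")     # placeholder
tokenizer = AutoTokenizer.from_pretrained("path/to/adapter-checkpoint")  # tokenizer saved after setup_chat_format

# Grow the embedding matrix (32000 -> 32002 in the traceback above) so the
# checkpointed embed_tokens / lm_head weights from modules_to_save fit.
base_model.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base_model, "path/to/adapter-checkpoint")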

@zyzhang1130
Author


Yes, actually I had to implement a version of it myself to handle those issues. Besides formatting, setup_chat_format also adds the additional token embeddings and resizes the embedding layer, which is also crucial.
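
For illustration, a rough approximation of that kind of custom setup (a sketch under assumed defaults, not the implementation referred to above): register the turn markers as additional special tokens and resize the embeddings, while leaving the pre-trained bos/eos/pad tokens untouched.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-base-model"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Register the turn markers without touching bos/eos/pad.
num_added = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)

# Grow the embedding matrix for the new tokens; pad_to_multiple_of is optional.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)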


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.
