
Eval Loss NaN on Llama-2 #13

Open
mmichaelzhang opened this issue Jul 22, 2023 · 3 comments
Labels
question (Further information is requested)

Comments

@mmichaelzhang

Hi,

By any chance, have you tried running this directly on the Llama-2 model?

I tried using the default llama parameters for pruning and post-training, which resulted in a similar wikitext2 score (~19) but a much worse PTB score (~70).

Also, when running post-training with llama's parameter set, the llama-2 loss explodes after ~0.2 epochs. I tried a smaller lr (1e-5), yet the eval loss still exploded to NaN.
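
For reference, a minimal sketch of the kind of adjustment tried (smaller lr plus gradient clipping), assuming a standard Hugging Face Trainer-based post-training loop; the argument names come from transformers.TrainingArguments, not from this repo's scripts, and the gradient-clipping/warmup values are illustrative:

```python
# Hypothetical post-training arguments: lower lr and clip gradients to damp
# loss spikes before they turn into NaN. Assumes a Trainer-based loop.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="llama2-7b-pruned-posttrain",   # placeholder path
    learning_rate=1e-5,             # smaller lr than the default llama recipe
    max_grad_norm=1.0,              # gradient clipping against spikes
    warmup_ratio=0.03,              # gentler warmup to stabilise the first ~0.2 epoch
    num_train_epochs=2,
    per_device_train_batch_size=4,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=100,
)
```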

It would be of great help if you could provide some insight into both the pruning and post-training parameters.

Thanks.

@horseee
Owner

horseee commented Jul 22, 2023

Greetings!

Regarding the test results on PTB after pruning: the worse score stems from the unpruned llama2-7B model itself, whose PTB perplexity is approximately 47, significantly higher than that of llama-7B (~22).

As for the issue of NaN during post-training, we encountered the same problem you reported. Currently, we are searching for the appropriate hyper-parameters to fine-tune the pruned model. If we obtain any new findings or find any bugs in our code, we will promptly update you.

@horseee added the question (Further information is requested) label on Jul 22, 2023
@mmichaelzhang
Author

Thank you for the timely reply! Hope to get back with good news.

Cheers.

@kyang-06

kyang-06 commented Mar 4, 2024

@mmichaelzhang Have you resolved this issue? I also observed the training loss exploding, and performance deteriorated for llama2-7b with the default llama settings:

| Wikitext2 w/o tune | PTB w/o tune | BoolQ Acc | PIQA Acc | HellaSwag Acc_norm | WinoGrande Acc | ARC-e Acc | ARC-c Acc_norm | OBQA Acc_norm |
|---|---|---|---|---|---|---|---|---|
| 19.24 | 72.61 | 37.83 | 52.34 | 26.64 | 49.41 | 25.08 | 27.82 | 28.40 |

Since the pruned model weights are quantized to int8 and frozen during post-training, I think this phenomenon is unrelated to the BF16/FP16 dtype, which the authors consider to be the cause:

> Tip: Training LLaMA-2 in float16 is not recommended and is known to produce nan; as such, the model should be trained in bfloat16.
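
For what it's worth, here is a minimal sketch of the setup described above, assuming the common frozen-int8-base-plus-LoRA recipe rather than this repo's actual script; the checkpoint path and LoRA target modules are placeholders, and training is kept in bfloat16 per the tip above:

```python
# Sketch: load the pruned checkpoint with int8 base weights (frozen during
# post-training), attach LoRA adapters, and train in bfloat16, not float16.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "path/to/pruned-llama2-7b",    # placeholder: the pruned checkpoint
    load_in_8bit=True,             # base weights quantized to int8 and frozen
    torch_dtype=torch.bfloat16,    # bf16 for the non-quantized modules
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed targets
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="llama2-7b-pruned-lora",  # placeholder
    learning_rate=1e-4,
    num_train_epochs=2,
    per_device_train_batch_size=4,
    bf16=True,    # train in bfloat16 ...
    fp16=False,   # ... not float16, which is known to produce NaN on Llama-2
)
```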
