Spikes in PPO policy loss #101

Closed
lvwerra opened this issue Jan 24, 2023 · 10 comments

@lvwerra
Member

lvwerra commented Jan 24, 2023

We sometimes experience huge spikes in the policy loss, which either cause training to fail or take a very long time to recover from. It would be useful to investigate where they come from and how to mitigate them. cc @natolambert

[Screenshots: training curves showing spikes in the policy loss]

@younesbelkada
Contributor

younesbelkada commented Jan 24, 2023

One idea could be that we don't mask out the logits corresponding to padding tokens when computing the loss; it is something I am having a look at in #100. But I am not sure if this is really the root cause.
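As a minimal illustration of the masking idea (all names, shapes, and the pad token id below are illustrative, not TRL's actual internals), averaging a per-token loss only over non-padding positions could look like this:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average only over real (non-padding) positions."""
    return (values * mask).sum() / mask.sum().clamp(min=1)

# Illustrative shapes: (batch, seq_len)
pad_token_id = 50256                                # e.g. GPT-2's eos token reused as padding
per_token_loss = torch.randn(4, 16)                 # stand-in for a per-token policy loss
response_ids = torch.randint(0, 50257, (4, 16))     # stand-in for padded response token ids
mask = (response_ids != pad_token_id).float()

loss = masked_mean(per_token_loss, mask)            # padding positions no longer pull on the mean
```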

@natolambert
Contributor

Yeah, so something weird is going on with a simultaneous large drop in entropy, clip fraction, etc. Can we log the model outputs at that step? Is there any chance the model output gets stuck on something?

@natolambert
Contributor

natolambert commented Jan 25, 2023

@younesbelkada your idea makes sense.

Some follow-ups:

  1. @lvwerra what experiment setup was this? I'd love to dig further.
  2. What does a clip frac of .55 mean? Is that half of the value samples being clipped in the PPO update, or am I off by a factor of 100?
  3. How should we configure the logger for a rich, research-y approach (lots of unknowns)?

Some musings on PPO stability:

  • A thread from Stable Baselines suggests the entropy coefficient was way too high (different domain than RLHF); I'll add more if I find it.

The more I look, the more I suspect there is some numerical instability in the loss computation at that step (NaN), which it is impressive to recover from. I'm thinking about what the right intermediate values to log are (maybe optionally). Could we do something so that if there is a NaN or a very large loss value, we dump a bunch of values to the logger (a rough sketch of this follows below)? I am sure we will see things like this when doing more RLHF.
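A rough sketch of what such a dump-on-anomaly hook could look like, assuming per-step statistics are available as a dict of tensors; the function name, dict keys, and threshold are illustrative only:

```python
import math
import torch

def dump_if_suspicious(stats: dict, logger, loss_threshold: float = 10.0) -> None:
    """Dump summary statistics of intermediate values when the loss is NaN or very large.

    `stats` is assumed to map names (e.g. 'loss/policy', 'ratio', 'advantages', 'logits')
    to tensors or scalars; both the key and the threshold are placeholders.
    """
    loss = stats.get("loss/policy", float("nan"))
    loss_val = loss.item() if torch.is_tensor(loss) else float(loss)
    if math.isnan(loss_val) or abs(loss_val) > loss_threshold:
        for name, value in stats.items():
            if torch.is_tensor(value):
                logger.warning(
                    "%s: min=%.4f max=%.4f mean=%.4f",
                    name, value.min().item(), value.max().item(), value.float().mean().item(),
                )
            else:
                logger.warning("%s: %r", name, value)

# e.g. call once per PPO step: dump_if_suspicious(step_stats, logging.getLogger(__name__))
```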

@DaehanKim

DaehanKim commented Feb 3, 2023

I also observed a spike in the policy loss when running the sentiment-control example, and I initially thought it was because of some strange samples or high variance in the positive logits.

And I found this: the pipeline doesn't always output the 'POSITIVE' logit at index 1.
[Screenshot: pipeline output with the label order swapped]

In the notebook, output[1]['score'] is treated as the positive logit and fed into the PPOTrainer. I guess this causes unstable training because the reward signal is not valid. Am I making sense?

By the way, I didn't realize this and ran several experiments with changed reward definitions (using both the positive and negative logits), and reward_mean wasn't increasing as training went on.
[Plot: reward_mean staying flat over training]

I'll report further experiment results at #120
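The by-index assumption described above can be avoided by selecting the score via its label. A small sketch, assuming the distilbert-imdb sentiment model used in the examples (the helper is illustrative, not the notebook's actual extract_pipe_output):

```python
from transformers import pipeline

# Sentiment model used in the TRL sentiment examples; labels assumed to be NEGATIVE/POSITIVE.
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", top_k=None)

def positive_score(text: str) -> float:
    # The pipeline does not guarantee that the 'POSITIVE' entry sits at index 1,
    # so select it by label rather than by position.
    outputs = sentiment_pipe([text])[0]
    return next(o["score"] for o in outputs if o["label"] == "POSITIVE")

reward = positive_score("this movie was really good!")
```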

@DaehanKim

I corrected the parsing of the pipeline output and the loss spike still remains in the sentiment-control notebook example, so there may be another reason for this instability.

[Plot: policy loss still spiking after the parsing fix]

@lvwerra
Member Author

lvwerra commented Feb 3, 2023

Thanks @DaehanKim, yes there is an issue besides the order of the logits. I tracked it down to some changes made in #80 (no spikes at the beginning of the PR, spikes at the time of merge), and I started tracking the issue down in #126. I'll report here as well if I figure it out!

@lvwerra
Member Author

lvwerra commented Feb 7, 2023

The issue with the loss spikes in the sentiment-control notebook was that sometimes only a few new tokens (1-2) would be generated, and this would cause the loss to spike. Not sure yet where exactly this behaviour comes from, but we now know where to look: we can deliberately generate short sequences and investigate what causes the loss explosion.
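One way to follow up on this is to control the number of new tokens at generation time. The sketch below is illustrative (the model name and values are placeholders, not the notebook's exact settings): one set of generation kwargs forces very short continuations to try to reproduce the spike, the other uses `min_new_tokens` to rule them out.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model

# Reproduce the failure mode on purpose: only 1-2 new tokens per response.
short_generation_kwargs = {
    "max_new_tokens": 2,
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "pad_token_id": tokenizer.eos_token_id,
}

# Or guard against it by requiring a minimum number of new tokens.
guarded_generation_kwargs = {
    **short_generation_kwargs,
    "min_new_tokens": 4,
    "max_new_tokens": 32,
}

# Either dict would then be passed to the generation call,
# e.g. response = ppo_trainer.generate(query_tensor, **short_generation_kwargs)
```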

@tengxiaoliu

I also experienced the loss spikes in my case. I'm using a seq2seq T5 model as the backbone, initialized from a supervised fine-tuned model. I find that the loss spikes come from steps that have a negative advantage and an extremely high ratio r(θ). This falls into situation 6 in the figure below.
[Figure: behavior of the clipped PPO objective across advantage/ratio regimes; situation 6 is a negative advantage with a very large ratio]

In my case, removing pg_losses1 and only keeping the clipped pg_losses2 helps restrict the ratio and stabilize the loss. I didn't train the model from scratch, so the clip fraction is low (less than 3%), but this would be a problem if the clip fraction were high and most of the loss were clipped. It's not a general solution, just some findings from my case.
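As a concrete illustration of the workaround described above, here is a minimal sketch of a clipped PPO policy loss with an option to keep only the clipped term. The tensor names follow the pg_losses1/pg_losses2 naming above; the signature is illustrative, not TRL's actual implementation. It also shows how a clip fraction is typically computed as a value in [0, 1] (so .55 would mean 55% of tokens).

```python
import torch

def policy_loss(logprobs, old_logprobs, advantages, clip_range=0.2, clipped_only=False):
    """Clipped PPO surrogate loss; all tensors are per-token, shape (batch, seq_len)."""
    ratio = torch.exp(logprobs - old_logprobs)
    pg_losses1 = -advantages * ratio                                                   # unclipped term
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)  # clipped term
    if clipped_only:
        # Workaround described above: drop the unclipped term so that a huge ratio
        # combined with a negative advantage cannot blow up the loss.
        pg_loss = pg_losses2.mean()
    else:
        pg_loss = torch.max(pg_losses1, pg_losses2).mean()
    # Clip fraction: share of tokens where the clipped term is the active one.
    clip_frac = (pg_losses2 > pg_losses1).float().mean()
    return pg_loss, clip_frac
```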

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@b11z

b11z commented Apr 24, 2024

The issue that @DaehanKim noticed is also present in the gpt2-sentiment.ipynb example. It might be nice to propagate the extract_pipe_output fix to that notebook as well.
