
fix repetition penalty error in modeling_utils.py #2303

Conversation

patrickvonplaten
Contributor

fix bug mentioned in #2302

@codecov-io

codecov-io commented Dec 24, 2019

Codecov Report

Merging #2303 into master will decrease coverage by 0.01%.
The diff coverage is 0%.


@@            Coverage Diff             @@
##           master    #2303      +/-   ##
==========================================
- Coverage   73.54%   73.52%   -0.02%     
==========================================
  Files          87       87              
  Lines       14789    14793       +4     
==========================================
  Hits        10876    10876              
- Misses       3913     3917       +4
Impacted Files Coverage Δ
src/transformers/modeling_utils.py 63.45% <0%> (-0.46%) ⬇️


@thomwolf
Member

Good catch.
But this is actually the technique mentioned in http://arxiv.org/abs/1909.05858.
So to fix it we should check the code of Nitish (https://github.com/salesforce/ctrl) and apply the same behavior here.

@patrickvonplaten
Contributor Author

I checked the code in https://github.com/salesforce/ctrl/blob/0f30306a8947ce0ede62e79c7e1f05a585cc56c9/generation.py#L217:
prompt_logits[_token][generated_token] /= penalty

So in the original code, division is always used regardless of the sign of the logits of the previously generated tokens.

Going a bit deeper and looking at the actual logit values in

next_token_logits[i, previous_tokens] /= repetition_penalty

for different models the following can be observed:

For the models ctrl and xlm, the logit values tend to be positive, which explains why division by the repetition penalty is used. BUT the values don't have to be positive: there were also very rare cases when using ctrl where the logit was actually negative, in which case a division increases the probability of that word being sampled.

For the models gpt2, openai-gpt, and xlnet, the logit values tend to be negative, in which case dividing by a repetition penalty increases the probability of previously generated tokens being sampled.

In the proposed PR, both cases would be correctly handled from a logical point of view.
If, on the other hand, we want to stick to the original code (only using division), we could add a warning that the repetition penalty should only be used in combination with ctrl.
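Concretely, the sign-aware handling proposed here looks roughly like the following sketch (function and variable names are illustrative, assuming logits of shape [batch_size, vocab_size]):

```python
import torch

def apply_repetition_penalty(next_token_logits, prev_output_tokens, repetition_penalty):
    # next_token_logits: [batch_size, vocab_size]; prev_output_tokens: list of token-id lists
    for i in range(next_token_logits.shape[0]):
        for previous_token in set(prev_output_tokens[i]):
            # Dividing a negative logit by a penalty > 1 would *increase* its
            # probability, so multiply instead when the logit is negative.
            if next_token_logits[i, previous_token] < 0:
                next_token_logits[i, previous_token] *= repetition_penalty
            else:
                next_token_logits[i, previous_token] /= repetition_penalty
    return next_token_logits
```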

@thomwolf
Member

Ok, I see, thanks for documenting this. Let's go for this solution for now.

@thomwolf thomwolf merged commit aeef482 into huggingface:master Dec 25, 2019
@patrickvonplaten patrickvonplaten deleted the fix_error_with_repetition_penalty branch December 25, 2019 22:05
@mvpcom

mvpcom commented Jan 5, 2020

Is this fix included in the pip package? If I install with pip, is this covered, or do I still have to install from source?

@w4nderlust
Contributor

w4nderlust commented Jan 8, 2020

Reading this after it was mentioned in the PPLM example PR.
The fix makes total sense, but I have a concern: the amount by which a negative logit is diminished is greater than the amount by which a positive logit is diminished.
If we have two values, say -2 and 2, this happens:

import numpy as np

x = np.array([-2, 2])
sx = np.exp(x)/sum(np.exp(x))
print(sx)  # array([0.01798621, 0.98201379])

if we apply the same penalty to both, we would want the probabilities to stay the same, but this is what happens:

# dividing the negative logit by 1/1.2 corresponds to multiplying it by 1.2 (the sign-aware penalty)
p = [1/1.2, 1.2]
spx = np.exp(x/p)/sum(np.exp(x/p))
print(spx)  # array([0.01684577, 0.98315423])

On the other hand, if we apply the penalty to the probabilities after the softmax (and we renormalize) this is what happens:

p2 = [1.2, 1.2]
sp2x = (sx/p2)/sum(sx/p2)
print(sp2x)  # array([0.01798621, 0.98201379])

The probabilities are intact, as we want, because we don't want to penalize negative values more than we penalize positive values.
So my proposal is to perform the penalty after the softmax, on probability values, always dividing, rather than on the logits.
What do you think?

Edit:
In math, I propose to move from

$$p_i = \frac{\exp(x_i / \theta_i)}{\sum_j \exp(x_j / \theta_j)}$$

to

$$p_i = \frac{\exp(x_i) / \theta_i}{\sum_j \exp(x_j) / \theta_j}$$

where $x$ are the logits and $\theta_i$ equals the repetition penalty for previously generated tokens and $1$ otherwise.
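Using the same numpy setup, a small sketch of this post-softmax scheme as a function (the function name is illustrative):

```python
import numpy as np

def penalize_after_softmax(logits, penalized_ids, penalty=1.2):
    # Standard softmax over the logits.
    probs = np.exp(logits) / np.sum(np.exp(logits))
    # Divide the probabilities of previously generated tokens by the penalty ...
    probs[penalized_ids] = probs[penalized_ids] / penalty
    # ... and renormalize so the distribution sums to 1 again.
    return probs / probs.sum()

x = np.array([-2.0, 2.0])
print(penalize_after_softmax(x, [0, 1]))  # both penalized -> unchanged: [0.01798621 0.98201379]
print(penalize_after_softmax(x, [1]))     # only the positive-logit token is penalized
```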

@patrickvonplaten patrickvonplaten self-assigned this Feb 17, 2020
@patrickvonplaten
Contributor Author

Sorry for the late response @w4nderlust !

I think what you are saying makes a lot of sense!

To implement your solution with minimal code change, one could simply change Eq. (1):

$$p_i = \frac{\exp(x_i) / \theta_{\text{new},i}}{\sum_j \exp(x_j) / \theta_{\text{new},j}} \qquad (1)$$

to the equivalent Eq. (2):

$$p_i = \frac{\exp(x_i - \log \theta_{\text{new},i})}{\sum_j \exp(x_j - \log \theta_{\text{new},j})} \qquad (2)$$

One question that remains is how the new repetition penalty $\theta_{\text{new}}$ in Eq. (1) & (2) will have to differ from the old repetition penalty $\theta_{\text{old}}$ in Eq. (3):

$$p_i = \frac{\exp(x_i / \theta_{\text{old},i})}{\sum_j \exp(x_j / \theta_{\text{old},j})} \qquad (3)$$

to have a similar effect on the softmax. It is quite obvious that $\theta_{\text{old}}$ reduces the probability of its token much more than $\theta_{\text{new}}$ does.

For the different LMHead models, I calculated the ratio $p_{\text{no penalty}} / p_{\text{penalty}}$ for different values of $\theta_{\text{old}}$. I simply generated randomly sampled sentences from the pretrained models and averaged the effect over the penalized tokens for 5 runs with max_length=100, so that the average is formed from ca. $5 \times 100 = 500$ tokens.

The following values show by how much $\theta_{\text{old}}$ scales down the probability after the softmax, which is equivalent to the value $\theta_{\text{new}}$ would have had to be set to:

Generate repetition penalty comparison for ctrl
Penalty factor: 1.1 - Without penalty / penalty ratio avg: 4e0
Penalty factor: 1.2 - Without penalty / penalty ratio avg: 31e0
Penalty factor: 1.3 - Without penalty / penalty ratio avg: 149e0
Penalty factor: 1.4 - Without penalty / penalty ratio avg: 25e3
Penalty factor: 1.5 - Without penalty / penalty ratio avg: 286e3
Generate repetition penalty comparison for distilgpt2
Penalty factor: 1.1 - Without penalty / penalty ratio avg: 23e3
Penalty factor: 1.2 - Without penalty / penalty ratio avg: 2e9
Penalty factor: 1.3 - Without penalty / penalty ratio avg: 223e9
Penalty factor: 1.4 - Without penalty / penalty ratio avg: 3e24
Generate repetition penalty comparison for gpt2
Penalty factor: 1.1 - Without penalty / penalty ratio avg: 1e9
Penalty factor: 1.2 - Without penalty / penalty ratio avg: 742e18
Generate repetition penalty comparison for xlm-clm-enfr-1024
Penalty factor: 1.1 - Without penalty / penalty ratio avg: 2e0
Penalty factor: 1.2 - Without penalty / penalty ratio avg: 3e0
Penalty factor: 1.3 - Without penalty / penalty ratio avg: 5e0
Penalty factor: 1.4 - Without penalty / penalty ratio avg: 9e0
Penalty factor: 1.5 - Without penalty / penalty ratio avg: 13e0
Generate repetition penalty comparison for openai-gpt
Penalty factor: 1.1 - Without penalty / penalty ratio avg: 1e0
Penalty factor: 1.2 - Without penalty / penalty ratio avg: 2e0
Penalty factor: 1.3 - Without penalty / penalty ratio avg: 4e0
Penalty factor: 1.4 - Without penalty / penalty ratio avg: 15e0
Penalty factor: 1.5 - Without penalty / penalty ratio avg: 19e0
Generate repetition penalty comparison for xlnet-base-cased
Penalty factor: 1.1 - Without penalty / penalty ratio avg: 5e0
Penalty factor: 1.2 - Without penalty / penalty ratio avg: 34e0
Penalty factor: 1.3 - Without penalty / penalty ratio avg: 2e3
Penalty factor: 1.4 - Without penalty / penalty ratio avg: 47e3
Penalty factor: 1.5 - Without penalty / penalty ratio avg: 8e6

It can be seen that gpt2, for example, produces much larger logit values, which lead to much more drastic reductions in the probability after the softmax. The repetition penalty was originally introduced for ctrl, so it's probably best to look at its behaviour.
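For reference, a hypothetical sketch (not the actual script used) of how such a ratio could be measured for a single penalized token, given a 1-D vector of logits:

```python
import torch
import torch.nn.functional as F

def prob_ratio_for_token(logits, token_id, repetition_penalty):
    # Probability of the token without any penalty.
    p_no_penalty = F.softmax(logits, dim=-1)[token_id]
    # Old-style penalty: divide the token's logit by the penalty before the softmax.
    penalized_logits = logits.clone()
    penalized_logits[token_id] = penalized_logits[token_id] / repetition_penalty
    p_penalty = F.softmax(penalized_logits, dim=-1)[token_id]
    # This ratio is what the post-softmax penalty theta_new would have to be
    # to reduce the token's probability by the same factor.
    return (p_no_penalty / p_penalty).item()
```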

@patrickvonplaten
Contributor Author

So I think there are three possibilities:

  1. Follow the proposed solution from @w4nderlust, implementing Eq. (1).
    This would mean, though, that the proposed repetition penalty of 1.3 in the ctrl paper would have to be changed to something around 150, which is quite a large value.

  2. Instead of subtracting log(rep_penalty) from the logits as in Eq. (2):
    $$p_i = \frac{\exp(x_i - \log \theta_{\text{new},i})}{\sum_j \exp(x_j - \log \theta_{\text{new},j})},$$
    one could subtract the rep_penalty itself, giving the equation:
    $$p_i = \frac{\exp(x_i - \theta_{\text{new},i})}{\sum_j \exp(x_j - \theta_{\text{new},j})}.$$
    This way the values for $\theta_{\text{new}}$ would equal $\log(p_{\text{no penalty}} / p_{\text{penalty}})$ and thus be much smaller. The repetition penalty for ctrl would thus only have to be around 5 to equal the behavior of the old penalty of 1.3. One disadvantage would be that the neutral element in this case is 0 instead of 1, which might be a bit confusing (a small sketch of this option follows below the list).

  3. Just leave it as it is now, since from what I have seen the logits are almost always either all positive or all negative, so the current behavior is not very prone to errors.

I would tend towards solution 2, giving a clear explanation of the variable in the argument section of the language generation function.
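To make option 2 concrete, here is a minimal sketch (the function name and default value are illustrative, not from the library), assuming logits of shape [batch_size, vocab_size]:

```python
import torch

def subtractive_repetition_penalty(next_token_logits, prev_output_tokens, repetition_penalty=5.0):
    # Subtracting a constant from a logit divides its exp() by exp(penalty), i.e. this is
    # the post-softmax penalty of Eq. (1) with theta_new = exp(repetition_penalty).
    # A penalty of 0 is the neutral element (no change), regardless of the logit's sign.
    for i in range(next_token_logits.shape[0]):
        for previous_token in set(prev_output_tokens[i]):
            next_token_logits[i, previous_token] -= repetition_penalty
    return next_token_logits
```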

What do you think @w4nderlust and @thomwolf ?

@w4nderlust
Contributor

Thank you for the thorough analysis @patrickvonplaten! I believe 2 would be fine. The log just scales things differently, but there's no specific reason to have it, as it is a user-tunable parameter anyway. The fact that the default would be 0 instead of 1 could, I think, be explained, and one could point to this conversation in a comment to give the full picture. Although I understand this is not a huge issue (because of what you say in 3), I kinda believe 2 is better, as there could potentially be a different model in the future that actually outputs both positive and negative logits, and in that case this could make a substantial difference in the quality of the sampling.
