
GPT model generate() function not correctly skipping the padding tokens indicated by attention_mask #14521

Closed
niansong1996 opened this issue Nov 25, 2021 · 11 comments



niansong1996 commented Nov 25, 2021

According to #7552, the padding tokens will be skipped when calculating the position_ids during generate(), as long as the corresponding positions are masked out in attention_mask. If I understand this correctly, this would mean that the placement of padding tokens does not matter as long as they are not attended to. However, I found that this is not exactly the case. Am I missing something here?


The following code reproduces the issue:

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

# note that input_str_1 and input_str_2 only differ in the number & position of eos tokens
input_str_1 = "# in a kilometer race , a beats b by 48 meters or 12 seconds . what time does a take to complete the race ? n0 = 48.0 n1 = 12.0\nleg = n0 / n1\n<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
input_str_2 = "# in a kilometer race , a beats b by 48 meters or 12 seconds . what time does a take to complete the race ? n0 = 48.0 n1 = 12.0\n<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>leg = n0 / n1\n"

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
gradient_ckpt = True
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", pad_token_id=tokenizer.eos_token_id, gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt)

def test_generate(input_str: str):
    input_ids = tokenizer.encode(input_str, add_special_tokens=False, return_tensors="pt")
    attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), torch.ones_like(input_ids)).to(model.device)
    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=30, num_return_sequences=1)
    output_str = tokenizer.decode(output_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)

    print(f"##################\n{output_str}\n##################")

test_generate(input_str_1)
test_generate(input_str_2)
@niansong1996 changed the title from "Iterative sentence generation and document generation results does not match" to "GPT model generate() function not correctly skipping the padding tokens indicated by attention_mask" on Nov 29, 2021
@LysandreJik

Maybe of interest to @patrickvonplaten @Narsil

@niansong1996

Update: I changed my experiment code from right padding to left padding and the performance improved greatly. If the generate() function truly skipped the padding tokens, this should not have happened.
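For reference, a minimal sketch of the left-padding setup referred to above, assuming the same gpt-neo-125M checkpoint used elsewhere in this thread (illustrative, not the exact experiment code):

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # pad on the left so generation continues from a real token

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", pad_token_id=tokenizer.eos_token_id)

prompts = ["This is a test", "A longer prompt that needs no padding in this batch"]
batch = tokenizer(prompts, return_tensors="pt", padding=True)
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_new_tokens=30)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))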


Narsil commented Dec 6, 2021

I just checked, and the attention_mask is correctly passed down to the GPT-Neo model, so if anything the model itself would be the culprit.

Looking at the code, the padding positions are correctly skipped when computing the position_ids: https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt_neo/modeling_gpt_neo.py#L688

Then you can check that the attention_mask adds a very large negative number: https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt_neo/modeling_gpt_neo.py#L200

I am not familiar enough with the internals to know whether that is enough, but it definitely seems to be doing what it should.
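The position-id shifting linked above boils down to roughly the following (a minimal sketch of how position_ids can be derived from an attention_mask so that padded slots do not advance the position counter; not the exact library code):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1]])  # two padding tokens on the left

position_ids = attention_mask.long().cumsum(-1) - 1  # count only attended tokens
position_ids.masked_fill_(attention_mask == 0, 1)    # dummy value at the padded slots

print(position_ids)  # tensor([[1, 1, 0, 1, 2, 3]])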

I even tried a much smaller example:

input_str_1 = "This is a test of<|endoftext|>"
input_str_2 = "This is a test<|endoftext|> of"

Now I checked that the ids are actually correct (which is not necessarily the case with extra spaces, etc.):

[ 1212, 318, 257, 1332, 286, 50256]
[ 1212, 318, 257, 1332, 50256, 286]

And then both generate exactly the same thing.

Is there a possibility that the issue comes from slightly malformed input_ids in your script?

@niansong1996
Copy link
Author

niansong1996 commented Dec 6, 2021

Hi @Narsil, thanks a lot for the reply!

Yes, I can see that code as well, and it seems to be doing the correct thing, but the results I am getting suggest otherwise. It could, however, be related to how GPT-Neo handles those position ids internally.

With the smaller example here, even though the generated sequences are the same, the logits are actually different, which is presumably why the incorrect behavior shows up with longer sequences. Here is the code to reproduce:

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

# note that input_str_3 and input_str_4 only differ in the number & position of eos tokens
input_str_3 = "This is a test of<|endoftext|>"
input_str_4 = "This is a test<|endoftext|> of"

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
gradient_ckpt = True
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", pad_token_id=tokenizer.eos_token_id, gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt)

def check_first_token_prob(input_str: str):
    input_ids = tokenizer.encode(input_str, add_special_tokens=False, return_tensors="pt")
    attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), torch.ones_like(input_ids)).to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=30, num_return_sequences=1, 
                             output_scores=True, return_dict_in_generate=True)

    print(f"##################\n{outputs['scores'][-1][0]}\n##################")
    return outputs['scores'][-1][0]

print(sum(check_first_token_prob(input_str_3) - check_first_token_prob(input_str_4)))

The output I got is:

##################
tensor([-15.1894, -12.5526, -13.0819,  ..., -19.2879, -14.2211, -12.7208])
##################
##################
tensor([-15.1894, -12.5526, -13.0818,  ..., -19.2878, -14.2211, -12.7208])
##################
tensor(-0.8249)

The output scores differ only by a very small amount, since the sequence is short and the position of the padding token is only off by one, but they are still different.


Narsil commented Dec 7, 2021

Tagging @patil-suraj, in case you have more information on how the attention_mask works and whether that behavior is in line with what it should do.

Just for reference, I also checked the outputs, and indeed there is variance (even more than in your post). I get:

----------------------------------------
tensor([[[ -8.1140,  -5.9630,  -8.3320,  ..., -18.4336, -13.0972,  -8.0018],
         [ -9.3932,  -7.8721, -12.6465,  ..., -17.8364, -15.9489, -11.9218],
         [ -7.0515,  -6.0169,  -8.5999,  ..., -15.7377, -12.0931,  -8.7372],
         [ -6.9112, -10.0014, -12.7149,  ..., -20.2539, -17.8208, -11.0143],
         [-10.9951,  -8.5840, -10.7879,  ..., -13.4873, -12.2152,  -9.3264],
         [ -6.2603,  -3.7231,  -7.3898,  ..., -11.6948, -10.7496,  -7.6801]]])
----------------------------------------
tensor([[[ -8.1140,  -5.9630,  -8.3320,  ..., -18.4336, -13.0972,  -8.0018],
         [ -9.3932,  -7.8721, -12.6465,  ..., -17.8364, -15.9489, -11.9218],
         [ -7.0515,  -6.0169,  -8.5999,  ..., -15.7377, -12.0931,  -8.7372],
         [ -6.9112, -10.0014, -12.7149,  ..., -20.2539, -17.8208, -11.0143],
         [ -7.6365,  -7.4540, -13.7994,  ..., -17.4893, -16.3242, -12.3888],
         [-10.9951,  -8.5840, -10.7879,  ..., -13.4873, -12.2152,  -9.3264]]]) # Here particularly different
----------------------------------------


patrickvonplaten commented Dec 10, 2021

Jumping in the conversation here to maybe solve some problems.

One thing to remember is that generate() will always auto-regressively sample from the last token. This means that if the last token is a padding token, then it will sample from it, which is always incorrect.

This means one should never look at the output of the padding token, i.e., in @Narsil's example:

----------------------------------------
tensor([[[ -8.1140,  -5.9630,  -8.3320,  ..., -18.4336, -13.0972,  -8.0018],
         [ -9.3932,  -7.8721, -12.6465,  ..., -17.8364, -15.9489, -11.9218],
         [ -7.0515,  -6.0169,  -8.5999,  ..., -15.7377, -12.0931,  -8.7372],
         [ -6.9112, -10.0014, -12.7149,  ..., -20.2539, -17.8208, -11.0143],
         [-10.9951,  -8.5840, -10.7879,  ..., -13.4873, -12.2152,  -9.3264],
         [ -6.2603,  -3.7231,  -7.3898,  ..., -11.6948, -10.7496,  -7.6801]]])
----------------------------------------
tensor([[[ -8.1140,  -5.9630,  -8.3320,  ..., -18.4336, -13.0972,  -8.0018],
         [ -9.3932,  -7.8721, -12.6465,  ..., -17.8364, -15.9489, -11.9218],
         [ -7.0515,  -6.0169,  -8.5999,  ..., -15.7377, -12.0931,  -8.7372],
         [ -6.9112, -10.0014, -12.7149,  ..., -20.2539, -17.8208, -11.0143],
         [ -7.6365,  -7.4540, -13.7994,  ..., -17.4893, -16.3242, -12.3888],
         [-10.9951,  -8.5840, -10.7879,  ..., -13.4873, -12.2152,  -9.3264]]]) # Here particularly different
----------------------------------------

This means that the last row of the first logits and the second-to-last row of the second logits are useless (they correspond to padding tokens). What we should instead compare is the second-to-last row of the first logits to the last row of the second logits (both corresponding to the output logits of "of") - and those are identical. This shows that the position ids are correctly shifted.
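A hedged sketch of that row comparison, assuming the same model and tokenizer setup as the snippets earlier in this thread and explicitly shifted position_ids (illustrative only):

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

def logits_for(input_str: str):
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    attention_mask = (input_ids != tokenizer.eos_token_id).long()
    position_ids = attention_mask.cumsum(-1) - 1       # skip padded slots
    position_ids.masked_fill_(attention_mask == 0, 1)  # dummy value at padding
    with torch.no_grad():
        return model(input_ids, attention_mask=attention_mask, position_ids=position_ids).logits[0]

logits_1 = logits_for("This is a test of<|endoftext|>")  # "of" is the second-to-last token
logits_2 = logits_for("This is a test<|endoftext|> of")  # "of" is the last token

# Compare the rows that correspond to "of", not the rows at padding positions.
print((logits_1[-2] - logits_2[-1]).abs().max())  # expected to be ~0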

Now, as a conclusion: for padded inputs to GPT-like models one should always use padding=left, because otherwise the model will necessarily have to sample from a padding token, which is wrong. Maybe we should actually put a warning for this somewhere - @Narsil, what do you think about adding a warning along these lines (in pseudo code):

if the model is not an encoder-decoder and the last token of any sequence is a padding token -> warn that the user should probably use padding=left
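In the same spirit, a hedged sketch of what such a check could look like (a hypothetical helper, not the actual transformers implementation):

import warnings
import torch

def maybe_warn_right_padding(input_ids: torch.Tensor, pad_token_id: int, is_encoder_decoder: bool) -> None:
    # Only decoder-only models sample directly from the last prompt token.
    if is_encoder_decoder or pad_token_id is None:
        return
    if (input_ids[:, -1] == pad_token_id).any():
        warnings.warn(
            "A decoder-only architecture is being used with right-padded inputs; "
            "generation will continue from a padding token. Consider setting "
            "tokenizer.padding_side = 'left' instead."
        )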

@niansong1996

@patrickvonplaten thanks a lot for the clarification! It confirms what I found in my experiments -- right padding for GPT-like models is incorrect and leads to performance degradation.

However, I do think the problem of not correctly skipping the padding tokens still exists in general. If sampling from the padding token leads to incorrect results, then in the following examples the logits for the generated tokens should be the same, since the last token is no longer a padding token:

input_str_3 = "This is a test of<|endoftext|> some"
input_str_4 = "This is a test<|endoftext|> of some"

However, the output I've been getting is:

##################
tensor([-15.8802, -16.3779, -15.6428,  ..., -21.8622, -17.9515, -14.6956])
##################
##################
tensor([-15.8802, -16.3779, -15.6428,  ..., -21.8622, -17.9514, -14.6956])
##################
tensor(0.6359)

Notice that they look the same, but subtracting them and summing the difference shows that the values are not identical.

In principle, if the padding tokens are correctly skipped everywhere, then it would not matter even if I have input like this:

input_str_3 = "This is a test of<|endoftext|><|endoftext|><|endoftext|> some"
input_str_4 = "This<|endoftext|> is a test<|endoftext|> of some"

Or am I understanding it incorrectly?

The full code snippet I used to generate the output is pasted below:

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

# note that input_str_3 and input_str_4 only differ in the number & position of eos tokens
input_str_3 = "This is a test of<|endoftext|> some"
input_str_4 = "This is a test<|endoftext|> of some"

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
gradient_ckpt = True
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", pad_token_id=tokenizer.eos_token_id, gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt)

def check_first_token_prob(input_str: str):
    input_ids = tokenizer.encode(input_str, add_special_tokens=False, return_tensors="pt")
    attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), torch.ones_like(input_ids)).to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=30, num_return_sequences=1, 
                             output_scores=True, return_dict_in_generate=True, do_sample=False)

    print(f"##################\n{outputs['scores'][-1][0]}\n##################")
    return outputs['scores'][-1][0]

print(sum(check_first_token_prob(input_str_3) - check_first_token_prob(input_str_4)))

@patrickvonplaten

Hey @niansong1996,

I think your understanding is very much correct here. If I understand your example

##################
tensor([-15.8802, -16.3779, -15.6428,  ..., -21.8622, -17.9515, -14.6956])
##################
##################
tensor([-15.8802, -16.3779, -15.6428,  ..., -21.8622, -17.9514, -14.6956])
##################
tensor(0.6359)

you are seeing (very) small differences in the output logits that shouldn't be there.
I'm quite sure that this is because masked tokens are not perfectly masked; instead, a large negative number (-10,000) is added to them, to avoid issues with float16. This is amplified in GPT2 for two reasons:

  1. GPT2 uses a causal mask by default with -10,000, and then if the token is also masked by the attention_mask it adds another -10,000 on top instead of replacing the value with just -10,000. E.g. see this line:

    attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

  2. GPT2 has been seen to produce very large logits (e.g. fix repetition penalty error in modeling_utils.py #2303 (comment)), which means that small differences in the padding value, e.g. using -10,000 or -20,000 instead of -inf before the softmax, can actually make a significant difference.

Now taking this into account for your example:

input_str_3 = "This is a test of<|endoftext|> some"
input_str_4 = "This is a test<|endoftext|> of some"

It means the following: for input_str_3, "of" attends to "<|endoftext|>" with a padding penalty of just -10,000 (padding mask), while for input_str_4, "of" attends to "<|endoftext|>" with a padding penalty of -20,000 (padding mask + causal mask). Even though -10,000 and -20,000 both essentially mean the softmax weight is zero, those differences can add up in GPT2 (especially since it tends to produce extreme values).

I think your reasoning is 100% correct, and those small differences in the values used for padding could be the explanation - you could maybe try replacing all the -10,000 values with -torch.inf to see if the problem persists.
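To make the additive-mask point concrete, here is a small sketch with made-up scores (not the actual GPT2 attention code): a position hidden by both the causal mask and the attention_mask ends up with a different pre-softmax value than one hidden by the attention_mask alone, even though both are "masked".

import torch

mask_value = -10000.0

attn_scores = torch.tensor([1.2, 0.7, 3.1])                  # raw scores for three key positions
causal_mask = torch.tensor([True, True, False])              # the last key lies in the "future"
padding_mask = torch.tensor([0.0, mask_value, mask_value])   # the last two keys are padding

scores = torch.where(causal_mask, attn_scores, torch.tensor(mask_value))  # causal mask replaces
scores = scores + padding_mask                                            # padding mask adds on top

print(scores)                  # tensor([     1.2000,  -9999.3000, -20000.0000])
print(scores.softmax(dim=-1))  # masked weights are ~0 either way, but the inputs to softmax differ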


github-actions bot commented Jan 6, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


sonsus commented Feb 3, 2022

I found this issue extremely helpful for my experiment. I was wondering why pretrained decoder-only LMs were failing to generate anything with tokenizer.add_special_tokens({'pad_token': '[PAD]'}); model.resize_token_embeddings(len(tokenizer)). This issue pretty much explains why my implementation failed so badly on the generation task. Again, I really appreciate it =]


ShiyuNee commented Aug 8, 2023

(Quoting @patrickvonplaten's comment above about generate() always sampling from the last token and the recommendation to use padding=left for GPT-like models.)

Why is sampling from the pad_token always incorrect?
For example, in a QA task we might use the pad_token as a separator: given the input "question" followed by the pad_token, we want the output "question", pad_token, "answer". In that case it seems fine to generate from the pad_token.
