Mismatch between logits from generate and forward with an attention mask for most GPT models #18388

Closed
LaurenceA opened this issue Aug 1, 2022 · 3 comments

LaurenceA commented Aug 1, 2022

System Info

  • transformers version: 4.21.0
  • Platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.10
  • Python version: 3.8.8
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.10.0+cu113 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@patil-suraj, @patrickvonplaten, @LysandreJik

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

"""
MWE showing that logits from generate match those from forward, except for the first token?
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.distributions import Categorical
import torch as t

#Broken:
model_name = "distilgpt2"
#model_name = "gpt2"
#model_name = "EleutherAI/gpt-neo-125M"
#model_name = "EleutherAI/gpt-neo-1.3B"
#Working:
#model_name = "EleutherAI/gpt-j-6B"
lm = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='right')


tokenizer.pad_token = tokenizer.eos_token
prompt = tokenizer(["big unpadded five token prompt ", "padded three token "], return_tensors='pt', padding=True, add_special_tokens=True)

#generate with plain sampling (https://huggingface.co/blog/how-to-generate)

result = lm.generate(prompt["input_ids"], attention_mask=prompt["attention_mask"], do_sample=True, output_scores=True, return_dict_in_generate=True, top_k=0, max_length=10)
x, logits_gen = result.sequences, result.scores
logits_gen = t.stack(logits_gen, 1)

x_attention_mask = (x != tokenizer.eos_token_id).to(dtype=t.int64)
position_ids = x_attention_mask.cumsum(-1)-1
print("Attention mask for prompt + generated text")
print(x_attention_mask)
print("Position IDs")
print(position_ids)
logits_for = lm(x, attention_mask=x_attention_mask, position_ids=position_ids).logits
#we drop the last element, and the first prompt_length-1 elements to get
#logits from forward to match those from generate
logits_for = logits_for[:, (prompt["input_ids"].shape[-1]-1):-1]

P_for = Categorical(logits = logits_for)
P_gen = Categorical(logits = logits_gen)

#Take only generated tokens
x = x[..., prompt['input_ids'].shape[-1]:]
log_prob_for = P_for.log_prob(x)
log_prob_gen = P_gen.log_prob(x)

print("log-probs from forward")
print(log_prob_for)
print("log-probs from generate")
print(log_prob_gen)

Expected behavior

I'm trying to get logits or log-probabilities from generate to match those from forward in the presence of a padded prompt.

For GPT models, I managed to get almost everything working by setting the position_ids for forward (see the MWE script above).

However, there still seems to be a slight mismatch for the first generated token whenever the prompt contains padding in its attention mask. You can see this in the script's output:

Attention mask for prompt + generated text
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 1, 1, 1]])
Position IDs
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 4, 4, 5, 6, 7]])
log-probs from forward
tensor([[ -8.3152,  -5.5587,  -3.0973],
        [ -2.6509, -10.6300,  -7.5426]], grad_fn=<SqueezeBackward1>)
log-probs from generate
tensor([[ -8.3152,  -5.5587,  -3.0973],
        [ -2.7818, -10.6300,  -7.5426]])

Note the slight mismatch in the bottom-left log-prob; this doesn't happen for any other log-probability.

I've tried a few GPT-flavour models: the problem occurs for distilgpt2, gpt2, EleutherAI/gpt-neo-125M and EleutherAI/gpt-neo-1.3B, but the log-probs all match for EleutherAI/gpt-j-6B.

LaurenceA added the bug label on Aug 1, 2022
LysandreJik (Member) commented:

cc @gante for generate :)

gante (Member) commented Aug 3, 2022

Hi @LaurenceA 👋

With decoder-only models, such as the ones you mentioned, padding should be done on the left. This is because the output is a continuation of the input prompt -- there would be gaps in the output without left padding. Our code to automatically prepare the position IDs for a given attention mask in decoder-only models has left-sided padding in mind and differs from the one you wrote in your example, hence the output mismatch :)

Not being aware that left-sided padding should be used for these models is a common issue. I'm leaving this issue open as a reminder that we should add some form of warning for users.


👉 example of code to prepare the position IDs
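
Roughly, that preparation amounts to the following (a minimal sketch rather than the library's exact code, assuming attention_mask is a [batch, seq] tensor of 0s and 1s):

import torch as t

attention_mask = t.tensor([[1, 1, 1, 1, 1],
                           [0, 0, 1, 1, 1]])  # second row is left-padded
# position counts up only over attended tokens...
position_ids = attention_mask.long().cumsum(-1) - 1
# ...and padded positions get a dummy value (they are masked out anyway)
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 3, 4],
#         [1, 1, 0, 1, 2]])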

Here's your example, with left padding and the same position IDs creation method:

"""
MWE showing that logits from generate match those from forward, except for the first token?
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.distributions import Categorical
import torch as t

#Broken:
model_name = "distilgpt2"
#model_name = "gpt2"
#model_name = "EleutherAI/gpt-neo-125M"
#model_name = "EleutherAI/gpt-neo-1.3B"
#Working:
#model_name = "EleutherAI/gpt-j-6B"
lm = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")


tokenizer.pad_token = tokenizer.eos_token
prompt = tokenizer(["big unpadded five token prompt ", "padded three token "], return_tensors='pt', padding=True, add_special_tokens=True)

#generate with plain sampling (https://huggingface.co/blog/how-to-generate)

result = lm.generate(prompt["input_ids"], attention_mask=prompt["attention_mask"], do_sample=True, output_scores=True, return_dict_in_generate=True, top_k=0, max_length=10)
x, logits_gen = result.sequences, result.scores
logits_gen = t.stack(logits_gen, 1)

x_attention_mask = (x != tokenizer.eos_token_id).to(dtype=t.int64)
position_ids = x_attention_mask.cumsum(-1)-1
position_ids.masked_fill_(x_attention_mask == 0, 1)
print("Attention mask for prompt + generated text")
print(x_attention_mask)
print("Position IDs")
print(position_ids)
logits_for = lm(x, attention_mask=x_attention_mask, position_ids=position_ids).logits
#we drop the last element, and the first prompt_length-1 elements to get
#logits from forward to match those from generate
logits_for = logits_for[:, (prompt["input_ids"].shape[-1]-1):-1]

P_for = Categorical(logits = logits_for)
P_gen = Categorical(logits = logits_gen)

#Take only generated tokens
x = x[..., prompt['input_ids'].shape[-1]:]
log_prob_for = P_for.log_prob(x)
log_prob_gen = P_gen.log_prob(x)

print("log-probs from forward")
print(log_prob_for)
print("log-probs from generate")
print(log_prob_gen)

gante self-assigned this on Aug 3, 2022
gante (Member) commented Sep 28, 2022

@LaurenceA if you run generate from the current main, you should see a warning if you don't use left-padding with decoder-only models like GPT2 :)

(#19067)
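
For completeness, a minimal sketch of a left-padded setup for decoder-only generation (the checkpoint name is just illustrative):

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Left padding keeps the generated text as a direct continuation of each prompt
inputs = tokenizer(["a longer example prompt", "short"], return_tensors="pt", padding=True)
out = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=10)
print(tokenizer.batch_decode(out, skip_special_tokens=True))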
