Mismatch between logits from generate and forward with an attention mask for most GPT models #18388
Comments
cc @gante for generate :)
Hi @LaurenceA 👋 With decoder-only models, such as the ones you mentioned, padding should be done on the left. This is because the output is a continuation of the input prompt -- there would be gaps in the output without left padding. Our code that automatically prepares the position IDs for a given attention mask in decoder-only models assumes left-sided padding and differs from the one you wrote in your example, hence the output mismatch :) Not being aware that left-sided padding should be used for these models is a common issue, so I'm leaving this issue open as a reminder that we should add some form of warning for users.

👉 example of code to prepare the position IDs

Here's your example, with left padding and the same position-ID creation method:

```python
"""
MWE showing that logits from generate match those from forward, except for the first token?
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.distributions import Categorical
import torch as t
#Broken:
model_name = "distilgpt2"
#model_name = "gpt2"
#model_name = "EleutherAI/gpt-neo-125M"
#model_name = "EleutherAI/gpt-neo-1.3B"
#Working:
#model_name = "EleutherAI/gpt-j-6B"
lm = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
prompt = tokenizer(["big unpadded five token prompt ", "padded three token "], return_tensors='pt', padding=True, add_special_tokens=True)
#generate with plain sampling (https://huggingface.co/blog/how-to-generate)
result = lm.generate(prompt["input_ids"], attention_mask=prompt["attention_mask"], do_sample=True, output_scores=True, return_dict_in_generate=True, top_k=0, max_length=10)
x, logits_gen = result.sequences, result.scores
logits_gen = t.stack(logits_gen, 1)
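# Rebuild the attention mask for the full (prompt + generated) sequences by treating
# EOS/pad tokens as padding, then derive position IDs the same way generate does
# internally: cumulative sum of the mask minus one, with padded slots set to 1.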
x_attention_mask = (x != tokenizer.eos_token_id).to(dtype=t.int64)
position_ids = x_attention_mask.cumsum(-1)-1
position_ids.masked_fill_(x_attention_mask == 0, 1)
print("Attention mask for prompt + generated text")
print(x_attention_mask)
print("Position IDs")
print(position_ids)
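# Re-score the full sequences with a single forward pass, passing the reconstructed
# attention mask and the matching position IDs.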
logits_for = lm(x, attention_mask=x_attention_mask, position_ids=position_ids).logits
#we drop the last element, and the first prompt_length-1 elements to get
#logits from forward to match those from generate
logits_for = logits_for[:, (prompt["input_ids"].shape[-1]-1):-1]
P_for = Categorical(logits = logits_for)
P_gen = Categorical(logits = logits_gen)
#Take only generated tokens
x = x[..., prompt['input_ids'].shape[-1]:]
log_prob_for = P_for.log_prob(x)
log_prob_gen = P_gen.log_prob(x)
print("log-probs from forward")
print(log_prob_for)
print("log-probs from generate")
print(log_prob_gen)
```
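For reference, here is a minimal sketch (not from the thread above) of what left- vs. right-sided padding looks like at the tokenizer level; the model name and prompts are only placeholders:

```python
from transformers import AutoTokenizer

tok_left = AutoTokenizer.from_pretrained("distilgpt2", padding_side="left")
tok_right = AutoTokenizer.from_pretrained("distilgpt2", padding_side="right")
tok_left.pad_token = tok_left.eos_token
tok_right.pad_token = tok_right.eos_token

texts = ["big unpadded five token prompt ", "padded three token "]
# With left padding the pad tokens sit before the short prompt, so generated tokens
# directly continue the real text; with right padding there would be a gap of pad
# tokens between the prompt and its continuation.
print(tok_left(texts, padding=True)["input_ids"])
print(tok_right(texts, padding=True)["input_ids"])
```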
@LaurenceA if you run (#19067)
System Info
transformers version: 4.21.0
Who can help?
@patil-suraj, @patrickvonplaten, @LysandreJik
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
I'm trying to get logits or log-probabilities from generate to match those from forward in the presence of a padded prompt.

For GPT models, I managed to get almost everything working by setting the position_ids for forward (see MWE script). However, there still seems to be a slight mismatch on the first token if the prompt has an attention mask. You can see this in the output returned by this script, which is:
Note the slight mismatch in the bottom-left log-prob, which doesn't happen for any other log-probability.
I've tried a few GPT-flavour models: the problem appears for distilgpt2, gpt2, EleutherAI/gpt-neo-125M and EleutherAI/gpt-neo-1.3B, but the log-probs all match for EleutherAI/gpt-j-6B.
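For completeness, the position-ID construction used in the MWE above can be distilled into a small standalone sketch. The helper name below is ours (not a transformers API) and the mask values are made up:

```python
import torch

def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Assign positions by counting only non-padded tokens; padded slots get a
    # dummy value of 1, mirroring how decoder-only models in transformers build
    # position IDs from the attention mask during generation.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

# Left-padded batch: the first row has two pad tokens, the second has none.
mask = torch.tensor([[0, 0, 1, 1, 1],
                     [1, 1, 1, 1, 1]])
print(position_ids_from_mask(mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```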