Support returning raw logits in generate
#17521
Comments
cc @patil-suraj @gante as well
I'm personally fine with adding a
I'm cool with it 👍 (and it might be interesting to use as part of PT-TF cross tests)
@patil-suraj what do you think? Do you want to open a PR to work on it?
@shijie-wu seems willing to open a PR, as mentioned at the end of the issue description.
I could open a PR for this.
I'm okay with this, let me know if you need any help @shijie-wu :)
Cool, thanks for taking care of it @shijie-wu
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
So, is there any work on it? I did not find a new feature for getting the raw logits.
I don't think so -- gently pinging @shijie-wu, who expressed interest in opening a PR :)
Sorry about the delay! I will resume working on it in the coming week.
Gently pinging @shijie-wu: any updates on this?
@gante should I open a PR? I think the change is fairly minor.
@xkianteb sounds good 👍
Is there any update on this?
None that I know of. Open to contributors :)
Hey, for folks running into this issue: I have a snippet that already gets the raw logits. Probably related to your quest as well, @xkianteb. It's for RLHF PPO, so you don't have to do another forward pass to get the logprobs.

```python
import torch
import torch.nn.functional as F
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
pad_id = tokenizer.pad_token_id

policy = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
policy.generation_config.pad_token_id = policy.generation_config.eos_token_id

# Two padded queries; the pad tokens are masked out before the model sees them.
query = torch.tensor([
    [pad_id, pad_id, 23073],
    [pad_id, pad_id, 234],
])
temperature = 0.7
context_length = query.shape[1]


def forward(model, query_responses, tokenizer):
    """Forward pass that masks out pad tokens and rebuilds position ids accordingly."""
    attention_mask = query_responses != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )


def generate_and_return_logits(lm_backbone, queries, tokenizer, generation_config):
    """Generate in a way that does not affect padding tokens."""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # position_ids=attention_mask.cumsum(1) - attention_mask.long(),  # already handled in generation
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # `output.scores` is a tuple with one (batch, vocab) tensor per generated
    # token; stack it into (batch, new_tokens, vocab).
    logits = torch.stack(output.scores, 1)
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits


generation_config = transformers.GenerationConfig(
    max_new_tokens=5,
    min_new_tokens=5,
    temperature=temperature,
    top_k=0,  # 0 (an int, not 0.0) disables top-k filtering
    top_p=1.0,
    do_sample=True,
)

query_response, logits = generate_and_return_logits(policy, query, tokenizer, generation_config)
response = query_response[:, context_length:]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{response=}")
print(f"{logprob=}")

# Sanity check: a second forward pass over the full sequence should recover
# (approximately) the same logprobs, once the logits are shifted by one
# position and divided by the temperature.
output = forward(policy, query_response, tokenizer)
logits = output.logits[:, context_length - 1 : -1]
logits /= temperature
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{logprob=}")
```
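A caveat worth flagging about the snippet above: `output.scores` holds the scores after the logits processors and warpers have run (here the temperature warper, plus the `min_new_tokens` constraint), not the raw logits. That is why the verification pass divides the forward-pass logits by `temperature` before comparing logprobs.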
(see #28667)
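For anyone arriving here after the fact: assuming #28667 is the PR that closed this out, recent versions of transformers expose an `output_logits` flag on `generate` that returns the unprocessed logits alongside the processed `scores`. A minimal sketch, assuming that flag is available in your installed version:

```python
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=5,
    do_sample=True,
    temperature=0.7,
    return_dict_in_generate=True,
    output_logits=True,  # raw logits, before any processor or warper runs
    output_scores=True,  # processed scores, for comparison
)

# Both are tuples with one (batch, vocab_size) tensor per generated token.
raw_logits = torch.stack(output.logits, dim=1)
processed_scores = torch.stack(output.scores, dim=1)
```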
Feature request
Support returning raw logits in `generate`, by either:

Motivation
Currently, a user-supplied `logits_processor` is appended to a new instance of `LogitsProcessorList`. As a result, users cannot get the raw logits with the current implementation, even with a custom `LogitsProcessor`.

See further discussion in #17424
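To make the motivation concrete, here is a minimal sketch of the limitation (the `RecordingProcessor` class is hypothetical, written for this illustration): even a pass-through processor supplied via `logits_processor` only runs after the predefined processors, so the scores it records are already modified.

```python
import torch
import transformers
from transformers import LogitsProcessor, LogitsProcessorList

class RecordingProcessor(LogitsProcessor):
    """Pass-through processor that stashes the scores it is given."""
    def __init__(self):
        self.recorded = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.recorded.append(scores.clone())
        return scores  # no modification

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

recorder = RecordingProcessor()
inputs = tokenizer("Hello", return_tensors="pt")
model.generate(
    **inputs,
    max_new_tokens=5,
    min_new_tokens=5,  # adds a predefined processor that bans EOS early on
    logits_processor=LogitsProcessorList([recorder]),
)

# `recorder.recorded` holds the scores *after* the predefined processors ran
# (e.g. with EOS already set to -inf), because the custom list is appended
# after the default one, not prepended to it.
```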
Your contribution
I could open a PR to reorder how `logits_processor` is merged into the predefined `LogitsProcessorList`.
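For what it's worth, the proposed fix amounts to flipping the merge order when the two lists are combined; a rough sketch of the idea (hypothetical, not the actual transformers internals):

```python
from transformers import LogitsProcessorList, MinLengthLogitsProcessor

# Stand-ins for the lists that generate() builds internally.
default_processors = [MinLengthLogitsProcessor(20, eos_token_id=50256)]
user_processors = []  # whatever the caller passed as `logits_processor`

# Today the user list is appended after the defaults, so custom processors
# receive already-processed scores:
merged = LogitsProcessorList(default_processors + user_processors)

# Prepending the user list instead would let a custom processor observe the
# raw logits before any predefined processor modifies them:
merged = LogitsProcessorList(user_processors + default_processors)
```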