
Support returning raw logits in generate #17521

Closed
shijie-wu opened this issue Jun 2, 2022 · 19 comments

Comments

@shijie-wu
Contributor

Feature request

Support returning raw logits in generate by either:

  1. creating a new arg that enables returning the raw logits
  2. or supporting a callback that allows users to collect the raw logits

Motivation

  • Raw logits "would be the most understandable & consistent across generation methods" (@patrickvonplaten)
  • For testing, returning raw logits would help "identify which parts get wrong if any test failure occurs" (@ydshieh)
  • There's concern about "rampant too many options" (@Narsil), so I would prefer the second option for supporting this feature.
  • However, the second option still needs a code change to support it, as the user-provided logits_processor is appended to a new instance of LogitsProcessorList. As a result, users cannot get the raw logits with the current implementation, even with a custom LogitsProcessor.

See further discussion in #17424

Your contribution

I could open a PR to reorder how the user-provided logits_processor is merged with the predefined LogitsProcessorList.
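
For reference, a minimal sketch (not part of the original proposal; the class name is illustrative) of what the callback route would look like with a score-capturing LogitsProcessor. Under the current merge order it only sees already-processed scores, not the raw logits:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList

class CaptureScoresProcessor(LogitsProcessor):
    """Records every score tensor it sees and passes it through unchanged."""
    def __init__(self):
        self.captured = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.captured.append(scores.detach().clone())
        return scores

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
capture = CaptureScoresProcessor()
inputs = tokenizer("Hello", return_tensors="pt")
model.generate(**inputs, max_new_tokens=3, logits_processor=LogitsProcessorList([capture]))
# capture.captured holds one (batch, vocab_size) tensor per generated step; because user-supplied
# processors are appended after the built-in ones, these are not guaranteed to be the raw logits.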

@patrickvonplaten
Contributor

cc @patil-suraj @gante as well

@patrickvonplaten
Contributor

I'm personally fine with adding an output_logits flag to generate; since it already has 50+ flags, one more won't make a difference, and it's a useful feature indeed. What do you think @patil-suraj @gante?

@gante
Member

gante commented Jun 2, 2022

I'm cool with it 👍 (and it might be interesting to use as part of PT-TF cross tests)

@patrickvonplaten
Contributor

@patil-suraj what do you think? Do you want to open a PR to work on it?

@ydshieh
Collaborator

ydshieh commented Jun 3, 2022

@patil-suraj what do you think? Do you want to open a PR to work on it?

@shijie-wu seems willing to open a PR, as mentioned at the end of the issue description.

@shijie-wu
Contributor Author

I could open a PR for this.

@patil-suraj
Contributor

I'm okay with this, let me know if you need any help @shijie-wu :)

@patrickvonplaten
Contributor

Cool thanks for taking care of it @shijie-wu

@huggingface huggingface deleted a comment from github-actions bot Jul 4, 2022
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Aug 5, 2022
@lxianl455

So, is there any work on this? I did not find a new feature for getting the raw logits.

@gante
Member

gante commented Oct 23, 2022

I don't think so -- gently pinging @shijie-wu, who manifested interest in opening a PR :)

@shijie-wu
Contributor Author

Sorry about the delay! I will resume working on it in the coming week.

@xkianteb

Gently pinging @shijie-wu --- any updates on this?

@xkianteb

@gante should I open a PR? I think the change is fairly minor.

@gante
Member

gante commented Mar 26, 2023

@xkianteb sounds good 👍

@khalidsaifullaah

Is there any update on this...?

@gante
Member

gante commented Aug 3, 2023

None that I know of. Open to contributors :)

@vwxyzjn
Contributor

vwxyzjn commented Feb 1, 2024

Hey, for folks running into this issue: I have a snippet that already gets the raw logits. Probably related to your quest as well, @xkianteb. It's for RLHF PPO, so you don't have to do another forward pass to get the logprobs.

import torch
import transformers
import torch.nn.functional as F

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
pad_id = tokenizer.pad_token_id
policy = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
policy.generation_config.pad_token_id = policy.generation_config.eos_token_id

query = torch.tensor([
    [pad_id, pad_id, 23073],
    [pad_id, pad_id, 234],
])
temperature = 0.7
context_length = query.shape[1]

def forward(model, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
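    # position ids start at 0 on the first non-pad token, so left padding does not shift the real positions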
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )

def generate_and_return_logits(lm_backbone, queries, tokenizer, generation_config):
    """generate in a way that does not affect padding tokens"""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # already handled in generation
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True
    )
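    # output.scores holds one (batch, vocab_size) tensor per generated token; these are the
    # post-processing scores (essentially logits / temperature here), not the raw model logits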
    logits = torch.stack(output.scores, 1)
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits

generation_config = transformers.GenerationConfig(
    max_new_tokens=5,
    min_new_tokens=5,
    temperature=temperature,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
)
query_response, logits = generate_and_return_logits(policy, query, tokenizer, generation_config)
response = query_response[:, context_length:]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{response=}")
print(f"{logprob=}")

output = forward(policy, query_response, tokenizer)
logits = output.logits[:, context_length - 1 : -1]
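# dividing the raw forward-pass logits by temperature reproduces the temperature-scaled scores
# used during sampling, so the logprobs below match the ones printed above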
logits /= temperature
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{logprob=}")

Output:

response=tensor([[  198,   198,     3,   399,   532],
        [  198,   198, 48412,  4803, 19321]])
logprob=tensor([[-3.2519e+00, -5.9604e-06, -5.2666e+00, -7.8440e+00, -2.6367e+00],
        [-1.5943e+00, -5.6028e-06, -9.8833e+00, -2.3764e+00, -4.8006e+00]])
logprob=tensor([[-3.2519e+00, -5.9604e-06, -5.2666e+00, -7.8440e+00, -2.6367e+00],
        [-1.5943e+00, -5.6028e-06, -9.8833e+00, -2.3764e+00, -4.8006e+00]],
       grad_fn=<SqueezeBackward1>)

@gante
Member

gante commented Feb 7, 2024

(see #28667)
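
For readers landing here later: #28667 adds a generate flag for this. A minimal usage sketch, assuming the flag landed as output_logits as proposed in this thread (check the PR for the exact name), which returns the unprocessed logits alongside the processed scores:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

out = model.generate(
    **inputs,
    max_new_tokens=5,
    do_sample=False,
    return_dict_in_generate=True,
    output_logits=True,  # raw, unprocessed logits (assumed flag name from #28667)
    output_scores=True,  # processed scores, for comparison
)
# out.logits and out.scores are tuples with one (batch, vocab_size) tensor per generated token
raw_logits = torch.stack(out.logits, dim=1)
processed_scores = torch.stack(out.scores, dim=1)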
