Refactor output handling in generate for cleaner decoding methods #40887
manueldeprada wants to merge 22 commits into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This reverts commit e3aed39.
@gante this PR is more of an RFC to see what you think than a full PR. If you agree this simplifies generate, I will put in more work to make it clean for assisted gen and make the tests happy!
related: #39834
gante
left a comment
Very much on board with this 👍 👍 👍
@gante This is ready for review, left some comments in the code. I would suggest merging this PR as-is, and then I can make a second PR that enables custom
```python
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
```
Since `output_x` comes from the generation config, how do you suggest we enable extra generation outputs?
It could be an `output_features=['attentions', 'hidden_states', 'scores']` list, etc.
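A minimal sketch of that suggestion (`output_features` and `collect_outputs` are hypothetical names, not part of the PR): the requested feature names act as a filter over whatever the decoding step produced.

```python
# Hypothetical sketch: keep only the per-step outputs the user asked for
# via an `output_features` list (names here are assumptions, not PR code).
def collect_outputs(step_output: dict, output_features: list) -> dict:
    return {key: value for key, value in step_output.items() if key in output_features}

step = {"attentions": ("a0",), "hidden_states": ("h0",), "scores": ("s0",)}
collected = collect_outputs(step, output_features=["attentions", "scores"])
# "hidden_states" is dropped because it was not requested
```

New features would then only need a new entry in the list, not a new hard-coded branch.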
```python
        "will be skipped."
    )

if can_compile:
```
zucchini-nlp
left a comment
hey @manueldeprada, great job! I am happy to have a first step toward better generation output handling.
Do you think we can make the dynamic output dict in this PR, since we already started the refactor? Would be super cool to get rid of the near-duplicate code.
```python
if not generation_config.return_dict_in_generate:
    return {"return_dict_in_generate": False, "next_scores": None}
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits

next_scores = () if output_scores else None
next_logits = () if output_logits else None
decoder_attentions = () if output_attentions else None
cross_attentions = () if output_attentions and self.config.is_encoder_decoder else None
decoder_hidden_states = () if output_hidden_states else None

encoder_attentions = encoder_hidden_states = None
```
I think we have to push further and make it output any value dynamically, as requested by users. Currently the PR splits the existing logic out into its own fn, but the existing code is very repetitive.
IMO we can check `model_outputs.keys()` and dynamically update our generation output dict with the keys that are available in `getattr(generation_config, f"output_{key}")`. Since all models follow standard naming in the output dict, it should have no edge cases.
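A sketch of that dynamic approach, under assumed names (`update_generation_outputs` is hypothetical): each key the forward pass returned is accumulated only when the matching `output_<key>` flag is set on the generation config.

```python
from types import SimpleNamespace

# Sketch, not PR code: accumulate any forward-pass output whose matching
# `output_<key>` flag is truthy on the generation config.
def update_generation_outputs(generation_outputs, model_outputs, generation_config):
    for key, value in model_outputs.items():
        if getattr(generation_config, f"output_{key}", False):
            generation_outputs[key] = generation_outputs.get(key, ()) + (value,)
    return generation_outputs

config = SimpleNamespace(output_attentions=True, output_hidden_states=False)
outputs = update_generation_outputs(
    {}, {"attentions": "attn_step0", "hidden_states": "h_step0"}, config
)
# only "attentions" is kept, one tuple entry per decoding step
```

Because the loop is driven by `model_outputs.keys()`, a model returning a new output type needs no changes in `generate` itself.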
```python
if cur_len is not None:
    for arg in splittable_args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = _split_model_outputs(
                kwargs[arg],
                cur_len,
                added_len,
                is_prefill_pass=len(generate_output[arg]) == 0,
                is_decoder_attention=(arg == "decoder_attentions"),
            )
    for arg in cropable_args:
        if generate_output[arg] is not None:
            kwargs[arg] = tuple(kwargs[arg][:, i, :] for i in range(added_len))
else:
    for arg in all_args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = (kwargs[arg],)
```
hmm, this could be simplified, no? If we set `cur_len=1` as the default, we can always try to split the output; it will catch up depending on the length value.
```python
if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
    cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
    cache = model_kwargs[cache_key]
```
nit: with smth like `caches_in_kwargs := [cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES]` we can avoid looping twice
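One way to write the single-pass version the nit asks for (a sketch with a hypothetical helper name; `ALL_CACHE_NAMES` here is an illustrative subset, not the library's full list):

```python
# Illustrative subset only; the real ALL_CACHE_NAMES lives in transformers.
ALL_CACHE_NAMES = ["past_key_values", "cache_params"]

def find_cache(model_kwargs):
    # One scan instead of `any(...)` followed by `next(...)`.
    for cache_key in ALL_CACHE_NAMES:
        if cache_key in model_kwargs:
            return model_kwargs[cache_key]
    return None
```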
```python
    return encoder_decoder_cls(
        sequences=sequences,
        scores=generate_output["next_scores"],
        logits=generate_output["next_logits"],
        encoder_attentions=generate_output["encoder_attentions"],
        encoder_hidden_states=generate_output["encoder_hidden_states"],
        decoder_attentions=generate_output["decoder_attentions"],
        cross_attentions=generate_output["cross_attentions"],
        decoder_hidden_states=generate_output["decoder_hidden_states"],
        past_key_values=cache,
        **kwargs,
    )
else:
    return decoder_only_cls(
        sequences=sequences,
        scores=generate_output["next_scores"],
        logits=generate_output["next_logits"],
        attentions=generate_output["decoder_attentions"],
        hidden_states=generate_output["decoder_hidden_states"],
        past_key_values=cache,
        **kwargs,
    )
```
maybe for the future: it would be great to not rely on an expected set of keys and instead unpack everything in `generate_output` into the output dict. The GenerationOutput class would have to be able to hold anything for that.
Pseudo-code like below:

```python
output_cls = encoder_decoder_cls if self.config.is_encoder_decoder else decoder_only_cls
return output_cls(sequences=sequences, past_key_values=cache, **generate_output)
```
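A runnable sketch of what such a flexible output container could look like (the class, `extras` field, and `build_output` helper are assumptions for illustration, not the PR's implementation):

```python
from dataclasses import dataclass, field

# Sketch: an output class that accepts arbitrary collected outputs instead
# of a fixed attribute set (all names here are hypothetical).
@dataclass
class FlexibleGenerateOutput:
    sequences: tuple
    past_key_values: object = None
    extras: dict = field(default_factory=dict)

def build_output(sequences, cache, generate_output):
    # Unpack whatever `generate_output` collected, with no hard-coded key list.
    return FlexibleGenerateOutput(
        sequences=sequences, past_key_values=cache, extras=dict(generate_output)
    )

out = build_output(sequences=("tok0",), cache=None, generate_output={"scores": (0.5,)})
```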
hey @manueldeprada, how is this going? Do you plan to finish it, or can I take over? Returning arbitrary outputs from the model is important for new multimodal models.
Each decoding method has a common block of output-handling boilerplate that worsens readability.
This PR moves that boilerplate into reusable `generate` helpers.
TODO: add generalization so that users can request `output_x` and `x` from `forward` gets forwarded.
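One possible shape for that TODO, as a hedged sketch (the `build_forward_kwargs` helper and the `output_router_logits` flag are illustrative assumptions): every truthy `output_x` flag on the config is collected so it can be passed through to the model's forward call.

```python
from types import SimpleNamespace

# Sketch only: gather all truthy `output_*` flags from the generation config
# so they can be forwarded as kwargs to the model's forward pass.
def build_forward_kwargs(generation_config) -> dict:
    return {
        name: value
        for name, value in vars(generation_config).items()
        if name.startswith("output_") and value
    }

cfg = SimpleNamespace(output_router_logits=True, output_scores=False, max_new_tokens=8)
fwd_kwargs = build_forward_kwargs(cfg)
# fwd_kwargs == {"output_router_logits": True}
```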