Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix kwargs handling in generate_with_fallback #29225

Merged
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,8 @@ def generate_with_fallback(
do_condition_on_prev_tokens,
kwargs,
):
kwargs = copy.copy(kwargs)

# 6.6 Batch generate current chunk
seek_sequence_list = [None for _ in range(cur_bsz)]
seek_outputs_list = [None for _ in range(cur_bsz)]
Expand All @@ -773,8 +775,12 @@ def generate_with_fallback(
generation_config.do_sample = temperature is not None and temperature > 0.0

generation_config.temperature = temperature if generation_config.do_sample else 1.0
generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1

generate_kwargs = copy.copy(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

temperature shouldn't be in kwargs as it's already an argument of .generate here right ?

It seems okay to check for do_sample and num_beams here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanted to be extra cautious here and make sure everything is safe locally, rather than relying on what gets passed down from 2 call frames up the stack. But I can remove temperature if you prefer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me as is - there's a preference for more explicit handling of kwargs than more buried ones

if key in generate_kwargs:
del generate_kwargs[key]
seek_outputs = super().generate(
segment_input,
generation_config,
Expand All @@ -783,7 +789,7 @@ def generate_with_fallback(
prefix_allowed_tokens_fn,
synced_gpus,
decoder_input_ids=decoder_input_ids,
**kwargs,
**generate_kwargs,
)

# post-process sequence tokens and outputs to be in list form
Expand Down