diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index f118cf5276b0..2716c79c702f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -999,11 +999,11 @@ def _get_candidate_generator(
         generation_config: GenerationConfig,
         input_ids: torch.LongTensor,
         inputs_tensor: torch.Tensor,
-        assistant_model: "PreTrainedModel",
         logits_processor: LogitsProcessorList,
-        target_tokenizer: "PreTrainedTokenizerBase",
-        assistant_tokenizer: "PreTrainedTokenizerBase",
         model_kwargs: dict,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        target_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
     ) -> CandidateGenerator:
         """
         Returns the candidate generator to be used in `assisted_generation`
@@ -1300,7 +1300,6 @@ def _get_stopping_criteria(
         generation_config: GenerationConfig,
         stopping_criteria: Optional[StoppingCriteriaList],
         tokenizer: Optional["PreTrainedTokenizerBase"] = None,
-        **kwargs,
     ) -> StoppingCriteriaList:
         criteria = StoppingCriteriaList()
         if generation_config.max_length is not None:
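Note: with `**kwargs` removed from `_get_stopping_criteria` above, the tokenizer used by string-based stopping criteria must now arrive through the explicit `tokenizer=` argument, which `generate()` later sources from `generation_mode_kwargs`. A hedged sketch of the user-facing pattern this serves (the checkpoint name is a placeholder, not part of the patch):

```python
# Sketch: `stop_strings` requires a tokenizer so the stopping criterion can
# encode the strings. The checkpoint below is only a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The capital of France is", return_tensors="pt")
out = model.generate(**inputs, stop_strings=["."], tokenizer=tok, max_new_tokens=20)
print(tok.decode(out[0], skip_special_tokens=True))
```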
@@ -1493,35 +1492,38 @@ def compute_transition_scores(
         return transition_scores
 
-    def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
-        if assistant_model is None:
-            return
-
-        if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
-            attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
-            attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
-            are_equal = all(
-                getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
+    def _validate_generation_mode(self, generation_mode, generation_mode_kwargs):
+        if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
+            raise ValueError(
+                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
             )
-            if not are_equal:
-                raise ValueError(
-                    "The main model and the assistant don't have compatible encoder-dependent input shapes. "
-                    "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
-                )
-
-        doc_reference = (
-            "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
-        )
-        if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
-            if assistant_tokenizer is not None:
-                raise ValueError(
-                    f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
-                )
-        else:
-            if tokenizer is None or assistant_tokenizer is None:
-                raise ValueError(
-                    f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
+        if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
+            if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
+                attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
+                attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
+                are_equal = all(
+                    getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
                 )
+                if not are_equal:
+                    raise ValueError(
+                        "The main model and the assistant don't have compatible encoder-dependent input shapes. "
+                        "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
+                    )
+
+            doc_reference = (
+                "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
+            )
+            if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
+                if "assistant_tokenizer" in generation_mode_kwargs:
+                    raise ValueError(
+                        f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
+                    )
+            else:
+                if "tokenizer" not in generation_mode_kwargs or "assistant_tokenizer" not in generation_mode_kwargs:
+                    raise ValueError(
+                        f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
+                    )
 
     def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
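The new `_validate_generation_mode` above folds the old `_validate_assistant` checks and the beam-search/streamer guard into a single mode-aware check. A hedged sketch of the calling patterns it accepts and rejects (checkpoints are placeholders; the behavior mirrors the raises in the hunk above):

```python
# Sketch of the contract enforced by _validate_generation_mode; checkpoints are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")  # shares gpt2's tokenizer
inputs = tok("Hello", return_tensors="pt")

# OK: assisted generation with a shared tokenizer; passing `assistant_tokenizer`
# here would raise ValueError.
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=8)

# Raises: streamer + beam search is rejected at validation time.
# model.generate(**inputs, num_beams=4, streamer=TextStreamer(tok))

# Universal assisted decoding (different tokenizers) needs both tokenizers:
# model.generate(**inputs, assistant_model=other, tokenizer=tok, assistant_tokenizer=other_tok)
```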
@@ -1869,7 +1871,7 @@ def _prepare_cache_for_generation(
         self,
         generation_config: GenerationConfig,
         model_kwargs: dict,
-        assistant_model: "PreTrainedModel",
+        generation_mode: GenerationMode,
         batch_size: int,
         max_cache_length: int,
     ) -> bool:
@@ -1923,7 +1925,10 @@
 
         # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
         # which is only supported in dynamic caches atm
-        if assistant_model is not None and generation_config.cache_implementation is not None:
+        if (
+            generation_mode == GenerationMode.ASSISTED_GENERATION
+            and generation_config.cache_implementation is not None
+        ):
             logger.warning_once(
                 "An assistant model is provided, using a dynamic cache instead of a cache of type="
                 f"'{generation_config.cache_implementation}'."
@@ -1933,7 +1938,6 @@
         # Assisted decoding and contrastive search require cache rollback, which is incompatible with sliding layers.
         # To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache).
         # The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own.
-        generation_mode = generation_config.get_generation_mode(assistant_model)
         if (
             generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH)
             or generation_config.cache_implementation == "dynamic_full"
@@ -2125,15 +2129,13 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: GenerationConfig) -> bool:
 
     def _get_deprecated_gen_repo(
         self,
-        generation_config: GenerationConfig,
+        generation_mode: GenerationMode,
         trust_remote_code: bool,
         custom_generate: Optional[str] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
     ) -> Optional[str]:
         """
-        Returns the Hub repo for a deprecated generation strategy, if any.
+        Returns the Hub repo for a deprecated generation mode, if any.
         """
-        generation_mode = generation_config.get_generation_mode(assistant_model)
         moved_to_hub_modes = {
             GenerationMode.DOLA_GENERATION: "transformers-community/dola",
             GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
@@ -2156,6 +2158,37 @@ def _get_deprecated_gen_repo(
             )
         return repo
 
+    def _extract_generation_mode_kwargs(
+        self,
+        custom_generate,
+        kwargs,
+        synced_gpus,
+        assistant_model,
+        streamer,
+    ) -> dict[str, Any]:
+        """
+        Extracts and returns the generation-mode-related keyword arguments from the provided kwargs.
+        """
+        generation_mode_kwargs = {
+            "tokenizer": kwargs.pop("tokenizer", None),
+            "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None),
+            "assistant_model": assistant_model,
+            "streamer": streamer,
+        }
+        # Respect an explicit `synced_gpus`; otherwise, enable it for sharded models.
+        if synced_gpus is None:
+            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
+        generation_mode_kwargs["synced_gpus"] = synced_gpus
+        generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
+        # `custom_generate` callables can have their own set of arguments.
+        # To extract them, we compare the signature with the standard `_sample` method.
+        if isinstance(custom_generate, Callable):
+            usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys()
+            custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys()
+            new_custom_keys = custom_generate_kwargs - usual_mode_kwargs
+            generation_mode_kwargs.update({k: kwargs.pop(k) for k in new_custom_keys if k in kwargs})
+        return generation_mode_kwargs
+
     @torch.no_grad()
     def generate(
         self,
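To make the signature comparison in `_extract_generation_mode_kwargs` concrete, here is a self-contained sketch of the same mechanism: any parameter a custom decoding callable declares beyond the standard `_sample`-style interface is treated as a mode kwarg and popped from the user's kwargs. All names below are illustrative, not part of the library.

```python
import inspect

def sample_like(self, input_ids, logits_processor, stopping_criteria,
                generation_config, synced_gpus=False, streamer=None, **model_kwargs):
    ...  # stand-in for GenerationMixin._sample

def my_decoding(self, input_ids, logits_processor, stopping_criteria,
                generation_config, temperature_floor=0.1, **model_kwargs):
    ...  # custom decoding loop with one extra knob

usual = inspect.signature(sample_like).parameters.keys()
custom = inspect.signature(my_decoding).parameters.keys()
print(custom - usual)  # {'temperature_floor'}: popped from kwargs, passed to the callable
```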
@@ -2292,47 +2325,46 @@ def generate(
             )
             return custom_generate_function(model=self, **generate_arguments)
 
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
-        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
+        # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
+        generation_mode_kwargs = self._extract_generation_mode_kwargs(
+            custom_generate,
+            kwargs,
+            synced_gpus,
+            assistant_model,
+            streamer,
+        )
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, use_model_defaults, **kwargs
         )
+        generation_mode = generation_config.get_generation_mode(assistant_model)
+
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
+        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
 
         # Deprecation-related step: set Hub repo for deprecated strategies.
         # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
         # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
         # TODO joao, manuel: remove this in v4.62.0
-        if deprecate_mode_repo := self._get_deprecated_gen_repo(
-            generation_config, trust_remote_code, custom_generate, assistant_model
-        ):
+        if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
-                inputs,
-                generation_config,
-                logits_processor,
-                stopping_criteria,
-                prefix_allowed_tokens_fn,
-                synced_gpus,
-                assistant_model,
-                streamer,
-                negative_prompt_ids,
-                negative_prompt_attention_mask,
-                use_model_defaults,
+                inputs=inputs,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                assistant_model=assistant_model,
+                negative_prompt_ids=negative_prompt_ids,
+                negative_prompt_attention_mask=negative_prompt_attention_mask,
+                use_model_defaults=use_model_defaults,
                 custom_generate=deprecate_mode_repo,
                 trust_remote_code=trust_remote_code,
-                tokenizer=tokenizer,
-                assistant_tokenizer=assistant_tokenizer,
+                **generation_mode_kwargs,
                 **kwargs,
             )
 
         # 2. Set generation parameters if not already defined
-        if synced_gpus is None:
-            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
-
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -2406,7 +2438,7 @@ def generate(
         )
 
         if generation_config.token_healing:
-            input_ids = self.heal_tokens(input_ids, tokenizer)
+            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -2444,17 +2476,9 @@
         ):
             max_cache_length += inputs_tensor.shape[1]
         self._prepare_cache_for_generation(
-            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
+            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
         )
 
-        # 8. determine generation mode
-        generation_mode = generation_config.get_generation_mode(assistant_model)
-
-        if streamer is not None and (generation_config.num_beams > 1):
-            raise ValueError(
-                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
-            )
-
         if self.device.type != input_ids.device.type:
             warnings.warn(
                 "You are calling .generate() with the `input_ids` being on a device type different"
@@ -2466,7 +2490,7 @@
                 UserWarning,
             )
 
-        # 9. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
@@ -2479,13 +2503,15 @@
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=generation_mode_kwargs.get("tokenizer"),
         )
 
         # Set model_kwargs `use_cache` so we can use it later in forward runs
         model_kwargs["use_cache"] = generation_config.use_cache
 
-        # 10. go into different generation modes
+        # 9. go into different generation modes
         if isinstance(custom_generate, Callable):
             result = custom_generate(
                 self,
@@ -2493,8 +2519,7 @@
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
+                **generation_mode_kwargs,
                 **model_kwargs,
             )
         elif generation_mode == GenerationMode.ASSISTED_GENERATION:
@@ -2516,50 +2541,48 @@
                     f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
                 )
 
-            # 11. Get the candidate generator, given the parameterization
+            # 10. Get the candidate generator, given the parameterization
             candidate_generator = self._get_candidate_generator(
                 generation_config=generation_config,
                 input_ids=input_ids,
                 inputs_tensor=inputs_tensor,
-                assistant_model=assistant_model,
+                assistant_model=generation_mode_kwargs.pop("assistant_model", None),
                 logits_processor=logits_processor,
-                target_tokenizer=tokenizer,
-                assistant_tokenizer=assistant_tokenizer,
+                target_tokenizer=generation_mode_kwargs.pop("tokenizer", None),
+                assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None),
                 model_kwargs=model_kwargs,
             )
 
-            # 12. run assisted generate
+            # 11. run assisted generate
             result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
+                **generation_mode_kwargs,
                 **model_kwargs,
             )
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
+                **generation_mode_kwargs,
                 **model_kwargs,
             )
 
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. run beam sample
+            # 10. run beam sample
             result = self._beam_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
+                **generation_mode_kwargs,
                 **model_kwargs,
             )
@@ -2681,8 +2704,8 @@ def _sample(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -3110,7 +3133,7 @@ def _beam_search(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
+        synced_gpus: bool = False,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -3447,8 +3470,8 @@ def _assisted_decoding(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
diff --git a/src/transformers/models/csm/generation_csm.py b/src/transformers/models/csm/generation_csm.py
index 9c2f06e6562f..b14f353685c2 100644
--- a/src/transformers/models/csm/generation_csm.py
+++ b/src/transformers/models/csm/generation_csm.py
@@ -153,8 +153,8 @@ def _sample(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         """
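Since `generate()` now forwards `**generation_mode_kwargs` into the chosen decoding method (see the dispatch hunks above), a custom decoding callable that mirrors `_sample`'s interface plus its own extra arguments receives everything in a single call. A hedged end-to-end sketch; the callable, its `boost` knob, and the setup are illustrative, not part of the patch:

```python
import torch

def greedy_boost_decoding(model, input_ids, logits_processor, stopping_criteria,
                          generation_config, synced_gpus=False, streamer=None,
                          boost=0.0, **model_kwargs):
    # Minimal greedy loop; `boost` is a made-up extra argument that
    # _extract_generation_mode_kwargs would pull out of the user's kwargs.
    for _ in range(generation_config.max_new_tokens or 1):
        logits = model(input_ids).logits[:, -1, :] + boost
        next_token = logits.argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids

# Usage sketch:
# out = model.generate(**inputs, custom_generate=greedy_boost_decoding,
#                      boost=0.5, max_new_tokens=5)
```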
diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py
index 22b607ec2865..439b498b0988 100644
--- a/src/transformers/models/dia/generation_dia.py
+++ b/src/transformers/models/dia/generation_dia.py
@@ -265,14 +265,20 @@
     ):
         # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
-        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
-
+        generation_mode_kwargs = self._extract_generation_mode_kwargs(
+            custom_generate,
+            kwargs,
+            synced_gpus,
+            assistant_model,
+            streamer,
+        )
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, use_model_defaults, **kwargs
         )
+        generation_mode = generation_config.get_generation_mode(assistant_model)
+
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
+        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
 
         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
@@ -308,7 +314,7 @@
         )
 
         if generation_config.token_healing:
-            input_ids = self.heal_tokens(input_ids, tokenizer)
+            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -347,18 +353,10 @@
         ):
             max_cache_length += inputs_tensor.shape[1]
         self._prepare_cache_for_generation(
-            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
+            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
         )
 
-        # 8. determine generation mode
-        generation_mode = generation_config.get_generation_mode(assistant_model)
-
-        if streamer is not None and (generation_config.num_beams > 1):
-            raise ValueError(
-                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
-            )
-
-        # 9. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
@@ -371,7 +369,9 @@
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=generation_mode_kwargs.get("tokenizer"),
         )
 
         # Set model_kwargs `use_cache` so we can use it later in forward runs
@@ -393,8 +393,7 @@
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
+                **generation_mode_kwargs,
                 **model_kwargs,
             )
         else:
diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
index 641eec0634d8..c10a0f80acf1 100644
--- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
@@ -1222,7 +1222,7 @@ def _prepare_model_inputs(
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=self.config.codec_config.sliding_window,
         )
diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
index 03b442b2edbd..8541a911e947 100644
--- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py
@@ -357,7 +357,7 @@ def _prepare_model_inputs(
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=self.config.codec_config.sliding_window,
         )
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index d16c32bd5cdf..8a66f9e13912 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1258,7 +1258,7 @@ def generate(
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=max_cache_length,
         )
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index bd103e36c034..58c012a1cfb9 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -2173,7 +2173,7 @@ def generate(
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=max_cache_length,
         )
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 5f1c592d3230..f3932137a082 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1566,7 +1566,7 @@ def extend_enc_output(tensor, num_beams=None):
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=input_ids.shape[0],
             max_cache_length=generation_config.max_length - 1,
         )
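Finally, a usage-level note on `_get_deprecated_gen_repo`: deprecated strategies are now keyed off the resolved `GenerationMode` and served as custom generation code from the Hub repos listed in `moved_to_hub_modes`. A hedged sketch of what this looks like from the caller's side (the checkpoint is a placeholder; the exact redirect behavior follows the hunks above):

```python
# Sketch: contrastive search now lives in transformers-community/contrastive-search.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Hello", return_tensors="pt")

# Implicit: contrastive-search flags resolve the mode, and generate() re-dispatches
# to the Hub repo (hence `trust_remote_code=True`).
out = model.generate(**inputs, penalty_alpha=0.6, top_k=4,
                     trust_remote_code=True, max_new_tokens=10)

# Explicit: target the repo directly via `custom_generate`.
out = model.generate(**inputs, custom_generate="transformers-community/contrastive-search",
                     trust_remote_code=True, max_new_tokens=10)
```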