Skip to content

Commit

Permalink
feat(flax): allow encoder_outputs in generate (#15554)
Browse files Browse the repository at this point in the history
* feat(flax): allow encoder_outputs in generate

* doc(flax): encoder_outputs in generate

* fix: style

* fix: style
  • Loading branch information
borisdayma authored Feb 8, 2022
1 parent 8406fa6 commit 077c00c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/transformers/generation_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def generate(
params (`Dict[str, jnp.ndarray]`, *optional*):
Optionally the model parameters can be passed. Can be useful for parallelized generation.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
Return:
[`~file_utils.ModelOutput`].
Expand Down Expand Up @@ -251,7 +253,8 @@ def generate(

if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

Expand Down

0 comments on commit 077c00c

Please sign in to comment.