Skip to content

Commit

Permalink
pass params to encode (#14370)
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj authored Nov 11, 2021
1 parent e92190c commit b1dbdf2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/generation_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ def _run_loop_in_debug(cond_fn, body_fn, init_state):
state = body_fn(state)
return state

def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
}
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
return model_kwargs

@staticmethod
Expand Down Expand Up @@ -251,7 +251,7 @@ def generate(

if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

Expand Down

0 comments on commit b1dbdf2

Please sign in to comment.