Skip to content

Commit

Permalink
Add inputs_embeds functionality when generating with GPT-Neox (huggin…
Browse files Browse the repository at this point in the history
…gface#22916)

* support gpt neox generate with inputs embeds

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

great thx for the suggestion!

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

---------

Co-authored-by: Lei Li <tobiaslee@qq.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
  • Loading branch information
3 people authored and novice03 committed Jun 23, 2023
1 parent 7f17603 commit 8e76e09
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,9 @@ def forward(
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
input_shape = input_ids.shape

# cut decoder_input_ids if past is used
Expand All @@ -716,12 +718,21 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
}
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)

return model_inputs

def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
Expand Down

0 comments on commit 8e76e09

Please sign in to comment.