Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inputs_embeds functionality when generating with GPT-Neox #22916

Merged
merged 2 commits into from
Apr 21, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,9 @@ def forward(
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
input_shape = input_ids.shape

# cut decoder_input_ids if past is used
Expand All @@ -716,12 +718,20 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
}
if inputs_embeds is not None and past_key_values is None:
TobiasLee marked this conversation as resolved.
Show resolved Hide resolved
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)

return model_inputs

def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
Expand Down