Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check if position_ids exists before using it #29306

Merged
merged 5 commits into from Feb 28, 2024
Merged

Conversation

jiqing-feng
Copy link
Contributor

@jiqing-feng jiqing-feng commented Feb 27, 2024

Hi @ArthurZucker and @younesbelkada

I think we should check if position_ids is None or not before we use it.

@@ -1274,7 +1274,11 @@ def prepare_inputs_for_generation(

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
cache_position = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the attention mask is always passed to the model so the position ids are always created before this!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, should we set attention_mask as a required parameter? like

def prepare_inputs_for_generation(
        self, input_ids, attention_mask, past_key_values=None, inputs_embeds=None, **kwargs
    ):

Because we cannot control user behavior, I think we should avoid this error in codes or set a warning.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have not seen any issue so far. If past key values is not None, the attention mask is created as well and always passed to the model by the generate function. Let's just check if positions ids exist or just use input_ids instead for device placement and shape as it always exists

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea! Have fixed it according to your comments:)

@jiqing-feng
Copy link
Contributor Author

The CI is weird, I can pass the CI locally, could you please help to re-run the CI? Thx!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM let's have a second look from @gante

@ArthurZucker
Copy link
Collaborator

feel free to merge from main and run make fixup to make sure CIs go green

@jiqing-feng
Copy link
Contributor Author

The CI still failed, and it is my local CI result:
image

@gante
Copy link
Member

gante commented Feb 28, 2024

@jiqing-feng I hope you don't mind, I took the liberty to fix the failing test case 🤗 TL;DR when inputs_embeds is passed, we can't rely on input_ids to get the correct input length

@gante gante merged commit 554e7ad into huggingface:main Feb 28, 2024
19 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng
Copy link
Contributor Author

@jiqing-feng I hope you don't mind, I took the liberty to fix the failing test case 🤗 TL;DR when inputs_embeds is passed, we can't rely on input_ids to get the correct input length

Thanks for your fix!

@jiqing-feng jiqing-feng deleted the llama branch February 29, 2024 05:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants