Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable flag to not pass PAD tokens in ffwd #775

Merged
merged 21 commits into from
Dec 11, 2023
Merged

Enable flag to not pass PAD tokens in ffwd #775

merged 21 commits into from
Dec 11, 2023

Conversation

bcui19
Copy link
Contributor

@bcui19 bcui19 commented Dec 4, 2023

This PR does two things:

  1. Modifies the attn_bias function to always return the attention_mask.
  2. Enables us to remove pad tokens before calling .forward on the ffwd network then re-add in the pad tokens.

Loss curves on a fully randomly initialized network:
image

We also get slightly higher throughput from this when there are PAD tokens in our dataset (and no degradation when compared to main with attn_impl: triton:
image

wandb: https://wandb.ai/mosaic-ml/padding_check?workspace=user-bcui

@bcui19 bcui19 changed the title [DRAFT] Changing how attention_mask is being passed around Enable flag to not pass PAD tokens in ffwd Dec 4, 2023
@bcui19 bcui19 marked this pull request as ready for review December 4, 2023 22:03
@bcui19 bcui19 requested a review from a team as a code owner December 4, 2023 22:03
@vchiley
Copy link
Contributor

vchiley commented Dec 4, 2023

We also get slightly higher throughput from this when there are PAD tokens in our dataset:

Can you run main vs your branch with use_pad_tok_in_ffwd flag vs your branch without use_pad_tok_in_ffwd flag?

Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test that tests numerical equivalence of computation with and without the flag? might be off by a bit because of numerics, but lets see.

Copy link
Collaborator

@mvpatel2000 mvpatel2000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nits

llmfoundry/models/layers/blocks.py Outdated Show resolved Hide resolved
llmfoundry/models/layers/blocks.py Outdated Show resolved Hide resolved
llmfoundry/models/layers/blocks.py Outdated Show resolved Hide resolved
bcui19 and others added 2 commits December 7, 2023 16:24
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
@bcui19 bcui19 requested a review from vchiley December 11, 2023 21:03
tests/models/test_model.py Outdated Show resolved Hide resolved
@vchiley vchiley self-requested a review December 11, 2023 21:41
Copy link
Collaborator

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

llmfoundry/models/mpt/configuration_mpt.py Outdated Show resolved Hide resolved
tests/models/test_model.py Outdated Show resolved Hide resolved
bcui19 and others added 2 commits December 11, 2023 17:03
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
@bcui19 bcui19 merged commit 410d5c7 into main Dec 11, 2023
8 checks passed
@dakinggg dakinggg deleted the mask_pad_token branch February 3, 2024 01:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants