
Add Sdpa Support: [Electra] #37106

Open

nnilayy wants to merge 2 commits into huggingface:main from nnilayy:electra_sdpa

Conversation

Contributor

@nnilayy nnilayy commented Mar 29, 2025

What does this PR do?

Towards #28005 and #37105

Adds SDPA (Scaled Dot-Product Attention) support for Google's Electra 🤗.
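
For context, the change amounts to routing the attention computation through `torch.nn.functional.scaled_dot_product_attention` instead of the explicit `softmax(QK^T / sqrt(d)) V` matmuls of the eager path. A minimal sketch (toy shapes standing in for Electra-base, not the actual module code) showing the two paths agree numerically:

```python
import torch
import torch.nn.functional as F

# Toy shapes roughly matching Electra-base (12 heads, head_dim 64); illustrative only.
batch, heads, seq, dim = 2, 12, 16, 64
q, k, v = (torch.randn(batch, heads, seq, dim) for _ in range(3))

# Eager path: explicit softmax(QK^T / sqrt(d)) V, as in the existing ElectraSelfAttention.
scores = q @ k.transpose(-2, -1) / dim**0.5
eager_out = torch.softmax(scores, dim=-1) @ v

# SDPA path: the fused PyTorch kernel this PR switches to.
sdpa_out = F.scaled_dot_product_attention(q, k, v)

# Same math, different kernel: outputs match to floating-point tolerance.
assert torch.allclose(eager_out, sdpa_out, atol=1e-5)
```

The speedups and memory savings in the tables below come from the fused kernel avoiding the materialized attention-score matrix, not from any change to the math.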

Sample benchmarks comparing SDPA and eager attention for both the google/electra-base-generator and google/electra-base-discriminator models, under both training and inference, are provided below.

Benchmarking scripts (adapted from @fxmarty's SDPA scripts) covering training, inference, and running Electra on a masked-LM task were written and are also linked below.

Electra Sdpa Benchmarking Scripts:

Reference Sdpa Benchmarking Scripts by @fxmarty:

PS: The memory-saving percentages remained consistent across runs, but the speedup percentages varied. To ensure reliability, each benchmark (training and inference, for both models) was run five times, and the reported results are the mean across all runs. The full set of individual run results is linked in this Benchmarking Runs gist.
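
The percentage columns in the tables appear to report the gain relative to the SDPA run, i.e. `(eager − sdpa) / sdpa × 100`; that convention is an assumption inferred from the numbers (it reproduces, e.g., the 13.792% memory saving for the 8×256 generator training row), not a quote of the script. A sketch under that assumption:

```python
def pct_gain(eager: float, sdpa: float) -> float:
    """Gain of SDPA over eager, as a percentage of the SDPA value.

    Dividing by the SDPA number (not the eager number) is an assumption
    inferred from the tables below.
    """
    return (eager - sdpa) / sdpa * 100

# Memory saving for the 8x256 generator training row: 517.463 MB eager vs 454.745 MB SDPA.
print(round(pct_gain(517.463, 454.745), 3))  # 13.792, matching the table
```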

Benchmarks For google/electra-base-generator

Training Benchmark

| num_training_steps | batch_size | seq_len | is_cuda | is_half | Time per batch, eager (s) | Time per batch, SDPA (s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 2 | 32 | True | True | 0.0314 | 0.024 | 30.2028 | 249.684 | 249.684 | 0.0 |
| 100 | 2 | 64 | True | True | 0.031 | 0.0244 | 27.0706 | 253.593 | 253.593 | 0.0 |
| 100 | 2 | 128 | True | True | 0.0308 | 0.0238 | 30.0252 | 261.409 | 261.409 | 0.0 |
| 100 | 2 | 256 | True | True | 0.03 | 0.024 | 26.7504 | 299.593 | 277.243 | 8.061 |
| 100 | 4 | 32 | True | True | 0.0308 | 0.024 | 28.1694 | 253.593 | 253.593 | 0.0 |
| 100 | 4 | 64 | True | True | 0.0306 | 0.0234 | 30.8344 | 261.409 | 261.409 | 0.0 |
| 100 | 4 | 128 | True | True | 0.0302 | 0.0236 | 27.4524 | 283.864 | 277.243 | 2.388 |
| 100 | 4 | 256 | True | True | 0.031 | 0.0252 | 22.8936 | 517.463 | 454.745 | 13.792 |
| 100 | 8 | 32 | True | True | 0.0304 | 0.0232 | 31.0826 | 261.409 | 261.409 | 0.0 |
| 100 | 8 | 64 | True | True | 0.03 | 0.0234 | 27.6406 | 277.243 | 277.243 | 0.0 |
| 100 | 8 | 128 | True | True | 0.031 | 0.0246 | 25.3224 | 486.006 | 454.745 | 6.874 |
| 100 | 8 | 256 | True | True | 0.04 | 0.0394 | 2.3676 | 947.451 | 821.915 | 15.274 |
| 100 | 16 | 32 | True | True | 0.03 | 0.0238 | 26.6688 | 277.446 | 277.446 | 0.0 |
| 100 | 16 | 64 | True | True | 0.0312 | 0.0242 | 28.5026 | 470.277 | 454.745 | 3.416 |
| 100 | 16 | 128 | True | True | 0.0368 | 0.0378 | -2.7766 | 884.536 | 821.915 | 7.619 |
| 100 | 16 | 256 | True | True | 0.0774 | 0.0712 | 8.6964 | 1804.67 | 1553.798 | 16.146 |

Inference Benchmark

| num_batches | batch_size | seq_len | is_cuda | is_half | use_mask | Per-token latency, eager (ms) | Per-token latency, SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 32 | True | True | True | 0.1566 | 0.1176 | 33.217 | 85.004 | 85.004 | 0.0 |
| 50 | 2 | 64 | True | True | True | 0.0778 | 0.0584 | 33.349 | 92.949 | 92.949 | 0.0 |
| 50 | 2 | 128 | True | True | True | 0.0424 | 0.0288 | 46.1208 | 108.84 | 109.043 | -0.186 |
| 50 | 2 | 256 | True | True | True | 0.021 | 0.0164 | 27.676 | 140.825 | 140.825 | 0.0 |
| 50 | 4 | 32 | True | True | True | 0.0764 | 0.0604 | 26.8224 | 92.949 | 92.949 | 0.0 |
| 50 | 4 | 64 | True | True | True | 0.0386 | 0.0294 | 30.849 | 109.043 | 109.043 | 0.0 |
| 50 | 4 | 128 | True | True | True | 0.0198 | 0.0158 | 25.8548 | 140.825 | 140.825 | 0.0 |
| 50 | 4 | 256 | True | True | True | 0.0112 | 0.009 | 25.464 | 204.997 | 204.997 | 0.0 |
| 50 | 8 | 32 | True | True | True | 0.0372 | 0.0282 | 31.764 | 109.043 | 109.043 | 0.0 |
| 50 | 8 | 64 | True | True | True | 0.0194 | 0.0156 | 24.4376 | 140.825 | 140.825 | 0.0 |
| 50 | 8 | 128 | True | True | True | 0.0108 | 0.0092 | 21.5298 | 204.997 | 204.997 | 0.0 |
| 50 | 8 | 256 | True | True | True | 0.007 | 0.006 | 18.5566 | 332.935 | 332.935 | 0.0 |
| 50 | 16 | 32 | True | True | True | 0.02 | 0.0152 | 32.5968 | 140.825 | 140.825 | 0.0 |
| 50 | 16 | 64 | True | True | True | 0.0112 | 0.0092 | 23.3076 | 204.997 | 204.997 | 0.0 |
| 50 | 16 | 128 | True | True | True | 0.007 | 0.006 | 17.745 | 332.935 | 332.935 | 0.0 |
| 50 | 16 | 256 | True | True | True | 0.006 | 0.005 | 21.8322 | 585.568 | 585.568 | 0.0 |

Benchmarks For google/electra-base-discriminator

Training Benchmark

| num_training_steps | batch_size | seq_len | is_cuda | is_half | Time per batch, eager (s) | Time per batch, SDPA (s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 2 | 32 | True | True | 0.0312 | 0.0242 | 29.8518 | 568.647 | 566.126 | 0.445 |
| 100 | 2 | 64 | True | True | 0.0308 | 0.024 | 29.3112 | 572.187 | 570.846 | 0.235 |
| 100 | 2 | 128 | True | True | 0.0312 | 0.0248 | 25.9492 | 588.412 | 587.232 | 0.201 |
| 100 | 2 | 256 | True | True | 0.031 | 0.0298 | 5.5548 | 631.001 | 602.644 | 4.705 |
| 100 | 4 | 32 | True | True | 0.0312 | 0.0238 | 30.2574 | 572.187 | 570.846 | 0.235 |
| 100 | 4 | 64 | True | True | 0.0312 | 0.0242 | 28.247 | 588.412 | 587.232 | 0.201 |
| 100 | 4 | 128 | True | True | 0.0314 | 0.0286 | 10.2442 | 604.379 | 602.644 | 0.288 |
| 100 | 4 | 256 | True | True | 0.053 | 0.0498 | 6.4654 | 1022.623 | 831.197 | 23.0304 |
| 100 | 8 | 32 | True | True | 0.031 | 0.0242 | 28.8872 | 587.263 | 586.813 | 0.077 |
| 100 | 8 | 64 | True | True | 0.0308 | 0.0278 | 10.7408 | 602.151 | 601.596 | 0.0924 |
| 100 | 8 | 128 | True | True | 0.0492 | 0.0474 | 3.8086 | 927.308 | 831.197 | 11.5626 |
| 100 | 8 | 256 | True | True | 0.105 | 0.0976 | 7.4594 | 1798.541 | 1418.122 | 26.8258 |
| 100 | 16 | 32 | True | True | 0.0314 | 0.0286 | 10.8604 | 598.278 | 595.934 | 0.3934 |
| 100 | 16 | 64 | True | True | 0.048 | 0.0464 | 3.7276 | 877.185 | 828.890 | 5.8264 |
| 100 | 16 | 128 | True | True | 0.0986 | 0.0944 | 4.3308 | 1605.603 | 1415.605 | 13.4216 |
| 100 | 16 | 256 | True | True | 0.2188 | 0.2036 | 7.4636 | 3318.753 | 2564.827 | 29.395 |

Inference Benchmark

| num_batches | batch_size | seq_len | is_cuda | is_half | use_mask | Per-token latency, eager (ms) | Per-token latency, SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 32 | True | True | True | 0.1586 | 0.115 | 37.9428 | 244.13 | 243.868 | 0.107 |
| 50 | 2 | 64 | True | True | True | 0.0782 | 0.0568 | 37.5862 | 252.141 | 251.878 | 0.104 |
| 50 | 2 | 128 | True | True | True | 0.0424 | 0.03 | 40.611 | 268.163 | 268.103 | 0.022 |
| 50 | 2 | 256 | True | True | True | 0.0204 | 0.0166 | 24.129 | 300.409 | 300.147 | 0.087 |
| 50 | 4 | 32 | True | True | True | 0.0738 | 0.061 | 22.644 | 252.141 | 251.878 | 0.104 |
| 50 | 4 | 64 | True | True | True | 0.0382 | 0.0296 | 28.352 | 268.365 | 268.103 | 0.098 |
| 50 | 4 | 128 | True | True | True | 0.0206 | 0.016 | 27.1508 | 300.409 | 300.147 | 0.087 |
| 50 | 4 | 256 | True | True | True | 0.017 | 0.014 | 22.3272 | 365.106 | 364.844 | 0.072 |
| 50 | 8 | 32 | True | True | True | 0.0378 | 0.0292 | 29.0196 | 268.365 | 268.103 | 0.098 |
| 50 | 8 | 64 | True | True | True | 0.0204 | 0.016 | 27.1664 | 300.409 | 300.147 | 0.087 |
| 50 | 8 | 128 | True | True | True | 0.0158 | 0.0138 | 14.3932 | 365.106 | 364.844 | 0.072 |
| 50 | 8 | 256 | True | True | True | 0.017 | 0.014 | 21.5304 | 494.093 | 493.962 | 0.027 |
| 50 | 16 | 32 | True | True | True | 0.02 | 0.0164 | 20.9608 | 300.409 | 300.147 | 0.087 |
| 50 | 16 | 64 | True | True | True | 0.015 | 0.0138 | 10.5592 | 365.106 | 364.844 | 0.072 |
| 50 | 16 | 128 | True | True | True | 0.0156 | 0.0138 | 14.2278 | 494.093 | 493.962 | 0.027 |
| 50 | 16 | 256 | True | True | True | 0.017 | 0.0142 | 18.969 | 748.823 | 748.561 | 0.035 |

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@fxmarty @ArthurZucker @amyeroberts @LysandreJik

@github-actions github-actions Bot marked this pull request as draft March 29, 2025 14:37
@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@nnilayy nnilayy marked this pull request as ready for review March 29, 2025 14:51
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks for opening the PR!

Comment on lines +347 to +348
# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Electra
class ElectraSdpaSelfAttention(ElectraSelfAttention):

hey! Would you mind updating the SDPA and non-SDPA paths to also support flex attention?
You have an example of this in modeling_llama.py with the ATTENTION_INTERFACE!
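
For readers unfamiliar with the pattern the reviewer refers to: newer transformers models keep their attention backends in a registry and pick one by name at runtime, rather than subclassing per backend. A rough sketch of that dispatch pattern; the function names and signatures here are simplified stand-ins, not the exact transformers API:

```python
from typing import Callable, Dict

# Placeholder backends: in the real code these would compute softmax(QK^T)V,
# call F.scaled_dot_product_attention, and torch flex attention respectively.
def eager_attention(q, k, v):
    return "eager-result"

def sdpa_attention(q, k, v):
    return "sdpa-result"

def flex_attention(q, k, v):
    return "flex-result"

# Registry of available attention implementations (illustrative name).
ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_attention,
    "sdpa": sdpa_attention,
    "flex_attention": flex_attention,
}

def run_attention(attn_implementation: str, q, k, v):
    # A model's forward looks up the configured implementation and calls it,
    # so supporting a new backend means registering one entry, not writing
    # a new self-attention subclass.
    return ATTENTION_FUNCTIONS[attn_implementation](q, k, v)

print(run_attention("sdpa", None, None, None))  # sdpa-result
```

This is why the reviewer points at modeling_llama.py: following that structure would replace the separate `ElectraSdpaSelfAttention` subclass with a single attention module that dispatches on the configured implementation.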
