Conversation

vasqu
Contributor

@vasqu vasqu commented May 22, 2025

Keeping track of the models that are done:

  • Bert
  • Roberta
  • Albert
  • Data2VecText
  • Xmod
  • Electra
  • XLM Roberta
  • Roberta Prelayernorm
  • Ernie
  • Camembert
  • Bert Generation
  • XLM Roberta XL
  • RoCBert
  • Mobile Bert

Up for discussion:

  • Flash attention is flaky; I suspect the norm layers are responsible (I encountered similar behavior with gemma3)

Would need another round; questionable whether it's worth it (ordered by priority):

  • Tapas
  • Vision x Text models, e.g.
    • Bridgetower (text would be good to go, would need image counterpart)
    • Altclip
    • ...
  • RemBert
  • Megatron Bert

@vasqu vasqu changed the title 🔴[Atttention] Bert-based Models Attention Refactor 🔴[Attention] Bert-based Models Attention Refactor May 22, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor Author

vasqu commented Jun 30, 2025

run-slow: bert

This comment contains run-slow, running the specified jobs:

models: ['models/bert']
quantizations: [] ...

Collaborator

@ArthurZucker ArthurZucker left a comment

Nice! I am asking a lot, but I believe this will help unbloat our code even more!

@vasqu vasqu marked this pull request as ready for review September 17, 2025 13:55
Collaborator

@ArthurZucker ArthurZucker left a comment

Let's GOOOOOO


 if position_ids is None:
-    position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+    position_ids = self.position_ids[:, :seq_length]
Collaborator

That's because we expect the position ids to be correct, right?

Contributor Author

Yes, either:

  • In base transformers fashion, we expect the correct positions or use this as a fallback (it will always work with right padding); a rough sketch of the fallback is below.
  • In vLLM, we expect the correct position ids. This is also covered by the padding vs padding-free test. I also informed Harry about this requirement.
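
A minimal sketch of that fallback behavior, assuming a hypothetical helper and buffer name (this is not the exact code in the PR): without explicit position_ids, positions simply count up from 0 for the current sequence, which is only guaranteed to be correct with right padding or no cache offset.

    import torch

    def resolve_position_ids(position_ids, position_ids_buffer, seq_length):
        # Hypothetical helper: trust caller-provided positions as-is.
        if position_ids is not None:
            return position_ids  # e.g. vLLM is expected to pass these explicitly
        # Fallback: positions 0..seq_length-1, correct only with right padding
        # (or no cache offset), hence the "be strict" expectation on callers.
        return position_ids_buffer[:, :seq_length]

    # Usage sketch:
    buffer = torch.arange(512).unsqueeze(0)   # analogous to self.position_ids
    ids = resolve_position_ids(None, buffer, seq_length=7)  # tensor([[0, 1, ..., 6]])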

@ArthurZucker
Collaborator

run-slow: bert, auto, bart, roberta

This comment contains run-slow, running the specified jobs:

models: ['models/auto', 'models/bart', 'models/bert', 'models/roberta']
quantizations: [] ...

@vasqu
Contributor Author

vasqu commented Sep 18, 2025

run-slow: bert, auto, bart, roberta

This comment contains run-slow, running the specified jobs:

models: ['models/auto', 'models/bart', 'models/bert', 'models/roberta']
quantizations: [] ...

@vasqu
Contributor Author

vasqu commented Sep 18, 2025

The same tests fail on main (examples_torch), and some FA tests fail, but those are known to be flaky. Otherwise, looks good!

@ArthurZucker ArthurZucker merged commit 155f7e2 into main Sep 19, 2025
21 of 25 checks passed
@ArthurZucker ArthurZucker deleted the vas-bert-attn-refactors branch September 19, 2025 09:24
Member

@Cyrilvallez Cyrilvallez left a comment

Was just checking the PR for other things and noticed what is probably a typing issue!

Comment on lines -1094 to +1102
-    past_key_values: Optional[Cache] = None,
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
Member

This is a typo, no, @vasqu? I don't think we can ever have a list here.

Contributor Author

It's for legacy caches (the old typing we had there). It will be removed in v5, when cache classes are finally mandatory!
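
For context, a hedged sketch of how the legacy format is typically bridged to the Cache classes, assuming the current transformers cache_utils API (the helper name is made up; this is not the code in this PR):

    from transformers.cache_utils import Cache, DynamicCache

    def ensure_cache(past_key_values):
        # Hypothetical helper: legacy caches are a list/tuple of (key, value)
        # tensor pairs per layer; new-style caches are Cache instances.
        if past_key_values is None or isinstance(past_key_values, Cache):
            return past_key_values
        # Convert the legacy format; this path goes away in v5 once Cache
        # classes are mandatory.
        return DynamicCache.from_legacy_cache(past_key_values)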

ErfanBaghaei pushed a commit to ErfanBaghaei/transformers that referenced this pull request Sep 25, 2025
* clean start to bert refactor

* some test fixes

* style

* fix last tests

* be strict on positional embeddings, fixup according tests

* cache support

* more cache fixes, new causal API

* simplify masks, fix tests for gen

* flex attn, static cache support, round of fixes

* ?

* this time

* style

* fix flash attention tests, flex attention requires torch 2.7.x to work with multiple classes (as recompile strats force a size call which is wrongly interpreted before)

* roberta

* fixup sdpa remains

* attention split, simplify args and kwargs, better typing

* fix encoder decoder

* fix test

* modular roberta

* albert

* data2vectext, making it modular tomorrow

* modular data2vec text

* tmp disable

* xmod + cache position fixes

* whoops

* electra + markuplm, small fixes

* remove wrong copy

* xlm_roberta + some embedding fixes

* roberta prelayernorm

* RemBert: remove copy, maybe doing it later

* ernie

* fix roberta offloading

* camembert

* copy fixes

* bert generation + fixes on eager

* xlm roberta xl

* bridgetower (text) + seamlessv2 copy fixes

* rocbert + small fixes

* whoops

* small round of fixups

* NOTE: kernels didnt load with an earlier version, some fixup (needs another look bc cross deps)

* the end of the tunnel?

* fixup nllbmoe + style

* we dont need this anymore

* megatron bert is barely used, low prio skip for now

* Modernize bert (template for others)

NOTE: trying to push this through, might be overdue if not in time possible

* check inputs for all others (if checkmarked)

* fix bridgetower

* style

* fix encoder decoder (partially but cause found and fix also, just needs to be done for everything else)

* proper fix for bert to force intermediate dict outputs

* propagate to others

* style

* xlm roberta xl investigation, its the layernorm...

* mobile bert

* revert this, might cause issues with composed models

* review

* style
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/transformers that referenced this pull request Oct 2, 2025
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Oct 4, 2025