🔴[Attention] Bert-based Models Attention Refactor #38301
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: bert
This comment contains run-slow, running the specified jobs: models: ['models/bert']
Nice! I am asking a lot, but I believe this will help unbloat our code even more!
Let's GOOOOOO
    if position_ids is None:
-       position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+       position_ids = self.position_ids[:, :seq_length]
That's because we expect the position ids to be correct, right?
Yes, either:
- in base transformers fashion, we expect the correct positions or use this as a fallback (it will always work with right padding; see the sketch below)
- in vLLM, we expect the correct position ids. This is also covered in the padding vs. padding-free test. I also informed Harry about this requirement.
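A rough illustration (not code from this PR) of why the `[:, :seq_length]` fallback only matches the true positions under right padding; the tensor values and the mask-based derivation below are assumptions made for the sake of the example:

```python
import torch

# Two sequences of length 4: the first is right-padded (or full), the second left-padded.
attention_mask = torch.tensor([
    [1, 1, 1, 1],  # real tokens at positions 0..3
    [0, 0, 1, 1],  # real tokens should get positions 0..1
])

# Fallback used when position_ids is None: 0..seq_length-1 for every row.
seq_length = attention_mask.shape[1]
fallback_position_ids = torch.arange(seq_length).expand_as(attention_mask)

# Positions derived from the mask (what a caller such as vLLM should pass explicitly).
explicit_position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)

print(fallback_position_ids)  # [[0, 1, 2, 3], [0, 1, 2, 3]]
print(explicit_position_ids)  # [[0, 1, 2, 3], [0, 0, 0, 1]]
```

The two only agree on the right-padded row, which is why callers that left-pad (or run padding-free) must provide explicit position ids.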
src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py (outdated review thread, resolved)
run-slow: bert, auto, bart, roberta
This comment contains run-slow, running the specified jobs: models: ['models/auto', 'models/bart', 'models/bert', 'models/roberta']
run-slow: bert, auto, bart, roberta
This comment contains run-slow, running the specified jobs: models: ['models/auto', 'models/bart', 'models/bert', 'models/roberta']
Same tests fail on main.
Was just checking the PR for other things and noticed what is probably a typing issue!
    past_key_values: Optional[Cache] = None,
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
This is a typo, no @vasqu? I don't think we can ever have a list here.
It's for legacy caches (the old typing we had there for this). Will be removed for v5 when cache classes are finally mandatory!
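For context, a minimal sketch (not taken from this PR) of how an argument typed `Optional[Union[list[torch.FloatTensor], Cache]]` is typically normalized; `normalize_cache` is a hypothetical helper, while `DynamicCache.from_legacy_cache` is the existing conversion utility in transformers:

```python
from transformers import DynamicCache

def normalize_cache(past_key_values):
    """Accept either a legacy list/tuple of per-layer (key, value) tensors or a Cache object."""
    if isinstance(past_key_values, (list, tuple)):
        # Legacy format: wrap it so downstream code only deals with Cache objects.
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values
```

Once cache classes become mandatory in v5, the legacy branch (and the `list[torch.FloatTensor]` part of the annotation) can be dropped.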
* clean start to bert refactor
* some test fixes
* style
* fix last tests
* be strict on positional embeddings, fixup according tests
* cache support
* more cache fixes, new causal API
* simplify masks, fix tests for gen
* flex attn, static cache support, round of fixes
* ?
* this time
* style
* fix flash attention tests, flex attention requires torch 2.7.x to work with multiple classes (as recompile strats force a size call which is wrongly interpreted before)
* roberta
* fixup sdpa remains
* attention split, simplify args and kwargs, better typing
* fix encoder decoder
* fix test
* modular roberta
* albert
* data2vectext, making it modular tomorrow
* modular data2vec text
* tmp disable
* xmod + cache position fixes
* whoops
* electra + markuplm, small fixes
* remove wrong copy
* xlm_roberta + some embedding fixes
* roberta prelayernorm
* RemBert: remove copy, maybe doing it later
* ernie
* fix roberta offloading
* camembert
* copy fixes
* bert generation + fixes on eager
* xlm roberta xl
* bridgetower (text) + seamlessv2 copy fixes
* rocbert + small fixes
* whoops
* small round of fixups
* NOTE: kernels didnt load with an earlier version, some fixup (needs another look bc cross deps)
* the end of the tunnel?
* fixup nllbmoe + style
* we dont need this anymore
* megatron bert is barely used, low prio skip for now
* Modernize bert (template for others) NOTE: trying to push this through, might be overdue if not in time possible
* check inputs for all others (if checkmarked)
* fix bridgetower
* style
* fix encoder decoder (partially but cause found and fix also, just needs to be done for everything else)
* proper fix for bert to force intermediate dict outputs
* propagate to others
* style
* xlm roberta xl investigation, its the layernorm...
* mobile bert
* revert this, might cause issues with composed models
* review
* style
Keeping track of the models that are done:
Up to discussion:
Would need another round; questionable if worth it (ordered by prio):