
Commit 2b4734b

Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)
* support passing flash_attn_kwargs when gradient_checkpointing is enabled
* make modeling_deepseek_v3.py consistent with modular_deepseek_v3.py
1 parent bd41b9c commit 2b4734b

29 files changed: +58 −30 lines
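Background for the diffs below: in each decoder `forward` loop, the layer normally receives `**flash_attn_kwargs`, but when gradient checkpointing is active the layer is instead invoked through `self._gradient_checkpointing_func`, which is handed positional arguments only. The flash-attention keyword arguments (used, for example, with packed/padding-free batches) were therefore silently dropped during checkpointed training. Binding them onto the layer callable with `functools.partial` keeps the positional call unchanged while still delivering the kwargs. A minimal self-contained sketch of the pattern (`checkpoint_like` and `toy_decoder_layer` are illustrative stand-ins, not the library code):

from functools import partial

def checkpoint_like(fn, *args):
    # Stand-in for a checkpointing wrapper that, like the call sites in this
    # commit, forwards positional arguments only.
    return fn(*args)

def toy_decoder_layer(hidden_states, causal_mask, **flash_attn_kwargs):
    # A real layer would hand flash_attn_kwargs (e.g. cumulative sequence
    # lengths for packed batches) on to the attention kernel.
    return hidden_states, flash_attn_kwargs

flash_attn_kwargs = {"cu_seq_lens_q": [0, 4, 9]}  # illustrative values

# Old call site: keyword arguments have nowhere to go and are lost.
_, dropped = checkpoint_like(toy_decoder_layer, "h", None)
assert dropped == {}

# New call site: bind the kwargs onto the callable first, then checkpoint.
_, kept = checkpoint_like(partial(toy_decoder_layer, **flash_attn_kwargs), "h", None)
assert kept == flash_attn_kwargs

Because the binding happens before the checkpointed call, the positional argument list seen by the checkpoint function is unchanged, which is why each call site is a one-line edit.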

examples/modular-transformers/modeling_dummy.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_dummy.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -544,7 +545,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

examples/modular-transformers/modeling_multimodal1.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_multimodal1.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -544,7 +545,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

src/transformers/models/aria/modeling_aria.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 from ...activations import ACT2FN
@@ -963,7 +964,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
 # This file is based on the LLama model definition file in transformers
 
 
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -613,7 +614,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -634,7 +635,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -533,7 +534,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 # modular_deepseek_v3.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -759,7 +760,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

src/transformers/models/diffllama/modeling_diffllama.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from functools import partial
 from typing import Optional, Tuple, Union
 
 import torch
@@ -852,7 +853,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

src/transformers/models/emu3/modeling_emu3.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 # limitations under the License.
 
 import math
-from functools import cached_property
+from functools import cached_property, partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -1439,7 +1439,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     causal_mask,
                     position_ids,

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
@@ -645,7 +646,7 @@ def forward(
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                     hidden_states,
                     position_embeddings,
                     causal_mask,
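For completeness, a hedged sketch of how such kwargs typically reach these loops: in padding-free (sequence-packed) training, flash-attention metadata is passed as extra keyword arguments to the model's `forward`, and with gradient checkpointing enabled this commit is what carries them into each checkpointed layer. The checkpoint name is a hypothetical choice, the kwarg names assume the `FlashAttentionKwargs` fields used by the `flash_attention_2` code path, and running it requires flash-attn 2 plus a CUDA device:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",               # hypothetical model choice
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).cuda()
model.gradient_checkpointing_enable()         # before this commit, the kwargs below were dropped here
model.train()

# Two packed sequences (lengths 4 and 3) flattened into a single row.
input_ids = torch.tensor([[11, 12, 13, 14, 21, 22, 23]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]], device="cuda")
flash_attn_kwargs = {                         # assumed FlashAttentionKwargs fields
    "cu_seq_lens_q": torch.tensor([0, 4, 7], dtype=torch.int32, device="cuda"),
    "cu_seq_lens_k": torch.tensor([0, 4, 7], dtype=torch.int32, device="cuda"),
    "max_length_q": 4,
    "max_length_k": 4,
}
outputs = model(input_ids=input_ids, position_ids=position_ids, **flash_attn_kwargs)
outputs.logits.sum().backward()               # the checkpointed re-forward now sees the same kwargs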
