Commit 71838e4: "lol one single nit"

1 parent: 4e35168


46 files changed: 104 additions, 104 deletions.
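Every hunk below makes the same rename: the typed catch-all parameter `**flash_attn_kwargs: Unpack[FlashAttentionKwargs]` becomes `**kwargs: Unpack[FlashAttentionKwargs]`. As a minimal, self-contained sketch of the typed-kwargs pattern being renamed (the TypedDict, field names, and functions below are illustrative stand-ins, not the library's actual definitions):

# Sketch of the **kwargs: Unpack[...] pattern this commit renames.
# FlashAttnKwargsSketch is a stand-in; the real FlashAttentionKwargs lives in transformers.
from typing_extensions import TypedDict, Unpack


class FlashAttnKwargsSketch(TypedDict, total=False):
    # Hypothetical field names, purely for illustration.
    sliding_window: int
    softmax_scale: float


def attention(hidden_states, **kwargs: Unpack[FlashAttnKwargsSketch]):
    # The annotation only constrains what a static type checker accepts here;
    # at runtime this behaves exactly like a plain **kwargs.
    scale = kwargs.get("softmax_scale", 1.0)
    return [h * scale for h in hidden_states]


def forward(hidden_states, **kwargs: Unpack[FlashAttnKwargsSketch]):
    # Before the commit this parameter was spelled **flash_attn_kwargs;
    # renaming it to **kwargs changes nothing for callers passing keywords.
    return attention(hidden_states, **kwargs)


print(forward([1.0, 2.0], softmax_scale=0.5))  # [0.5, 1.0]

Because `Unpack[...]` only informs static type checkers, the rename is cosmetic at runtime; call sites that forward keyword arguments are unaffected.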

examples/modular-transformers/modeling_dummy.py

Lines changed: 2 additions & 2 deletions
@@ -477,7 +477,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -539,7 +539,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

examples/modular-transformers/modeling_multimodal1.py

Lines changed: 2 additions & 2 deletions
@@ -477,7 +477,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -539,7 +539,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

src/transformers/models/aria/modeling_aria.py

Lines changed: 2 additions & 2 deletions
@@ -754,7 +754,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -790,7 +790,7 @@ def forward(
                 past_key_value=past_key_values,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 3 additions & 3 deletions
@@ -315,7 +315,7 @@ def forward(
             output_attentions=output_attentions,
             position_ids=position_ids,
             cache_position=cache_position,
-            **flash_attn_kwargs,
+            **kwargs,
         )
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
@@ -648,7 +648,7 @@ def forward(

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
+                    partial(decoder_layer.__call__, **kwargs),
                     hidden_states,
                     causal_mask,
                     head_mask[idx] if head_mask is not None else None,
@@ -668,7 +668,7 @@ def forward(
                     use_cache=use_cache,
                     position_ids=position_ids,
                     cache_position=cache_position,
-                    **flash_attn_kwargs,
+                    **kwargs,
                 )

             hidden_states = layer_outputs[0]

src/transformers/models/biogpt/modular_biogpt.py

Lines changed: 3 additions & 3 deletions
@@ -142,7 +142,7 @@ def forward(
             output_attentions=output_attentions,
             position_ids=position_ids,
             cache_position=cache_position,
-            **flash_attn_kwargs,
+            **kwargs,
         )
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
@@ -475,7 +475,7 @@ def forward(

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
+                    partial(decoder_layer.__call__, **kwargs),
                     hidden_states,
                     causal_mask,
                     head_mask[idx] if head_mask is not None else None,
@@ -495,7 +495,7 @@ def forward(
                     use_cache=use_cache,
                     position_ids=position_ids,
                     cache_position=cache_position,
-                    **flash_attn_kwargs,
+                    **kwargs,
                 )

             hidden_states = layer_outputs[0]
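The two BioGPT files above also touch the gradient-checkpointing path, where the renamed keyword arguments cannot travel through `self._gradient_checkpointing_func` directly (checkpoint runners re-invoke the layer with positional tensors only), so they are pre-bound onto the layer with `functools.partial`. A hedged, self-contained sketch of that pattern, with toy stand-ins for the decoder layer and the checkpoint runner:

# Sketch of why the diff wraps decoder_layer.__call__ in functools.partial.
from functools import partial


def toy_checkpoint(fn, *tensors):
    # Stand-in for a gradient-checkpointing runner: it re-invokes `fn`
    # with positional arguments only and has no notion of **kwargs.
    return fn(*tensors)


def decoder_layer(hidden_states, attention_mask=None, *, sliding_window=None):
    # `sliding_window` plays the role of a keyword-only flash-attention argument.
    offset = 0.0 if sliding_window is None else float(sliding_window)
    return [h + offset for h in hidden_states]


kwargs = {"sliding_window": 2}  # the **kwargs renamed in this commit

# Analogous to partial(decoder_layer.__call__, **kwargs) in the hunks above:
# bind the keywords first, then let the runner supply the positional tensors.
layer_outputs = toy_checkpoint(partial(decoder_layer, **kwargs), [1.0, 2.0], None)
print(layer_outputs)  # [3.0, 4.0]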

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 2 additions & 2 deletions
@@ -368,7 +368,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -404,7 +404,7 @@ def forward(
                 past_key_value=past_key_values,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 2 additions & 2 deletions
@@ -410,7 +410,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -446,7 +446,7 @@ def forward(
                 past_key_value=past_key_values,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 2 additions & 2 deletions
@@ -394,7 +394,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -462,7 +462,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 2 additions & 2 deletions
@@ -415,7 +415,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -483,7 +483,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]

src/transformers/models/csm/modeling_csm.py

Lines changed: 3 additions & 3 deletions
@@ -529,7 +529,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]
@@ -759,7 +759,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
@@ -807,7 +807,7 @@ def forward(
                 past_key_value=past_key_values,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]
