39 changes: 13 additions & 26 deletions src/transformers/models/blt/modeling_blt.py
@@ -28,7 +28,7 @@
import torch.nn.functional as F

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -321,7 +321,6 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
use_cache: bool = False,
past_key_values=None,
cache_position=None,
**kwargs,
@@ -393,9 +392,7 @@ def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -404,27 +401,13 @@
query_states = self.q_proj(query_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

if cross_attention_states is not None:
cross_attention_states = self.k_norm(cross_attention_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if past_key_values is not None:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_values.layers[self.layer_idx].keys,
past_key_values.layers[self.layer_idx].values,
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
attention_interface: Callable = eager_attention_forward
cross_attention_states = self.k_norm(cross_attention_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

@@ -1089,6 +1072,9 @@ def forward(
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

# Extract input embeddings as early as possible
if inputs_embeds is not None:
encoder_embeds = inputs_embeds
@@ -1137,7 +1123,7 @@ def forward(
input_embeds=encoder_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
position_ids=position_ids,
)

@@ -1157,6 +1143,7 @@ def forward(
encoder_attention_mask=cross_attn_mask_enc,
num_patches=patch_lengths.shape[1],
patch_ids=patch_ids,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
**kwargs,
)
encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -1192,7 +1179,7 @@ def forward(
patch_embeds=global_hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
cache_position=cache_position,
encoder_attention_mask=cross_attn_mask_dec,
**kwargs,
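Note on the cache layout introduced above: when `use_cache` is set and no cache is passed in, the model now allocates a single `EncoderDecoderCache` bundling two `DynamicCache` instances, then routes the `self_attention_cache` half to the self-attention path and the `cross_attention_cache` half to the cross-attention path. The snippet below is an illustrative sketch, not part of this diff; it assumes the `transformers` cache API used here (positional `EncoderDecoderCache(self_attention_cache, cross_attention_cache)` plus `DynamicCache.update` / `get_seq_length`), and the tensor shapes are toy values.

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# Two independent caches behind one `past_key_values` object: the first
# tracks self-attention K/V, the second cross-attention K/V.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Toy key/value states shaped (batch, num_kv_heads, seq_len, head_dim).
key = value = torch.zeros(1, 2, 4, 8)

# Self-attention layers write into the first cache...
past_key_values.self_attention_cache.update(key, value, layer_idx=0)

# ...cross-attention layers into the second, so both streams can reuse the
# same layer indices without overwriting each other.
past_key_values.cross_attention_cache.update(key, value, layer_idx=0)

print(past_key_values.self_attention_cache.get_seq_length())   # 4
print(past_key_values.cross_attention_cache.get_seq_length())  # 4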
59 changes: 13 additions & 46 deletions src/transformers/models/blt/modular_blt.py
@@ -22,7 +22,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import dynamic_rope_update
@@ -299,27 +299,6 @@ def __init__(self, config, layer_idx: int):
class BltSelfAttention(MllamaTextSelfAttention):
def __init__(self, config: BltConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.is_causal = True

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
use_cache: bool = False,
past_key_values=None,
cache_position=None,
**kwargs,
):
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
use_cache=use_cache,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)


class BltCrossAttention(MllamaTextCrossAttention):
@@ -335,37 +314,21 @@ def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_norm(hidden_states)
query_states = self.q_proj(query_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

if cross_attention_states is not None:
cross_attention_states = self.k_norm(cross_attention_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if past_key_values is not None:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_values.layers[self.layer_idx].keys,
past_key_values.layers[self.layer_idx].values,
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
attention_interface: Callable = eager_attention_forward
cross_attention_states = self.k_norm(cross_attention_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

@@ -828,6 +791,9 @@ def forward(
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

# Extract input embeddings as early as possible
if inputs_embeds is not None:
encoder_embeds = inputs_embeds
@@ -876,7 +842,7 @@ def forward(
input_embeds=encoder_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
position_ids=position_ids,
)

@@ -896,6 +862,7 @@ def forward(
encoder_attention_mask=cross_attn_mask_enc,
num_patches=patch_lengths.shape[1],
patch_ids=patch_ids,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
**kwargs,
)
encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -931,7 +898,7 @@ def forward(
patch_embeds=global_hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
cache_position=cache_position,
encoder_attention_mask=cross_attn_mask_dec,
**kwargs,
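The `BltCrossAttention` rewrite above reduces the key/value path to a single unconditional branch: keys and values are always re-projected from `cross_attention_states`, with no fallback to cached values. A self-contained sketch of that shape flow follows; the sizes and the `LayerNorm` stand-in for `k_norm` are placeholders for illustration, not BLT's actual modules or configuration.

import torch
import torch.nn as nn

# Placeholder sizes, not BLT's real config values.
bsz, kv_len, hidden_size, num_kv_heads, head_dim = 2, 7, 32, 4, 8

k_norm = nn.LayerNorm(hidden_size)  # stand-in for the model's k_norm
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

cross_attention_states = torch.randn(bsz, kv_len, hidden_size)

# Same flow as the new forward: normalize once, project, split into heads.
cross_attention_states = k_norm(cross_attention_states)
key_states = k_proj(cross_attention_states)
value_states = v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, num_kv_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, num_kv_heads, head_dim).transpose(1, 2)

print(key_states.shape)    # torch.Size([2, 4, 7, 8])
print(value_states.shape)  # torch.Size([2, 4, 7, 8])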
2 changes: 0 additions & 2 deletions src/transformers/models/mllama/modeling_mllama.py
@@ -534,10 +534,8 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
use_cache: bool = False,
past_key_values=None,
cache_position=None,
position_ids=None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()
9 changes: 6 additions & 3 deletions tests/models/blt/test_modeling_blt.py
@@ -224,12 +224,15 @@ def test_eager_matches_sdpa_inference(

@require_torch_accelerator
class BltIntegrationTest(unittest.TestCase):
def setup(self):
cleanup(torch_device, gc_collect=True)

def tearDown(self):
# TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
# some memory allocated in the cache, which means some object is not being released properly. This causes some
# unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
# Investigate the root cause.
cleanup(torch_device, gc_collect=False)
cleanup(torch_device, gc_collect=True)

@slow
@require_read_token
@@ -339,7 +342,7 @@ def test_model_logits(self):
def test_model_bf16(self):
"""Test Blt model with bfloat16 precision."""
NUM_TOKENS_TO_GENERATE = 200
EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m"

prompt = "my name is"

@@ -472,7 +475,7 @@ def test_model_eager(self):
def test_model_bf16_static_cache(self):
"""Test Blt model with bfloat16 precision and static cache."""
NUM_TOKENS_TO_GENERATE = 200
EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m"

prompt = "my name is"

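The test changes above apply a cleanup-between-tests pattern: accelerator memory is freed both before and after each integration test, with garbage collection forced in both hooks. A minimal sketch of that pattern using transformers' testing helpers (the class and test names are illustrative, and the sketch uses unittest's standard `setUp`/`tearDown` spelling):

import unittest

from transformers.testing_utils import cleanup, torch_device


class ExampleIntegrationTest(unittest.TestCase):
    def setUp(self):
        # Free accelerator memory and force garbage collection before each
        # test so allocations from earlier tests cannot skew this one.
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        # Same after the test: gc_collect=True also releases objects kept
        # alive by compiled or static-cache generation runs.
        cleanup(torch_device, gc_collect=True)

    def test_placeholder(self):
        self.assertTrue(True)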