diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py
index fe435876db2a..ae40c5e75ab9 100644
--- a/src/transformers/models/blt/modeling_blt.py
+++ b/src/transformers/models/blt/modeling_blt.py
@@ -28,7 +28,7 @@
 import torch.nn.functional as F

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -321,7 +321,6 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
-        use_cache: bool = False,
         past_key_values=None,
         cache_position=None,
         **kwargs,
@@ -393,9 +392,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -404,27 +401,13 @@ def forward(
         query_states = self.q_proj(query_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-        if cross_attention_states is not None:
-            cross_attention_states = self.k_norm(cross_attention_states)
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            if past_key_values is not None:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_values.layers[self.layer_idx].keys,
-                past_key_values.layers[self.layer_idx].values,
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        attention_interface: Callable = eager_attention_forward
+        cross_attention_states = self.k_norm(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states)
+        value_states = self.v_proj(cross_attention_states)
+        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        attention_interface: Callable = eager_attention_forward

         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@@ -1089,6 +1072,9 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         # Extract input embeddings as early as possible
         if inputs_embeds is not None:
             encoder_embeds = inputs_embeds
@@ -1137,7 +1123,7 @@ def forward(
             input_embeds=encoder_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             position_ids=position_ids,
         )

@@ -1157,6 +1143,7 @@ def forward(
             encoder_attention_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             **kwargs,
         )
         encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -1192,7 +1179,7 @@ def forward(
             patch_embeds=global_hidden_states,
             attention_mask=causal_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
             cache_position=cache_position,
             encoder_attention_mask=cross_attn_mask_dec,
             **kwargs,
diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py
index 78d5aa5a15ef..f9c9bf434005 100644
--- a/src/transformers/models/blt/modular_blt.py
+++ b/src/transformers/models/blt/modular_blt.py
@@ -22,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import dynamic_rope_update
@@ -299,27 +299,6 @@ def __init__(self, config, layer_idx: int):
 class BltSelfAttention(MllamaTextSelfAttention):
     def __init__(self, config: BltConfig, layer_idx: int):
         super().__init__(config, layer_idx)
-        self.is_causal = True
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_embeddings: torch.Tensor,
-        use_cache: bool = False,
-        past_key_values=None,
-        cache_position=None,
-        **kwargs,
-    ):
-        return super().forward(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_embeddings=position_embeddings,
-            use_cache=use_cache,
-            past_key_values=past_key_values,
-            cache_position=cache_position,
-            **kwargs,
-        )


 class BltCrossAttention(MllamaTextCrossAttention):
@@ -335,9 +314,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ):
         bsz, q_len, _ = hidden_states.size()
@@ -345,27 +322,13 @@ def forward(
         query_states = self.q_proj(query_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-        if cross_attention_states is not None:
-            cross_attention_states = self.k_norm(cross_attention_states)
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            if past_key_values is not None:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_values.layers[self.layer_idx].keys,
-                past_key_values.layers[self.layer_idx].values,
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        attention_interface: Callable = eager_attention_forward
+        cross_attention_states = self.k_norm(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states)
+        value_states = self.v_proj(cross_attention_states)
+        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        attention_interface: Callable = eager_attention_forward

         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@@ -828,6 +791,9 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         # Extract input embeddings as early as possible
         if inputs_embeds is not None:
             encoder_embeds = inputs_embeds
@@ -876,7 +842,7 @@ def forward(
             input_embeds=encoder_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             position_ids=position_ids,
         )

@@ -896,6 +862,7 @@ def forward(
             encoder_attention_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             **kwargs,
         )
         encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -931,7 +898,7 @@ def forward(
             patch_embeds=global_hidden_states,
             attention_mask=causal_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
             cache_position=cache_position,
             encoder_attention_mask=cross_attn_mask_dec,
             **kwargs,
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 1edcdb21dad3..1f52d30be45c 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -534,10 +534,8 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
-        use_cache: bool = False,
         past_key_values=None,
         cache_position=None,
-        position_ids=None,
         **kwargs,
     ):
         bsz, q_len, _ = hidden_states.size()
diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py
index c7ca7099582d..7688da9b04c2 100644
--- a/tests/models/blt/test_modeling_blt.py
+++ b/tests/models/blt/test_modeling_blt.py
@@ -224,12 +224,15 @@ def test_eager_matches_sdpa_inference(

 @require_torch_accelerator
 class BltIntegrationTest(unittest.TestCase):
+    def setup(self):
+        cleanup(torch_device, gc_collect=True)
+
     def tearDown(self):
         # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
         # some memory allocated in the cache, which means some object is not being released properly. This causes some
         # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
         # Investigate the root cause.
-        cleanup(torch_device, gc_collect=False)
+        cleanup(torch_device, gc_collect=True)

     @slow
     @require_read_token
@@ -339,7 +342,7 @@ def test_model_logits(self):
     def test_model_bf16(self):
         """Test Blt model with bfloat16 precision."""
         NUM_TOKENS_TO_GENERATE = 200
-        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m"

         prompt = "my name is"
@@ -472,7 +475,7 @@ def test_model_eager(self):
     def test_model_bf16_static_cache(self):
         """Test Blt model with bfloat16 precision and static cache."""
         NUM_TOKENS_TO_GENERATE = 200
-        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m"

         prompt = "my name is"
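For reference, a minimal sketch (not part of the patch itself) of the cache layout the diff moves to, assuming only the public `transformers.cache_utils` API: a single `EncoderDecoderCache` wrapping two `DynamicCache` objects, with the self-attention half and the cross-attention half routed to different sub-modules, mirroring the `.self_attention_cache` / `.cross_attention_cache` split above.

```python
# Illustrative only -- shows the cache wiring pattern used by the diff,
# built from the public transformers cache API.
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# One cache for self-attention layers, one for cross-attention layers,
# bundled into a single object that can be passed around as past_key_values.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Components that only run self-attention receive the first half ...
self_attn_cache = past_key_values.self_attention_cache
# ... while cross-attention layers receive the second half.
cross_attn_cache = past_key_values.cross_attention_cache
```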