diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 9cc2858e5271a..55dabe7cbe26b 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -124,7 +124,7 @@ def rotate_half(x):
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -132,9 +132,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -940,6 +939,10 @@ def forward(
             attentions=all_self_attns,
         )
 
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
@@ -955,16 +958,8 @@ def _update_causal_mask(self, attention_mask, input_tensor):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
 
         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
@@ -1146,29 +1141,32 @@ def prepare_inputs_for_generation(
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]
 
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
 
         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
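A note on the first two hunks: `position_ids` becomes optional and unused in `apply_rotary_pos_emb`, since `cos`/`sin` are now expected to arrive already gathered per position. A toy sketch of the new calling convention (the shapes and import path below are illustrative assumptions, not part of the diff):

```python
# Toy sketch (not from the diff): apply_rotary_pos_emb without position_ids.
# Assumes cos/sin were already indexed per position by the rotary embedding module.
import torch
from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb

batch, heads, seq, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, seq, head_dim)  # per-position cosine, already gathered
sin = torch.randn(batch, seq, head_dim)  # per-position sine, already gathered

# unsqueeze_dim=1 (the default) broadcasts cos/sin over the head dimension of q/k.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 5, 8]) twice
```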
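The TODO added in the third hunk concerns running `generate` under `torch.compile` with the static KV cache. For context, a minimal sketch of that setup as discussed around this PR (the checkpoint name, dtype, and compile flags are illustrative assumptions, not prescribed by the diff):

```python
# Minimal sketch: decode with a fixed-size ("static") KV cache so tensor shapes stay
# constant across steps and torch.compile can reuse its cudagraphs instead of recapturing.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

model.generation_config.cache_implementation = "static"  # use the static KV cache in generate
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Static caches make decoding", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```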
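The `contiguous()` calls in the last hunk exist because torchdynamo guards on input strides: a sliced `input_ids`/`position_ids` tensor can change stride between decode steps and force a recompile. A standalone sketch of that effect (illustrative only; whether a recompile actually fires can vary with the torch version and compile settings):

```python
import torch

# Log the failing guard whenever torch.compile has to recompile.
torch._logging.set_logs(recompiles=True)

@torch.compile(dynamic=False)
def step(x):
    return x * 2

base = torch.arange(12.0).reshape(3, 4)
step(base[:, :2])               # first call compiles; guards record the input's size/stride/dtype
step(base[:, :2].contiguous())  # same shape, different stride -> guard miss, likely a recompile
```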