From 80e6a0580c2b3da925132febd73a7d9f2b8f4863 Mon Sep 17 00:00:00 2001
From: TechxGenus
Date: Sun, 31 Mar 2024 02:37:25 +0800
Subject: [PATCH] [`BC`] Fix BC for AWQ quant (#29965)

fix awq quant
---
 src/transformers/models/cohere/modeling_cohere.py | 2 +-
 src/transformers/models/gemma/modeling_gemma.py   | 2 +-
 src/transformers/models/llama/modeling_llama.py   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index d8dfac2035678b..e949bc14482e74 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -963,7 +963,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 9bebe29e3e941f..2d93c43425f99a 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -971,7 +971,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 28315114c82ecf..8d0baf63c7b3fe 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1064,7 +1064,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
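
Note (not part of the patch): below is a minimal, self-contained sketch of the failure mode the guarded lookup avoids. `plain_layer` and `fused_layer` are hypothetical stand-ins, not transformers classes; the assumption is that an AWQ-fused decoder layer may not expose a `self_attn` attribute, in which case the old `hasattr(self.layers[0].self_attn, "past_key_value")` expression raised `AttributeError` before `hasattr` could ever return False.

from types import SimpleNamespace

# Layer whose attention module carries a `past_key_value` attribute,
# i.e. the static-cache case the check is looking for.
plain_layer = SimpleNamespace(self_attn=SimpleNamespace(past_key_value="static cache"))

# Hypothetical stand-in for an AWQ-fused decoder layer that does not
# expose a `self_attn` attribute at all.
fused_layer = SimpleNamespace()

for name, layer in (("plain layer", plain_layer), ("fused layer", fused_layer)):
    # Old check: hasattr(layer.self_attn, "past_key_value") evaluates
    # `layer.self_attn` first, so it raises AttributeError on fused_layer
    # before hasattr runs.
    # Guarded check: getattr falls back to an empty dict, so hasattr
    # simply returns False and the dynamic-cache branch is taken.
    is_static = hasattr(getattr(layer, "self_attn", {}), "past_key_value")
    print(name, "-> static cache" if is_static else "-> dynamic cache")

The `{}` fallback keeps the expression a one-liner: `hasattr({}, "past_key_value")` is False, so layers without `self_attn` are simply treated as using a dynamic cache.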