From eebb3865907608cd3ce6e57ef030ebd414ad0ba6 Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Mon, 24 Feb 2025 18:33:49 +0000
Subject: [PATCH] fix mask dtype

---
 keras_hub/src/models/gemma/gemma_attention.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index 680685dd64..8549e0a9d6 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -128,14 +128,8 @@ def _compute_attention(
                     "Please set `dropout` to 0.0."
                 )
             if attention_mask is not None:
-                while len(attention_mask.shape) < 4:
-                    attention_mask = ops.expand_dims(
-                        attention_mask, axis=1
-                    )  # Add dimension for num_heads
-                if attention_mask.shape[1] != self.num_query_heads:
-                    attention_mask = ops.tile(
-                        attention_mask, [1, self.num_query_heads, 1, 1]
-                    )
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
             attention_output = ops.dot_product_attention(
                 query=q,
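
Note (not part of the patch): a minimal sketch of the new mask handling, assuming an incoming 3D attention mask of shape (batch, q_len, kv_len); the tensor shapes, random data, and variable names below are illustrative only, not taken from the Gemma model code.

    import numpy as np
    from keras import ops

    batch, heads, q_len, kv_len, head_dim = 2, 4, 8, 8, 16
    # dot_product_attention expects (batch, seq_len, num_heads, head_dim) inputs.
    q = np.random.rand(batch, q_len, heads, head_dim).astype("float32")
    k = np.random.rand(batch, kv_len, heads, head_dim).astype("float32")
    v = np.random.rand(batch, kv_len, heads, head_dim).astype("float32")

    # A float causal mask, as it may arrive from upstream layers.
    attention_mask = np.tril(np.ones((batch, q_len, kv_len), dtype="float32"))

    # The patched logic: add a broadcast axis for the head dimension and cast
    # to bool. The mask broadcasts over heads, so tiling to num_query_heads is
    # unnecessary, and the bool dtype is what the fused (flash) attention path
    # expects rather than a float mask.
    attention_mask = ops.expand_dims(attention_mask, axis=1)  # (batch, 1, q_len, kv_len)
    attention_mask = ops.cast(attention_mask, dtype="bool")

    out = ops.dot_product_attention(query=q, key=k, value=v, mask=attention_mask)
    print(ops.shape(out))  # (2, 8, 4, 16)

The design point of the change: instead of looping to pad the mask rank and tiling it across query heads, a single broadcastable (batch, 1, q_len, kv_len) boolean mask is passed through, which is both cheaper and matches the dtype the fused attention kernels require.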