diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py index 22b228f958..680685dd64 100644 --- a/keras_hub/src/models/gemma/gemma_attention.py +++ b/keras_hub/src/models/gemma/gemma_attention.py @@ -141,11 +141,8 @@ def _compute_attention( query=q, key=k, value=v, - bias=None, mask=attention_mask, scale=query_normalization, - is_causal=True, - flash_attention=True, ) return attention_output