diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py index b217509541..621b4562de 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py @@ -204,9 +204,8 @@ def __init__( self.intermediate_dim = intermediate_dim def compute_attention(self, x, mask=None): - mask = None if mask is not None: - mask = ops.cast(mask, dtype=x.dtype) if mask is not None else None + mask = ops.cast(mask, dtype=x.dtype) return self.attn(x, attention_mask=mask)[0] def build(self, input_shape):