diff --git a/keras_hub/src/models/clip/clip_encoder_block.py b/keras_hub/src/models/clip/clip_encoder_block.py
index b438115fc7..ba0193afd9 100644
--- a/keras_hub/src/models/clip/clip_encoder_block.py
+++ b/keras_hub/src/models/clip/clip_encoder_block.py
@@ -10,28 +10,11 @@ def quick_gelu(x):
 # TODO: Deprecate this in favor of `keras.layers.MultiHeadAttention` once the
 # dtype compatibility issue is resolved.
 class CLIPMultiHeadAttention(layers.MultiHeadAttention):
-    def _compute_attention(
-        self, query, key, value, attention_mask=None, training=None
-    ):
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self._masked_softmax(
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        attention_scores = super()._masked_softmax(
             attention_scores, attention_mask
         )
-        # Fix the dtype compatibility.
-        attention_scores = ops.cast(attention_scores, value.dtype)
-        if self.dropout:
-            final_attn_scores = self._dropout_layer(
-                attention_scores, training=training
-            )
-        else:
-            final_attn_scores = attention_scores
-        attention_output = ops.einsum(
-            self._combine_equation, final_attn_scores, value
-        )
-        return attention_output, attention_scores
+        return ops.cast(attention_scores, self._value_dense.compute_dtype)
 
 
 class CLIPEncoderBlock(layers.Layer):
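
The change narrows the workaround: instead of reimplementing all of `_compute_attention`, the subclass now only overrides `_masked_softmax` and casts its output to `self._value_dense.compute_dtype`, which reproduces the old `ops.cast(attention_scores, value.dtype)` without needing access to `value` and keeps the rest of the attention math in sync with upstream `keras.layers.MultiHeadAttention`. Below is a minimal sketch of how the patched class could be smoke-tested in isolation; it is not part of the diff, and the dtype policy, layer sizes, and input shapes are assumptions chosen for illustration.

```python
import keras
from keras import layers, ops


class CLIPMultiHeadAttention(layers.MultiHeadAttention):
    def _masked_softmax(self, attention_scores, attention_mask=None):
        attention_scores = super()._masked_softmax(
            attention_scores, attention_mask
        )
        # Cast the softmax output to the value projection's compute dtype so
        # the combining einsum sees scores and values in the same dtype.
        return ops.cast(attention_scores, self._value_dense.compute_dtype)


# Hypothetical smoke test: build and call the patched layer under a
# mixed-precision policy to confirm the override executes.
keras.config.set_dtype_policy("mixed_float16")
attention = CLIPMultiHeadAttention(num_heads=2, key_dim=8)
x = keras.random.normal((1, 4, 16))
y = attention(query=x, value=x, key=x)
print(y.shape, y.dtype)
```

The smoke test only checks that the patched layer builds and runs; the dtype mismatch the `TODO` refers to arises from the dtype configuration inside the CLIP model itself, where the softmax output and the value projection can end up with different compute dtypes.
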