diff --git a/keras_hub/src/models/deberta_v3/disentangled_self_attention.py b/keras_hub/src/models/deberta_v3/disentangled_self_attention.py index 769f8b041e..a2ba30a528 100644 --- a/keras_hub/src/models/deberta_v3/disentangled_self_attention.py +++ b/keras_hub/src/models/deberta_v3/disentangled_self_attention.py @@ -217,9 +217,14 @@ def _make_log_bucket_position(self, rel_pos): ) def _get_log_pos(abs_pos, mid): - numerator = ops.log(abs_pos / mid) + numerator = ops.log( + ops.cast(abs_pos, "float32") / ops.cast(mid, "float32") + ) numerator = numerator * ops.cast(mid - 1, dtype=numerator.dtype) - denominator = ops.log((self.max_position_embeddings - 1) / mid) + denominator = ops.log( + ops.cast(self.max_position_embeddings - 1, "float32") + / ops.cast(mid, "float32") + ) val = ops.ceil(numerator / denominator) val = ops.cast(val, dtype=mid.dtype) val = val + mid