From b9bd5127dbe85138a5400e9a017238e2de97e782 Mon Sep 17 00:00:00 2001
From: suhana
Date: Tue, 26 Aug 2025 15:27:25 +0530
Subject: [PATCH 1/2] Fixing the dtype error causing failing tests

---
 .../src/models/deberta_v3/disentangled_self_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/keras_hub/src/models/deberta_v3/disentangled_self_attention.py b/keras_hub/src/models/deberta_v3/disentangled_self_attention.py
index 769f8b041e..66bf75efb0 100644
--- a/keras_hub/src/models/deberta_v3/disentangled_self_attention.py
+++ b/keras_hub/src/models/deberta_v3/disentangled_self_attention.py
@@ -217,9 +217,9 @@ def _make_log_bucket_position(self, rel_pos):
         )
 
         def _get_log_pos(abs_pos, mid):
-            numerator = ops.log(abs_pos / mid)
+            numerator = ops.log(ops.cast(abs_pos, "float32") / ops.cast(mid, "float32"))
             numerator = numerator * ops.cast(mid - 1, dtype=numerator.dtype)
-            denominator = ops.log((self.max_position_embeddings - 1) / mid)
+            denominator = ops.log(ops.cast(self.max_position_embeddings - 1, "float32") / ops.cast(mid, "float32"))
             val = ops.ceil(numerator / denominator)
             val = ops.cast(val, dtype=mid.dtype)
             val = val + mid

From 13143e6e7f63943914f55b3ceef1e62262c191b4 Mon Sep 17 00:00:00 2001
From: suhana
Date: Wed, 27 Aug 2025 11:55:24 +0530
Subject: [PATCH 2/2] Fixing failure

---
 .../src/models/deberta_v3/disentangled_self_attention.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/keras_hub/src/models/deberta_v3/disentangled_self_attention.py b/keras_hub/src/models/deberta_v3/disentangled_self_attention.py
index 66bf75efb0..a2ba30a528 100644
--- a/keras_hub/src/models/deberta_v3/disentangled_self_attention.py
+++ b/keras_hub/src/models/deberta_v3/disentangled_self_attention.py
@@ -217,9 +217,14 @@ def _make_log_bucket_position(self, rel_pos):
         )
 
         def _get_log_pos(abs_pos, mid):
-            numerator = ops.log(ops.cast(abs_pos, "float32") / ops.cast(mid, "float32"))
+            numerator = ops.log(
+                ops.cast(abs_pos, "float32") / ops.cast(mid, "float32")
+            )
             numerator = numerator * ops.cast(mid - 1, dtype=numerator.dtype)
-            denominator = ops.log(ops.cast(self.max_position_embeddings - 1, "float32") / ops.cast(mid, "float32"))
+            denominator = ops.log(
+                ops.cast(self.max_position_embeddings - 1, "float32")
+                / ops.cast(mid, "float32")
+            )
             val = ops.ceil(numerator / denominator)
             val = ops.cast(val, dtype=mid.dtype)
             val = val + mid