From 2823496ca434e8325d8fddfa13983990ef7e454a Mon Sep 17 00:00:00 2001 From: Yang Li Date: Thu, 13 Nov 2025 23:57:34 +0000 Subject: [PATCH 1/2] Update keras3 Softmax mask handling to be more numerically robust. --- keras/src/layers/activations/softmax.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/keras/src/layers/activations/softmax.py b/keras/src/layers/activations/softmax.py index 8660877977ec..6fdfd5754366 100644 --- a/keras/src/layers/activations/softmax.py +++ b/keras/src/layers/activations/softmax.py @@ -52,10 +52,14 @@ def __init__(self, axis=-1, **kwargs): def call(self, inputs, mask=None): if mask is not None: - adder = ( - 1.0 - backend.cast(mask, inputs.dtype) - ) * _large_negative_number(inputs.dtype) - inputs += adder + # We keep the positions where the mask is True or > 0.5, and set the + # other (masked) positions to a large negative number. + if backend.standardize_dtype(mask.dtype) != "bool": + mask = backend.numpy.greater( + mask, backend.cast(0.5, dtype=mask.dtype)) + inputs = backend.numpy.where( + mask, inputs, _large_negative_number(inputs.dtype) + ) if isinstance(self.axis, (tuple, list)): if len(self.axis) > 1: outputs = backend.numpy.exp( From 6c87c58902a2a011cbb948ac5c01d0b5ae07b748 Mon Sep 17 00:00:00 2001 From: Yang Li Date: Fri, 14 Nov 2025 17:42:18 +0000 Subject: [PATCH 2/2] Fix formatting --- keras/src/layers/activations/softmax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/activations/softmax.py b/keras/src/layers/activations/softmax.py index 6fdfd5754366..dc233ac92742 100644 --- a/keras/src/layers/activations/softmax.py +++ b/keras/src/layers/activations/softmax.py @@ -56,7 +56,8 @@ def call(self, inputs, mask=None): # other (masked) positions to a large negative number. 
if backend.standardize_dtype(mask.dtype) != "bool": mask = backend.numpy.greater( - mask, backend.cast(0.5, dtype=mask.dtype)) + mask, backend.cast(0.5, dtype=mask.dtype) + ) inputs = backend.numpy.where( mask, inputs, _large_negative_number(inputs.dtype) )