diff --git a/keras/src/layers/activations/softmax.py b/keras/src/layers/activations/softmax.py index 8660877977ec..dc233ac92742 100644 --- a/keras/src/layers/activations/softmax.py +++ b/keras/src/layers/activations/softmax.py @@ -52,10 +52,15 @@ def __init__(self, axis=-1, **kwargs): def call(self, inputs, mask=None): if mask is not None: - adder = ( - 1.0 - backend.cast(mask, inputs.dtype) - ) * _large_negative_number(inputs.dtype) - inputs += adder + # We keep the positions where the mask is True or > 0.5, and set the + # other (masked) positions to -1e.9. + if backend.standardize_dtype(mask.dtype) != "bool": + mask = backend.numpy.greater( + mask, backend.cast(0.5, dtype=mask.dtype) + ) + inputs = backend.numpy.where( + mask, inputs, _large_negative_number(inputs.dtype) + ) if isinstance(self.axis, (tuple, list)): if len(self.axis) > 1: outputs = backend.numpy.exp(