13 changes: 9 additions & 4 deletions keras/src/layers/activations/softmax.py
@@ -52,10 +52,15 @@ def __init__(self, axis=-1, **kwargs):

     def call(self, inputs, mask=None):
         if mask is not None:
-            adder = (
-                1.0 - backend.cast(mask, inputs.dtype)
-            ) * _large_negative_number(inputs.dtype)
-            inputs += adder
+            # We keep the positions where the mask is True or > 0.5, and set the
+            # other (masked) positions to -1e.9.
Review comment (Contributor, severity: medium):
There's a small typo in this comment. -1e.9 appears to be a typo for -1e9. To improve clarity and accuracy, especially since _large_negative_number can return different values based on dtype, I suggest making the comment more general.

Suggested change:
-            # other (masked) positions to -1e.9.
+            # other (masked) positions to a large negative number.
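
To illustrate the reviewer's point that the constant depends on the dtype, here is a plausible sketch of a dtype-aware helper in the spirit of _large_negative_number; the exact constants in the Keras implementation may differ, so treat the values below as assumptions.

from keras.src import backend

def large_negative_number(dtype):
    """Hypothetical helper: a large negative value representable in `dtype`."""
    # float16 overflows near -6.5e4, so a smaller magnitude is used there.
    if backend.standardize_dtype(dtype) == "float16":
        return -3e4
    # Wider float types can represent -1e9 without overflow.
    return -1e9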

+            if backend.standardize_dtype(mask.dtype) != "bool":
+                mask = backend.numpy.greater(
+                    mask, backend.cast(0.5, dtype=mask.dtype)
+                )
+            inputs = backend.numpy.where(
+                mask, inputs, _large_negative_number(inputs.dtype)
+            )
         if isinstance(self.axis, (tuple, list)):
             if len(self.axis) > 1:
                 outputs = backend.numpy.exp(
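
A minimal usage sketch (not part of the PR) showing the behavior the new masking code implements: positions where the mask is False are replaced by a large negative value before the softmax, so they end up with a probability of roughly zero.

import numpy as np
import keras

x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype="float32")
# Boolean mask: keep the first two positions, mask out the last two.
mask = np.array([[True, True, False, False]])

probs = keras.layers.Softmax()(x, mask=mask)
print(np.asarray(probs))
# Expected: the probability mass is split between the first two positions
# (roughly [0.27, 0.73, 0.0, 0.0]), with the masked positions near zero.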