Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions keras/src/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,12 +755,26 @@ def binary_crossentropy(target, output, from_logits=False):
target = convert_to_tensor(target)
output = convert_to_tensor(output)

# We only apply the squeeze fix if we are on an MPS device,
# as this change breaks tests on other platforms that
# expect the original tensor shape to be preserved.
if (
torch.backends.mps.is_available()
and target.ndim > 1
and output.ndim == target.ndim
and target.shape[-1] == 1
and output.shape[-1] == 1
):
target = torch.squeeze(target, -1).contiguous()
output = torch.squeeze(output, -1).contiguous()

if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)

# By default, PyTorch, does reduction of `sum` over all rows,
# change reduction to `none` to keep dim
if from_logits:
Expand Down