diff --git a/templates/api/losses/index.md b/templates/api/losses/index.md index 1045ae4eeb..9b3588a3cf 100644 --- a/templates/api/losses/index.md +++ b/templates/api/losses/index.md @@ -76,7 +76,8 @@ A loss is a callable with arguments `loss_fn(y_true, y_pred, sample_weight=None) by the corresponding value of `sample_weight`. (Note on`dN-1`: all loss functions reduce by 1 dimension, usually `axis=-1`.) -By default, loss functions return one scalar loss value per input sample, e.g. +By default, loss functions return one scalar loss value for each input sample +in the batch dimension, e.g. ``` >>> from keras import ops