Commit 451b006

"""Calculates the squared error for a set of predictions.
  Mean Squared Error can be computed as squared_error(a, b).mean().

  Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in
  "Pattern Recognition and Machine Learning" by Bishop, but not
  "The Elements of Statistical Learning" by Tibshirani.

  References:
    [Chris Bishop, 2006](https://bit.ly/3eeP0ga)

  Args:
    predictions: a vector of arbitrary shape `[...]`.
    targets: a vector with shape broadcastable to that of `predictions`;
      if not provided then it is assumed to be a vector of zeros.

  Returns:
    elementwise squared differences, with same shape as `predictions`.

PiperOrigin-RevId: 515627280
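
As a quick illustration of the API described in the commit message above, here is a minimal usage sketch in JAX. It assumes an optax build that includes this commit, so that both `squared_error` and `l2_loss` are exported:

    import jax.numpy as jnp
    import optax

    predictions = jnp.array([-2., -1., 0.5, 1.])
    targets = jnp.array([-1.5, 0., -1., 1.])

    # Elementwise squared differences, same shape as `predictions`.
    errors = optax.squared_error(predictions, targets)

    # Mean Squared Error, computed as suggested in the docstring.
    mse = errors.mean()

    # l2_loss keeps Bishop's 0.5 convention, so it is half the squared error.
    print(bool(jnp.allclose(optax.l2_loss(predictions, targets), 0.5 * errors)))  # True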
OptaxDev committed Mar 10, 2023
1 parent 5f0f5da commit 451b006
Showing 4 changed files with 63 additions and 7 deletions.
3 changes: 2 additions & 1 deletion docs/api.rst
@@ -512,6 +512,7 @@ Common Losses
smooth_labels
softmax_cross_entropy
softmax_cross_entropy_with_integer_labels
squared_error


Losses
@@ -531,7 +532,7 @@ Losses
.. autofunction:: smooth_labels
.. autofunction:: softmax_cross_entropy
.. autofunction:: softmax_cross_entropy_with_integer_labels

.. autofunction:: squared_error

Linear Algebra Operators
========================
1 change: 1 addition & 0 deletions optax/__init__.py
@@ -93,6 +93,7 @@
from optax._src.loss import smooth_labels
from optax._src.loss import softmax_cross_entropy
from optax._src.loss import softmax_cross_entropy_with_integer_labels
from optax._src.loss import squared_error
from optax._src.numerics import safe_int32_increment
from optax._src.numerics import safe_norm
from optax._src.numerics import safe_root_mean_squares
38 changes: 32 additions & 6 deletions optax/_src/loss.py
@@ -29,14 +29,17 @@
from optax._src import utils


def l2_loss(
def squared_error(
predictions: chex.Array,
targets: Optional[chex.Array] = None,
) -> chex.Array:
"""Calculates the L2 loss for a set of predictions.
"""Calculates the squared error for a set of predictions.
Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.
Mean Squared Error can be computed as squared_error(a, b).mean().
Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in
"Pattern Recognition and Machine Learning" by Bishop, but not
"The Elements of Statistical Learning" by Tibshirani.
References:
[Chris Bishop, 2006](https://bit.ly/3eeP0ga)
@@ -53,8 +56,31 @@ def l2_loss(
if targets is not None:
# Avoid broadcasting logic for "-" operator.
chex.assert_equal_shape((predictions, targets))
errors = (predictions - targets) if (targets is not None) else predictions
return 0.5 * (errors)**2
errors = predictions - targets if targets is not None else predictions
return errors ** 2


def l2_loss(
predictions: chex.Array,
targets: Optional[chex.Array] = None,
) -> chex.Array:
"""Calculates the L2 loss for a set of predictions.
Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.
References:
[Chris Bishop, 2006](https://bit.ly/3eeP0ga)
Args:
predictions: a vector of arbitrary shape `[...]`.
targets: a vector with shape broadcastable to that of `predictions`;
if not provided then it is assumed to be a vector of zeros.
Returns:
elementwise squared differences, with same shape as `predictions`.
"""
return 0.5 * squared_error(predictions, targets)


def huber_loss(
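As a side note on the default-target behaviour shown in the diff above: when `targets` is omitted it is treated as a vector of zeros, so the result is simply the elementwise square of `predictions`. A minimal sketch, assuming the `optax.squared_error` added in this commit:

    import jax.numpy as jnp
    import optax

    predictions = jnp.array([3., -4.])

    # With targets omitted, squared_error squares the predictions directly.
    print(optax.squared_error(predictions))  # [ 9. 16.]
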
28 changes: 28 additions & 0 deletions optax/_src/loss_test.py
@@ -25,6 +25,34 @@
from optax._src import loss


class SquaredErrorTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.ys = jnp.array([-2., -1., 0.5, 1.])
self.ts = jnp.array([-1.5, 0., -1, 1.])
# compute expected outputs in numpy.
self.exp = (self.ts - self.ys) ** 2

@chex.all_variants
def test_scalar(self):
np.testing.assert_allclose(
self.variant(loss.squared_error)(
self.ys[0], self.ts[0]), self.exp[0])

@chex.all_variants
def test_batched(self):
np.testing.assert_allclose(
self.variant(loss.squared_error)(
self.ys, self.ts), self.exp)

@chex.all_variants
def test_shape_mismatch(self):
with self.assertRaises(AssertionError):
_ = self.variant(loss.squared_error)(
self.ys, jnp.expand_dims(self.ts, axis=-1))


class L2LossTest(parameterized.TestCase):

def setUp(self):
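The new `test_shape_mismatch` case mirrors the `chex.assert_equal_shape` guard in the implementation. A minimal sketch of the behaviour it exercises, again assuming the `optax.squared_error` added in this commit:

    import jax.numpy as jnp
    import optax

    ys = jnp.array([-2., -1., 0.5, 1.])
    ts = jnp.array([-1.5, 0., -1., 1.])

    try:
        # Shapes differ, (4,) vs (4, 1): the equal-shape assertion fires
        # instead of silently broadcasting.
        optax.squared_error(ys, jnp.expand_dims(ts, axis=-1))
    except AssertionError as err:
        print('shape mismatch rejected:', err)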
