diff --git a/docs/api.rst b/docs/api.rst
index 49653141..dd72591c 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -512,6 +512,7 @@ Common Losses
     smooth_labels
     softmax_cross_entropy
     softmax_cross_entropy_with_integer_labels
+    squared_error
 
 
 Losses
@@ -531,7 +532,7 @@ Losses
 .. autofunction:: smooth_labels
 .. autofunction:: softmax_cross_entropy
 .. autofunction:: softmax_cross_entropy_with_integer_labels
-
+.. autofunction:: squared_error
 
 Linear Algebra Operators
 ========================
diff --git a/optax/__init__.py b/optax/__init__.py
index 7bfc00af..5f5c2a15 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -93,6 +93,7 @@
 from optax._src.loss import smooth_labels
 from optax._src.loss import softmax_cross_entropy
 from optax._src.loss import softmax_cross_entropy_with_integer_labels
+from optax._src.loss import squared_error
 from optax._src.numerics import safe_int32_increment
 from optax._src.numerics import safe_norm
 from optax._src.numerics import safe_root_mean_squares
diff --git a/optax/_src/loss.py b/optax/_src/loss.py
index 23ad78d5..103b46e2 100644
--- a/optax/_src/loss.py
+++ b/optax/_src/loss.py
@@ -29,14 +29,17 @@
 from optax._src import utils
 
 
-def l2_loss(
+def squared_error(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
 ) -> chex.Array:
-  """Calculates the L2 loss for a set of predictions.
+  """Calculates the squared error for a set of predictions.
 
-  Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
-  by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.
+  Mean Squared Error can be computed as squared_error(a, b).mean().
+
+  Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in
+  "Pattern Recognition and Machine Learning" by Bishop, but not
+  "The Elements of Statistical Learning" by Tibshirani.
 
   References:
     [Chris Bishop, 2006](https://bit.ly/3eeP0ga)
@@ -53,8 +56,31 @@ def l2_loss(
   if targets is not None:
     # Avoid broadcasting logic for "-" operator.
     chex.assert_equal_shape((predictions, targets))
-  errors = (predictions - targets) if (targets is not None) else predictions
-  return 0.5 * (errors)**2
+  errors = predictions - targets if targets is not None else predictions
+  return errors ** 2
+
+
+def l2_loss(
+    predictions: chex.Array,
+    targets: Optional[chex.Array] = None,
+) -> chex.Array:
+  """Calculates the L2 loss for a set of predictions.
+
+  Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
+  by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.
+
+  References:
+    [Chris Bishop, 2006](https://bit.ly/3eeP0ga)
+
+  Args:
+    predictions: a vector of arbitrary shape `[...]`.
+    targets: a vector with shape broadcastable to that of `predictions`;
+      if not provided then it is assumed to be a vector of zeros.
+
+  Returns:
+    elementwise squared differences scaled by 0.5, same shape as `predictions`.
+  """
+  return 0.5 * squared_error(predictions, targets)
 
 
 def huber_loss(
diff --git a/optax/_src/loss_test.py b/optax/_src/loss_test.py
index 8717660f..3b523d90 100644
--- a/optax/_src/loss_test.py
+++ b/optax/_src/loss_test.py
@@ -25,6 +25,34 @@
 from optax._src import loss
 
 
+class SquaredErrorTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.ys = jnp.array([-2., -1., 0.5, 1.])
+    self.ts = jnp.array([-1.5, 0., -1, 1.])
+    # Compute expected outputs (elementwise squared differences).
+    self.exp = (self.ts - self.ys) ** 2
+
+  @chex.all_variants
+  def test_scalar(self):
+    np.testing.assert_allclose(
+        self.variant(loss.squared_error)(
+            self.ys[0], self.ts[0]), self.exp[0])
+
+  @chex.all_variants
+  def test_batched(self):
+    np.testing.assert_allclose(
+        self.variant(loss.squared_error)(
+            self.ys, self.ts), self.exp)
+
+  @chex.all_variants
+  def test_shape_mismatch(self):
+    with self.assertRaises(AssertionError):
+      _ = self.variant(loss.squared_error)(
+          self.ys, jnp.expand_dims(self.ts, axis=-1))
+
+
 class L2LossTest(parameterized.TestCase):
 
   def setUp(self):
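For context, a minimal usage sketch of the relationship documented in the new docstrings (illustrative values; assumes this change is merged so that squared_error and l2_loss are both exported from the optax namespace):

import jax.numpy as jnp
import optax

predictions = jnp.array([-2., -1., 0.5, 1.])
targets = jnp.array([-1.5, 0., -1., 1.])

per_example = optax.squared_error(predictions, targets)  # (predictions - targets) ** 2
mse = per_example.mean()                                  # Mean Squared Error
half_l2 = optax.l2_loss(predictions, targets)             # equals 0.5 * per_example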