From 7a537f0e0c6a3c8281f0c3740a9cca1684fb5e35 Mon Sep 17 00:00:00 2001 From: Matteo Hessel Date: Fri, 10 Nov 2023 07:04:56 -0800 Subject: [PATCH] Update second_order package internal structure. Use the new subpackage structure where individual files are private (hence their names being prefixed with `_`) and you can only import the entire subpackages (e.g. `from optax import second_order`, instead of `from optax.second_order import hessian`) PiperOrigin-RevId: 581246429 --- optax/second_order/__init__.py | 6 +++--- optax/second_order/{base.py => _base.py} | 0 optax/second_order/{fisher.py => _fisher.py} | 4 ++-- optax/second_order/{hessian.py => _hessian.py} | 6 +++--- optax/second_order/{hessian_test.py => _hessian_test.py} | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) rename optax/second_order/{base.py => _base.py} (100%) rename optax/second_order/{fisher.py => _fisher.py} (95%) rename optax/second_order/{hessian.py => _hessian.py} (96%) rename optax/second_order/{hessian_test.py => _hessian_test.py} (96%) diff --git a/optax/second_order/__init__.py b/optax/second_order/__init__.py index 72c230fe..1df9fae5 100644 --- a/optax/second_order/__init__.py +++ b/optax/second_order/__init__.py @@ -14,6 +14,6 @@ # ============================================================================== """The second order optimisation sub-package.""" -from optax.second_order.fisher import fisher_diag -from optax.second_order.hessian import hessian_diag -from optax.second_order.hessian import hvp +from optax.second_order._fisher import fisher_diag +from optax.second_order._hessian import hessian_diag +from optax.second_order._hessian import hvp diff --git a/optax/second_order/base.py b/optax/second_order/_base.py similarity index 100% rename from optax/second_order/base.py rename to optax/second_order/_base.py diff --git a/optax/second_order/fisher.py b/optax/second_order/_fisher.py similarity index 95% rename from optax/second_order/fisher.py rename to optax/second_order/_fisher.py index d7e6e44d..11d7d2e0 100644 --- a/optax/second_order/fisher.py +++ b/optax/second_order/_fisher.py @@ -25,7 +25,7 @@ from jax import flatten_util import jax.numpy as jnp -from optax.second_order import base +from optax.second_order import _base def _ravel(p: Any) -> jax.Array: @@ -33,7 +33,7 @@ def _ravel(p: Any) -> jax.Array: def fisher_diag( - negative_log_likelihood: base.LossFn, + negative_log_likelihood: _base.LossFn, params: Any, inputs: jax.Array, targets: jax.Array, diff --git a/optax/second_order/hessian.py b/optax/second_order/_hessian.py similarity index 96% rename from optax/second_order/hessian.py rename to optax/second_order/_hessian.py index 33a69884..688fcb70 100644 --- a/optax/second_order/hessian.py +++ b/optax/second_order/_hessian.py @@ -25,7 +25,7 @@ from jax import flatten_util import jax.numpy as jnp -from optax.second_order import base +from optax.second_order import _base def _ravel(p: Any) -> jax.Array: @@ -33,7 +33,7 @@ def _ravel(p: Any) -> jax.Array: def hvp( - loss: base.LossFn, + loss: _base.LossFn, v: jax.Array, params: Any, inputs: jax.Array, @@ -58,7 +58,7 @@ def hvp( def hessian_diag( - loss: base.LossFn, + loss: _base.LossFn, params: Any, inputs: jax.Array, targets: jax.Array, diff --git a/optax/second_order/hessian_test.py b/optax/second_order/_hessian_test.py similarity index 96% rename from optax/second_order/hessian_test.py rename to optax/second_order/_hessian_test.py index 29a7b3b5..2442025f 100644 --- a/optax/second_order/hessian_test.py +++ b/optax/second_order/_hessian_test.py @@ -26,7 +26,7 @@ import jax.numpy as jnp import numpy as np -from optax import second_order +from optax.second_order import _hessian NUM_CLASSES = 2 @@ -77,7 +77,7 @@ def jax_hessian_diag(loss_fun, params, inputs, targets): @chex.all_variants def test_hessian_diag(self): hessian_diag_fn = self.variant( - functools.partial(second_order.hessian_diag, self.loss_fn)) + functools.partial(_hessian.hessian_diag, self.loss_fn)) actual = hessian_diag_fn(self.parameters, self.data, self.labels) np.testing.assert_array_almost_equal(self.hessian, actual, 5)