diff --git a/optax/second_order/__init__.py b/optax/second_order/__init__.py index 72c230fe..1df9fae5 100644 --- a/optax/second_order/__init__.py +++ b/optax/second_order/__init__.py @@ -14,6 +14,6 @@ # ============================================================================== """The second order optimisation sub-package.""" -from optax.second_order.fisher import fisher_diag -from optax.second_order.hessian import hessian_diag -from optax.second_order.hessian import hvp +from optax.second_order._fisher import fisher_diag +from optax.second_order._hessian import hessian_diag +from optax.second_order._hessian import hvp diff --git a/optax/second_order/base.py b/optax/second_order/_base.py similarity index 100% rename from optax/second_order/base.py rename to optax/second_order/_base.py diff --git a/optax/second_order/fisher.py b/optax/second_order/_fisher.py similarity index 95% rename from optax/second_order/fisher.py rename to optax/second_order/_fisher.py index d7e6e44d..11d7d2e0 100644 --- a/optax/second_order/fisher.py +++ b/optax/second_order/_fisher.py @@ -25,7 +25,7 @@ from jax import flatten_util import jax.numpy as jnp -from optax.second_order import base +from optax.second_order import _base def _ravel(p: Any) -> jax.Array: @@ -33,7 +33,7 @@ def _ravel(p: Any) -> jax.Array: def fisher_diag( - negative_log_likelihood: base.LossFn, + negative_log_likelihood: _base.LossFn, params: Any, inputs: jax.Array, targets: jax.Array, diff --git a/optax/second_order/hessian.py b/optax/second_order/_hessian.py similarity index 96% rename from optax/second_order/hessian.py rename to optax/second_order/_hessian.py index 33a69884..688fcb70 100644 --- a/optax/second_order/hessian.py +++ b/optax/second_order/_hessian.py @@ -25,7 +25,7 @@ from jax import flatten_util import jax.numpy as jnp -from optax.second_order import base +from optax.second_order import _base def _ravel(p: Any) -> jax.Array: @@ -33,7 +33,7 @@ def _ravel(p: Any) -> jax.Array: def hvp( - loss: base.LossFn, + loss: _base.LossFn, v: jax.Array, params: Any, inputs: jax.Array, @@ -58,7 +58,7 @@ def hvp( def hessian_diag( - loss: base.LossFn, + loss: _base.LossFn, params: Any, inputs: jax.Array, targets: jax.Array, diff --git a/optax/second_order/hessian_test.py b/optax/second_order/_hessian_test.py similarity index 96% rename from optax/second_order/hessian_test.py rename to optax/second_order/_hessian_test.py index 29a7b3b5..2442025f 100644 --- a/optax/second_order/hessian_test.py +++ b/optax/second_order/_hessian_test.py @@ -26,7 +26,7 @@ import jax.numpy as jnp import numpy as np -from optax import second_order +from optax.second_order import _hessian NUM_CLASSES = 2 @@ -77,7 +77,7 @@ def jax_hessian_diag(loss_fun, params, inputs, targets): @chex.all_variants def test_hessian_diag(self): hessian_diag_fn = self.variant( - functools.partial(second_order.hessian_diag, self.loss_fn)) + functools.partial(_hessian.hessian_diag, self.loss_fn)) actual = hessian_diag_fn(self.parameters, self.data, self.labels) np.testing.assert_array_almost_equal(self.hessian, actual, 5)