Skip to content

Commit

Permalink
Update second_order package internal structure.
Browse files Browse the repository at this point in the history
Use the new subpackage structure where individual files are private (hence their names being prefixed with `_`) and you can only import the entire subpackages (e.g. `from optax import second_order`, instead of `from optax.second_order import hessian`)

PiperOrigin-RevId: 581246429
  • Loading branch information
mtthss authored and OptaxDev committed Nov 10, 2023
1 parent 86ea3bf commit 7a537f0
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions optax/second_order/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
# ==============================================================================
"""The second order optimisation sub-package."""

from optax.second_order.fisher import fisher_diag
from optax.second_order.hessian import hessian_diag
from optax.second_order.hessian import hvp
from optax.second_order._fisher import fisher_diag
from optax.second_order._hessian import hessian_diag
from optax.second_order._hessian import hvp
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
from jax import flatten_util
import jax.numpy as jnp

from optax.second_order import base
from optax.second_order import _base


def _ravel(p: Any) -> jax.Array:
return flatten_util.ravel_pytree(p)[0]


def fisher_diag(
negative_log_likelihood: base.LossFn,
negative_log_likelihood: _base.LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
from jax import flatten_util
import jax.numpy as jnp

from optax.second_order import base
from optax.second_order import _base


def _ravel(p: Any) -> jax.Array:
return flatten_util.ravel_pytree(p)[0]


def hvp(
loss: base.LossFn,
loss: _base.LossFn,
v: jax.Array,
params: Any,
inputs: jax.Array,
Expand All @@ -58,7 +58,7 @@ def hvp(


def hessian_diag(
loss: base.LossFn,
loss: _base.LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import jax.numpy as jnp
import numpy as np

from optax import second_order
from optax.second_order import _hessian


NUM_CLASSES = 2
Expand Down Expand Up @@ -77,7 +77,7 @@ def jax_hessian_diag(loss_fun, params, inputs, targets):
@chex.all_variants
def test_hessian_diag(self):
hessian_diag_fn = self.variant(
functools.partial(second_order.hessian_diag, self.loss_fn))
functools.partial(_hessian.hessian_diag, self.loss_fn))
actual = hessian_diag_fn(self.parameters, self.data, self.labels)
np.testing.assert_array_almost_equal(self.hessian, actual, 5)

Expand Down

0 comments on commit 7a537f0

Please sign in to comment.