diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 3ebde731..420aff36 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -14,6 +14,7 @@ # ============================================================================== """Linear algebra utilities used in optimisation.""" +import chex import jax from jax import lax import jax.numpy as jnp @@ -29,11 +30,10 @@ def global_norm(updates: base.Updates) -> base.Updates: jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates))) -def power_iteration( - matrix, - num_iters=100, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST): +def power_iteration(matrix: chex.Array, + num_iters: int = 100, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST): r"""Power iteration algorithm. The power iteration algorithm takes a symmetric PSD matrix `A`, and produces @@ -48,9 +48,9 @@ def power_iteration( num_iters: Number of iterations. error_tolerance: Iterative exit condition. precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) - b) lax.Precision.HIGH (increased precision, slower) - c) lax.Precision.HIGHEST (best possible precision, slowest) + a) lax.Precision.DEFAULT (better step time, but not precise); + b) lax.Precision.HIGH (increased precision, slower); + c) lax.Precision.HIGHEST (best possible precision, slowest). Returns: eigen vector, eigen value @@ -80,13 +80,12 @@ def _iter_body(state): return v_out, s_out -def matrix_inverse_pth_root( - matrix, - p, - num_iters=100, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST): +def matrix_inverse_pth_root(matrix: chex.Array, + p: int, + num_iters: int = 100, + ridge_epsilon: float = 1e-6, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST): """Computes `matrix^(-1/p)`, where `p` is a positive integer. This function uses the Coupled newton iterations algorithm for @@ -105,9 +104,9 @@ def matrix_inverse_pth_root( ridge_epsilon: Ridge epsilon added to make the matrix positive definite. error_tolerance: Error indicator, useful for early termination. precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) - b) lax.Precision.HIGH (increased precision, slower) - c) lax.Precision.HIGHEST (best possible precision, slowest) + a) lax.Precision.DEFAULT (better step time, but not precise); + b) lax.Precision.HIGH (increased precision, slower); + c) lax.Precision.HIGHEST (best possible precision, slowest). Returns: matrix^(-1/p)