Skip to content

Commit

Permalink
Merge pull request #413 from shivance:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 479022039
  • Loading branch information
OptaxDev committed Oct 5, 2022
2 parents e026a15 + 98d19d3 commit d744b3d
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Linear algebra utilities used in optimisation."""

import chex
import jax
from jax import lax
import jax.numpy as jnp
Expand All @@ -29,11 +30,10 @@ def global_norm(updates: base.Updates) -> base.Updates:
jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))


def power_iteration(
matrix,
num_iters=100,
error_tolerance=1e-6,
precision=lax.Precision.HIGHEST):
def power_iteration(matrix: chex.Array,
num_iters: int = 100,
error_tolerance: float = 1e-6,
precision: lax.Precision = lax.Precision.HIGHEST):
r"""Power iteration algorithm.
The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
Expand All @@ -48,9 +48,9 @@ def power_iteration(
num_iters: Number of iterations.
error_tolerance: Iterative exit condition.
precision: precision XLA related flag, the available options are:
a) lax.Precision.DEFAULT (better step time, but not precise)
b) lax.Precision.HIGH (increased precision, slower)
c) lax.Precision.HIGHEST (best possible precision, slowest)
a) lax.Precision.DEFAULT (better step time, but not precise);
b) lax.Precision.HIGH (increased precision, slower);
c) lax.Precision.HIGHEST (best possible precision, slowest).
Returns:
eigen vector, eigen value
Expand Down Expand Up @@ -80,13 +80,12 @@ def _iter_body(state):
return v_out, s_out


def matrix_inverse_pth_root(
matrix,
p,
num_iters=100,
ridge_epsilon=1e-6,
error_tolerance=1e-6,
precision=lax.Precision.HIGHEST):
def matrix_inverse_pth_root(matrix: chex.Array,
p: int,
num_iters: int = 100,
ridge_epsilon: float = 1e-6,
error_tolerance: float = 1e-6,
precision: lax.Precision = lax.Precision.HIGHEST):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
This function uses the Coupled newton iterations algorithm for
Expand All @@ -105,9 +104,9 @@ def matrix_inverse_pth_root(
ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
error_tolerance: Error indicator, useful for early termination.
precision: precision XLA related flag, the available options are:
a) lax.Precision.DEFAULT (better step time, but not precise)
b) lax.Precision.HIGH (increased precision, slower)
c) lax.Precision.HIGHEST (best possible precision, slowest)
a) lax.Precision.DEFAULT (better step time, but not precise);
b) lax.Precision.HIGH (increased precision, slower);
c) lax.Precision.HIGHEST (best possible precision, slowest).
Returns:
matrix^(-1/p)
Expand Down

0 comments on commit d744b3d

Please sign in to comment.