Skip to content

Commit

Permalink
fixed optional type
Browse files Browse the repository at this point in the history
  • Loading branch information
shivance committed Sep 18, 2022
1 parent 8f5560a commit d5e8dbf
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def global_norm(updates: base.Updates) -> base.Updates:

def power_iteration(
matrix: chex.Array,
num_iters: Optional[int]=100,
error_tolerance: Optional[float]=1e-6,
precision: Optional[jnp.float32]=lax.Precision.HIGHEST):
num_iters: int=100,
error_tolerance: float=1e-6,
precision: lax.Precision=lax.Precision.HIGHEST):
r"""Power iteration algorithm.
The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
Expand Down Expand Up @@ -85,10 +85,10 @@ def _iter_body(state):
def matrix_inverse_pth_root(
matrix: chex.Array,
p: int,
num_iters: Optional[int]=100,
ridge_epsilon: Optional[float]=1e-6,
error_tolerance: Optional[float]=1e-6,
precision: Optional[jnp.float32]=lax.Precision.HIGHEST):
num_iters: int=100,
ridge_epsilon: float=1e-6,
error_tolerance: float=1e-6,
precision: lax.Precision=lax.Precision.HIGHEST):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
This function uses the Coupled newton iterations algorithm for
Expand Down

0 comments on commit d5e8dbf

Please sign in to comment.