diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 849c01f9..f1504445 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -33,9 +33,9 @@ def global_norm(updates: base.Updates) -> base.Updates: def power_iteration( matrix: chex.Array, - num_iters: Optional[int]=100, - error_tolerance: Optional[float]=1e-6, - precision: Optional[jnp.float32]=lax.Precision.HIGHEST): + num_iters: int=100, + error_tolerance: float=1e-6, + precision: lax.Precision=lax.Precision.HIGHEST): r"""Power iteration algorithm. The power iteration algorithm takes a symmetric PSD matrix `A`, and produces @@ -85,10 +85,10 @@ def _iter_body(state): def matrix_inverse_pth_root( matrix: chex.Array, p: int, - num_iters: Optional[int]=100, - ridge_epsilon: Optional[float]=1e-6, - error_tolerance: Optional[float]=1e-6, - precision: Optional[jnp.float32]=lax.Precision.HIGHEST): + num_iters: int=100, + ridge_epsilon: float=1e-6, + error_tolerance: float=1e-6, + precision: lax.Precision=lax.Precision.HIGHEST): """Computes `matrix^(-1/p)`, where `p` is a positive integer. This function uses the Coupled newton iterations algorithm for