From d5e8dbf66ebdd1fd501a3be97e1fc61cd820cd82 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 18 Sep 2022 16:49:11 +0530 Subject: [PATCH] fixed optional type --- optax/_src/linear_algebra.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 849c01f9..f1504445 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -33,9 +33,9 @@ def global_norm(updates: base.Updates) -> base.Updates: def power_iteration( matrix: chex.Array, - num_iters: Optional[int]=100, - error_tolerance: Optional[float]=1e-6, - precision: Optional[jnp.float32]=lax.Precision.HIGHEST): + num_iters: int=100, + error_tolerance: float=1e-6, + precision: lax.Precision=lax.Precision.HIGHEST): r"""Power iteration algorithm. The power iteration algorithm takes a symmetric PSD matrix `A`, and produces @@ -85,10 +85,10 @@ def _iter_body(state): def matrix_inverse_pth_root( matrix: chex.Array, p: int, - num_iters: Optional[int]=100, - ridge_epsilon: Optional[float]=1e-6, - error_tolerance: Optional[float]=1e-6, - precision: Optional[jnp.float32]=lax.Precision.HIGHEST): + num_iters: int=100, + ridge_epsilon: float=1e-6, + error_tolerance: float=1e-6, + precision: lax.Precision=lax.Precision.HIGHEST): """Computes `matrix^(-1/p)`, where `p` is a positive integer. This function uses the Coupled newton iterations algorithm for