From ba7837f65f9eba652e0a4452498c683c916efbf4 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 11 Sep 2022 17:36:18 +0530 Subject: [PATCH 1/4] added typing to linear_algebra.py --- optax/_src/linear_algebra.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 3ebde731..c0b6f9cf 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -14,10 +14,13 @@ # ============================================================================== """Linear algebra utilities used in optimisation.""" +from optparse import Option import jax from jax import lax import jax.numpy as jnp +import chex import numpy as np +from typing import Optional from optax._src import base from optax._src import numerics @@ -30,10 +33,10 @@ def global_norm(updates: base.Updates) -> base.Updates: def power_iteration( - matrix, - num_iters=100, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST): + matrix: chex.Array, + num_iters: Optional[int]=100, + error_tolerance: Optional[float]=1e-6, + precision: Optional[jnp.float32]=lax.Precision.HIGHEST): r"""Power iteration algorithm. The power iteration algorithm takes a symmetric PSD matrix `A`, and produces @@ -81,12 +84,12 @@ def _iter_body(state): def matrix_inverse_pth_root( - matrix, - p, - num_iters=100, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST): + matrix: chex.Array, + p: int, + num_iters: Optional[int]=100, + ridge_epsilon: Optional[float]=1e-6, + error_tolerance: Optional[float]=1e-6, + precision: Optional[jnp.float32]=lax.Precision.HIGHEST): """Computes `matrix^(-1/p)`, where `p` is a positive integer. This function uses the Coupled newton iterations algorithm for From 8f5560aa3c0d50ffbaec31715d2683222c2df6a7 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sat, 17 Sep 2022 09:44:27 +0530 Subject: [PATCH 2/4] removing unused optparse import this was imported mistakenly, is not used in the code --- optax/_src/linear_algebra.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index c0b6f9cf..849c01f9 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -14,7 +14,6 @@ # ============================================================================== """Linear algebra utilities used in optimisation.""" -from optparse import Option import jax from jax import lax import jax.numpy as jnp From d5e8dbf66ebdd1fd501a3be97e1fc61cd820cd82 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 18 Sep 2022 16:49:11 +0530 Subject: [PATCH 3/4] fixed optional type --- optax/_src/linear_algebra.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 849c01f9..f1504445 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -33,9 +33,9 @@ def global_norm(updates: base.Updates) -> base.Updates: def power_iteration( matrix: chex.Array, - num_iters: Optional[int]=100, - error_tolerance: Optional[float]=1e-6, - precision: Optional[jnp.float32]=lax.Precision.HIGHEST): + num_iters: int=100, + error_tolerance: float=1e-6, + precision: lax.Precision=lax.Precision.HIGHEST): r"""Power iteration algorithm. The power iteration algorithm takes a symmetric PSD matrix `A`, and produces @@ -85,10 +85,10 @@ def _iter_body(state): def matrix_inverse_pth_root( matrix: chex.Array, p: int, - num_iters: Optional[int]=100, - ridge_epsilon: Optional[float]=1e-6, - error_tolerance: Optional[float]=1e-6, - precision: Optional[jnp.float32]=lax.Precision.HIGHEST): + num_iters: int=100, + ridge_epsilon: float=1e-6, + error_tolerance: float=1e-6, + precision: lax.Precision=lax.Precision.HIGHEST): """Computes `matrix^(-1/p)`, where `p` is a positive integer. This function uses the Coupled newton iterations algorithm for From 98d19d37381d1fd8794b0c6ac323dcd2b282e671 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 4 Oct 2022 21:49:31 +0530 Subject: [PATCH 4/4] Update linear_algebra.py --- optax/_src/linear_algebra.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index f1504445..ea5b83b8 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -19,7 +19,6 @@ import jax.numpy as jnp import chex import numpy as np -from typing import Optional from optax._src import base from optax._src import numerics