Skip to content

Commit

Permalink
added typing to linear_algebra.py
Browse files Browse the repository at this point in the history
  • Loading branch information
shivance committed Sep 11, 2022
1 parent 44df918 commit ba7837f
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
# ==============================================================================
"""Linear algebra utilities used in optimisation."""

from optparse import Option
import jax
from jax import lax
import jax.numpy as jnp
import chex
import numpy as np
from typing import Optional

from optax._src import base
from optax._src import numerics
Expand All @@ -30,10 +33,10 @@ def global_norm(updates: base.Updates) -> base.Updates:


def power_iteration(
matrix,
num_iters=100,
error_tolerance=1e-6,
precision=lax.Precision.HIGHEST):
matrix: chex.Array,
num_iters: Optional[int]=100,
error_tolerance: Optional[float]=1e-6,
precision: Optional[jnp.float32]=lax.Precision.HIGHEST):
r"""Power iteration algorithm.
The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
Expand Down Expand Up @@ -81,12 +84,12 @@ def _iter_body(state):


def matrix_inverse_pth_root(
matrix,
p,
num_iters=100,
ridge_epsilon=1e-6,
error_tolerance=1e-6,
precision=lax.Precision.HIGHEST):
matrix: chex.Array,
p: int,
num_iters: Optional[int]=100,
ridge_epsilon: Optional[float]=1e-6,
error_tolerance: Optional[float]=1e-6,
precision: Optional[jnp.float32]=lax.Precision.HIGHEST):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
This function uses the Coupled newton iterations algorithm for
Expand Down

0 comments on commit ba7837f

Please sign in to comment.