Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added typing to linear_algebra.py #413

Merged
merged 4 commits into from
Oct 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax
from jax import lax
import jax.numpy as jnp
import chex
import numpy as np

from optax._src import base
Expand All @@ -30,10 +31,10 @@ def global_norm(updates: base.Updates) -> base.Updates:


def power_iteration(
matrix,
num_iters=100,
error_tolerance=1e-6,
precision=lax.Precision.HIGHEST):
matrix: chex.Array,
num_iters: int=100,
error_tolerance: float=1e-6,
precision: lax.Precision=lax.Precision.HIGHEST):
r"""Power iteration algorithm.

The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
Expand Down Expand Up @@ -81,12 +82,12 @@ def _iter_body(state):


def matrix_inverse_pth_root(
matrix,
p,
num_iters=100,
ridge_epsilon=1e-6,
error_tolerance=1e-6,
precision=lax.Precision.HIGHEST):
matrix: chex.Array,
p: int,
num_iters: int=100,
ridge_epsilon: float=1e-6,
error_tolerance: float=1e-6,
precision: lax.Precision=lax.Precision.HIGHEST):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.

This function uses the Coupled newton iterations algorithm for
Expand Down