Skip to content

Commit

Permalink
Update pytypes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 517430400
  • Loading branch information
hbq1 authored and OptaxDev committed Mar 17, 2023
1 parent 84c6449 commit 04768d2
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions optax/_src/clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Note that complex numbers are also supported, see
https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
"""
from typing import Tuple
from typing import List, Tuple

import chex
import jax
Expand Down Expand Up @@ -126,8 +126,8 @@ def clip_fn(t):


def per_example_global_norm_clip(
grads: chex.Array, l2_norm_clip: float
) -> Tuple[chex.Array, jax.Array]:
grads: List[chex.Array], l2_norm_clip: float
) -> Tuple[List[chex.Array], jax.Array]:
"""Applies gradient clipping per-example using their global norm.
References:
Expand Down
2 changes: 1 addition & 1 deletion optax/_src/control_variates.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@


CvState = Any
ComputeCv = Callable[[base.Params, chex.Array, CvState], float]
ComputeCv = Callable[[base.Params, chex.Array, CvState], chex.Array]
CvExpectedValue = Callable[[base.Params, CvState], CvState]
UpdateCvState = Callable[[base.Params, chex.Array, CvState], CvState]
ControlVariate = Tuple[ComputeCv, CvExpectedValue, UpdateCvState]
Expand Down
4 changes: 2 additions & 2 deletions optax/_src/factorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
# pylint:disable=no-value-for-parameter


def _decay_rate_pow(i: int, exponent: float = 0.8) -> float:
def _decay_rate_pow(i: int, exponent: float = 0.8) -> chex.Array:
"""Second-order moment decay schedule."""
t = jnp.array(i, jnp.float32) + 1.0
t = jnp.array(i + 1, jnp.float32)
return 1.0 - t**(-exponent)


Expand Down
2 changes: 1 addition & 1 deletion optax/_src/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def join_schedules(schedules: Sequence[base.Schedule],
Returns:
schedule: A function that maps step counts to values.
"""
def schedule(step: jax.Array) -> jax.Array:
def schedule(step: chex.Numeric) -> chex.Numeric:
output = schedules[0](step)
for boundary, schedule in zip(boundaries, schedules[1:]):
output = jnp.where(step < boundary, output, schedule(step - boundary))
Expand Down
7 changes: 4 additions & 3 deletions optax/_src/state_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_adam(self):
)

expected = (
transform.ScaleByAdamState(
transform.ScaleByAdamState( # pytype:disable=wrong-arg-types
count=FakeShardSpec(sharding_axis=None),
mu={
'my/fake/module': {
Expand Down Expand Up @@ -225,12 +225,13 @@ def test_map_non_params_to_none(self):
)

expected = (
transform.ScaleByAdamState(
transform.ScaleByAdamState( # pytype:disable=wrong-arg-types
count=None,
mu={'a': 1},
nu={'a': 1},
),
transform.ScaleByScheduleState(count=None),
transform.ScaleByScheduleState( # pytype:disable=wrong-arg-types
count=None),
)
self.assertEqual(state, expected)

Expand Down

0 comments on commit 04768d2

Please sign in to comment.