Skip to content

Commit

Permalink
Move gradient transformation wrappers to optax.transforms sub-package…
Browse files Browse the repository at this point in the history
… - 1/N

PiperOrigin-RevId: 622167971
  • Loading branch information
mtthss authored and OptaxDev committed Apr 10, 2024
1 parent addb322 commit a4c7655
Show file tree
Hide file tree
Showing 8 changed files with 1,017 additions and 765 deletions.
82 changes: 5 additions & 77 deletions optax/_src/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,81 +14,9 @@
# ==============================================================================
"""Gradient transformations used to enforce specific constraints."""

from typing import Any, NamedTuple
from optax.transforms import _constraining

import jax
import jax.numpy as jnp

from optax._src import base

# pylint:disable=no-value-for-parameter


# State used by `keep_params_nonnegative`. The transformation keeps no data
# between steps, so an empty state is sufficient.
NonNegativeParamsState = base.EmptyState


def keep_params_nonnegative() -> base.GradientTransformation:
  """Modifies the updates to keep parameters non-negative, i.e. >= 0.

  This transformation ensures that parameters after the update will be
  larger than or equal to zero. In a chain of transformations, this should
  be the last one.

  WARNING: the transformation expects input params to be non-negative.
  When params is negative the transformed update will move them to 0.

  Returns:
    A `GradientTransformation` object.
  """

  def init(params):
    del params  # The transformation is stateless.
    return NonNegativeParamsState()

  def update(updates, state, params):
    # Parameter values are required to decide how far each update may go.
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _clamp(param, upd):
      # If applying the raw update would push the parameter below zero,
      # substitute an update that lands exactly on zero instead.
      would_be_negative = (param + upd) < 0.
      return jnp.where(would_be_negative, -param, upd)

    clamped = jax.tree_util.tree_map(_clamp, params, updates)
    return clamped, state

  return base.GradientTransformation(init, update)


class ZeroNansState(NamedTuple):
  """Contains a tree.

  The entry `found_nan` has the same tree structure as that of the parameters.
  Each leaf is a single boolean which contains True iff a NaN was detected in
  the corresponding parameter array at the last call to `update`.
  """
  # Pytree of per-leaf booleans; True where the last `update` saw a NaN.
  found_nan: Any


def zero_nans() -> base.GradientTransformation:
  """A transformation which replaces NaNs with 0.

  The state of the transformation has the same tree structure as that of the
  parameters. Each leaf is a single boolean which contains True iff a NaN was
  detected in the corresponding parameter array at the last call to ``update``.
  This state is not used by the transformation internally, but lets users be
  aware when NaNs have been zeroed out.

  Returns:
    A `GradientTransformation`.
  """

  def _leaf_has_nan(arr):
    # Single boolean per leaf: did any entry of this array turn out NaN?
    return jnp.any(jnp.isnan(arr))

  def _replace_nans(arr):
    # Substitute zeros wherever the array holds NaN; other entries pass through.
    return jnp.where(jnp.isnan(arr), jnp.zeros_like(arr), arr)

  def init(params):
    # Start with an all-False flag tree matching the parameter structure.
    flags = jax.tree_util.tree_map(
        lambda p: jnp.array(False, dtype=jnp.bool_), params)
    return ZeroNansState(flags)

  def update(updates, opt_state, params=None):
    del params  # Unused: only the updates are inspected.
    new_state = ZeroNansState(
        jax.tree_util.tree_map(_leaf_has_nan, updates))
    cleaned_updates = jax.tree_util.tree_map(_replace_nans, updates)
    return cleaned_updates, new_state

  return base.GradientTransformation(init=init, update=update)
# Backward-compatible aliases: the implementations appear to have moved to
# `optax.transforms._constraining` (see the import above); re-exporting them
# here keeps existing `optax._src.constrain` imports working.
keep_params_nonnegative = _constraining.keep_params_nonnegative
NonNegativeParamsState = _constraining.NonNegativeParamsState
zero_nans = _constraining.zero_nans
ZeroNansState = _constraining.ZeroNansState
Loading

0 comments on commit a4c7655

Please sign in to comment.