Move gradient transformations to optax.transforms sub-package - 1/N #923

Merged (1 commit) on Apr 11, 2024
optax/_src/constrain.py: 82 changes (5 additions, 77 deletions)
@@ -14,81 +14,9 @@
# ==============================================================================
"""Gradient transformations used to enforce specific constraints."""

from typing import Any, NamedTuple
from optax.transforms import _constraining

import jax
import jax.numpy as jnp

from optax._src import base

# pylint:disable=no-value-for-parameter


NonNegativeParamsState = base.EmptyState


def keep_params_nonnegative() -> base.GradientTransformation:
"""Modifies the updates to keep parameters non-negative, i.e. >= 0.

This transformation ensures that parameters after the update will be
larger than or equal to zero.
In a chain of transformations, this should be the last one.

  WARNING: the transformation expects the input params to be non-negative.
  When a param is negative, the transformed update will move it to 0.

Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return NonNegativeParamsState()

def update_fn(updates, state, params):
if params is None:
raise ValueError(base.NO_PARAMS_MSG)

updates = jax.tree_util.tree_map(
lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
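
A minimal usage sketch for keep_params_nonnegative, assuming the public optax.chain / optax.sgd / optax.apply_updates API; the optimizer choice and the values are illustrative:

import jax.numpy as jnp
import optax

# Illustrative non-negative starting params (the transform assumes this).
params = {'w': jnp.array([0.3, 1.2, 0.0])}

# keep_params_nonnegative should be the last transformation in the chain.
opt = optax.chain(optax.sgd(learning_rate=0.1), optax.keep_params_nonnegative())
state = opt.init(params)

grads = {'w': jnp.array([5.0, -1.0, 2.0])}
# params must be passed to update here, otherwise the transform raises an error.
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)  # every entry stays >= 0
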


class ZeroNansState(NamedTuple):
"""Contains a tree.

The entry `found_nan` has the same tree structure as that of the parameters.
Each leaf is a single boolean which contains True iff a NaN was detected in
the corresponding parameter array at the last call to `update`.
"""
found_nan: Any


def zero_nans() -> base.GradientTransformation:
"""A transformation which replaces NaNs with 0.

The state of the transformation has the same tree structure as that of the
parameters. Each leaf is a single boolean which contains True iff a NaN was
detected in the corresponding parameter array at the last call to ``update``.
This state is not used by the transformation internally, but lets users be
aware when NaNs have been zeroed out.

Returns:
A `GradientTransformation`.
"""

def init_fn(params):
return ZeroNansState(jax.tree_util.tree_map(
lambda p: jnp.array(False, dtype=jnp.bool_), params))

def update_fn(updates, opt_state, params=None):
del params
opt_state = ZeroNansState(
jax.tree_util.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates))
updates = jax.tree_util.tree_map(
lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates)
return updates, opt_state

return base.GradientTransformation(init=init_fn, update=update_fn)
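
A minimal sketch of zero_nans on its own, assuming the public optax.zero_nans alias; the values are illustrative:

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}
tx = optax.zero_nans()
state = tx.init(params)

# One NaN entry in the incoming updates.
grads = {'w': jnp.array([1.0, jnp.nan, 2.0])}
updates, state = tx.update(grads, state)  # params are not needed by this transform

# updates['w'] is now [1., 0., 2.] and state.found_nan['w'] is True,
# signalling that a NaN was zeroed out in this call.
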
keep_params_nonnegative = _constraining.keep_params_nonnegative
NonNegativeParamsState = _constraining.NonNegativeParamsState
zero_nans = _constraining.zero_nans
ZeroNansState = _constraining.ZeroNansState
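
A small sketch of what these re-exports imply, assuming both modules are importable: the old optax._src.constrain names now resolve to the objects defined in optax.transforms._constraining, so existing imports of these names keep working.

from optax._src import constrain
from optax.transforms import _constraining

# After this change, the names in optax._src.constrain are thin aliases.
assert constrain.keep_params_nonnegative is _constraining.keep_params_nonnegative
assert constrain.zero_nans is _constraining.zero_nans
assert constrain.ZeroNansState is _constraining.ZeroNansState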