From a4c765543d1a8610060fc2a34c92ad35471f0fa1 Mon Sep 17 00:00:00 2001 From: Matteo Hessel Date: Fri, 5 Apr 2024 07:01:39 -0700 Subject: [PATCH] Move gradient transformation wrappers to optax.transforms sub-package - 1/N PiperOrigin-RevId: 622167971 --- optax/_src/constrain.py | 82 +--- optax/_src/wrappers.py | 713 +--------------------------- optax/transforms/__init__.py | 36 ++ optax/transforms/_accumulation.py | 393 +++++++++++++++ optax/transforms/_conditionality.py | 252 ++++++++++ optax/transforms/_constraining.py | 93 ++++ optax/transforms/_layouts.py | 77 +++ optax/transforms/_masking.py | 136 ++++++ 8 files changed, 1017 insertions(+), 765 deletions(-) create mode 100644 optax/transforms/__init__.py create mode 100644 optax/transforms/_accumulation.py create mode 100644 optax/transforms/_conditionality.py create mode 100644 optax/transforms/_constraining.py create mode 100644 optax/transforms/_layouts.py create mode 100644 optax/transforms/_masking.py diff --git a/optax/_src/constrain.py b/optax/_src/constrain.py index 55a7c497..74f4bc74 100644 --- a/optax/_src/constrain.py +++ b/optax/_src/constrain.py @@ -14,81 +14,9 @@ # ============================================================================== """Gradient transformations used to enforce specific constraints.""" -from typing import Any, NamedTuple +from optax.transforms import _constraining -import jax -import jax.numpy as jnp - -from optax._src import base - -# pylint:disable=no-value-for-parameter - - -NonNegativeParamsState = base.EmptyState - - -def keep_params_nonnegative() -> base.GradientTransformation: - """Modifies the updates to keep parameters non-negative, i.e. >= 0. - - This transformation ensures that parameters after the update will be - larger than or equal to zero. - In a chain of transformations, this should be the last one. - - WARNING: the transformation expects input params to be non-negative. - When params is negative the transformed update will move them to 0. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return NonNegativeParamsState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - - updates = jax.tree_util.tree_map( - lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class ZeroNansState(NamedTuple): - """Contains a tree. - - The entry `found_nan` has the same tree structure as that of the parameters. - Each leaf is a single boolean which contains True iff a NaN was detected in - the corresponding parameter array at the last call to `update`. - """ - found_nan: Any - - -def zero_nans() -> base.GradientTransformation: - """A transformation which replaces NaNs with 0. - - The state of the transformation has the same tree structure as that of the - parameters. Each leaf is a single boolean which contains True iff a NaN was - detected in the corresponding parameter array at the last call to ``update``. - This state is not used by the transformation internally, but lets users be - aware when NaNs have been zeroed out. - - Returns: - A `GradientTransformation`. 
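The constraining transforms touched in this file (``keep_params_nonnegative``, ``zero_nans``) keep their behaviour after the move; their implementations are relocated verbatim and re-exported as aliases. A minimal usage sketch, assuming a recent ``optax``/``jax`` release (the quadratic loss and parameter values are illustrative)::

    import jax
    import jax.numpy as jnp
    import optax

    params = {'w': jnp.array([0.5, 2.0])}
    opt = optax.chain(
        optax.zero_nans(),                # replace NaN gradients with 0
        optax.sgd(learning_rate=0.1),
        optax.keep_params_nonnegative(),  # keep this last in the chain
    )
    state = opt.init(params)
    grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)  # params stay >= 0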
- """ - - def init_fn(params): - return ZeroNansState(jax.tree_util.tree_map( - lambda p: jnp.array(False, dtype=jnp.bool_), params)) - - def update_fn(updates, opt_state, params=None): - del params - opt_state = ZeroNansState( - jax.tree_util.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates)) - updates = jax.tree_util.tree_map( - lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates) - return updates, opt_state - - return base.GradientTransformation(init=init_fn, update=update_fn) +keep_params_nonnegative = _constraining.keep_params_nonnegative +NonNegativeParamsState = _constraining.NonNegativeParamsState +zero_nans = _constraining.zero_nans +ZeroNansState = _constraining.ZeroNansState diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index 3ead5b56..262c80a6 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -15,708 +15,45 @@ """Transformation wrappers.""" import functools -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +from typing import Callable import chex -import jax -from jax import lax -from jax import tree_util as jtu import jax.numpy as jnp -import numpy as np -from optax import tree_utils as otu from optax._src import base -from optax._src import numerics -from optax.tree_utils import _state_utils - - -Array = jnp.ndarray - - -def flatten( - inner: base.GradientTransformation -) -> base.GradientTransformationExtraArgs: - """Flattens parameters and gradients for init and update of inner transform. - - This can reduce the overhead of performing many calculations on lots of small - variables, at the cost of slightly increased memory usage. - - Args: - inner: Inner transformation to flatten inputs for. - - Returns: - New ``GradientTransformationExtraArgs`` - """ - - inner = base.with_extra_args_support(inner) - - def _flatten(params): - """Flattens and concatenates all tensors in params to a single vector.""" - params, _ = jtu.tree_flatten(params) - return jnp.concatenate([jnp.reshape(param, [-1]) for param in params]) - - def _unflatten(updates, flat): - """Extracts tensors from flat, using the structure and shapes of params.""" - updates_flat, treedef = jtu.tree_flatten(updates) - offsets = [] - for update in updates_flat: - size = np.size(update) - if offsets: - offsets.append(size + offsets[-1]) - else: - offsets.append(size) - del offsets[-1] - flat_split = jnp.split(flat, offsets) - reshaped = [ - jnp.reshape(flat_update, update.shape) - for flat_update, update in zip(flat_split, updates_flat) - ] - return jtu.tree_unflatten(treedef, reshaped) - - def init_fn(params): - flat = _flatten(params) - return inner.init(flat) - - def update_fn(updates, state, params=None, **extra_args): - if params is not None: - params = _flatten(params) - updates_flat, state = inner.update( - _flatten(updates), state, params, **extra_args - ) - updates = _unflatten(updates, updates_flat) - return updates, state - - return base.GradientTransformationExtraArgs(init_fn, update_fn) - - -class ApplyIfFiniteState(NamedTuple): - """State of the `GradientTransformation` returned by `apply_if_finite`. - - Fields: - notfinite_count: Number of consecutive gradient updates containing an Inf or - a NaN. This number is reset to 0 whenever a gradient update without an Inf - or a NaN is done. - last_finite: Whether or not the last gradient update contained an Inf or a - NaN. - total_notfinite: Total number of gradient updates containing an Inf or - a NaN since this optimizer was initialised. This number is never reset. 
- inner_state: The state of the inner `GradientTransformation`. - """ - # TODO(optax-dev): notfinite_count, last_finite and inner_state used to be - # annotated as `jnp.array` but that is not a valid annotation (it's a function - # and secretely resolved to `Any`. We should add back typing. - notfinite_count: Any - last_finite: Any - total_notfinite: Any - inner_state: Any - - -def apply_if_finite( - inner: base.GradientTransformation, - max_consecutive_errors: int -) -> base.GradientTransformation: - """A function that wraps an optimizer to make it robust to a few NaNs or Infs. - - The purpose of this function is to prevent any optimization to happen if the - gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the - gradients, the wrapped optimizer ignores that gradient update. If the NaNs or - Infs persist after a given number of updates, the wrapped optimizer gives up - and accepts the update. - - Args: - inner: Inner transformation to be wrapped. - max_consecutive_errors: Maximum number of consecutive gradient updates - containing NaNs or Infs that the wrapped optimizer will ignore. After - that many ignored updates, the optimizer will give up and accept. - - Returns: - New ``GradientTransformationExtraArgs``. - """ - - inner = base.with_extra_args_support(inner) - - def init(params): - return ApplyIfFiniteState( - notfinite_count=jnp.zeros([], jnp.int32), - last_finite=jnp.array(True, jnp.bool_), - total_notfinite=jnp.zeros([], jnp.int32), - inner_state=inner.init(params)) - - def update(updates, state, params=None, **extra_args): - inner_state = state.inner_state - flat_updates = jtu.tree_flatten(updates)[0] - isfinite = jnp.all( - jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates])) - notfinite_count = jnp.where( - isfinite, jnp.zeros([], jnp.int32), - numerics.safe_int32_increment(state.notfinite_count)) - - def do_update(_): - return inner.update(updates, inner_state, params, **extra_args) - def reject_update(_): - return otu.tree_zeros_like(updates), inner_state - - updates, new_inner_state = lax.cond( - jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors), - do_update, reject_update, operand=None) - - return updates, ApplyIfFiniteState( - notfinite_count=notfinite_count, - last_finite=isfinite, - total_notfinite=jnp.where( - isfinite, state.total_notfinite, - numerics.safe_int32_increment(state.total_notfinite)), - inner_state=new_inner_state) - - return base.GradientTransformationExtraArgs(init=init, update=update) - - -class MultiStepsState(NamedTuple): - """State of the `GradientTransformation` returned by `MultiSteps`. - - Fields: - mini_step: current mini-step counter. At an update, this either increases by - 1 or is reset to 0. - gradient_step: gradient step counter. This only increases after enough - mini-steps have been accumulated. - inner_opt_state: the state of the wrapped otpimiser. - acc_grads: accumulated gradients over multiple mini-steps. - skip_state: an arbitrarily nested tree of arrays. This is only - relevant when passing a `should_skip_update_fn` to `MultiSteps`. This - structure will then contain values for debugging and or monitoring. The - actual structure will vary depending on the choice of - `ShouldSkipUpdateFunction`. 
- """ - mini_step: Array - gradient_step: Array - inner_opt_state: Any - acc_grads: Any - skip_state: chex.ArrayTree = () - - -class ShouldSkipUpdateFunction(Protocol): - - def __call__(self, updates: base.Updates, gradient_step: Array, - params: Optional[base.Params]) -> tuple[Array, chex.ArrayTree]: - """Returns true to indicate that updates should be skipped in a multi-step. - - Args: - updates: The updates that the gradient transformation has proposed - to apply - gradient_step: The current gradient step (see - `MultiStepsState.gradient_step`). This can be used for example to reject - large gradients with an annealed maximum allowed gradient norm. - params: If known, the current parameter tree of the function being - transformed. - Returns: - A tuple: - * First element is an array with a single bool indicating whether or not - the updates should be applied. - * Second element is an arbitrarily nested structure of arrays that will be - stored in `MultiStepsState.skip_state`. The structure will vary from - function to function. Debugging info, or values to monitor, can be put - in this structure. - """ - - -def skip_not_finite( - updates: base.Updates, gradient_step: Array, - params: Optional[base.Params]) -> tuple[Array, chex.ArrayTree]: - """Returns True iff any of the `updates` contains an inf or a NaN. - - Args: - updates: see `ShouldSkipUpdateFunction`. - gradient_step: see `ShouldSkipUpdateFunction`. - params: see `ShouldSkipUpdateFunction`. - - Returns: - A tuple: - * First element is a scalar array of type bool. - * Second element is a dictionary with keys: - - `should_skip`: True iff `updates` contains an inf or a NaN. - - `num_not_finite`: total number of inf and NaN found in `updates`. - """ - del gradient_step, params - all_is_finite = [jnp.sum(jnp.logical_not(jnp.isfinite(p))) - for p in jtu.tree_leaves(updates)] - num_not_finite = jnp.sum(jnp.array(all_is_finite)) - should_skip = num_not_finite > 0 - return should_skip, dict(should_skip=should_skip, - num_not_finite=num_not_finite) - - -def skip_large_updates(updates: base.Updates, - gradient_step: Array, - params: Optional[base.Params], - max_squared_norm: float) -> tuple[Array, chex.ArrayTree]: - """Returns True if the global norm square of `updates` is small enough. - - Args: - updates: see `ShouldSkipUpdateFunction`. - gradient_step: see `ShouldSkipUpdateFunction`. - params: see `ShouldSkipUpdateFunction`. - max_squared_norm: only updates with a norm square strictly less than this - value will be accepted. - - Returns: - A tuple: - * First element is a scalar array of type bool. - * Second element is a dictionary with keys: - - `should_skip`: True iff square norm of `updates` is larger or equal than - `max_squared_norm`. - - `norm_squared`: overall norm square of the `updates`. - """ - del gradient_step, params - norm_sq = jnp.sum( - jnp.array([jnp.sum(p**2) for p in jtu.tree_leaves(updates)])) - # This will also return True if `norm_sq` is NaN. - should_skip = jnp.logical_not(norm_sq < max_squared_norm) - return should_skip, dict(should_skip=should_skip, norm_squared=norm_sq) - - -class MultiSteps: - """An optimizer wrapper to accumulate gradients over multiple steps. - - This wrapper collects together the updates passed to its ``update`` function - over consecutive steps until a given number of scheduled steps is reached. - In each of these intermediate steps, the returned value from the optimizer is - a tree of zeros of the same shape of the updates passed as input. 
- - Once the scheduled number of intermediate 'mini-steps' has been reached, the - gradients accumulated to the current time will be passed to the wrapped - optimizer's update function, (with the inner optimizer's state being updated - appropriately) and then returned to the caller. The wrapper's accumulated - gradients are then set back to zero and the process starts again. - - The number of mini-steps per gradient update is controlled by a function, and - it can vary over training. This offers a means of varying batch size over - training. - """ - - def __init__( - self, - opt: base.GradientTransformation, - every_k_schedule: Union[int, Callable[[Array], Array]], - use_grad_mean: bool = True, - should_skip_update_fn: Optional[ShouldSkipUpdateFunction] = None): - # pylint: disable=line-too-long - """Initialiser. - - Args: - opt: the wrapped optimizer. - every_k_schedule: an int or a function. - - * As a function, it returns how many mini-steps should be accumulated - in a single gradient step. Its only argument is the current - gradient step count. By varying the returned value, users can vary the - overall training batch size. - * If an ``int``, this is the constant number of mini-steps per gradient - update. - use_grad_mean: if ``True`` (the default), gradients accumulated over - multiple mini-steps are averaged. Otherwise, they are summed. - should_skip_update_fn: if provided, this function is used to decide when - to accept or reject the updates from a mini-step. When a mini-step is - rejected, the inner state of `MultiSteps` is not updated. In other - words, it is as if this mini-step never happened. For example: - - * to ignore updates containing inf or NaN, do - ``should_skip_update_fn=skip_not_finite``; - * to ignore updates with a norm square larger then 42, do: - ``should_skip_update_fn=functools.partial(skip_large_updates, max_norm_sq=42.)`` - - Note that the optimizer's state :class:`optax.MultiStepsState` contains - a keyword argument ``skip_state`` in which debugging and monitoring - information returned by ``should_skip_update_fn`` is written. - """ - # pylint: enable=line-too-long - self._opt = base.with_extra_args_support(opt) - - if isinstance(every_k_schedule, int): - self._every_k_schedule = lambda step: every_k_schedule - else: - self._every_k_schedule = every_k_schedule - self._use_grad_mean = use_grad_mean - - if self._use_grad_mean: - # Use Welford algorithm for numerically stable aggregation of mean. 
- self._acc_update = ( - lambda grad, acc, *, n_acc: acc + (grad - acc) / (n_acc + 1)) - else: - self._acc_update = lambda grad, acc, *, n_acc: grad + acc - - if should_skip_update_fn is None: - - def should_skip_update_fn(*unused_args, **unused_kwargs): - return jnp.array(False, dtype=jnp.bool_), () - - self._should_skip_update_fn = should_skip_update_fn - - @property - def inner_opt(self): - return self._opt - - def init(self, params: Any) -> MultiStepsState: - """Builds and returns initial `MultiStepsState`.""" - updates = otu.tree_zeros_like(params) - gradient_step = jnp.zeros([], dtype=jnp.int32) - _, skip_state = self._should_skip_update_fn(updates, gradient_step, params) - init_state = MultiStepsState( - mini_step=jnp.zeros([], dtype=jnp.int32), - gradient_step=gradient_step, - inner_opt_state=self._opt.init(params), - acc_grads=updates, - skip_state=skip_state) - return init_state - - def update(self, - updates: base.Updates, - state: MultiStepsState, - params: Optional[base.Params] = None, - **extra_args: Any, - ) -> tuple[base.Updates, MultiStepsState]: - """Accumulates gradients and proposes non-zero updates every `k_steps`.""" - k_steps = self._every_k_schedule(state.gradient_step) - should_skip_update, skip_state = self._should_skip_update_fn( - updates, state.gradient_step, params) - if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()): - raise ValueError( - 'The `should_skip_update_fn` function should return a boolean scalar ' - f'array, but it returned an array of dtype {should_skip_update.dtype}' - f' and shape {should_skip_update.shape}' - ) - - # Note: we do not enclose variables to allow JAX to re-use memory buffers. - def _do_update(updates, state, params): - acc_grads = jtu.tree_map( - lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step), - updates, - state.acc_grads, - ) - - final_updates, new_inner_state = self._opt.update( - acc_grads, state.inner_opt_state, params=params, **extra_args - ) - - emit = state.mini_step == (k_steps - 1) - new_state = MultiStepsState( - mini_step=numerics.safe_int32_increment(state.mini_step) % k_steps, - gradient_step=emit - * numerics.safe_int32_increment(state.gradient_step) - + (1 - emit) * state.gradient_step, - inner_opt_state=jtu.tree_map( - lambda st, nst: jnp.where(emit, nst, st), - state.inner_opt_state, - new_inner_state, - ), - acc_grads=jtu.tree_map( - lambda ga: (1 - emit) * ga, acc_grads - ), - skip_state=skip_state, - ) - - final_updates = jtu.tree_map( - lambda ga: emit * ga, final_updates - ) - return final_updates, new_state - - def _skip_update(updates, state, params): - del updates, params - multi_state_when_skip = MultiStepsState( - mini_step=state.mini_step, - gradient_step=state.gradient_step, - inner_opt_state=state.inner_opt_state, - acc_grads=state.acc_grads, - skip_state=skip_state, - ) - zero_updates = otu.tree_zeros_like(state.acc_grads) - return zero_updates, multi_state_when_skip - - new_updates, new_state = jax.lax.cond( - should_skip_update, _skip_update, _do_update, *(updates, state, params) - ) - return new_updates, new_state - - def has_updated(self, state: Union[MultiStepsState, chex.ArrayTree]) -> Array: - # Use `getattr` to bypass pytype checks. 
- return jnp.logical_and( - getattr(state, 'mini_step') == 0, getattr(state, 'gradient_step') > 0 - ) - - def gradient_transformation(self) -> base.GradientTransformation: - return base.GradientTransformation(init=self.init, update=self.update) - - -class MaskedState(NamedTuple): - """Maintains inner transform state for masked transformations.""" - inner_state: Any - - -class MaskedNode(NamedTuple): - """A node used to mask out unspecified parts of a tree. - - This node is ignored when mapping functions across the tree e.g. using - `jtu.tree_map` since it is a container without children. It can - therefore be used to mask out parts of a tree. - """ - - -def masked( - inner: base.GradientTransformation, - mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]], - *, - mask_compatible_extra_args: bool = False, -) -> base.GradientTransformationExtraArgs: - """Mask updates so only some are transformed, the rest are passed through. - - For example, it is common to skip weight decay for BatchNorm scale and all - bias parameters. In many networks, these are the only parameters with only - one dimension. So, you may create a mask function to mask these out as - follows:: - - mask_fn = lambda p: jtu.tree_map(lambda x: x.ndim != 1, p) - weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn) - - You may alternatively create the mask pytree upfront:: - - mask = jtu.tree_map(lambda x: x.ndim != 1, params) - weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask) - - For the ``inner`` transform, state will only be stored for the parameters that - have a mask value of ``True``. - - Note that, when using ``tree_map_params``, it may be required to pass the - argument `is_leaf=lambda v: isinstance(v, optax.MaskedNode)`, if the tree - map needs to take additional arguments with the same shape as the original - input tree. - - Args: - inner: Inner transformation to mask. - mask: a PyTree with same structure as (or a prefix of) the params PyTree, or - a Callable that returns such a pytree given the params/updates. The leaves - should be booleans, ``True`` for leaves/subtrees you want to apply the - transformation to, and ``False`` for those you want to skip. The mask must - be static for the gradient transformation to be jit-compilable. - mask_compatible_extra_args: whether to also apply the same masking to - extra_arg fields with the same tree structure as params/updates. - - Returns: - New ``GradientTransformationExtraArgs`` wrapping ``inner``. - """ - inner = base.with_extra_args_support(inner) - - def mask_pytree(pytree, mask_tree): - return jtu.tree_map( - lambda m, p: p if m else MaskedNode(), mask_tree, pytree - ) - - # It is possible that `extra_args` of update_fn has pytrees with the same - # structure as params/updates, e.g. parameter tags. This function applies - # the mask to those pytrees. - def maybe_mask_values(pytree_dict, base_pytree, mask_tree): - base_structure = jtu.tree_structure(base_pytree) - - def _maybe_mask(pytree): - if mask_compatible_extra_args and ( - jtu.tree_structure(pytree) == base_structure): - return mask_pytree(pytree, mask_tree) - else: - return pytree - - return {k: _maybe_mask(v) for k, v in pytree_dict.items()} - - def init_fn(params): - # This is a workaround to make tree_map_params work with masking. - # The API of `masked` takes a mask on construction, instead of at init. - # This means that this gradient transformation can only work for parameter - # trees that match the shape of the mask. 
Technically this breaks the API - # of optax, and this causes tree_map_params to break. This is because - # tree_map_params calls init with a placeholder in order to detect copies - # of the parameter tree. As a (slightly ugly) workaround, we detect when - # the init is being called by tree_map_params, and pass the placeholder - # down without masking. This is safe, since tree_map_params does not impose - # any particular constraints on the shape of the parameter tree, as long - # as tree_map_params is being called on a tree with the correct structure. - # See wrappers_test for proof that this works! - if isinstance(params, _state_utils._ParamsPlaceholder): # pylint:disable=protected-access - return MaskedState(inner_state=inner.init(params)) - - mask_tree = mask(params) if callable(mask) else mask - masked_params = mask_pytree(params, mask_tree) - return MaskedState(inner_state=inner.init(masked_params)) - - def update_fn(updates, state, params=None, **extra_args): - mask_tree = mask(updates) if callable(mask) else mask - masked_extra_args = maybe_mask_values(extra_args, updates, mask_tree) - masked_updates = mask_pytree(updates, mask_tree) - masked_params = None if params is None else mask_pytree(params, mask_tree) - - new_masked_updates, new_inner_state = inner.update( - masked_updates, state.inner_state, masked_params, **masked_extra_args) - - new_updates = jtu.tree_map( - lambda m, new_u, old_u: new_u if m else old_u, - mask_tree, new_masked_updates, updates) - return new_updates, MaskedState(inner_state=new_inner_state) - - return base.GradientTransformationExtraArgs(init_fn, update_fn) - - -class ConditionFn(Protocol): - """Condition function for conditional transformations.""" - - def __call__( - self, - step: Array, - **extra_args: Any, - ) -> Array: - """Update function with optional extra arguments. - - Args: - step: a counter (array of shape [] and dtype ``int32``) - **extra_args: Additional keyword arguments passed to this condition fn. - - Returns: - a boolean array of shape [] and dtype ``bool`` indicating whether the - inner transformation should be called. - """ - - -class ConditionallyTransformState(NamedTuple): - """Maintains inner transform state and adds a step counter.""" - inner_state: Any - step: Array - - -def conditionally_transform( - inner: base.GradientTransformation, - should_transform_fn: ConditionFn, - forward_extra_args: bool = False, -) -> base.GradientTransformationExtraArgs: - """Calls the inner update function only at certain steps. - - Creates a transformation wrapper that conditionally applies the inner gradient - transformation, and if the condition is not met, just passes the updates and - inner state through unchanged. The behaviour is controlled by a user specified - function ``should_transform_fn`` that is called by ``conditionally_transform`` - passing as input a counter of the number of times that the ``update`` function - has been previously called, the user specified function must returns a boolean - controlling whether the inner transformation should be called. - - WARNING: if instead you want to set the ``updates`` to zero when the condition - is not met, you can use the ``conditionally_mask`` wrapper. - - Args: - inner: the inner transformation. - should_transform_fn: function takes in a ``step`` counter (array of shape [] - and dtype ``int32``), and returns a boolean array of shape []. If - ``forward_extra_args`` is set to True, any extra arguments are also - forwarded to the ``should_transform_fn`. 
- forward_extra_args: forward extra args to ``should_transform_fn``. - - Returns: - A new ``GradientTransformationExtraArgs``. - - .. versionadded:: 0.2.3 - """ - inner = base.with_extra_args_support(inner) - - def init_fn(params): - return ConditionallyTransformState( - inner_state=inner.init(params), step=jnp.zeros([], dtype=jnp.int32)) - - def update_fn(updates, state, params=None, **extra_args): - - def do_update(_): - return inner.update(updates, state.inner_state, params, **extra_args) - - def reject_update(_): - return updates, state.inner_state - - condition_kwargs = extra_args if forward_extra_args else {} - updates, new_inner_state = lax.cond( - should_transform_fn(state.step, **condition_kwargs), - do_update, reject_update, operand=None) - return updates, ConditionallyTransformState( - new_inner_state, numerics.safe_int32_increment(state.step)) - - return base.GradientTransformationExtraArgs(init_fn, update_fn) +from optax.transforms import _accumulation +from optax.transforms import _conditionality +from optax.transforms import _layouts +from optax.transforms import _masking + + +apply_if_finite = _conditionality.apply_if_finite +ApplyIfFiniteState = _conditionality.ApplyIfFiniteState +ConditionFn = _conditionality.ConditionFn +conditionally_mask = _conditionality.conditionally_mask +conditionally_transform = _conditionality.conditionally_transform +ConditionallyMaskState = _conditionality.ConditionallyMaskState +ConditionallyTransformState = _conditionality.ConditionallyTransformState +flatten = _layouts.flatten +masked = _masking.masked +MaskedNode = _masking.MaskedNode +MaskedState = _masking.MaskedState +MultiSteps = _accumulation.MultiSteps +MultiStepsState = _accumulation.MultiStepsState +ShouldSkipUpdateFunction = _accumulation.ShouldSkipUpdateFunction +skip_not_finite = _accumulation.skip_not_finite +skip_large_updates = _accumulation.skip_large_updates @functools.partial( chex.warn_deprecated_function, - replacement='maybe_transform') + replacement='optax.transforms.maybe_transform') def maybe_update( inner: base.GradientTransformation, - should_update_fn: Callable[[Array], Array] + should_update_fn: Callable[[jnp.ndarray], jnp.ndarray] ) -> base.GradientTransformationExtraArgs: return conditionally_transform( inner=inner, should_transform_fn=should_update_fn) -# TODO(mtthss): delete with deprecated ``maybe_update``. MaybeUpdateState = ConditionallyTransformState - - -class ConditionallyMaskState(NamedTuple): - step: chex.Array - inner_state: base.OptState - - -def conditionally_mask( - inner: base.GradientTransformation, - should_transform_fn: ConditionFn, - forward_extra_args: bool = False, -) -> base.GradientTransformationExtraArgs: - """Calls the inner update function only at certain steps. - - Creates a transformation wrapper that conditionally applies the inner gradient - transformation, and if the condition is not met, the updates are set to 0, - while the inner state is passed through unchanged. The behaviour is controlled - by a user specified function ``should_transform_fn`` that is called - by ``conditionally_transform`` passing as input a counter of the number of - times that the ``update`` function has been previously called, the user - specified function must returns a boolean controlling whether the inner - transformation should be called. - - WARNING: if instead you want to leave ``updates`` unchanged when the condition - is not met, you can use the ``conditionally_transform`` wrapper. - - Args: - inner: the inner transformation. 
- should_transform_fn: function takes in a step counter (array of shape [] - and dtype ``int32``), and returns a boolean array of shape []. If - ``forward_extra_args`` is set to True, any extra arguments are also - forwarded to the ``should_transform_fn`. - forward_extra_args: forward extra args to ``should_transform_fn``. - - Returns: - A new ``GradientTransformationExtraArgs``. - - .. versionadded:: 0.2.3 - """ - inner = base.with_extra_args_support(inner) - - def init_fn(params): - return ConditionallyMaskState( - step=jnp.zeros([], jnp.int32), inner_state=inner.init(params) - ) - - def update_fn(updates, state, params=None, **extra_args): - - def do_update(_): - return inner.update(updates, state.inner_state, params, **extra_args) - - def reject_update(_): - return jax.tree_util.tree_map(jnp.zeros_like, updates), state.inner_state - - condition_kwargs = extra_args if forward_extra_args else {} - updates, new_inner_state = lax.cond( - should_transform_fn(state.step, **condition_kwargs), - do_update, reject_update, operand=None) - - return updates, ConditionallyMaskState( - step=numerics.safe_int32_increment(state.step), - inner_state=new_inner_state, - ) - - return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/optax/transforms/__init__.py b/optax/transforms/__init__.py new file mode 100644 index 00000000..be67d1e8 --- /dev/null +++ b/optax/transforms/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The transforms sub-package.""" + +from optax.transforms._accumulation import ema +from optax.transforms._accumulation import EmaState +from optax.transforms._accumulation import MultiSteps +from optax.transforms._accumulation import MultiStepsState +from optax.transforms._accumulation import ShouldSkipUpdateFunction +from optax.transforms._accumulation import skip_large_updates +from optax.transforms._accumulation import skip_not_finite +from optax.transforms._accumulation import trace +from optax.transforms._accumulation import TraceState +from optax.transforms._conditionality import apply_if_finite +from optax.transforms._conditionality import ApplyIfFiniteState +from optax.transforms._conditionality import conditionally_mask +from optax.transforms._conditionality import conditionally_transform +from optax.transforms._conditionality import ConditionallyMaskState +from optax.transforms._conditionality import ConditionallyTransformState +from optax.transforms._conditionality import ConditionFn +from optax.transforms._layouts import flatten +from optax.transforms._masking import masked +from optax.transforms._masking import MaskedNode +from optax.transforms._masking import MaskedState diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py new file mode 100644 index 00000000..2b420d66 --- /dev/null +++ b/optax/transforms/_accumulation.py @@ -0,0 +1,393 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient transformations for accumulating gradients across updates.""" + +from typing import Any, Callable, NamedTuple, Optional, Protocol, Union + +import chex +from jax import lax +from jax import tree_util as jtu +import jax.numpy as jnp + +from optax import tree_utils as otu + +from optax._src import base +from optax._src import numerics +from optax._src import utils + + +class TraceState(NamedTuple): + """Holds an aggregation of past updates.""" + trace: base.Params + + +def trace( + decay: float, + nesterov: bool = False, + accumulator_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Compute a trace of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`. + Both are frequently found in the optimization literature. + + Args: + decay: Decay rate for the trace of past updates. + nesterov: Whether to use Nesterov momentum. + accumulator_dtype: Optional `dtype` to be used for the accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. 
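To make the note above concrete: with ``decay = 0.9`` and a constant unit update, ``trace`` accumulates towards ``1 / (1 - 0.9) = 10`` while the un-debiased ``ema`` converges to ``1``. A small standalone sketch of the two transforms::

    import jax.numpy as jnp
    import optax

    g = {'w': jnp.ones(3)}
    tr, em = optax.trace(decay=0.9), optax.ema(decay=0.9, debias=False)
    tr_state, em_state = tr.init(g), em.init(g)
    for _ in range(100):
        tr_out, tr_state = tr.update(g, tr_state)
        em_out, em_state = em.update(g, em_state)
    # tr_out['w'] is close to 10.0, em_out['w'] is close to 1.0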
+ """ + + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + + def init_fn(params): + return TraceState( + trace=otu.tree_zeros_like(params, dtype=accumulator_dtype)) + + def update_fn(updates, state, params=None): + del params + f = lambda g, t: g + decay * t + new_trace = jtu.tree_map(f, updates, state.trace) + updates = jtu.tree_map(f, updates, new_trace) if nesterov else new_trace + new_trace = otu.tree_cast(new_trace, accumulator_dtype) + return updates, TraceState(trace=new_trace) + + return base.GradientTransformation(init_fn, update_fn) + + +class EmaState(NamedTuple): + """Holds an exponential moving average of past updates.""" + count: chex.Array # shape=(), dtype=jnp.int32. + ema: base.Params + + +def ema( + decay: float, + debias: bool = True, + accumulator_dtype: Optional[Any] = None +) -> base.GradientTransformation: + """Compute an exponential moving average of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`. + Both are frequently found in the optimization literature. + + Args: + decay: Decay rate for the exponential moving average. + debias: Whether to debias the transformed gradient. + accumulator_dtype: Optional `dtype` to used for the accumulator; if `None` + then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. + """ + + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + + def init_fn(params): + return EmaState( + count=jnp.zeros([], jnp.int32), + ema=otu.tree_zeros_like(params, dtype=accumulator_dtype)) + + def update_fn(updates, state, params=None): + del params + updates = new_ema = otu.tree_update_moment( + updates, state.ema, decay, order=1) + count_inc = utils.safe_int32_increment(state.count) + if debias: + updates = otu.tree_bias_correction(new_ema, decay, count_inc) + state_ema = otu.tree_cast(new_ema, accumulator_dtype) + return updates, EmaState(count=count_inc, ema=state_ema) + + return base.GradientTransformation(init_fn, update_fn) + + +class ShouldSkipUpdateFunction(Protocol): + + def __call__( + self, + updates: base.Updates, + gradient_step: chex.Array, + params: Optional[base.Params] + ) -> tuple[chex.Array, chex.ArrayTree]: + """Returns true to indicate that updates should be skipped in a multi-step. + + Args: + updates: The updates that the gradient transformation has proposed. + gradient_step: The current gradient step (see + `MultiStepsState.gradient_step`). This can be used for example to reject + large gradients with an annealed maximum allowed gradient norm. + params: If known, the current params of the function being transformed. + + Returns: + A tuple: + * First element is an array with a single bool indicating whether or not + the updates should be applied. + * Second element is an arbitrary py-tree that will be stored in + `MultiStepsState.skip_state`. Debugging info can be put here. + """ + + +def skip_not_finite( + updates: base.Updates, + gradient_step: chex.Array, + params: Optional[base.Params] +) -> tuple[chex.Array, chex.ArrayTree]: + """Returns True iff any of the `updates` contains an inf or a NaN. + + Args: + updates: see `ShouldSkipUpdateFunction`. + gradient_step: see `ShouldSkipUpdateFunction`. + params: see `ShouldSkipUpdateFunction`. + + Returns: + A tuple: + * First element is a scalar array of type bool. + * Second element is a dictionary with keys: + - `should_skip`: True iff `updates` contains an inf or a NaN. 
+ - `num_not_finite`: total number of inf and NaN found in `updates`. + """ + del gradient_step, params + all_is_finite = [jnp.sum(jnp.logical_not(jnp.isfinite(p))) + for p in jtu.tree_leaves(updates)] + num_not_finite = jnp.sum(jnp.array(all_is_finite)) + should_skip = num_not_finite > 0 + return should_skip, dict(should_skip=should_skip, + num_not_finite=num_not_finite) + + +def skip_large_updates( + updates: base.Updates, + gradient_step: chex.Array, + params: Optional[base.Params], + max_squared_norm: float +) -> tuple[chex.Array, chex.ArrayTree]: + """Returns True if the global norm square of `updates` is small enough. + + Args: + updates: see `ShouldSkipUpdateFunction`. + gradient_step: see `ShouldSkipUpdateFunction`. + params: see `ShouldSkipUpdateFunction`. + max_squared_norm: max square norm that can be accepted in updates. + + Returns: + A tuple: + * First element is a scalar array of type bool. + * Second element is a dictionary with keys: + - `should_skip`: iff ||updates||^2 is greater than `max_squared_norm`. + - `norm_squared`: overall norm square of the `updates`. + """ + del gradient_step, params + norm_sq = jnp.sum( + jnp.array([jnp.sum(p**2) for p in jtu.tree_leaves(updates)])) + # This will also return True if `norm_sq` is NaN. + should_skip = jnp.logical_not(norm_sq < max_squared_norm) + return should_skip, dict(should_skip=should_skip, norm_squared=norm_sq) + + +class MultiStepsState(NamedTuple): + """State of the `GradientTransformation` returned by `MultiSteps`. + + Fields: + mini_step: current mini-step counter. At an update, this either increases by + 1 or is reset to 0. + gradient_step: gradient step counter. This only increases after enough + mini-steps have been accumulated. + inner_opt_state: the state of the wrapped otpimiser. + acc_grads: accumulated gradients over multiple mini-steps. + skip_state: an arbitrarily py tree. This is only relevant when passing + a `should_skip_update_fn` to `MultiSteps`. + """ + mini_step: chex.Array + gradient_step: chex.Array + inner_opt_state: Any + acc_grads: Any + skip_state: chex.ArrayTree = () + + +class MultiSteps: + """An optimizer wrapper to accumulate gradients over multiple steps. + + This wrapper collects together the updates passed to its ``update`` function + over consecutive steps until a given number of scheduled steps is reached. + In each of these intermediate steps, the returned value from the optimizer is + a tree of zeros of the same shape of the updates passed as input. + + Once the scheduled number of intermediate 'mini-steps' has been reached, the + gradients accumulated to the current time will be passed to the wrapped + optimizer's update function, (with the inner optimizer's state being updated + appropriately) and then returned to the caller. The wrapper's accumulated + gradients are then set back to zero and the process starts again. + + The number of mini-steps per gradient update is controlled by a function, and + can vary over training, this also allows varying batch size over training. + """ + + def __init__( + self, + opt: base.GradientTransformation, + every_k_schedule: Union[int, Callable[[chex.Array], chex.Array]], + use_grad_mean: bool = True, + should_skip_update_fn: Optional[ShouldSkipUpdateFunction] = None): + # pylint: disable=line-too-long + """Initialiser. + + Args: + opt: the wrapped optimizer. + every_k_schedule: an int or a function. + + * As a function, it returns how many mini-steps should be accumulated + in a single gradient step. 
Its only argument is the current + gradient step count. By varying the returned value, users can vary the + overall training batch size. + * If an ``int``, this is the constant number of mini-steps per gradient + update. + use_grad_mean: if ``True`` (the default), gradients accumulated over + multiple mini-steps are averaged. Otherwise, they are summed. + should_skip_update_fn: if provided, this function is used to decide when + to accept or reject the updates from a mini-step. When a mini-step is + rejected, the inner state of `MultiSteps` is not updated. In other + words, it is as if this mini-step never happened. For example: + + * to ignore updates containing inf or NaN, do + ``should_skip_update_fn=skip_not_finite``; + * to ignore updates with a norm square larger then 42, do: + ``should_skip_update_fn=functools.partial(skip_large_updates, max_norm_sq=42.)`` + + Note that the optimizer's state :class:`optax.MultiStepsState` contains + a keyword argument ``skip_state`` in which debugging and monitoring + information returned by ``should_skip_update_fn`` is written. + """ + # pylint: enable=line-too-long + self._opt = base.with_extra_args_support(opt) + + if isinstance(every_k_schedule, int): + self._every_k_schedule = lambda step: every_k_schedule + else: + self._every_k_schedule = every_k_schedule + self._use_grad_mean = use_grad_mean + + if self._use_grad_mean: + # Use Welford algorithm for numerically stable aggregation of mean. + self._acc_update = ( + lambda grad, acc, *, n_acc: acc + (grad - acc) / (n_acc + 1)) + else: + self._acc_update = lambda grad, acc, *, n_acc: grad + acc + + if should_skip_update_fn is None: + + def should_skip_update_fn(*unused_args, **unused_kwargs): + return jnp.array(False, dtype=jnp.bool_), () + + self._should_skip_update_fn = should_skip_update_fn + + @property + def inner_opt(self): + return self._opt + + def init(self, params: Any) -> MultiStepsState: + """Builds and returns initial `MultiStepsState`.""" + updates = otu.tree_zeros_like(params) + gradient_step = jnp.zeros([], dtype=jnp.int32) + _, skip_state = self._should_skip_update_fn(updates, gradient_step, params) + init_state = MultiStepsState( + mini_step=jnp.zeros([], dtype=jnp.int32), + gradient_step=gradient_step, + inner_opt_state=self._opt.init(params), + acc_grads=updates, + skip_state=skip_state) + return init_state + + def update(self, + updates: base.Updates, + state: MultiStepsState, + params: Optional[base.Params] = None, + **extra_args: Any, + ) -> tuple[base.Updates, MultiStepsState]: + """Accumulates gradients and proposes non-zero updates every `k_steps`.""" + k_steps = self._every_k_schedule(state.gradient_step) + should_skip_update, skip_state = self._should_skip_update_fn( + updates, state.gradient_step, params) + if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()): + raise ValueError( + 'The `should_skip_update_fn` function should return a boolean scalar ' + f'array, but it returned an array of dtype {should_skip_update.dtype}' + f' and shape {should_skip_update.shape}' + ) + + # Note: we do not enclose variables to allow JAX to re-use memory buffers. 
+ def _do_update(updates, state, params): + acc_grads = jtu.tree_map( + lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step), + updates, + state.acc_grads, + ) + + final_updates, new_inner_state = self._opt.update( + acc_grads, state.inner_opt_state, params=params, **extra_args + ) + + emit = state.mini_step == (k_steps - 1) + new_state = MultiStepsState( + mini_step=numerics.safe_int32_increment(state.mini_step) % k_steps, + gradient_step=emit + * numerics.safe_int32_increment(state.gradient_step) + + (1 - emit) * state.gradient_step, + inner_opt_state=jtu.tree_map( + lambda st, nst: jnp.where(emit, nst, st), + state.inner_opt_state, + new_inner_state, + ), + acc_grads=jtu.tree_map( + lambda ga: (1 - emit) * ga, acc_grads + ), + skip_state=skip_state, + ) + + final_updates = jtu.tree_map( + lambda ga: emit * ga, final_updates + ) + return final_updates, new_state + + def _skip_update(updates, state, params): + del updates, params + multi_state_when_skip = MultiStepsState( + mini_step=state.mini_step, + gradient_step=state.gradient_step, + inner_opt_state=state.inner_opt_state, + acc_grads=state.acc_grads, + skip_state=skip_state, + ) + zero_updates = otu.tree_zeros_like(state.acc_grads) + return zero_updates, multi_state_when_skip + + new_updates, new_state = lax.cond( + should_skip_update, _skip_update, _do_update, *(updates, state, params) + ) + return new_updates, new_state + + def has_updated( + self, state: Union[MultiStepsState, chex.ArrayTree]) -> chex.Array: + # Use `getattr` to bypass pytype checks. + return jnp.logical_and( + getattr(state, 'mini_step') == 0, getattr(state, 'gradient_step') > 0 + ) + + def gradient_transformation(self) -> base.GradientTransformation: + return base.GradientTransformation(init=self.init, update=self.update) diff --git a/optax/transforms/_conditionality.py b/optax/transforms/_conditionality.py new file mode 100644 index 00000000..70b0d75f --- /dev/null +++ b/optax/transforms/_conditionality.py @@ -0,0 +1,252 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrappers that allow transformations to be applied conditionally.""" + +from typing import Any, NamedTuple, Protocol + +import chex +from jax import lax +from jax import tree_util as jtu +import jax.numpy as jnp + +from optax import tree_utils as otu +from optax._src import base +from optax._src import numerics + + +class ConditionFn(Protocol): + """Condition function for conditional transformations.""" + + def __call__( + self, + step: chex.Array, + **extra_args: Any, + ) -> chex.Array: + """Update function with optional extra arguments. + + Args: + step: a counter (array of shape [] and dtype ``int32``) + **extra_args: Additional keyword arguments passed to this condition fn. + + Returns: + a boolean array of shape [] and dtype ``bool`` indicating whether the + inner transformation should be called. 
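A short usage sketch for the conditionality wrappers defined below; the every-10th-step schedule is illustrative::

    import optax
    from optax import transforms  # sub-package introduced by this change

    # Let the inner optimizer act only on every 10th step; on other steps
    # the updates pass through unchanged (conditionally_mask would zero
    # them out instead).
    opt = transforms.conditionally_transform(
        optax.adam(1e-3),
        should_transform_fn=lambda step: step % 10 == 0,
    )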
+ """ + + +class ConditionallyTransformState(NamedTuple): + """Maintains inner transform state and adds a step counter.""" + inner_state: Any + step: chex.Array + + +def conditionally_transform( + inner: base.GradientTransformation, + should_transform_fn: ConditionFn, + forward_extra_args: bool = False, +) -> base.GradientTransformationExtraArgs: + """Calls the inner update function only at certain steps. + + Creates a transformation wrapper that conditionally applies the inner gradient + transformation, and if the condition is not met, just passes the updates and + inner state through unchanged. The behaviour is controlled by a user specified + function ``should_transform_fn`` that is called by ``conditionally_transform`` + passing as input a counter of the number of times that the ``update`` function + has been previously called, the user specified function must returns a boolean + controlling whether the inner transformation should be called. + + WARNING: if instead you want to set the ``updates`` to zero when the condition + is not met, you can use the ``conditionally_mask`` wrapper. + + Args: + inner: the inner transformation. + should_transform_fn: function takes in a ``step`` counter (array of shape [] + and dtype ``int32``), and returns a boolean array of shape []. If + ``forward_extra_args`` is set to True, any extra arguments are also + forwarded to the ``should_transform_fn`. + forward_extra_args: forward extra args to ``should_transform_fn``. + + Returns: + A new ``GradientTransformationExtraArgs``. + + .. versionadded:: 0.2.3 + """ + inner = base.with_extra_args_support(inner) + + def init_fn(params): + return ConditionallyTransformState( + inner_state=inner.init(params), step=jnp.zeros([], dtype=jnp.int32)) + + def update_fn(updates, state, params=None, **extra_args): + + def do_update(_): + return inner.update(updates, state.inner_state, params, **extra_args) + + def reject_update(_): + return updates, state.inner_state + + condition_kwargs = extra_args if forward_extra_args else {} + updates, new_inner_state = lax.cond( + should_transform_fn(state.step, **condition_kwargs), + do_update, reject_update, operand=None) + return updates, ConditionallyTransformState( + new_inner_state, numerics.safe_int32_increment(state.step)) + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +class ConditionallyMaskState(NamedTuple): + step: chex.Array + inner_state: base.OptState + + +def conditionally_mask( + inner: base.GradientTransformation, + should_transform_fn: ConditionFn, + forward_extra_args: bool = False, +) -> base.GradientTransformationExtraArgs: + """Calls the inner update function only at certain steps. + + Creates a transformation wrapper that conditionally applies the inner gradient + transformation, and if the condition is not met, the updates are set to 0, + while the inner state is passed through unchanged. The behaviour is controlled + by a user specified function ``should_transform_fn`` that is called + by ``conditionally_transform`` passing as input a counter of the number of + times that the ``update`` function has been previously called, the user + specified function must returns a boolean controlling whether the inner + transformation should be called. + + WARNING: if instead you want to leave ``updates`` unchanged when the condition + is not met, you can use the ``conditionally_transform`` wrapper. + + Args: + inner: the inner transformation. 
+ should_transform_fn: function takes in a step counter (array of shape [] + and dtype ``int32``), and returns a boolean array of shape []. If + ``forward_extra_args`` is set to True, any extra arguments are also + forwarded to the ``should_transform_fn`. + forward_extra_args: forward extra args to ``should_transform_fn``. + + Returns: + A new ``GradientTransformationExtraArgs``. + + .. versionadded:: 0.2.3 + """ + inner = base.with_extra_args_support(inner) + + def init_fn(params): + return ConditionallyMaskState( + step=jnp.zeros([], jnp.int32), inner_state=inner.init(params) + ) + + def update_fn(updates, state, params=None, **extra_args): + + def do_update(_): + return inner.update(updates, state.inner_state, params, **extra_args) + + def reject_update(_): + return otu.tree_zeros_like(updates), state.inner_state + + condition_kwargs = extra_args if forward_extra_args else {} + updates, new_inner_state = lax.cond( + should_transform_fn(state.step, **condition_kwargs), + do_update, reject_update, operand=None) + + return updates, ConditionallyMaskState( + step=numerics.safe_int32_increment(state.step), + inner_state=new_inner_state, + ) + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +class ApplyIfFiniteState(NamedTuple): + """State of the `GradientTransformation` returned by `apply_if_finite`. + + Fields: + notfinite_count: Number of consecutive gradient updates containing an Inf or + a NaN. This number is reset to 0 whenever a gradient update without an Inf + or a NaN is done. + last_finite: Whether or not the last gradient update contained an Inf or a + NaN. + total_notfinite: Total number of gradient updates containing an Inf or + a NaN since this optimizer was initialised. This number is never reset. + inner_state: The state of the inner `GradientTransformation`. + """ + notfinite_count: Any + last_finite: Any + total_notfinite: Any + inner_state: Any + + +def apply_if_finite( + inner: base.GradientTransformation, + max_consecutive_errors: int +) -> base.GradientTransformation: + """A function that wraps an optimizer to make it robust to a few NaNs or Infs. + + The purpose of this function is to prevent any optimization to happen if the + gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the + gradients, the wrapped optimizer ignores that gradient update. If the NaNs or + Infs persist after a given number of updates, the wrapped optimizer gives up + and accepts the update. + + Args: + inner: Inner transformation to be wrapped. + max_consecutive_errors: Maximum number of consecutive gradient updates + containing NaNs or Infs that the wrapped optimizer will ignore. After + that many ignored updates, the optimizer will give up and accept. + + Returns: + New ``GradientTransformationExtraArgs``. 
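A rough sketch of how this wrapper behaves when a non-finite gradient arrives; ``params`` is a placeholder pytree::

    import jax
    import jax.numpy as jnp
    import optax

    opt = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
    state = opt.init(params)
    nan_grads = jax.tree_util.tree_map(
        lambda p: jnp.full_like(p, jnp.nan), params)
    updates, state = opt.update(nan_grads, state, params)
    # The update is rejected: `updates` is all zeros, state.notfinite_count
    # is 1 and state.last_finite is False.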
+ """ + + inner = base.with_extra_args_support(inner) + + def init(params): + return ApplyIfFiniteState( + notfinite_count=jnp.zeros([], jnp.int32), + last_finite=jnp.array(True, jnp.bool_), + total_notfinite=jnp.zeros([], jnp.int32), + inner_state=inner.init(params)) + + def update(updates, state, params=None, **extra_args): + inner_state = state.inner_state + flat_updates = jtu.tree_flatten(updates)[0] + isfinite = jnp.all( + jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates])) + notfinite_count = jnp.where( + isfinite, jnp.zeros([], jnp.int32), + numerics.safe_int32_increment(state.notfinite_count)) + + def do_update(_): + return inner.update(updates, inner_state, params, **extra_args) + + def reject_update(_): + return otu.tree_zeros_like(updates), inner_state + + updates, new_inner_state = lax.cond( + jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors), + do_update, reject_update, operand=None) + + return updates, ApplyIfFiniteState( + notfinite_count=notfinite_count, + last_finite=isfinite, + total_notfinite=jnp.where( + isfinite, state.total_notfinite, + numerics.safe_int32_increment(state.total_notfinite)), + inner_state=new_inner_state) + + return base.GradientTransformationExtraArgs(init=init, update=update) diff --git a/optax/transforms/_constraining.py b/optax/transforms/_constraining.py new file mode 100644 index 00000000..1579e722 --- /dev/null +++ b/optax/transforms/_constraining.py @@ -0,0 +1,93 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient transformations used to enforce specific constraints.""" + +from typing import Any, NamedTuple + +from jax import tree_util as jtu +import jax.numpy as jnp + +from optax._src import base + + +NonNegativeParamsState = base.EmptyState + + +def keep_params_nonnegative() -> base.GradientTransformation: + """Modifies the updates to keep parameters non-negative, i.e. >= 0. + + This transformation ensures that parameters after the update will be + larger than or equal to zero. + In a chain of transformations, this should be the last one. + + WARNING: the transformation expects input params to be non-negative. + When params is negative the transformed update will move them to 0. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return NonNegativeParamsState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + + updates = jtu.tree_map( + lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class ZeroNansState(NamedTuple): + """Contains a tree. + + The entry `found_nan` has the same tree structure as that of the parameters. + Each leaf is a single boolean which contains True iff a NaN was detected in + the corresponding parameter array at the last call to `update`. 
+ """ + found_nan: Any + + +def zero_nans() -> base.GradientTransformation: + """A transformation which replaces NaNs with 0. + + The state of the transformation has the same tree structure as that of the + parameters. Each leaf is a single boolean which contains True iff a NaN was + detected in the corresponding parameter array at the last call to ``update``. + This state is not used by the transformation internally, but lets users be + aware when NaNs have been zeroed out. + + Returns: + A `GradientTransformation`. + """ + + def init_fn(params): + return ZeroNansState( + found_nan=jtu.tree_map( + lambda p: jnp.array(False, dtype=jnp.bool_), params)) + + def update_fn(updates, opt_state, params=None): + del params, opt_state + opt_state = ZeroNansState( + found_nan=jtu.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates)) + updates = jtu.tree_map( + lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates) + return updates, opt_state + + return base.GradientTransformation(init=init_fn, update=update_fn) diff --git a/optax/transforms/_layouts.py b/optax/transforms/_layouts.py new file mode 100644 index 00000000..9c8d5171 --- /dev/null +++ b/optax/transforms/_layouts.py @@ -0,0 +1,77 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrappers changing the layouts of the tensors that transforms operate on.""" + +from jax import tree_util as jtu +import jax.numpy as jnp +import numpy as np + +from optax._src import base + + +def flatten( + inner: base.GradientTransformation +) -> base.GradientTransformationExtraArgs: + """Flattens parameters and gradients for init and update of inner transform. + + This can reduce the overhead of performing many calculations on lots of small + variables, at the cost of slightly increased memory usage. + + Args: + inner: Inner transformation to flatten inputs for. 
+ + Returns: + New ``GradientTransformationExtraArgs`` + """ + + inner = base.with_extra_args_support(inner) + + def _flatten(params): + """Flattens and concatenates all tensors in params to a single vector.""" + params, _ = jtu.tree_flatten(params) + return jnp.concatenate([jnp.reshape(param, [-1]) for param in params]) + + def _unflatten(updates, flat): + """Extracts tensors from flat, using the structure and shapes of params.""" + updates_flat, treedef = jtu.tree_flatten(updates) + offsets = [] + for update in updates_flat: + size = np.size(update) + if offsets: + offsets.append(size + offsets[-1]) + else: + offsets.append(size) + del offsets[-1] + flat_split = jnp.split(flat, offsets) + reshaped = [ + jnp.reshape(flat_update, update.shape) + for flat_update, update in zip(flat_split, updates_flat) + ] + return jtu.tree_unflatten(treedef, reshaped) + + def init_fn(params): + flat = _flatten(params) + return inner.init(flat) + + def update_fn(updates, state, params=None, **extra_args): + if params is not None: + params = _flatten(params) + updates_flat, state = inner.update( + _flatten(updates), state, params, **extra_args + ) + updates = _unflatten(updates, updates_flat) + return updates, state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/optax/transforms/_masking.py b/optax/transforms/_masking.py new file mode 100644 index 00000000..1e0648a5 --- /dev/null +++ b/optax/transforms/_masking.py @@ -0,0 +1,136 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrappers that mask out part of the parameters when applying a transform.""" + +from typing import Any, Callable, NamedTuple, Union + +from jax import tree_util as jtu + +from optax._src import base +from optax.tree_utils import _state_utils + + +class MaskedState(NamedTuple): + """Maintains inner transform state for masked transformations.""" + inner_state: Any + + +class MaskedNode(NamedTuple): + """A node used to mask out unspecified parts of a tree. + + This node is ignored when mapping functions across the tree e.g. using + `jtu.tree_map` since it is a container without children. It can + therefore be used to mask out parts of a tree. + """ + + +def masked( + inner: base.GradientTransformation, + mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]], + *, + mask_compatible_extra_args: bool = False, +) -> base.GradientTransformationExtraArgs: + """Mask updates so only some are transformed, the rest are passed through. + + For example, it is common to skip weight decay for BatchNorm scale and all + bias parameters. 
Since in many networks, these are the only 1D parameters, + you may for instance create a mask function to mask them out as follows:: + + mask_fn = lambda p: jtu.tree_map(lambda x: x.ndim != 1, p) + weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn) + + You may alternatively create the mask pytree upfront:: + + mask = jtu.tree_map(lambda x: x.ndim != 1, params) + weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask) + + For the ``inner`` transform, state will only be stored for the parameters that + have a mask value of ``True``. + + Note that, when using ``tree_map_params``, it may be required to pass the + argument `is_leaf=lambda v: isinstance(v, optax.MaskedNode)`, if the tree + map needs to take additional arguments with the same shape as the original + input tree. + + Args: + inner: Inner transformation to mask. + mask: a PyTree with same structure as (or a prefix of) the params PyTree, or + a Callable that returns such a pytree given the params/updates. The leaves + should be booleans, ``True`` for leaves/subtrees you want to apply the + transformation to, and ``False`` for those you want to skip. The mask must + be static for the gradient transformation to be jit-compilable. + mask_compatible_extra_args: whether to also apply the same masking to + extra_arg fields with the same tree structure as params/updates. + + Returns: + New ``GradientTransformationExtraArgs`` wrapping ``inner``. + """ + inner = base.with_extra_args_support(inner) + + def mask_pytree(pytree, mask_tree): + return jtu.tree_map( + lambda m, p: p if m else MaskedNode(), mask_tree, pytree + ) + + # It is possible that `extra_args` of update_fn has pytrees with the same + # structure as params/updates, e.g. parameter tags. This function applies + # the mask to those pytrees. + def maybe_mask_values(pytree_dict, base_pytree, mask_tree): + base_structure = jtu.tree_structure(base_pytree) + + def _maybe_mask(pytree): + if mask_compatible_extra_args and ( + jtu.tree_structure(pytree) == base_structure): + return mask_pytree(pytree, mask_tree) + else: + return pytree + + return {k: _maybe_mask(v) for k, v in pytree_dict.items()} + + def init_fn(params): + # This is a workaround to make tree_map_params work with masking. + # The API of `masked` takes a mask on construction, instead of at init. + # This means that this gradient transformation can only work for parameter + # trees that match the shape of the mask. Technically this breaks the API + # of optax, and this causes tree_map_params to break. This is because + # tree_map_params calls init with a placeholder in order to detect copies + # of the parameter tree. As a (slightly ugly) workaround, we detect when + # the init is being called by tree_map_params, and pass the placeholder + # down without masking. This is safe, since tree_map_params does not impose + # any particular constraints on the shape of the parameter tree, as long + # as tree_map_params is being called on a tree with the correct structure. + # See wrappers_test for proof that this works! 
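+    # When `init` is reached via `tree_map_params`, `params` is a placeholder
+    # object rather than a real parameter tree, so it is forwarded to
+    # `inner.init` without applying the mask.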
+ if isinstance(params, _state_utils._ParamsPlaceholder): # pylint:disable=protected-access + return MaskedState(inner_state=inner.init(params)) + + mask_tree = mask(params) if callable(mask) else mask + masked_params = mask_pytree(params, mask_tree) + return MaskedState(inner_state=inner.init(masked_params)) + + def update_fn(updates, state, params=None, **extra_args): + mask_tree = mask(updates) if callable(mask) else mask + masked_extra_args = maybe_mask_values(extra_args, updates, mask_tree) + masked_updates = mask_pytree(updates, mask_tree) + masked_params = None if params is None else mask_pytree(params, mask_tree) + + new_masked_updates, new_inner_state = inner.update( + masked_updates, state.inner_state, masked_params, **masked_extra_args) + + new_updates = jtu.tree_map( + lambda m, new_u, old_u: new_u if m else old_u, + mask_tree, new_masked_updates, updates) + return new_updates, MaskedState(inner_state=new_inner_state) + + return base.GradientTransformationExtraArgs(init_fn, update_fn)
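
A minimal end-to-end sketch of how the wrappers above compose (illustrative
only: the parameter tree and gradients are made up, and it assumes the public
``optax.masked``, ``optax.flatten``, ``optax.zero_nans`` and
``optax.apply_if_finite`` aliases keep re-exporting these implementations)::

    import jax
    import jax.numpy as jnp
    import optax

    params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
    grads = {'w': jnp.full((3, 2), 0.1), 'b': jnp.array([jnp.nan, 0.2])}

    tx = optax.chain(
        optax.zero_nans(),  # zero out NaNs in the incoming gradients
        optax.masked(       # weight-decay only the non-1D parameters
            optax.add_decayed_weights(1e-4),
            mask=lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p),
        ),
        optax.apply_if_finite(  # reject steps whose updates are non-finite
            optax.flatten(optax.adam(1e-3)),  # run Adam on one flat vector
            max_consecutive_errors=5,
        ),
    )
    state = tx.init(params)
    updates, state = tx.update(grads, state, params)
    new_params = optax.apply_updates(params, updates)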