Skip to content

Commit

Permalink
Utility to extract value from pytree (and so from state)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615502669
  • Loading branch information
vroulet authored and OptaxDev committed Mar 13, 2024
1 parent 0f9ea47 commit f45b2eb
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 130 deletions.
20 changes: 15 additions & 5 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Scale gradient
.. autofunction:: scale_gradient

Value and grad from state
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: value_and_grad_from_state


Expand Down Expand Up @@ -98,7 +98,8 @@ Tree
tree_add
tree_add_scalar_mul
tree_div
tree_vdot
tree_get
tree_get_all_with_path
tree_l2_norm
tree_map_params
tree_mul
Expand All @@ -107,6 +108,7 @@ Tree
tree_scalar_mul
tree_sub
tree_sum
tree_vdot
tree_zeros_like

Tree add
Expand All @@ -121,9 +123,13 @@ Tree divide
~~~~~~~~~~~
.. autofunction:: tree_div

Tree inner product
~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_vdot
Fetch single value that matches a given key
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_get

Fetch all values that match a given key
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_get_all_with_path

Tree l2 norm
~~~~~~~~~~~~
Expand Down Expand Up @@ -157,6 +163,10 @@ Tree sum
~~~~~~~~
.. autofunction:: tree_sum

Tree inner product
~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_vdot

Tree zeros like
~~~~~~~~~~~~~~~
.. autofunction:: tree_zeros_like
83 changes: 7 additions & 76 deletions optax/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from optax._src import base
from optax._src import linear_algebra
from optax._src import numerics
from optax.tree_utils import _state_utils


def tile_second_to_last_dim(a: chex.Array) -> chex.Array:
Expand Down Expand Up @@ -226,69 +227,6 @@ def _extract_fns_kwargs(
return fns_kwargs, remaining_kwargs


def _extract_from_state(
state: chex.ArrayTree,
key: str,
) -> list[tuple[Any, str, list[int]]]:
r"""Extract values from state.
Search in a state (potentially a pytree with :class:`optax.OptState` leaves
returned by :func:`optax.chain`) for a specific ``key``. That key may appear
more than once in the state (see example below). So this function returns a
list of all values corresponding to the key with the name of the associated
state and the path to the state in the pytree of states.
Examples:
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> base_opt = optax.chain(
... optax.adam(learning_rate=1.),
... optax.adam(learning_rate=1.)
... )
>>> solver = optax.chain(optax.adam(learning_rate=1.), base_opt)
>>> state = solver.init(params)
>>> values_found = _extract_from_state(state, 'count')
>>> print(len(values_found))
3
>>> count, state_name, path_to_state = values_found[0]
>>> print(count, state_name, path_to_state)
0 ScaleByAdamState [0, 0]
>>> state_with_entry = state
>>> for i in path_to_state:
... state_with_entry = state_with_entry[i]
>>> print(state_with_entry.__class__.__name__ == state_name)
True
>>> print(getattr(state_with_entry, 'count') == count)
True
Args:
state: state to search in. It can be an ``optax.OptState``
or a pytree of ``optax.OptState`` returned by, e.g.,
``optax.chain(...).init(params)``.
key: keyword to search state for.
Returns:
values
list of tuples where each tuple is of the form (``value``, ``state_name``,
``path_to_state``). Here ``value`` is one entry of the state that
corresponds to the ``key``, ``state_name`` is the name of the state where
this value has been found, and ``path_to_state`` is a sequence of indexes
that lead to the state where the value has been found (see example).
"""
values_found = []
tree_flatten, _ = jax.tree_util.tree_flatten_with_path(state)
for path, val in tree_flatten:
if getattr(path[-1], 'name') == key:
path_to_state = [path[i].idx for i in range(len(path)-1)]
substate = state
for i in path_to_state:
substate = substate[i]
state_name = substate.__class__.__name__
values_found.append((val, state_name, path_to_state))
return values_found


def value_and_grad_from_state(
value_fn: Callable[..., Union[jax.Array, float]],
) -> Callable[..., tuple[Union[float, jax.Array], base.Updates]]:
Expand Down Expand Up @@ -346,20 +284,13 @@ def _value_and_grad(
state: base.OptState,
**fn_kwargs: dict[str, Any],
):
values_found = _extract_from_state(state, 'value')
grads_found = _extract_from_state(state, 'grad')
if len(values_found) > 1 or len(grads_found) > 1:
raise ValueError('Found multiple values or gradients.')
elif not values_found or not grads_found:
raise ValueError('Found no value or no gradient.')
else:
value = values_found[0][0]
grad = grads_found[0][0]
if grad is None:
value = _state_utils.tree_get(state, 'value')
grad = _state_utils.tree_get(state, 'grad')
if (value is None) or (grad is None):
raise ValueError(
'Gradient is None. Make sure that the gradient is stored in the '
'state, e.g., using store_grad=True in the definition of, e.g. '
'optax.scale_by_backtracking_linesearch.'
'Value or gradient not found in the state. '
'Make sure that these values are stored in the state by the '
'optimizer.'
)
value, grad = jax.lax.cond(
(~jnp.isinf(value)) & (~jnp.isnan(value)),
Expand Down
43 changes: 3 additions & 40 deletions optax/_src/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,43 +258,6 @@ def test_canonicalize_dtype(self, dtype, expected_dtype):
canonical = utils.canonicalize_dtype(dtype)
self.assertIs(canonical, expected_dtype)

def test_extract_from_state(self):
  """Checks `_extract_from_state` with one, zero and several matches."""
  searched_key = 'count'
  params = jnp.array([1.0, 2.0, 3.0])

  def assert_hits_are_consistent(opt_state, hits):
    # Following the reported path must lead to a state of the reported name
    # whose attribute holds the reported value.
    for found_value, found_name, found_path in hits:
      node = opt_state
      for index in found_path:
        node = node[index]
      self.assertEqual(getattr(node, searched_key), found_value)
      self.assertEqual(node.__class__.__name__, found_name)

  # Single value of 'count', simple OptState
  single_state = transform.scale_by_adam().init(params)
  single_hits = utils._extract_from_state(single_state, searched_key)
  self.assertLen(single_hits, 1)
  assert_hits_are_consistent(single_state, single_hits)

  # No value of 'count'
  empty_state = alias.sgd(learning_rate=1.0).init(params)
  self.assertEmpty(utils._extract_from_state(empty_state, searched_key))

  # Several values of 'count', state defined by chain
  nested_opt = combine.chain(
      alias.adam(learning_rate=1.0),
      combine.chain(
          alias.adam(learning_rate=1.0), alias.adam(learning_rate=1.0)
      ),
  )
  nested_state = nested_opt.init(params)
  nested_hits = utils._extract_from_state(nested_state, searched_key)
  self.assertLen(nested_hits, 3)
  assert_hits_are_consistent(nested_state, nested_hits)

@chex.variants(
with_jit=True,
without_jit=True,
Expand Down Expand Up @@ -323,7 +286,7 @@ def fn(x):
linesearch.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
state = opt.init(params)
self.assertRaises(ValueError, value_and_grad, params, state=state)
self.assertRaises(KeyError, value_and_grad, params, state=state)

# It should work efficiently when the linesearch stores the gradient
opt = combine.chain(
Expand All @@ -341,15 +304,15 @@ def fn(x):
params = jax.block_until_ready(params)

def false_fn(_):
return 1.
return 1.0

false_value_and_grad_ = utils.value_and_grad_from_state(false_fn)
false_value_and_grad = self.variant(false_value_and_grad_)

# At the second step we should not evaluate the function
# so in this case it should not return the output of false_fn
value, _ = false_value_and_grad(params, state=state)
self.assertNotEqual(value, 1.)
self.assertNotEqual(value, 1.0)

def test_extract_fns_kwargs(self):
def fn1(a, b):
Expand Down
3 changes: 3 additions & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
# ==============================================================================
"""The tree_utils sub-package."""

from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
from optax.tree_utils._state_utils import tree_map_params

from optax.tree_utils._tree_math import tree_add
from optax.tree_utils._tree_math import tree_add_scalar_mul
from optax.tree_utils._tree_math import tree_div
Expand Down
122 changes: 121 additions & 1 deletion optax/tree_utils/_state_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,21 @@
"""Tools for mapping over optimizer states."""

import typing
from typing import Any, Callable, Optional, Protocol, Union, cast
from typing import Any, Callable, Hashable, Optional, Protocol, Union, cast

import jax
from optax._src import base

# Union of the key kinds accepted by `_convert_jax_key_fn`: plain python keys
# (int, str, or any hashable) and the key classes exposed by `jax.tree_util`
# (sequence index, dict key, flattened index, attribute name).
_JaxKeyType = Union[
    int,
    str,
    Hashable,
    jax.tree_util.SequenceKey,
    jax.tree_util.DictKey,
    jax.tree_util.FlattenedIndexKey,
    jax.tree_util.GetAttrKey,
]


@typing.runtime_checkable
class Initable(Protocol):
Expand Down Expand Up @@ -107,6 +117,116 @@ def map_params(maybe_placeholder_value, value):
)


def _convert_jax_key_fn(key: _JaxKeyType) -> Union[int, str]:
"""Convert a key returned by `jax.tree_util` to a usual type."""
if isinstance(key, (str, int)):
return key # int | str.
if isinstance(key, jax.tree_util.SequenceKey):
return key.idx # int.
if isinstance(key, jax.tree_util.DictKey):
if isinstance(key.key, (str, int)):
return key.key
raise KeyError("Hashable keys not supported")
if isinstance(key, jax.tree_util.FlattenedIndexKey):
return key.key # int.
if isinstance(key, jax.tree_util.GetAttrKey):
return key.name # str.
raise KeyError(f"Jax tree key '{key}' of type '{type(key)}' not valid.")


def tree_get_all_with_path(
    tree: base.PyTree,
    key: Any,
) -> list[tuple[jax._src.tree_util.KeyPath, Any]]:
  r"""Extract values from leaves of a pytree matching a given key.

  Search in the leaves of a pytree for a specific ``key`` (which can be a key
  from a dictionary or a name from a NamedTuple for example).
  That key or name may appear more than once in the pytree. So this function
  returns a list of all values corresponding to ``key`` with the path to
  that value.

  A leaf is compared to ``key`` through the last entry of its path only. If
  ``tree`` is itself a single leaf, that leaf has an empty path, cannot match
  any key, and an empty list is returned.

  Examples:
    >>> import jax.numpy as jnp
    >>> import optax
    >>> params = jnp.array([1., 2., 3.])
    >>> base_opt = optax.chain(
    ...   optax.adam(learning_rate=1.),
    ...   optax.adam(learning_rate=1.)
    ... )
    >>> solver = optax.chain(optax.adam(learning_rate=1.), base_opt)
    >>> state = solver.init(params)
    >>> values_found = optax.tree_utils.tree_get_all_with_path(state, 'count')
    >>> print(len(values_found))
    3
    >>> path_to_count, count = values_found[0]
    >>> print(path_to_count, count)
    (SequenceKey(idx=0), SequenceKey(idx=0), GetAttrKey(name='count')) 0

  .. seealso:: :func:`optax.tree_utils.tree_get`

  Args:
    tree: tree to search in.
    key: keyword or name to search in tree for.

  Returns:
    values_with_path
      list of tuples where each tuple is of the form
      (``path_to_value``, ``value``). Here ``value`` is one entry of the state
      that corresponds to the ``key``, and ``path_to_value`` is a path returned
      by :func:`jax.tree_util.tree_flatten_with_path`.
  """
  values_with_path_found = []
  tree_flatten_with_path, _ = jax.tree_util.tree_flatten_with_path(tree)
  for path, val in tree_flatten_with_path:
    if not path:
      # `tree` is a bare leaf: it is not stored under any key, and indexing
      # `path[-1]` would raise an IndexError.
      continue
    key_leaf = _convert_jax_key_fn(path[-1])
    if key_leaf == key:
      values_with_path_found.append((path, val))
  return values_with_path_found


def tree_get(tree: base.PyTree, key: Any, default: Optional[Any] = None) -> Any:
  """Fetch the unique leaf value of a pytree stored under a given key.

  The leaves of ``tree`` are searched for ``key``, which may be a dictionary
  key or the name of a NamedTuple field. When no leaf is stored under ``key``
  the ``default`` is returned instead.

  .. seealso:: :func:`optax.tree_utils.tree_get_all_with_path`

  Examples:
    >>> import jax.numpy as jnp
    >>> import optax
    >>> params = jnp.array([1., 2., 3.])
    >>> solver = optax.inject_hyperparams(optax.adam)(learning_rate=1.)
    >>> state = solver.init(params)
    >>> lr = optax.tree_utils.tree_get(state, 'learning_rate')
    >>> print(lr)
    1.0

  Args:
    tree: tree to search in.
    key: keyword or name to search in tree for.
    default: value returned when no leaf of the tree matches ``key``.

  Returns:
    value
      the value in the tree stored under ``key``, or ``default`` when there is
      no such value.

  Raises:
    KeyError: If multiple values of ``key`` are found in ``tree``.
  """
  matches = tree_get_all_with_path(tree, key)
  if not matches:
    return default
  if len(matches) == 1:
    # Each match is a (path, value) pair; only the value is returned.
    _, value = matches[0]
    return value
  raise KeyError(f"Found multiple values for '{key}' in {tree}.")


@jax.tree_util.register_pytree_node_class
class _ParamsPlaceholder:

Expand Down
Loading

0 comments on commit f45b2eb

Please sign in to comment.