From f45b2eb82ffdb8c22d57923b5039000539bed4bc Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Wed, 13 Mar 2024 12:14:35 -0700 Subject: [PATCH] Utility to extract value from pytree (and so from state) PiperOrigin-RevId: 615502669 --- docs/api/utilities.rst | 20 +++-- optax/_src/utils.py | 83 ++---------------- optax/_src/utils_test.py | 43 +-------- optax/tree_utils/__init__.py | 3 + optax/tree_utils/_state_utils.py | 122 +++++++++++++++++++++++++- optax/tree_utils/_state_utils_test.py | 106 ++++++++++++++++++++-- 6 files changed, 247 insertions(+), 130 deletions(-) diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 2a4ad3ca..fbc1b497 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -15,7 +15,7 @@ Scale gradient .. autofunction:: scale_gradient Value and grad from state -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: value_and_grad_from_state @@ -98,7 +98,8 @@ Tree tree_add tree_add_scalar_mul tree_div - tree_vdot + tree_get + tree_get_all_with_path tree_l2_norm tree_map_params tree_mul @@ -107,6 +108,7 @@ Tree tree_scalar_mul tree_sub tree_sum + tree_vdot tree_zeros_like Tree add @@ -121,9 +123,13 @@ Tree divide ~~~~~~~~~~~ .. autofunction:: tree_div -Tree inner product -~~~~~~~~~~~~~~~~~~ -.. autofunction:: tree_vdot +Fetch a single value matching a given key +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_get + +Fetch all values that match a given key +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_get_all_with_path Tree l2 norm ~~~~~~~~~~~~ @@ -157,6 +163,10 @@ Tree sum ~~~~~~~~ .. autofunction:: tree_sum +Tree inner product +~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_vdot + Tree zeros like ~~~~~~~~~~~~~~~ .. 
autofunction:: tree_zeros_like diff --git a/optax/_src/utils.py b/optax/_src/utils.py index bc0a111c..55038060 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -26,6 +26,7 @@ from optax._src import base from optax._src import linear_algebra from optax._src import numerics +from optax.tree_utils import _state_utils def tile_second_to_last_dim(a: chex.Array) -> chex.Array: @@ -226,69 +227,6 @@ def _extract_fns_kwargs( return fns_kwargs, remaining_kwargs -def _extract_from_state( - state: chex.ArrayTree, - key: str, -) -> list[tuple[Any, str, list[int]]]: - r"""Extract values from state. - - Search in a state (potentially a pytree with :class:`optax.OptState` leaves - returned by :func:`optax.chain`) for a specific ``key``. That key may appear - more than once in the state (see example below). So this function returns a - list of all values corresponding to the key with the name of the associated - state and the path to the state in the pytree of states. - - Examples: - >>> import jax.numpy as jnp - >>> import optax - >>> params = jnp.array([1., 2., 3.]) - >>> base_opt = optax.chain( - ... optax.adam(learning_rate=1.), - ... optax.adam(learning_rate=1.) - ... ) - >>> solver = optax.chain(optax.adam(learning_rate=1.), base_opt) - >>> state = solver.init(params) - >>> values_found = _extract_from_state(state, 'count') - >>> print(len(values_found)) - 3 - >>> count, state_name, path_to_state = values_found[0] - >>> print(count, state_name, path_to_state) - 0 ScaleByAdamState [0, 0] - >>> state_with_entry = state - >>> for i in path_to_state: - ... state_with_entry = state_with_entry[i] - >>> print(state_with_entry.__class__.__name__ == state_name) - True - >>> print(getattr(state_with_entry, 'count') == count) - True - - Args: - state: state to search in. It can be an ``optax.OptState`` - or a pytree of ``optax.OptState`` returned by, e.g., - ``optax.chain(...).init(params)``. - key: keyword to search state for. 
- - Returns: - values - list of tuples where each tuple is of the form (``value``, ``state_name``, - ``path_to_state``). Here ``value`` is one entry of the state that - corresponds to the ``key``, ``state_name`` is the name of the state where - this value has been found, and ``path_to_state`` is a sequence of indexes - that lead to the state where the value has been found (see example). - """ - values_found = [] - tree_flatten, _ = jax.tree_util.tree_flatten_with_path(state) - for path, val in tree_flatten: - if getattr(path[-1], 'name') == key: - path_to_state = [path[i].idx for i in range(len(path)-1)] - substate = state - for i in path_to_state: - substate = substate[i] - state_name = substate.__class__.__name__ - values_found.append((val, state_name, path_to_state)) - return values_found - - def value_and_grad_from_state( value_fn: Callable[..., Union[jax.Array, float]], ) -> Callable[..., tuple[Union[float, jax.Array], base.Updates]]: @@ -346,20 +284,13 @@ def _value_and_grad( state: base.OptState, **fn_kwargs: dict[str, Any], ): - values_found = _extract_from_state(state, 'value') - grads_found = _extract_from_state(state, 'grad') - if len(values_found) > 1 or len(grads_found) > 1: - raise ValueError('Found multiple values or gradients.') - elif not values_found or not grads_found: - raise ValueError('Found no value or no gradient.') - else: - value = values_found[0][0] - grad = grads_found[0][0] - if grad is None: + value = _state_utils.tree_get(state, 'value') + grad = _state_utils.tree_get(state, 'grad') + if (value is None) or (grad is None): raise ValueError( - 'Gradient is None. Make sure that the gradient is stored in the ' - 'state, e.g., using store_grad=True in the definition of, e.g. ' - 'optax.scale_by_backtracking_linesearch.' + 'Value or gradient not found in the state. ' + 'Make sure that these values are stored in the state by the ' + 'optimizer.' 
) value, grad = jax.lax.cond( (~jnp.isinf(value)) & (~jnp.isnan(value)), diff --git a/optax/_src/utils_test.py b/optax/_src/utils_test.py index bfcfcee2..b77a550e 100644 --- a/optax/_src/utils_test.py +++ b/optax/_src/utils_test.py @@ -258,43 +258,6 @@ def test_canonicalize_dtype(self, dtype, expected_dtype): canonical = utils.canonicalize_dtype(dtype) self.assertIs(canonical, expected_dtype) - def test_extract_from_state(self): - params = jnp.array([1.0, 2.0, 3.0]) - - def check_values_found(state, values_found): - for value, state_name, path_to_state in values_found: - state_with_value = state - for i in path_to_state: - state_with_value = state_with_value[i] - self.assertEqual(getattr(state_with_value, key), value) - self.assertEqual(state_with_value.__class__.__name__, state_name) - - key = 'count' - # Single value of 'count', simple OptState - opt = transform.scale_by_adam() - state = opt.init(params) - values_found = utils._extract_from_state(state, key) - self.assertLen(values_found, 1) - check_values_found(state, values_found) - - # No value of 'count' - opt = alias.sgd(learning_rate=1.0) - state = opt.init(params) - values_found = utils._extract_from_state(state, key) - self.assertEmpty(values_found) - - # Several values of 'count', state defined by chain - opt = combine.chain( - alias.adam(learning_rate=1.0), - combine.chain( - alias.adam(learning_rate=1.0), alias.adam(learning_rate=1.0) - ), - ) - state = opt.init(params) - values_found = utils._extract_from_state(state, key) - self.assertLen(values_found, 3) - check_values_found(state, values_found) - @chex.variants( with_jit=True, without_jit=True, @@ -323,7 +286,7 @@ def fn(x): linesearch.scale_by_backtracking_linesearch(max_backtracking_steps=15), ) state = opt.init(params) - self.assertRaises(ValueError, value_and_grad, params, state=state) + self.assertRaises(KeyError, value_and_grad, params, state=state) # It should work efficiently when the linesearch stores the gradient opt = combine.chain( @@ 
-341,7 +304,7 @@ def fn(x): params = jax.block_until_ready(params) def false_fn(_): - return 1. + return 1.0 false_value_and_grad_ = utils.value_and_grad_from_state(false_fn) false_value_and_grad = self.variant(false_value_and_grad_) @@ -349,7 +312,7 @@ def false_fn(_): # At the second step we should not evaluate the function # so in this case it should not return the output of false_fn value, _ = false_value_and_grad(params, state=state) - self.assertNotEqual(value, 1.) + self.assertNotEqual(value, 1.0) def test_extract_fns_kwargs(self): def fn1(a, b): diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index 960ed66d..0c12e353 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -14,7 +14,10 @@ # ============================================================================== """The tree_utils sub-package.""" +from optax.tree_utils._state_utils import tree_get +from optax.tree_utils._state_utils import tree_get_all_with_path from optax.tree_utils._state_utils import tree_map_params + from optax.tree_utils._tree_math import tree_add from optax.tree_utils._tree_math import tree_add_scalar_mul from optax.tree_utils._tree_math import tree_div diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index 36c5c1f0..dde675f5 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -15,11 +15,21 @@ """Tools for mapping over optimizer states.""" import typing -from typing import Any, Callable, Optional, Protocol, Union, cast +from typing import Any, Callable, Hashable, Optional, Protocol, Union, cast import jax from optax._src import base +_JaxKeyType = Union[ + int, + str, + Hashable, + jax.tree_util.SequenceKey, + jax.tree_util.DictKey, + jax.tree_util.FlattenedIndexKey, + jax.tree_util.GetAttrKey, +] + @typing.runtime_checkable class Initable(Protocol): @@ -107,6 +117,116 @@ def map_params(maybe_placeholder_value, value): ) +def _convert_jax_key_fn(key: _JaxKeyType) -> 
Union[int, str]: + """Convert a key returned by `jax.tree_util` to a usual type.""" + if isinstance(key, (str, int)): + return key # int | str. + if isinstance(key, jax.tree_util.SequenceKey): + return key.idx # int. + if isinstance(key, jax.tree_util.DictKey): + if isinstance(key.key, (str, int)): + return key.key + raise KeyError("Hashable keys not supported") + if isinstance(key, jax.tree_util.FlattenedIndexKey): + return key.key # int. + if isinstance(key, jax.tree_util.GetAttrKey): + return key.name # str. + raise KeyError(f"Jax tree key '{key}' of type '{type(key)}' not valid.") + + +def tree_get_all_with_path( + tree: base.PyTree, + key: Any, +) -> list[tuple[jax._src.tree_util.KeyPath, Any]]: + r"""Extract values from leaves of a pytree matching a given key. + + Search in the leaves of a pytree for a specific ``key`` (which can be a key + from a dictionary or a name from a NamedTuple for example). + That key or name may appear more than once in the pytree. So this function + returns a list of all values corresponding to ``key`` with the path to + that value. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> params = jnp.array([1., 2., 3.]) + >>> base_opt = optax.chain( + ... optax.adam(learning_rate=1.), + ... optax.adam(learning_rate=1.) + ... ) + >>> solver = optax.chain(optax.adam(learning_rate=1.), base_opt) + >>> state = solver.init(params) + >>> values_found = optax.tree_utils.tree_get_all_with_path(state, 'count') + >>> print(len(values_found)) + 3 + >>> path_to_count, count = values_found[0] + >>> print(path_to_count, count) + (SequenceKey(idx=0), SequenceKey(idx=0), GetAttrKey(name='count')) 0 + + .. seealso:: :func:`optax.tree_utils.tree_get` + + Args: + tree: tree to search in. + key: keyword or name to search in tree for. + + Returns: + values_with_path + list of tuples where each tuple is of the form + (``path_to_value``, ``value``). 
Here ``value`` is one entry of the state + that corresponds to the ``key``, and ``path_to_value`` is a path returned + by :func:`jax.tree_util.tree_flatten_with_path`. + """ + values_with_path_found = [] + tree_flatten_with_path, _ = jax.tree_util.tree_flatten_with_path(tree) + for path, val in tree_flatten_with_path: + key_leaf = _convert_jax_key_fn(path[-1]) + if key_leaf == key: + values_with_path_found.append((path, val)) + return values_with_path_found + + +def tree_get(tree: base.PyTree, key: Any, default: Optional[Any] = None) -> Any: + """Extract a value from leaves of a pytree matching a given key. + + Search in the leaves of a pytree for a specific ``key`` (which can be a key + from a dictionary or a name from a NamedTuple). + If no leaves in the tree have the required ``key`` returns ``default``. + + .. seealso:: :func:`optax.tree_utils.tree_get_all_with_path` + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> params = jnp.array([1., 2., 3.]) + >>> solver = optax.inject_hyperparams(optax.adam)(learning_rate=1.) + >>> state = solver.init(params) + >>> lr = optax.tree_utils.tree_get(state, 'learning_rate') + >>> print(lr) + 1.0 + + Args: + tree: tree to search in. + key: keyword or name to search in tree for. + default: default value to return if no leaves in the tree matched the given + ``key``. + + Returns: + value + value in the tree matching the given ``key``. If none are + found return default value. If multiple are found raises an error. + + Raises: + KeyError: If multiple values of ``key`` are found in ``tree``. 
+ """ + values_with_path_found = tree_get_all_with_path(tree, key) + if len(values_with_path_found) > 1: + raise KeyError(f"Found multiple values for '{key}' in {tree}.") + elif not values_with_path_found: + return default + else: + return values_with_path_found[0][1] + + @jax.tree_util.register_pytree_node_class class _ParamsPlaceholder: diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py index ce38bd0e..9597f99c 100644 --- a/optax/tree_utils/_state_utils_test.py +++ b/optax/tree_utils/_state_utils_test.py @@ -21,7 +21,7 @@ import chex import jax import jax.numpy as jnp - +import jax.tree_util as jtu from optax._src import alias from optax._src import base from optax._src import combine @@ -127,7 +127,6 @@ def test_dict_based_optimizers(self): self.assertEqual(expected, opt_state_sharding_spec) def test_state_chex_dataclass(self): - @chex.dataclass class Foo: count: int @@ -141,7 +140,7 @@ def init(params): } state = init(params) - state = _state_utils.tree_map_params(init, lambda v: v+1, state) + state = _state_utils.tree_map_params(init, lambda v: v + 1, state) state = cast(Foo, state) self.assertEqual(int(state.count), 0) @@ -196,11 +195,11 @@ def test_inject_hparams(self): params = _fake_params() state = opt.init(params) - state = _state_utils.tree_map_params(opt, lambda v: v+1, state) + state = _state_utils.tree_map_params(opt, lambda v: v + 1, state) state = cast(_inject.InjectHyperparamsState, state) self.assertEqual(1e-3, state.hyperparams['learning_rate']) - params_plus_one = jax.tree_map(lambda v: v+1, params) + params_plus_one = jax.tree_map(lambda v: v + 1, params) mu = getattr(state.inner_state[0], 'mu') chex.assert_trees_all_close(mu, params_plus_one) @@ -227,8 +226,7 @@ def test_map_non_params_to_none(self): state = opt.init(params) state = _state_utils.tree_map_params( - opt, - lambda v: 1, state, transform_non_params=lambda _: None + opt, lambda v: 1, state, transform_non_params=lambda _: None ) expected = ( 
@@ -238,10 +236,102 @@ def test_map_non_params_to_none(self): nu={'a': 1}, ), transform.ScaleByScheduleState( # pytype:disable=wrong-arg-types - count=None), + count=None + ), ) self.assertEqual(state, expected) + def test_tree_get_all_with_path(self): + params = jnp.array([1.0, 2.0, 3.0]) + + with self.subTest('Test with single value in state'): + key = 'count' + opt = transform.scale_by_adam() + state = opt.init(params) + values_found = _state_utils.tree_get_all_with_path(state, key) + expected_result = [((jtu.GetAttrKey(name='count'),), jnp.array(0.0))] + self.assertEqual(values_found, expected_result) + + with self.subTest('Test with no value in state'): + key = 'count' + opt = alias.sgd(learning_rate=1.0) + state = opt.init(params) + values_found = _state_utils.tree_get_all_with_path(state, key) + self.assertEmpty(values_found) + + with self.subTest('Test with multiple values in state'): + key = 'learning_rate' + opt = combine.chain( + _inject.inject_hyperparams(alias.sgd)(learning_rate=1.0), + combine.chain( + alias.adam(learning_rate=1.0), + _inject.inject_hyperparams(alias.adam)(learning_rate=1e-4), + ), + ) + state = opt.init(params) + values_found = _state_utils.tree_get_all_with_path(state, key) + expected_result = [ + ( + ( + jtu.SequenceKey(idx=0), + jtu.GetAttrKey(name='hyperparams'), + jtu.DictKey(key='learning_rate'), + ), + jnp.array(1.0), + ), + ( + ( + jtu.SequenceKey(idx=1), + jtu.SequenceKey(idx=1), + jtu.GetAttrKey(name='hyperparams'), + jtu.DictKey(key='learning_rate'), + ), + jnp.array(1e-4), + ), + ] + self.assertEqual(values_found, expected_result) + + def test_tree_get(self): + params = jnp.array([1.0, 2.0, 3.0]) + + with self.subTest('Test with unique value matching the key'): + solver = _inject.inject_hyperparams(alias.sgd)(learning_rate=42.0) + state = solver.init(params) + lr = _state_utils.tree_get(state, 'learning_rate') + self.assertEqual(lr, 42.0) + + with self.subTest('Test with no value matching the key'): + solver = 
_inject.inject_hyperparams(alias.sgd)(learning_rate=42.0) + state = solver.init(params) + ema = _state_utils.tree_get(state, 'ema') + self.assertIsNone(ema) + ema = _state_utils.tree_get(state, 'ema', default=7.0) + self.assertEqual(ema, 7.0) + + with self.subTest('Test with multiple values matching the key'): + solver = combine.chain( + _inject.inject_hyperparams(alias.sgd)(learning_rate=42.0), + _inject.inject_hyperparams(alias.sgd)(learning_rate=42.0), + ) + state = solver.init(params) + self.assertRaises(KeyError, _state_utils.tree_get, state, 'learning_rate') + + with self.subTest('Test jitted tree_get'): + opt = _inject.inject_hyperparams(alias.sgd)( + learning_rate=lambda x: 1/(x+1) + ) + state = opt.init(params) + + @jax.jit + def get_learning_rate(state): + return _state_utils.tree_get(state, 'learning_rate') + + for i in range(4): + # we simply update state, we don't care about updates. + _, state = opt.update(params, state) + lr = get_learning_rate(state) + self.assertEqual(lr, 1/(i+1)) + def _fake_params(): return {