Skip to content

Commit

Permalink
adding tree utils for generating trees with random values and averagi…
Browse files Browse the repository at this point in the history
…ng across trees.

PiperOrigin-RevId: 612828927
  • Loading branch information
q-berthet authored and OptaxDev committed Mar 5, 2024
1 parent 6de95bf commit 927994c
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 8 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Tree
tree_map_params
tree_mul
tree_ones_like
tree_random_like
tree_scalar_mul
tree_sub
tree_sum
Expand Down Expand Up @@ -140,6 +141,10 @@ Tree ones like
~~~~~~~~~~~~~~
.. autofunction:: tree_ones_like

Tree with random values
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_random_like

Tree scalar multiply
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_scalar_mul
Expand Down
2 changes: 1 addition & 1 deletion optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
"""The tree_utils sub-package."""

from optax.tree_utils._state_utils import tree_map_params

from optax.tree_utils._tree_math import tree_add
from optax.tree_utils._tree_math import tree_add_scalar_mul
from optax.tree_utils._tree_math import tree_div
from optax.tree_utils._tree_math import tree_l2_norm
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
from optax.tree_utils._tree_math import tree_random_like
from optax.tree_utils._tree_math import tree_scalar_mul
from optax.tree_utils._tree_math import tree_sub
from optax.tree_utils._tree_math import tree_sum
Expand Down
59 changes: 54 additions & 5 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,31 @@

import functools
import operator
from typing import Any, Union
from typing import Any, Callable, Union

import chex
import jax
from jax import tree_util as jtu
import jax.numpy as jnp
from optax._src import base


def tree_add(tree_x: Any, tree_y: Any) -> Any:
r"""Add two pytrees.
Shape = base.Shape


def tree_add(tree_x: Any, tree_y: Any, *other_trees: Any) -> Any:
r"""Add two (or more) pytrees.
Args:
tree_x: first pytree.
tree_y: second pytree.
*other_trees: optional other trees to add
Returns:
the sum of the two pytrees.
the sum of the two (or more) pytrees.
"""
return jtu.tree_map(operator.add, tree_x, tree_y)
trees = [tree_x, tree_y, *other_trees]
return jtu.tree_map(lambda *leaves: sum(leaves), *trees)


def tree_sub(tree_x: Any, tree_y: Any) -> Any:
Expand Down Expand Up @@ -197,3 +203,46 @@ def tree_ones_like(tree: Any) -> Any:
an all-ones tree with the same structure as ``tree``.
"""
return jtu.tree_map(jnp.ones_like, tree)


def _tree_rng_keys_split(
rng_key: chex.PRNGKey, target_tree: chex.ArrayTree
) -> chex.ArrayTree:
"""Split keys to match structure of target tree.
Args:
rng_key: the key to split.
target_tree: the tree whose structure to match.
Returns:
a tree of rng keys.
"""
tree_def = jtu.tree_structure(target_tree)
keys = jax.random.split(rng_key, tree_def.num_leaves)
return jtu.tree_unflatten(tree_def, keys)


def tree_random_like(
rng_key: chex.PRNGKey,
target_tree: chex.ArrayTree,
sampler: Callable[
[chex.PRNGKey, Shape], chex.Array
] = jax.random.normal,
) -> chex.ArrayTree:
"""Create tree with normal random entries of the same shape as target tree.
Args:
rng_key: key for the random number generator.
target_tree: the tree whose structure to match. Leaves must be arrays.
sampler: the noise sampling function
Returns:
a random tree with the same structure as ``target_tree``, whose leaves have
distribution ``sampler``.
"""
keys_tree = _tree_rng_keys_split(rng_key, target_tree)
return jtu.tree_map(
lambda l, k: sampler(k, l.shape),
target_tree,
keys_tree,
)
56 changes: 54 additions & 2 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
"""Tests for optax.tree_utils."""

from absl.testing import absltest

import chex
import jax
from jax import tree_util as jtu
import jax.numpy as jnp
import numpy as np

from optax import tree_utils as tu


Expand All @@ -30,6 +30,8 @@ def setUp(self):
super().setUp()
rng = np.random.RandomState(0)

self.rng_jax = jax.random.PRNGKey(0)

self.tree_a = (rng.randn(20, 10) + 1j * rng.randn(20, 10), rng.randn(20))
self.tree_b = (rng.randn(20, 10), rng.randn(20))

Expand All @@ -39,6 +41,9 @@ def setUp(self):
self.array_a = rng.randn(20) + 1j * rng.randn(20)
self.array_b = rng.randn(20)

self.tree_a_dict_jax = jtu.tree_map(jnp.array, self.tree_a_dict)
self.tree_b_dict_jax = jtu.tree_map(jnp.array, self.tree_b_dict)

def test_tree_add(self):
expected = self.array_a + self.array_b
got = tu.tree_add(self.array_a, self.array_b)
Expand Down Expand Up @@ -145,5 +150,52 @@ def test_tree_ones_like(self):
got = tu.tree_ones_like(self.tree_a)
chex.assert_trees_all_close(expected, got)

def test_add_multiple_trees(self):
"""Test adding more than 2 trees with tree_add."""
trees = [self.tree_a_dict_jax, self.tree_a_dict_jax, self.tree_a_dict_jax]
expected = tu.tree_scalar_mul(3., self.tree_a_dict_jax)
got = tu.tree_add(*trees)
chex.assert_trees_all_close(expected, got)

def test_random_like_tree(self, eps=1e-6):
"""Test for `tree_random_like`.
Args:
eps: amount of noise.
Tests that `tree_random_like` generates a tree of the proper structure,
that it can be added to a target tree with a small multiplicative factor
without errors, and that the resulting addition is close to the original.
"""
rand_tree_a = tu.tree_random_like(self.rng_jax, self.tree_a)
rand_tree_b = tu.tree_random_like(self.rng_jax, self.tree_b)
rand_tree_a_dict = tu.tree_random_like(self.rng_jax, self.tree_a_dict_jax)
rand_tree_b_dict = tu.tree_random_like(self.rng_jax, self.tree_b_dict_jax)
rand_array_a = tu.tree_random_like(self.rng_jax, self.array_a)
rand_array_b = tu.tree_random_like(self.rng_jax, self.array_b)
sum_tree_a = tu.tree_add_scalar_mul(self.tree_a, eps, rand_tree_a)
sum_tree_b = tu.tree_add_scalar_mul(self.tree_b, eps, rand_tree_b)
sum_tree_a_dict = tu.tree_add_scalar_mul(self.tree_a_dict,
eps,
rand_tree_a_dict)
sum_tree_b_dict = tu.tree_add_scalar_mul(self.tree_b_dict,
eps,
rand_tree_b_dict)
sum_array_a = tu.tree_add_scalar_mul(self.array_a, eps, rand_array_a)
sum_array_b = tu.tree_add_scalar_mul(self.array_b, eps, rand_array_b)
tree_sums = [sum_tree_a,
sum_tree_b,
sum_tree_a_dict,
sum_tree_b_dict,
sum_array_a,
sum_array_b]
trees = [self.tree_a,
self.tree_b,
self.tree_a_dict,
self.tree_b_dict,
self.array_a,
self.array_b]
chex.assert_trees_all_close(trees, tree_sums, atol=1e-5)

if __name__ == '__main__':
absltest.main()

0 comments on commit 927994c

Please sign in to comment.