From fa4f976e3fc7d3c2f95def3e676315163e32bbd2 Mon Sep 17 00:00:00 2001
From: Ross Hemsley
Date: Mon, 19 Jun 2023 02:41:57 -0700
Subject: [PATCH] Add support for extra args to multi transform

PiperOrigin-RevId: 541573248
---
 optax/_src/combine.py      | 14 ++++++++++----
 optax/_src/combine_test.py | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/optax/_src/combine.py b/optax/_src/combine.py
index 00f81ea7..ef7bf2d9 100644
--- a/optax/_src/combine.py
+++ b/optax/_src/combine.py
@@ -75,7 +75,7 @@ class MultiTransformState(NamedTuple):
 def multi_transform(
     transforms: Mapping[Hashable, base.GradientTransformation],
     param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]]
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """Partitions params and applies a different transformation to each subset.

   Below is an example where we apply Adam to the weights and SGD to the biases
@@ -130,6 +130,12 @@ def map_fn(nested_dict):
   Returns:
     An ``optax.GradientTransformation``.
   """
+
+  transforms = {
+      k: base.with_extra_args_support(v)
+      for k, v in transforms.items()
+  }
+
   def make_mask(labels, group):
     return jax.tree_util.tree_map(lambda label: label == group, labels)

@@ -148,13 +154,13 @@ def init_fn(params):
     }
     return MultiTransformState(inner_states)

-  def update_fn(updates, state, params=None):
+  def update_fn(updates, state, params=None, **extra_args):
     labels = param_labels(updates) if callable(param_labels) else param_labels
     new_inner_state = {}
     for group, tx in transforms.items():
       masked_tx = wrappers.masked(tx, make_mask(labels, group))
       updates, new_inner_state[group] = masked_tx.update(
-          updates, state.inner_states[group], params)
+          updates, state.inner_states[group], params, **extra_args)
     return updates, MultiTransformState(new_inner_state)

-  return base.GradientTransformation(init_fn, update_fn)
+  return base.GradientTransformationExtraArgs(init_fn, update_fn)
diff --git a/optax/_src/combine_test.py b/optax/_src/combine_test.py
index 4fc8c108..b377243e 100644
--- a/optax/_src/combine_test.py
+++ b/optax/_src/combine_test.py
@@ -179,6 +179,44 @@ def test_multi_transform(self, use_fn):
       updates, state = update_fn(updates, state)
       chex.assert_trees_all_close(updates, correct_updates)

+  def test_extra_args(self):
+
+    class ArgNotEqual1Error(ValueError):
+      """Raised when argument not set as expected."""
+
+    def init(params):
+      return {'mu': params}
+
+    def update_with_arg(updates, state, params=None, *, arg, **extra_args):
+      del params, extra_args
+      if arg != 1:
+        raise ArgNotEqual1Error()
+      return updates, state
+
+    def update_without_arg(updates, state, params=None):
+      del params
+      return updates, state
+
+    opt_no_arg = base.GradientTransformation(init, update_without_arg)
+    opt_extra_arg = base.GradientTransformationExtraArgs(init, update_with_arg)
+
+    opt = combine.multi_transform(
+        {
+            'a': opt_no_arg,
+            'b': opt_extra_arg,
+        },
+        ('a', 'b'),
+    )
+
+    fake_params = ({'u': jnp.array([1])}, {'v': jnp.array([1])})
+    state = opt.init(fake_params)
+
+    with self.assertRaises(TypeError):
+      opt.update(fake_params, state)
+    with self.assertRaises(ArgNotEqual1Error):
+      opt.update(fake_params, state, arg=2, ignored_kwarg='hi')
+    opt.update(fake_params, state, arg=1, ignored_kwarg='hi')
+
   @parameterized.parameters(list, tuple, dict)
   def test_empty(self, container):
     init_fn, update_fn = combine.multi_transform(
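
For reference, a minimal usage sketch of what this patch enables: after the change, `multi_transform` forwards extra keyword arguments to each inner transformation, so an inner update that consumes an extra argument receives it, while plain transformations ignore it. This sketch is not part of the patch; it assumes the public exports `optax.multi_transform`, `optax.GradientTransformationExtraArgs`, and `optax.EmptyState`, and the names `loss_aware_scale` and the `loss` kwarg are illustrative only.

import jax
import jax.numpy as jnp
import optax


def _init(params):
  del params
  return optax.EmptyState()


def _update(updates, state, params=None, *, loss, **extra_args):
  # Scale the updates by a scalar `loss` passed through `multi_transform`
  # as an extra keyword argument (hypothetical example).
  del params, extra_args
  return jax.tree_util.tree_map(lambda u: u * loss, updates), state


loss_aware_scale = optax.GradientTransformationExtraArgs(_init, _update)

opt = optax.multi_transform(
    # One extra-args transform and one plain transform; `multi_transform`
    # wraps the plain one with `with_extra_args_support`, so it silently
    # ignores the `loss` kwarg.
    {'w': loss_aware_scale, 'b': optax.sgd(1e-2)},
    {'w': 'w', 'b': 'b'},  # param labels matching the params pytree below
)

params = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
state = opt.init(params)
grads = {'w': jnp.ones(3), 'b': jnp.ones(2)}

# `loss=0.5` reaches `loss_aware_scale` and is dropped by the sgd branch.
updates, state = opt.update(grads, state, params, loss=0.5)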