Skip to content

Commit

Permalink
Add support for extra args to multi transform
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 541573248
  • Loading branch information
rosshemsley authored and OptaxDev committed Jun 19, 2023
1 parent f527be8 commit fa4f976
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
14 changes: 10 additions & 4 deletions optax/_src/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MultiTransformState(NamedTuple):
def multi_transform(
transforms: Mapping[Hashable, base.GradientTransformation],
param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]]
) -> base.GradientTransformation:
) -> base.GradientTransformationExtraArgs:
"""Partitions params and applies a different transformation to each subset.
Below is an example where we apply Adam to the weights and SGD to the biases
Expand Down Expand Up @@ -130,6 +130,12 @@ def map_fn(nested_dict):
Returns:
An ``optax.GradientTransformationExtraArgs`` that forwards extra
keyword arguments passed to ``update`` on to the inner transforms.
"""

# Wrap every inner transform so its update fn uniformly accepts extra
# keyword arguments; NOTE(review): presumably a no-op for transforms that
# already support them — confirm against `base.with_extra_args_support`.
transforms = {
k: base.with_extra_args_support(v)
for k, v in transforms.items()
}

def make_mask(labels, group):
  """Returns a pytree of booleans marking the leaves of `labels` equal to `group`."""
  def is_in_group(label):
    return label == group
  return jax.tree_util.tree_map(is_in_group, labels)

Expand All @@ -148,13 +154,13 @@ def init_fn(params):
}
return MultiTransformState(inner_states)

def update_fn(updates, state, params=None):
def update_fn(updates, state, params=None, **extra_args):
labels = param_labels(updates) if callable(param_labels) else param_labels
new_inner_state = {}
for group, tx in transforms.items():
masked_tx = wrappers.masked(tx, make_mask(labels, group))
updates, new_inner_state[group] = masked_tx.update(
updates, state.inner_states[group], params)
updates, state.inner_states[group], params, **extra_args)
return updates, MultiTransformState(new_inner_state)

return base.GradientTransformation(init_fn, update_fn)
return base.GradientTransformationExtraArgs(init_fn, update_fn)
38 changes: 38 additions & 0 deletions optax/_src/combine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,44 @@ def test_multi_transform(self, use_fn):
updates, state = update_fn(updates, state)
chex.assert_trees_all_close(updates, correct_updates)

def test_extra_args(self):
  """Extra kwargs reach transforms that declare them and are ignored otherwise."""

  class ArgNotEqual1Error(ValueError):
    """Raised when argument not set as expected."""

  def init(params):
    return {'mu': params}

  def update_requiring_arg(updates, state, params=None, *, arg, **extra_args):
    # Only the keyword-only `arg` matters; everything else is discarded.
    del params, extra_args
    if arg != 1:
      raise ArgNotEqual1Error()
    return updates, state

  def update_plain(updates, state, params=None):
    del params
    return updates, state

  tx_plain = base.GradientTransformation(init, update_plain)
  tx_with_arg = base.GradientTransformationExtraArgs(init, update_requiring_arg)

  opt = combine.multi_transform(
      {'a': tx_plain, 'b': tx_with_arg},
      ('a', 'b'),
  )

  fake_params = ({'u': jnp.array([1])}, {'v': jnp.array([1])})
  state = opt.init(fake_params)

  # `arg` is required by one inner transform, so omitting it must fail.
  with self.assertRaises(TypeError):
    opt.update(fake_params, state)
  # A wrong value of `arg` surfaces the inner transform's own error.
  with self.assertRaises(ArgNotEqual1Error):
    opt.update(fake_params, state, arg=2, ignored_kwarg='hi')
  # The expected value succeeds; unrecognized kwargs are silently dropped.
  opt.update(fake_params, state, arg=1, ignored_kwarg='hi')

@parameterized.parameters(list, tuple, dict)
def test_empty(self, container):
init_fn, update_fn = combine.multi_transform(
Expand Down

0 comments on commit fa4f976

Please sign in to comment.