Skip to content

Commit

Permalink
Make hk.switch support multiple operands just like jax.lax.switch.
Browse files Browse the repository at this point in the history
jax.lax.switch remained backwards compatible with passing `operand` as a
keyword, but because hk.switch is not widely used I propose we make this a
breaking change and remove the `operand` kwarg.

PiperOrigin-RevId: 500190018
  • Loading branch information
LenaMartens authored and Copybara-Service committed Jan 6, 2023
1 parent 9c9624a commit 057878b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
6 changes: 3 additions & 3 deletions haiku/_src/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def cond(*args, **kwargs):


@with_output_structure_hint
def switch(index, branches, operand):
def switch(index, branches, *operands):
"""Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out.
Note that creating parameters inside a switch branch is not supported, as such
Expand All @@ -560,7 +560,7 @@ def switch(index, branches, operand):
Args:
index: Integer scalar type, indicating which branch function to apply.
branches: Sequence of functions (A -> B) to be applied based on index.
operand: Operands (A) input to whichever branch is applied.
operands: Operands (A) input to whichever branch is applied.
Returns:
Value (B) of branch(*operands) for the branch that was selected based on
Expand All @@ -575,7 +575,7 @@ def switch(index, branches, operand):
stateful_branch_mem = _memoize_by_id(stateful_branch)
state = internal_state()
out, state = jax.lax.switch(
index, tuple(map(stateful_branch_mem, branches)), (state, (operand,)))
index, tuple(map(stateful_branch_mem, branches)), (state, operands))
update_internal_state(state)
return out

Expand Down
18 changes: 18 additions & 0 deletions haiku/_src/stateful_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,24 @@ def f(i, x):
self.assertEqual(state, {"square_module": {"y": y}})
self.assertEqual(out, y)

def test_switch_multiple_operands(self):
def f(i, x, y, z):
mod = SquareModule()
branches = [lambda x, y, z: mod(x),
lambda y, x, z: mod(x),
lambda y, z, x: mod(x),
]
return stateful.switch(i, branches, x, y, z)

f = transform.transform_with_state(f)
xyz = (1, 3, 5)
for i in range(3):
params, state = f.init(None, i, *xyz)
out, state = f.apply(params, state, None, i, *xyz)
expected_out = xyz[i]**2
self.assertEqual(state, {"square_module": {"y": expected_out}})
self.assertEqual(out, expected_out)

@test_utils.transform_and_run(run_apply=False)
def test_cond_branch_structure_error(self):
true_fn = lambda x: base.get_parameter("w", x.shape, x.dtype, init=jnp.ones)
Expand Down

0 comments on commit 057878b

Please sign in to comment.