Delegate state management for update on events #126
I think I found a way to proceed. Here is a proposal for such a setter function:

```python
from typing import Mapping, Set

import haiku as hk
import jax.numpy as jnp
from haiku._src import base  # Private Haiku API: may change between releases.


def set_params_or_state_dict(
    module_name: str,
    submodules: Set[str],
    which: str,
    state_dict: Mapping[str, jnp.ndarray],
):
    """Sets module parameters or state for the given module or submodules."""
    assert which in ("params", "state")
    frame = base.current_frame()
    for their_module_name, bundle in getattr(frame, which).items():
        if (
            their_module_name == module_name
            or their_module_name.startswith(module_name + "/")
            or their_module_name in submodules
        ):
            for name, value in bundle.items():
                fq_name = their_module_name + "/" + name
                if which == "state":
                    # Assumes each state entry is a namedtuple with a `current`
                    # field (Haiku internals); only that field is replaced.
                    value_dict = value._asdict()
                    value_dict["current"] = state_dict[fq_name]
                    bundle[name] = type(value)(**value_dict)
                else:
                    bundle[name] = state_dict[fq_name]


def set_state_from_dict(self, next_state_dict):
    """Set state keyed by name for this module and submodules."""
    if not base.frame_stack:
        raise ValueError(
            "`module.set_state_from_dict()` must be used as part of an `hk.transform`."
        )
    set_params_or_state_dict(
        self.module_name,
        self._submodules,
        "state",
        next_state_dict,
    )


# Monkey-patch the setter onto hk.Module.
hk.Module.set_state_from_dict = set_state_from_dict
```

Then, here is a much more generic implementation of the wrapper module:

```python
import haiku as hk
import numpy as np
from jax.tree_util import tree_map


class UpdateOnEvent(hk.Module):
    """Apply a module when an event occurs; otherwise return the last computed output.

    If the wrapped module has state, its state management is delegated to
    this wrapper.
    """

    def __init__(self, module, initial_output_value=np.nan, name=None):
        super().__init__(name=name)
        self.module = module
        self.initial_output_value = initial_output_value

    def __call__(self, on_event, input):
        # Snapshot the wrapped module's state before and after running it.
        prev_state_dict = self.module.state_dict()
        output = self.module(input)
        prev_output = hk.get_state(
            "prev_output",
            shape=[],
            init=lambda *_: tree_map(
                lambda x: np.full(x.shape, self.initial_output_value, dtype=x.dtype),
                output,
            ),
        )
        next_state_dict = self.module.state_dict() if prev_state_dict else {}

        def true_fun(operand):
            output, next_state, _, _ = operand
            return output, next_state

        def false_fun(operand):
            _, _, prev_output, prev_state = operand
            return prev_output, prev_state

        operand = (output, next_state_dict, prev_output, prev_state_dict)
        # Same argument order as jax.lax.cond(pred, true_fun, false_fun, operand).
        output, next_state_dict = hk.cond(on_event, true_fun, false_fun, operand)
        if next_state_dict:
            self.module.set_state_from_dict(next_state_dict)
        hk.set_state("prev_output", output)
        return output
```

It should work with any module. Here is a usage example:

```python
import haiku as hk
import jax
import numpy as np


class MyModule(hk.Module):
    def __call__(self, x):
        prev_state = hk.get_state(
            "state",
            shape=x.shape,
            dtype=x.dtype,
            init=lambda shape, dtype: np.full(shape, fill_value=0, dtype=dtype),
        )
        state = prev_state
        state = state + x
        output = state * 2
        hk.set_state("state", state)
        return output


def test_delegate_some_event_true():
    print("test_delegate_some_event_true")
    fun = hk.transform_with_state(
        lambda event, x: UpdateOnEvent(MyModule(), initial_output_value=-999)(event, x)
    )
    seq = hk.PRNGSequence(42)
    x = jax.random.normal(next(seq), shape=(10,))
    event = False
    params, state = fun.init(next(seq), event, x)
    for i in range(10):
        # The wrapped module is only updated at steps 3 and 7.
        event = i in [3, 7]
        output, state = fun.apply(params, state, next(seq), event, x)
        print(i, output[0])
    print("")
    return output


test_delegate_some_event_true()
```

If you think it is worthwhile, two things could be done:
To give a little more context, this development is now used in the open-source library WAX-ML, available at https://github.com/eserie/wax-ml.
Hi @eserie, I'm glad you were able to implement something to support your use case using Haiku! In general we try to be very conservative when adding features to Haiku, and where it is possible to implement things on top of Haiku it is preferable to do so. If lots of users tell us this is something they need, and your solution ends up getting forked into multiple projects, then I think we should reconsider whether to upstream similar functionality. FYI: your implementation makes use of several private APIs in Haiku, so future changes to our internals may break it (and require you to update it). However, you can pin your library to specific versions of Haiku to mitigate this.
Hi @tomhennigan, thanks for your feedback.
Hi,
I would like to implement a module that updates another module on a given condition/event and freezes its outputs the rest of the time.
Since JAX control flow has some limitations (see the JAX documentation), it seems to me that to implement this idea we need a mechanism that allows us to delegate the management of a module's state to an external module/wrapper.
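To make the control-flow constraint concrete in plain JAX (a minimal sketch added for illustration, not part of the original issue; `step` is a hypothetical name): `jax.lax.cond` traces both branches and requires them to return values with the same pytree structure, shapes and dtypes, so "freeze when there is no event" has to be written as a branch that explicitly returns the previous output and state. The accumulator logic mirrors `MyModule` from this thread.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(event, x, prev_state, prev_output):
    """Update the accumulator only when `event` is True (same logic as MyModule)."""

    def on_event(operand):
        x, state, _ = operand
        new_state = state + x
        return new_state * 2, new_state

    def no_event(operand):
        _, state, output = operand
        # Both branches must return the same pytree structure, shapes and dtypes.
        return output, state

    return jax.lax.cond(event, on_event, no_event, (x, prev_state, prev_output))


x = jnp.ones(3)
state = jnp.zeros(3)
output = jnp.full(3, -999.0)
for i in range(5):
    # Events at steps 1 and 3 only.
    output, state = step(i in (1, 3), x, state, output)
```

Under `jax.jit` the predicate is a traced value, so a Python `if event:` would fail here; `lax.cond` is what selects the branch at run time.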
I show below an example of such a mechanism.
In the `manage_state=False` mode of my implementation, the state is handled much like `RNNCore` in Haiku. As the example shows, this mechanism works.
However, this approach will not work easily if the internal module is itself composed of other modules with state.
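One possible way around the nesting problem, still in the explicit-state (`RNNCore`-like) style, is to treat the state as an arbitrary pytree and freeze it leaf by leaf with `jax.tree_util.tree_map`, so the wrapper never needs to know how many stateful submodules the core contains. This is an editor's sketch under that assumption: `core` and `update_on_event` are hypothetical names, and the nested dict merely stands in for the state of a core built from stateful submodules.

```python
import jax
import jax.numpy as jnp


def core(x, prev_state):
    """Hypothetical core with RNNCore's (inputs, prev_state) -> (output, next_state) shape.

    The nested dict stands in for the state of a core made of stateful submodules.
    """
    inner = prev_state["inner"] + x
    outer = prev_state["outer"] + inner
    return outer * 2, {"inner": inner, "outer": outer}


def update_on_event(core_fn, event, x, prev_output, prev_state):
    """Run core_fn, but keep the previous output and state when event is False."""
    output, next_state = core_fn(x, prev_state)
    keep = lambda new, old: jnp.where(event, new, old)
    # tree_map applies the same selection to every leaf, however deeply nested.
    return keep(output, prev_output), jax.tree_util.tree_map(keep, next_state, prev_state)


state = {"inner": jnp.zeros(2), "outer": jnp.zeros(2)}
output = jnp.full(2, -999.0)
for event in [False, True, False, True]:
    output, state = update_on_event(core, event, jnp.ones(2), output, state)
```

Unlike `lax.cond`, `jnp.where` always evaluates the core and then discards the result when there is no event; this is simpler, but wasteful if the core is expensive.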
I wonder if there is a mechanism in Haiku that would allow a given module to "intercept" all the states of its internal modules, in order to implement this idea in a more scalable way.
Does anyone have any insight into this?
Do you think it would be interesting to integrate such functionality in Haiku?
Here is my shorthand implementation (maybe a bit long for this post, but it makes things explicit):
This will print: