Delegate state management for update on events #126

Closed
eserie opened this issue May 1, 2021 · 4 comments

Comments

@eserie commented May 1, 2021

Hi,

I would like to implement a module that updates another module on a given condition/event and freezes its outputs the rest of the time.

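Since JAX control flow has some limitations (for example, a traced boolean cannot drive a Python `if`, so branching has to go through `jax.lax.cond`/`hk.cond`, whose branches must return values of the same structure; see the JAX documentation), it seems to me that implementing this idea requires a mechanism that lets a module delegate the management of its state to an external module/wrapper.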
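To make the control-flow limitation concrete, here is a minimal plain-JAX sketch (no Haiku; the functions `step` and `update_on_event` are hypothetical). Under `jax.jit` the event flag is a tracer, so a Python `if event:` would raise an error. `jax.lax.cond` instead selects one of two branches, and both branches must return an `(output, state)` pair with identical shapes and dtypes. That is why the previous output and state have to be available as explicit values to choose from.

```python
import jax
import jax.numpy as jnp


def step(x, prev_state):
    """Toy stateful computation: accumulate x, output twice the accumulator."""
    state = prev_state + x
    return 2.0 * state, state


@jax.jit
def update_on_event(event, x, prev_output, prev_state):
    """Hypothetical helper: run `step` when `event` is True, otherwise freeze.

    `event` is traced under jit, so a Python `if event:` would fail here.
    lax.cond picks one branch; both return (output, state) with the same types.
    """
    def on_true(operand):
        x, _, prev_state = operand
        return step(x, prev_state)

    def on_false(operand):
        _, prev_output, prev_state = operand
        return prev_output, prev_state

    return jax.lax.cond(event, on_true, on_false, (x, prev_output, prev_state))


x = jnp.array([1.0, -0.5], dtype=jnp.float32)
state = jnp.zeros(2, dtype=jnp.float32)
output = jnp.full(2, jnp.nan, dtype=jnp.float32)  # no output computed yet
for event in [False, True, False, True]:
    output, state = update_on_event(event, x, output, state)
# Two updates were applied, so state == 2 * x and output == 4 * x.
```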

I show an example of such a mechanism below.
In the `manage_state=False` mode of my implementation, the state is handled much like in Haiku's `hk.RNNCore`: it is passed in as an argument and returned alongside the output.

As the example shows, this mechanism works.

However, this approach does not scale easily when the inner module is itself composed of other stateful modules.

I wonder whether Haiku has a mechanism that would allow a given module to "intercept" all the states of its inner modules, so that this idea could be implemented in a more scalable way.

Does anyone have any insight into this?

Do you think it would be worthwhile to integrate such functionality into Haiku?
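One way to picture the "intercept all states" idea: if the states of the wrapped module and all of its submodules are gathered into one nested dict (a pytree), a single `jax.tree_util.tree_map` can freeze or update all of them at once, however deeply they are nested. The minimal plain-JAX sketch below assumes such a gathered dict already exists; the module names in it are made up, and it swaps in an element-wise `jnp.where` for the `hk.cond` used in the code that follows.

```python
import jax
import jax.numpy as jnp


def select_state(event, new_state, old_state):
    """Return new_state where `event` is True, old_state otherwise.

    Works for any nested dict/tuple of arrays with matching structure, so the
    number of stateful submodules does not matter.
    """
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(event, new, old), new_state, old_state
    )


# Hypothetical gathered state of a module with two stateful submodules.
old = {
    "outer/inner_a": {"state": jnp.zeros(2)},
    "outer/inner_b": {"count": jnp.array(0), "ema": jnp.zeros(3)},
}
new = jax.tree_util.tree_map(lambda v: v + 1, old)  # pretend one update step

frozen = select_state(jnp.asarray(False), new, old)
updated = select_state(jnp.asarray(True), new, old)
```

`jnp.where` needs both candidates to be computed already, which is the case here because the module always runs and only its results are selected; it also behaves the same under `jax.vmap` with a batched event flag.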

Here is my shorthand implementation (maybe a bit long for this post, but it makes things explicit):

```python
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from typing import Any, Optional, Sequence

import haiku as hk
import jax
import jax.numpy as np  # note: `np` is jax.numpy throughout these snippets


class DelegateStateCore(hk.Module, metaclass=ABCMeta):
    def __init__(self, manage_state=True, name=None):
        super().__init__(name=name)
        self.manage_state = manage_state

    def delegate_state_management(self):
        """Delegate state management to the outer module."""
        self.manage_state = False
        return self

    @abstractmethod
    def initial_state(
        self, shape: Optional[Sequence[int]] = None, dtype: Any = np.float32
    ) -> Any:
        ...

    @abstractmethod
    def initial_output(
        self, shape: Optional[Sequence[int]] = None, dtype: Any = np.float32
    ) -> Any:
        ...


class UpdateOnEvent(hk.Module):
    """Apply the module when an event occurs; otherwise return the last output.

    If the module manages its own state, it is asked to delegate state
    management to this wrapper.
    """

    def __init__(self, module, name=None):
        super().__init__(name=name)
        if isinstance(module, DelegateStateCore):
            self.module = module.delegate_state_management()
        else:
            raise TypeError("only DelegateStateCore modules are managed")

        self._state_name = self.module.name + "_state"
        self._output_name = self.module.name + "_output"

    def __call__(self, on_event, input):
        # The wrapped module's state and last output live in this wrapper.
        prev_state = hk.get_state(
            self._state_name,
            shape=input.shape,
            dtype=input.dtype,
            init=self.module.initial_state,
        )
        prev_output = hk.get_state(
            self._output_name,
            shape=input.shape,
            dtype=input.dtype,
            init=self.module.initial_output,
        )
        # Always run the module; the event only decides which values are kept.
        output, next_state = self.module((input, prev_state))

        def true_fun(operand):
            output, next_state, _, _ = operand
            return output, next_state

        def false_fun(operand):
            _, _, prev_output, prev_state = operand
            return prev_output, prev_state

        operand = (output, next_state, prev_output, prev_state)
        # hk.cond(pred, true_fun, false_fun, *operands), mirroring jax.lax.cond.
        output, next_state = hk.cond(on_event, true_fun, false_fun, operand)
        hk.set_state(self._state_name, next_state)
        hk.set_state(self._output_name, output)
        return output
```


```python
class MyModule(DelegateStateCore):
    State = namedtuple("MyModuleState", "state")
    Output = namedtuple("MyModuleOutput", "output")

    def initial_state(self, shape, dtype):
        return self.State(np.full(shape, fill_value=0, dtype=dtype))

    def initial_output(self, shape, dtype):
        return self.Output(np.full(shape, fill_value=np.nan, dtype=dtype))

    def __call__(self, x):
        if self.manage_state:
            # Standalone mode: the module stores its own state.
            prev_state = hk.get_state(
                "state", shape=x.shape, dtype=x.dtype, init=self.initial_state
            )
        else:
            # Delegated mode: the caller passes (input, prev_state).
            x, prev_state = x

        state = prev_state.state
        state = state + x
        output = state * 2

        if self.manage_state:
            hk.set_state("state", self.State(state))
            return self.Output(output)
        else:
            # Delegated mode: return the new state instead of storing it.
            return self.Output(output), self.State(state)
```
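With `manage_state=False`, `MyModule` follows the same idea as an `hk.RNNCore`: it receives the previous state together with its input and returns `(output, next_state)`, so whoever calls it owns the state. As a plain-JAX illustration of that convention outside Haiku (the functions `initial_state`, `core` and `unroll` here are hypothetical), the caller can thread the state through a sequence with `jax.lax.scan`, in the same spirit as `hk.dynamic_unroll`:

```python
import jax
import jax.numpy as jnp


def initial_state(shape, dtype=jnp.float32):
    """Illustrative counterpart of MyModule.initial_state."""
    return jnp.zeros(shape, dtype)


def core(x, prev_state):
    """RNNCore-style step: state comes in as an argument and goes out
    alongside the output; nothing is stored inside the function."""
    state = prev_state + x
    return 2.0 * state, state


def unroll(xs):
    """Drive `core` over the leading axis of xs, threading the state."""
    def body(state, x):
        output, next_state = core(x, state)
        return next_state, output  # scan expects (carry, per-step output)

    final_state, outputs = jax.lax.scan(body, initial_state(xs.shape[1:]), xs)
    return outputs, final_state


xs = jnp.ones((4, 3), dtype=jnp.float32)
outputs, final_state = unroll(xs)
# outputs[i] == 2 * (i + 1) elementwise, and final_state == 4.
```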


```python
def test_no_wrapping():
    print("test_no_wrapping")

    fun = hk.transform_with_state(lambda x: MyModule()(x))

    seq = hk.PRNGSequence(42)
    x = jax.random.normal(next(seq), shape=(10,))

    params, state = fun.init(next(seq), x)

    for i in range(10):
        output, state = fun.apply(params, state, next(seq), x)
        print(i, output.output[0])
    print("")
    return output


test_no_wrapping()


def test_deleguate_all_event_true():
    print("test_deleguate_all_event_true")
    fun = hk.transform_with_state(lambda event, x: UpdateOnEvent(MyModule())(event, x))

    seq = hk.PRNGSequence(42)
    x = jax.random.normal(next(seq), shape=(10,))

    event = True
    params, state = fun.init(next(seq), event, x)

    for i in range(10):
        output, state = fun.apply(params, state, next(seq), event, x)
        print(i, output.output[0])
    print("")
    return output


test_deleguate_all_event_true()


def test_wrapping_some_event_true():
    print("test_wrapping_some_event_true")
    fun = hk.transform_with_state(lambda event, x: UpdateOnEvent(MyModule())(event, x))

    seq = hk.PRNGSequence(42)
    x = jax.random.normal(next(seq), shape=(10,))

    event = False
    params, state = fun.init(next(seq), event, x)

    for i in range(10):
        # Trigger an update only at steps 3 and 7.
        event = i in [3, 7]
        output, state = fun.apply(params, state, next(seq), event, x)
        print(i, output.output[0])
    print("")
    return output


test_wrapping_some_event_true()
```

This will print:

```
test_no_wrapping
0 -3.3304188
1 -6.6608377
2 -9.991257
3 -13.321675
4 -16.652094
5 -19.982513
6 -23.312933
7 -26.643353
8 -29.973772
9 -33.30419

test_deleguate_all_event_true
0 -3.3304188
1 -6.6608377
2 -9.991257
3 -13.321675
4 -16.652094
5 -19.982513
6 -23.312933
7 -26.643353
8 -29.973772
9 -33.30419

test_wrapping_some_event_true
0 nan
1 nan
2 nan
3 -3.3304188
4 -3.3304188
5 -3.3304188
6 -3.3304188
7 -6.6608377
8 -6.6608377
9 -6.6608377
```
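The printed values follow from simple arithmetic: `MyModule` adds `x` to its state at each update and outputs twice the state, so after k updates the first element is `2 * k * x[0]`, with `x[0]` about -1.6652094 (half of the first printed value). With events only at steps 3 and 7, the first update happens at step 3 and the second at step 7, and the initial NaN output is held before step 3. A small pure-Python replay of that bookkeeping (the helper `replay` is hypothetical, written only to check the pattern):

```python
import math


def replay(x0, events, initial_output=math.nan):
    """Replay the UpdateOnEvent + MyModule bookkeeping for one element."""
    state, output, outputs = 0.0, initial_output, []
    for event in events:
        if event:  # update: accumulate x and recompute the output
            state += x0
            output = 2 * state
        outputs.append(output)  # no event: the previous output is held
    return outputs


x0 = -3.3304188 / 2
always = replay(x0, [True] * 10)
some = replay(x0, [i in (3, 7) for i in range(10)])
```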
@eserie (Author) commented May 2, 2021

I think I found a way to proceed.
We would need a "setter" counterpart of the `state_dict()` method of `hk.Module`.

Here is a proposal for such a setter function:

```python
from typing import Mapping, Set

import haiku as hk
import jax.numpy as jnp
from haiku._src import base  # private Haiku API: may change between versions


def set_params_or_state_dict(
    module_name: str,
    submodules: Set[str],
    which: str,
    state_dict: Mapping[str, jnp.ndarray],
):
    """Overwrite module parameters or state for the given module or submodules.

    `state_dict` is keyed by fully qualified names ("<module_name>/<name>").
    """
    assert which in ("params", "state")
    frame = base.current_frame()
    for their_module_name, bundle in getattr(frame, which).items():
        if (
            their_module_name == module_name
            or their_module_name.startswith(module_name + "/")
            or their_module_name in submodules
        ):
            for name, value in bundle.items():
                fq_name = their_module_name + "/" + name
                if which == "state":
                    # Internal state entries are namedtuples with a `current`
                    # field (alongside the initial value); only `current` is
                    # replaced.
                    value_dict = value._asdict()
                    value_dict["current"] = state_dict[fq_name]
                    bundle[name] = type(value)(**value_dict)
                else:
                    bundle[name] = state_dict[fq_name]


def set_state_from_dict(self, next_state_dict):
    """Set state keyed by name for this module and its submodules."""
    if not base.frame_stack:
        raise ValueError(
            "`module.set_state_from_dict()` must be used as part of an `hk.transform`."
        )
    set_params_or_state_dict(
        self.module_name,
        self._submodules,  # private attribute of hk.Module
        "state",
        next_state_dict,
    )


# Monkey-patch the setter onto hk.Module.
hk.Module.set_state_from_dict = set_state_from_dict
```
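The core of `set_params_or_state_dict` is name matching: a flat dict keyed by fully qualified names is scattered back into Haiku's two-level `{module_name: {entry_name: value}}` layout, touching only the target module, its children (names with the `module_name + "/"` prefix) and any explicitly listed submodules. The sketch below isolates just that logic on a plain nested dict instead of Haiku's internal frame; the function name `scatter_flat_dict` and the module names are hypothetical. Note that the `"/"` in the prefix test keeps a sibling such as `"wrapper/my_module_1"` from matching.

```python
from typing import Any, Dict, Iterable, Mapping


def scatter_flat_dict(
    nested: Dict[str, Dict[str, Any]],
    module_name: str,
    submodules: Iterable[str],
    flat: Mapping[str, Any],
) -> None:
    """Overwrite entries of `nested` in place from `flat`.

    `nested` mirrors the {module_name: {entry_name: value}} layout;
    keys of `flat` are "<module_name>/<entry_name>".
    """
    submodules = set(submodules)
    for their_module_name, bundle in nested.items():
        if (
            their_module_name == module_name
            or their_module_name.startswith(module_name + "/")
            or their_module_name in submodules
        ):
            for name in bundle:  # reassigning existing keys is safe here
                bundle[name] = flat[their_module_name + "/" + name]


state = {
    "wrapper/my_module": {"state": 1.0},
    "wrapper/my_module/inner": {"count": 3},
    "wrapper/my_module_1": {"state": 5.0},
    "other_module": {"state": 7.0},
}
scatter_flat_dict(
    state,
    "wrapper/my_module",
    submodules=(),
    flat={"wrapper/my_module/state": 0.0, "wrapper/my_module/inner/count": 0},
)
# Only entries under "wrapper/my_module" changed.
```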

Then, a much more generic implementation of the module UpdateOnEvent would be:

```python
import haiku as hk
import jax.numpy as np  # `np` is jax.numpy, as in the first snippet
from jax.tree_util import tree_map


class UpdateOnEvent(hk.Module):
    """Apply the module when an event occurs; otherwise return the last output.

    The wrapped module's state is snapshotted before and after the call, and
    the selected snapshot is written back with `set_state_from_dict` (defined
    above).
    """

    def __init__(self, module, initial_output_value=np.nan, name=None):
        super().__init__(name=name)
        self.module = module
        self.initial_output_value = initial_output_value

    def __call__(self, on_event, input):
        # Snapshot before the call (empty during init, before the module has
        # created its state).
        prev_state_dict = self.module.state_dict()

        output = self.module(input)

        # Last emitted output, initialized to `initial_output_value` with the
        # same structure, shapes and dtypes as `output`.
        prev_output = hk.get_state(
            "prev_output",
            shape=[],
            init=lambda *_: tree_map(
                lambda x: np.full(x.shape, self.initial_output_value, dtype=x.dtype),
                output,
            ),
        )

        # Snapshot after the call; kept empty when there was no previous state
        # so that both branches of hk.cond return the same structure.
        next_state_dict = self.module.state_dict() if prev_state_dict else {}

        def true_fun(operand):
            output, next_state, _, _ = operand
            return output, next_state

        def false_fun(operand):
            _, _, prev_output, prev_state = operand
            return prev_output, prev_state

        operand = (output, next_state_dict, prev_output, prev_state_dict)
        output, next_state_dict = hk.cond(on_event, true_fun, false_fun, operand)

        if next_state_dict:
            self.module.set_state_from_dict(next_state_dict)

        hk.set_state("prev_output", output)
        return output
```

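This version works with any module: it does not rely on a `DelegateStateCore` interface, because it snapshots the wrapped module's state via `state_dict()` before and after the call and writes the selected snapshot back.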
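The order of operations in the generic wrapper (snapshot, run, snapshot, select, write back) can be sketched without Haiku. In this pure-Python illustration a plain dict stands in for the module state and a Python `if` stands in for `hk.cond`; both substitutions are for illustration only, since inside a jitted Haiku function the selection must go through `hk.cond` or an equivalent. The names `update_on_event` and `my_module` are hypothetical.

```python
from typing import Any, Callable, Dict

Store = Dict[str, Any]  # flat "module/entry" -> value, like state_dict()


def update_on_event(
    store: Store,
    module: Callable[[Store, float], float],
    event: bool,
    x: float,
    prev_output: float,
) -> float:
    """Hypothetical stand-in for UpdateOnEvent.__call__."""
    prev_state = dict(store)  # snapshot before the call
    output = module(store, x)  # the module mutates `store` freely
    next_state = dict(store)  # snapshot after the call
    if event:
        chosen_output, chosen_state = output, next_state
    else:
        chosen_output, chosen_state = prev_output, prev_state
    store.clear()
    store.update(chosen_state)  # write the selected snapshot back
    return chosen_output


def my_module(store: Store, x: float) -> float:
    """Accumulate x and return twice the accumulator, like MyModule."""
    store["my_module/state"] = store.get("my_module/state", 0.0) + x
    return 2 * store["my_module/state"]


store: Store = {"my_module/state": 0.0}
out = float("nan")
history = []
for i in range(10):
    out = update_on_event(store, my_module, i in (3, 7), 1.0, out)
    history.append(out)
```

Because the wrapper only reads and writes the store, the wrapped function needs no knowledge of the event logic, which is the property the Haiku version gets from `state_dict()` and `set_state_from_dict()`.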

Here is a usage example:

```python
# Relies on UpdateOnEvent and set_state_from_dict defined above.
class MyModule(hk.Module):
    def __call__(self, x):
        prev_state = hk.get_state(
            "state",
            shape=x.shape,
            dtype=x.dtype,
            init=lambda shape, dtype: np.full(shape, fill_value=0, dtype=dtype),
        )

        state = prev_state
        state = state + x
        output = state * 2

        hk.set_state("state", state)

        return output


def test_wrapping_some_event_true():
    print("test_wrapping_some_event_true")
    fun = hk.transform_with_state(
        lambda event, x: UpdateOnEvent(MyModule(), initial_output_value=-999)(event, x)
    )

    seq = hk.PRNGSequence(42)
    x = jax.random.normal(next(seq), shape=(10,))

    event = False
    params, state = fun.init(next(seq), event, x)

    for i in range(10):
        # Trigger an update only at steps 3 and 7.
        event = i in [3, 7]
        output, state = fun.apply(params, state, next(seq), event, x)
        print(i, output[0])
    print("")
    return output


test_wrapping_some_event_true()
```

If you think it is worthwhile, two things could be done:

  • integrate the function `set_params_or_state_dict()` into `haiku._src.module` and add a `.set_state_from_dict()` method to `hk.Module`;
  • integrate the new `UpdateOnEvent` module, if you think it is sufficiently generic and within the scope of the library.

@eserie (Author) commented Jun 16, 2021

To give a little more context, this development is now used in the open-source library WAX-ML, available at https://github.com/eserie/wax-ml.
It would be interesting to know whether this could be integrated into Haiku or whether you would prefer it to be implemented outside of Haiku.

@tomhennigan (Collaborator)

Hi @eserie, I'm glad you were able to implement something to support your use case using Haiku!

In general we try to be very conservative when adding features to Haiku, and where it is possible to implement things on top of Haiku, it is preferable to do so. If lots of users tell us this is something they need, and your solution ends up getting forked into multiple projects, then I think we should reconsider whether to upstream similar functionality.

FYI: your implementation makes use of several private APIs in Haiku, so you may find that future changes to our internals break your implementation (and require you to update it). However, you can pin your library to specific versions of Haiku to mitigate this.

@eserie (Author) commented Jun 17, 2021

Hi @tomhennigan, thanks for your feedback.
I'll watch for future changes and follow your recommendation to pin WAX-ML to a fixed version, which makes sense since Haiku is, so far, perfect for my use case!
Feel free to contact me if you plan to implement something along these lines, so I can anticipate the changes.
