
Faithfully reconstruct tree from context #38

Open
mattwescott opened this issue May 11, 2020 · 0 comments

@mattwescott

This variation on @tomhennigan's example tries to build a tree of module types that mirrors the parameter tree.

It assumes that the order in which parameters are created is preserved when the parameter dictionary is flattened, which may not be the case. A more satisfying solution would be possible if the parameter's path could be added to the context, or recovered from it. However, since module names and parameter names may contain "/", it is not clear to me how to construct the path. What am I missing?
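To illustrate why the ordering assumption is fragile: dm-tree documents that dict keys are sorted when a structure is flattened, so the `tree.unflatten_as` call in the code below consumes the flat list in sorted-key order, while a custom creator sees parameters in creation order. The plain-Python sketch below mimics that sorted-key flattening with a hypothetical `sorted_leaves` helper (it is not dm-tree itself) on an illustrative two-level params dict; the module names are made up in the style of Haiku's `mlp/~/linear_N`.

```python
from typing import Any, Dict, List


def sorted_leaves(params: Dict[str, Dict[str, Any]]) -> List[Any]:
    """Hypothetical helper: flatten a two-level dict in sorted key order,
    mimicking dm-tree's documented behaviour for dicts."""
    return [params[m][p] for m in sorted(params) for p in sorted(params[m])]


# Illustrative (module_name, param_name) pairs in the order a creator saw them.
creation_order = [
    ("mlp/~/linear_2", "w"),
    ("mlp/~/linear_2", "b"),
    ("mlp/~/linear_10", "w"),
    ("mlp/~/linear_10", "b"),
]

# Build params in creation order; each leaf is just its own path string.
params = {}
for module_name, name in creation_order:
    params.setdefault(module_name, {})[name] = f"{module_name}/{name}"

created = [f"{m}/{p}" for m, p in creation_order]
flattened = sorted_leaves(params)

print(created == flattened)  # False
print(flattened)
# ['mlp/~/linear_10/b', 'mlp/~/linear_10/w', 'mlp/~/linear_2/b', 'mlp/~/linear_2/w']
```

Both "b" sorting before "w" and "linear_10" sorting before "linear_2" break the correspondence, so a flat list recorded in creation order can be paired with the wrong leaves.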

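import haiku
import tree  # dm-tree


def init_and_build_module_tree(f):
    """
    Decorated functions build a tree of module types alongside the parameters.

    Usage:
      def f(x):
        net = haiku.nets.MLP([300, 100, 10])
        return net(x)

      params, modules = init_and_build_module_tree(f)(rng_key, np.zeros(4))
      params = tree.map_structure(transform_params, params, modules)
    """

    def _init_and_build_module_tree(rng_key, *args, **kwargs):
        # Module types, appended in the order their parameters are created.
        module_types = []

        def record_module_type(next_creator, shape, dtype, init, context):
            # Record the type of the module that owns this parameter,
            # then defer to the next creator to actually create it.
            module_types.append(type(context.module))
            return next_creator(shape, dtype, init)

        def with_creator(*aargs, **kkwargs):
            # Install the creator inside the transformed function so it
            # intercepts every parameter created during init.
            with haiku.experimental.custom_creator(record_module_type):
                return f(*aargs, **kkwargs)

        params, _ = haiku.transform_with_state(with_creator).init(
            rng_key,
            *args,
            **kwargs
        )

        # Assumes creation order matches the order in which `params` is
        # flattened; dm-tree sorts dict keys, so this may not hold.
        module_tree = tree.unflatten_as(
            params,
            module_types
        )

        return params, module_tree

    return _init_and_build_module_tree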
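If the creator could see each parameter's full path, the module tree could be keyed by path instead of relying on flatten order. Recent Haiku versions appear to expose this as `context.full_name` (e.g. `mlp/~/linear_0/w`) on the creator context, but check your version. The plain-Python sketch below uses a hypothetical `record_by_path` helper and assumes the parameter name itself never contains "/", so splitting on the last "/" separates the module name (which may contain "/") from the parameter name, matching the `params[module_name][param_name]` layout.

```python
from typing import Any, Dict


def record_by_path(out: Dict[str, Dict[str, Any]], full_name: str, value: Any) -> None:
    """Hypothetical helper: store `value` at out[module_name][param_name].

    Assumes the parameter name is the last "/"-separated component and
    contains no "/" itself; the module name may contain "/".
    """
    module_name, param_name = full_name.rsplit("/", 1)
    out.setdefault(module_name, {})[param_name] = value


# Inside a custom creator this might be called as
#   record_by_path(module_types, context.full_name, type(context.module))
# assuming the context exposes `full_name` in your Haiku version.


class Linear:  # placeholder standing in for a real module type
    pass


module_types = {}
for full_name in ["mlp/~/linear_2/w", "mlp/~/linear_2/b", "mlp/~/linear_10/w"]:
    record_by_path(module_types, full_name, Linear)

print(sorted(module_types))  # ['mlp/~/linear_10', 'mlp/~/linear_2']
print(sorted(module_types["mlp/~/linear_2"]))  # ['b', 'w']
```

Because the result is keyed the same way as `params`, it should line up with it leaf for leaf regardless of creation or flattening order.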