Allow creating module instances outside hk.transform #16

Open
gehring opened this issue Feb 29, 2020 · 5 comments
Labels
enhancement New feature or request

Comments

@gehring

gehring commented Feb 29, 2020

This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside hk.transform? I took a look at hk.Module and ModuleMetaClass, but I feared my soul would get harvested by the dark forbidden magic involved before I could identify all the API features it permits.

For example, I would have expected this to be possible:

import haiku as hk

linear = hk.Linear(10)  # currently not allowed: modules must be created inside hk.transform

def forward(x):
  return linear(x)

model = hk.transform(forward)

Concretely, I'm curious to know what would have to be sacrificed (if anything) to support this kind of usage? Is it meant to prevent a module instance from being used in two different functions wrapped by two different hk.transform calls?

I wouldn't be surprised if I were missing some nasty side effect of allowing module creation outside of hk.transform but, if not, I think allowing this kind of usage would be more intuitive.

@trevorcai
Contributor

trevorcai commented Mar 1, 2020

The primary reason we don't currently allow this is that hk.Module objects have unique names (within their hk.transform), accessible via self.name or self.module_name. These names route parameters & state into the right place for hk.get_parameter calls, and are given to the module at construction time (in super().__init__(name=name)).

Uniquifying names requires us to track some state about the names that have already been created. We've made an attempt towards allowing the construction of modules that don't use hk.get_parameter and the other provided monads in their constructors, but we haven't managed to do this without introducing persistent global state.

There are other solutions that we could try here! One idea is to late-bind names inside hk.transform, but we haven't prioritized this line of work.

Does that make sense? WDYT?

@gehring
Author

gehring commented Mar 3, 2020

That all makes sense, thanks for the explanation!

> One idea is to late-bind names inside hk.transform, but we haven't prioritized this line of work.

I think it would be great if that could be implemented without adding much complexity, but I completely agree that it doesn't feel like a priority. I think the current API is just as powerful without this feature once you get used to it (which, in my personal experience, took about 3 "oopsies" and cost me no more than 5 minutes of refactoring).

@gehring
Author

gehring commented Mar 3, 2020

I'm not sure if you want to keep this issue open for feature request tracking purposes but, if not, feel free to close it.

@trevorcai
Contributor

That's good to hear - that's been my experience as well :)
I'll leave this issue open to track this feature request.

copybara-service bot pushed a commit that referenced this issue Mar 15, 2020
This makes transform trivial (the actual impl is only longer because of error
checking):

    def transform_with_state(f):
      def init_fn(rng, x):
        with new_context(rng=rng) as ctx:
          f(x)
        return ctx.collect_params(), ctx.collect_state()

      def apply_fn(params, state, rng, x):
        with new_context(params=params, state=state, rng=rng) as ctx:
          return f(x), ctx.collect_state()

      return init_fn, apply_fn

But also means we could in theory offer a fully imperative API:

    with new_context(rng=rng) as ctx:
      mod = hk.nets.MLP([300, 100, 10])
      mod(example)
      params = ctx.collect_params()

    # ... at some point later ...

    with new_context(params=params):
      out = mod(x)

My motivation for exploring this API is that users very commonly want to be able
to use their Haiku modules without having to wrap them in `hk.transform`, and
some advanced users would like to be able to produce functions that look like
`hk.transform` but with a different contract (e.g. producing multiple apply
methods).

My gut feeling is that this API, while strictly more flexible, is also quite
dangerous in JAX (e.g. JAX makes many assumptions about functional purity, which
the results of `hk.transform` satisfy but this API does not). I would like
to make this available for folks to play with, but not expose it or promote it
widely for now.

Ping #16.

PiperOrigin-RevId: 301013361
Change-Id: Ia5766a470ef08109c90be6d07f1226b742880c2b
tomhennigan added the enhancement (New feature or request) label Mar 18, 2020
@awav

awav commented Mar 21, 2020

Hello @trevorcai, @tomhennigan. I like the out-of-the-box solutions a lot, but I'm struggling to extend Haiku at the moment. I need constrained parameters, such as a (strictly positive) variance for Gaussian distributions. Such a parameter can be represented as a composition: constrained_value = bijector.forward(unconstrained_parameter); in my code, it is a property of the module. The parameter dictionary contains only the unconstrained version, but for tracking and printing the model we need the constrained values, and there is no way to get them because the model instance is hidden inside the transformed function.

from typing import Optional, Text

import haiku as hk
import jax.numpy as jnp


class Parameter(hk.Module):
  # Stores log(value) as the trainable parameter and exposes exp(...) > 0.
  def __init__(self, init_value: float, name: Text):
    super().__init__(name="parameter")
    self._param_name = name
    self._init = hk.initializers.Constant(jnp.log(init_value))

  def __call__(self):
    return jnp.exp(hk.get_parameter(f"unconstrained_{self._param_name}", shape=[], init=self._init))

class Model(hk.Module):
  def __init__(self, init_variance: float, name: Optional[Text] = None):
    super().__init__(name=name)
    self._variance = Parameter(init_variance, "variance")

  @property
  def variance(self):
    return self._variance()

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(self.variance * x)

As you can see, a variance value in the parameter dictionary has little meaning without knowing which transformation the model applies (it could be exp, softplus or another positive bijector).

1. One solution could be to return a model with transformed functions.

def forward_fn(x):
  m = Model(0.1)
  hk.link(m)  # proposed API: expose m after transformation
  return m(x)

forward = hk.transform(forward_fn)
model = forward.linked_objects  # proposed: read-only access to the linked objects

2. Another possible (?) solution could be making hk.transform a context manager

class Holder(hk.ModuleHolder):
  @hk.transform
  def forward(self, x):
    self.model = Model(0.1)
    return self.model(x)

forward = Holder().forward()

PS: For me, this is a very important issue and a deciding factor in how I'm going to use the library.

li-zihang pushed a commit to li-zihang/dm-haiku that referenced this issue Mar 8, 2022
PiperOrigin-RevId: 350006409
Change-Id: If5c80b0e443ad41767185459940b524b31e1c607