
In haiku, is there something equivalent to flax.nn.Model? #69

Closed
TuanNguyen27 opened this issue Aug 27, 2020 · 4 comments

Comments

@TuanNguyen27

TuanNguyen27 commented Aug 27, 2020

Thanks for creating such a nice library!

My example use case in flax:

nn = ...  # flax.nn.Module
_, nn_params = nn.init(rng_key, data)
model = flax.nn.Model(nn, nn_params)

With this, I can call model(x) and also access the current params via model.params.

For Haiku:

nn = ...  # hk.Transformed, i.e. the result of hk.transform(...)
nn_params = nn.init(rng_key, data)
# I can do this:
partial_fun = lambda data: nn.apply(nn_params, rng_key, data)

I can call partial_fun(x), but because partial_fun is a lambda, I can't access nn_params from it. Is there any workaround to achieve this in Haiku?

For more context, I am trying to integrate Haiku with NumPyro so that we can convert a traditional NN into a Bayesian NN. See pyro-ppl/numpyro#705, or this example notebook.

@cgarciae

Haiku doesn't offer a Model interface. You can try Elegy, which offers a Keras-like Model interface. However, since version 0.2.0 it is no longer compatible with Haiku, though its Module system is very similar.

@TuanNguyen27
Author

Thanks for the tip, @cgarciae! Just out of curiosity, is there any feature available in Haiku but not in Elegy?

@cgarciae

@TuanNguyen27 As far as the base implementation of the Module class and its hooks (get_parameter and friends) goes, probably not. If you need a layer that hasn't been ported yet, porting it is usually very straightforward, since the code is identical 99% of the time. You can also open an issue, since our goal is to have all of them.

@tomhennigan
Collaborator

We don't ship a model interface in Haiku (in general, we try to be unopinionated about how you use your neural network). If you want to hold your apply function and parameters as a single object, you can write a small wrapper class:

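from dataclasses import dataclass
from typing import Callable

import haiku as hk

@dataclass
class Model:
  params: hk.Params
  apply: Callable  # e.g. f.apply, called as apply(params, x)

  def __call__(self, x):
    # Run the wrapped apply function with the stored parameters.
    return self.apply(self.params, x)

# Assumes f = hk.without_apply_rng(hk.transform(...)) so that apply takes
# (params, x) with no rng argument, and params = f.init(rng, x).
model = Model(params, f.apply)
print(model.params)  # current parameters, analogous to flax.nn.Model.params
out = model(x)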


3 participants