-
Notifications
You must be signed in to change notification settings - Fork 226
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create (internal only for now) context manager API for Haiku modules.
This makes transform trivial (the actual impl is only longer because of error checking): def transform_with_state(f): def init_fn(rng, x): with new_context(rng=rng) as ctx: f(x) return ctx.collect_params(), ctx.collect_state() def apply_fn(params, state, rng, x): with new_context(params=params, state=state, rng=rng) as ctx: return f(x), ctx.collect_state() return init_fn, apply_fn But also means we could in theory offer a fully imperative API: with new_context(rng=rng) as ctx: mod = hk.nets.MLP([300, 100, 10]) mod(example) params = ctx.collect_params() .. at some point later .. with new_context(params=params): out = mod(x) My motivation for exploring this API is that users very commonly want to be able to use their Haiku modules without having to wrap them in `hk.transform`, and some advanced users would like to be able to produce functions that look like `hk.transform` but with a different contract (e.g. producing multiple apply methods). My gut feeling is that this API while being strictly more flexible is also quite dangerous in JAX (e.g. there are many assumptions about functional purity in JAX which the results of `hk.transform` satisfies that this does not). I would like to make this available for folks to play with, but not expose it or promote it widely for now. Ping #16. PiperOrigin-RevId: 301013361 Change-Id: Ia5766a470ef08109c90be6d07f1226b742880c2b
- Loading branch information
1 parent
3a34141
commit 7560469
Showing
3 changed files
with
143 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters