It's very difficult to write libraries that support both Haiku and plain Jax #100
Comments
I think that's fine. I (personally) won't be able to make time to make the change, but I'm happy to review a PR that does. @tomhennigan for visibility
Okay, great, thanks for the quick reply. I'll start working on a PR.
Hi @NeilGirdhar and @trevorcai, I have some reservations about lifting this restriction. I think we should try as much as possible to avoid Haiku "leaking" into more of your program than it needs to, and this is why hk.scan etc. should only be used when you know that the code running is going to be hk.transform-ed. One concrete reason why this might be a bad idea is that you are now locking your users into using Haiku for NNs, while they may prefer to use another library (e.g. Flax, Trax etc.). The lock-in comes from the fact that each OOP JAX library needs to provide a drop-in replacement for scan et al. when it is managing the state, so if you want to support Flax, you would need to be able to switch between hk.scan and flax.scan (I assume they have one). I'll reply on your other issues in a moment, but I wonder if it would be easier for your library to instead integrate with Haiku's pure functions that you get back from transform (e.g. …).
Hi Tom, thanks for the detailed explanation.
The problem is that you don't know whether the code is going to be transformed. Let's look at some concrete examples. In my exponential-family library, efax, I calculate the Fisher information using … The problem I'm actually running into is in tjax, with my fixed-point solver. As before, there's no way to know whether this will be called by Jax or Haiku.
Yes, I agree with your point. I would love to hear alternatives. The issue is that this code leaks tracers if it's called from Haiku and isn't properly wrapped. I would love it if my libraries didn't have to know about Haiku.
Same as with the other problem, it's not possible because these functions are called deep in the code. If they were the outermost thing (for example, if I was always only calculating the Fisher information of the whole network, or the fixed point iteration was only ever applied to the whole network), then I agree that this is a reasonable workaround. Maybe, to prevent the locking-in behavior, we could come up with a way for the wrappers to be registered with Jax? It would be more work on the Jax side, but it would be much nicer for users. Something like this:
That makes things easier for users, since they never have to remember to use … Another benefit is that it would simplify Haiku's interface by eliminating the user-facing functions … Finally, no one is trapped in any choice, and my libraries don't even need to know about Haiku. They just blindly use … What do you think?
I really like the idea of enabling libraries to override/monkey-patch jit et al., and I think I've discussed this before with @shoyer and @jekbradbury, although I can't find the relevant issue. Do either of you remember if we filed a GitHub issue?
Ah, here is the issue: google/jax#4117.
@tomhennigan Nice! I love the way they implemented that. I'm not sure what I should do while I wait for that pull request to be merged. Maybe I should merge that PR locally into my version of Jax and then add the appropriate Haiku calls (if you haven't got those somewhere already)?
I think that sounds good. It looks like the PR mostly needs rebasing and someone to merge it, so I don't think you'll need to wait too long. With respect to how/if we make this a part of Haiku, I think we'll need to think quite carefully about whether there are performance implications (e.g. transforms like …).

```python
import functools

import haiku as hk
import jax

_TRANSFORMS = {'lax.scan': hk.scan}  # etc ..

def override_jax_transforms(f):
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    # jax.override_context stands for the hook proposed in google/jax#4117;
    # it does not exist in released JAX.
    with jax.override_context(_TRANSFORMS):
      return f(*args, **kwargs)
  return wrapped

# user code
@hk.experimental.override_jax_transforms
def f(x):
  ...

# everything else stays the same
f = hk.transform(f)
params = f.init(..)
```
Looks great! Looking forward to this.
Yes, I'm happy to revive google/jax#4117 :) My original motivation was actually exactly this issue: we have code that we want to support both Haiku and JAX. So far we've gotten around this by writing our own stateful versions of higher-order functions like …
I'm working with @shoyer's pull request and entering the context manager, but I'm getting …
First off, what about parameters that are initialized by the … ? Also, I can do something like what's recommended … in my code, but this is very painful for libraries that don't want to know anything about Haiku. Why not just make hk.while_loop behave like this:

```python
# Proposed behaviour for hk.while_loop (pseudocode against Haiku internals).
def while_loop(cond_fun, body_fun, init_val):
  if not base.params_frozen():
    cond_fun(init_val)
    return body_fun(init_val)
  ...
```

I guess this doesn't work in some very weird cases where the parameter-getter in … If this is unacceptable, then what about simply forcing the user to initialize every parameter that's used in the condition or body of a while loop before the while loop? Something like:

```python
# Alternative proposal for hk.while_loop (pseudocode against Haiku internals).
def while_loop(cond_fun, body_fun, init_val):
  if not base.params_frozen():
    try:
      with base.assert_state_unchanged():
        cond_fun(init_val)
        return body_fun(init_val)
    except StateChangedError as e:
      raise hk.StateChangedError("""No part of the Haiku managed state can be initialized in a while_loop.
Try to initialize the state beforehand. For example,
  # Unconditionally initialize the state.
  jax.initialize(cond_fun, init_val)
  jax.initialize(body_fun, init_val)
  val = hk.while_loop(cond_fun, body_fun, init_val)""") from e
  ...
```

The problem with the latter approach is that you might have to convince the Jax team to add some kind of corresponding hook, like

```python
@overrideable('initialize')  # hypothetical JAX hook
def initialize(f, *args, **kwargs):
  pass
```

so that libraries can run this initialization code without knowing about Haiku. Haiku would then override it to

```python
def initialize(f, *args, **kwargs):
  if hk.running_init():
    f(*args, **kwargs)
```

Thoughts?
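The first hk.while_loop proposal above can be tried in plain JAX. In this minimal sketch, `init_safe_while_loop`, `cond_fun` and `body_fun` are hypothetical names, and the `initializing` flag is an assumption standing in for Haiku's `not base.params_frozen()` check: during initialization the condition and body each run once, otherwise the call forwards to `jax.lax.while_loop`.

```python
import jax
import jax.numpy as jnp


def init_safe_while_loop(cond_fun, body_fun, init_val, initializing=False):
  """Sketch of the proposed hk.while_loop behaviour (hypothetical helper).

  `initializing` stands in for Haiku's "parameters are not frozen yet" check.
  """
  if initializing:
    # Trace the condition and the body exactly once, outside lax.while_loop,
    # so any state they create would be created here. The single step's result
    # only gives the caller correctly shaped output.
    cond_fun(init_val)
    return body_fun(init_val)
  return jax.lax.while_loop(cond_fun, body_fun, init_val)


def cond_fun(val):
  i, _ = val
  return i < 5


def body_fun(val):
  i, total = val
  return i + 1, total + i
```

One consequence worth noting: during initialization the returned value is a single loop step rather than the loop's result, so it is only meaningful for its shape and structure.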
Sorry for the long delay, it's been a busy few weeks and I've not had the headspace to dig into this.
IIRC, cond requires the output structure of each branch to be the same, so I think by construction … Since we know one of the branches will run, and both branches create/use the same params, we can safely support creation in cond.
This is what is implemented; we enforce this by requiring you to have … I guess what you're asking is whether we can allow …
It seems like there has not been much movement on this. An alternative solution might be to document some alternative designs that would allow library authors to accept functionally impure code and make use of JAX transforms. For example, instead of:

```python
def my_library_f(f, x):
  x = jax.some_transform(f)(x)
  ...
  return x
```

We could suggest libraries support users passing in the transforms to use:
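A minimal sketch of that design, assuming the library takes the transform as an extra argument, as in the my_library_f(hk.Linear(1), x, hk.some_transform) call. The choice of `jax.vmap` as the default and the `square` function are illustrative assumptions, not part of any library.

```python
import jax
import jax.numpy as jnp


def my_library_f(f, x, transform=jax.vmap):
  """Hypothetical library function: the caller chooses the transform.

  Plain-JAX users keep the default; users of a stateful library can pass
  that library's wrapper instead, so this code never needs to know about it.
  """
  x = transform(f)(x)
  return x


def square(v):
  return v * v


xs = jnp.arange(3.0)
print(my_library_f(square, xs))                      # default transform
print(my_library_f(square, xs, transform=jax.jit))   # caller-supplied transform
```

The library only relies on the transform having the "takes a function, returns a function" shape, which is what lets a Haiku-aware wrapper slot in.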
In the while loop case, users can then work around any library specific restrictions (as we have in hk.while_loop) without the library needing to care:

```python
if hk.running_init():
  x = hk.Linear(1)(x).reshape(..)
else:
  x = my_library_f(hk.Linear(1), x, hk.some_transform)
```
There are other ways to solve the "state management" problem. Haiku, Flax et al. have point solutions outside JAX (which are convenient at first, but have some very sharp edges when combined with JAX transforms and other libraries). JAX itself could support "implicit state", so at least the sharp edges would be consistent across stateful libraries. Doing so without losing the beautiful simplicity of JAX's current explicit data-flow design will be a challenge. I know that @LenaMartens, @mattjj, @jekbradbury and others have been thinking about this for a while, but there isn't a clear solution.
(I'm sick in bed, so apologies if this reply doesn't make sense.)
Sorry, I wasn't clear. If you look at the error message in
I was just saying that you might also want to recommend
I'm just saying that ultimately I figured that
To my eyes, @shoyer's solution is extremely elegant by comparison.
Yes, I understand your point. Designing this well is going to make a really big difference. I appreciate all of the thought you all are putting into this. Jax is a marvel of beautiful design, and Haiku is getting there!
We decided not to merge google/jax#4117. But let me share how we've solved this problem in our own codebase, using our own versions of higher-order functions like … First, we define a version of …:

```python
import contextlib

import jax
import jax.numpy as jnp

_INITIALIZING = False

@contextlib.contextmanager
def init_context():
  global _INITIALIZING
  assert not _INITIALIZING
  _INITIALIZING = True
  try:
    yield
  finally:
    # Reset the flag even if initialization raises.
    _INITIALIZING = False

def init_safe_scan(f, init, xs, length=None, default_scan=jax.lax.scan):
  # version of lax.scan that allows for use under flax/haiku initialization
  if _INITIALIZING:  # could also use hk.running_init() here
    # Run the step function once on the first slice of xs, then tile its
    # per-step output `length` times so `ys` has the shape scan would give.
    xs_flat, treedef = jax.tree_util.tree_flatten(xs)
    if length is None:
      length, = {x.shape[0] for x in xs_flat}
    x0 = jax.tree_util.tree_unflatten(treedef, [x[0] for x in xs_flat])
    carry, y0 = f(init, x0)
    ys = jax.tree_util.tree_map(lambda *z: jnp.stack(z), *(length * [y0]))
    return carry, ys
  return default_scan(f, init, xs, length)
```

Then in Haiku, you can write something like:

```python
import haiku as hk

def neural_net(x):
  return hk.Linear(5)(x)

def haiku_init_safe_scan(f, init, xs, length=None):
  return init_safe_scan(f, init, xs, length=length, default_scan=hk.scan)

def my_model(step_fn):
  def doubled_step(x):
    y, _ = haiku_init_safe_scan(
        lambda x, _: (step_fn(x), _), init=x, xs=jnp.arange(2))
    return y
  return doubled_step

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 5])
forward = hk.transform(my_model(neural_net))
with init_context():
  params = forward.init(rng, x)
print(params)  # only a single set of weights
logits = forward.apply(params, rng, x)
print(logits)  # does not crash
```

Presumably this sort of thing could be done for most/all higher-order functions in JAX. It doesn't even have to be library-specific, so I can imagine this being a good fit for a third-party library or perhaps even a …
I would like to extend my fixed-point solver to work in both Haiku and Jax. I thought it would be as simple as replacing jax.lax.scan with hk.scan, etc. Unfortunately, I get … Perhaps I'm missing something, but would it be possible to reverse the design decision (#17) to raise an error, and instead simply fall back to the Jax version of the command if the stateful context isn't needed?
Also, just out of curiosity, is jacfwd broken in Haiku, or does it not need a stateful wrapper?
@trevorcai WDYT?