
Error using batch with jit #5

Closed
ravidziv opened this issue Sep 21, 2019 · 6 comments
Labels
bug Something isn't working

Comments

@ravidziv

ravidziv commented Sep 21, 2019

I get the error `Too many leaves for PyTreeDef; expected 6.` when I try to run the following code:

```python
# Imports assumed from the surrounding setup; adjust to your neural-tangents version.
from jax import jit
from neural_tangents import stax
from neural_tangents.utils.batch import batch

def get_network(W_std=1):
    init_fun, apply_fun, ker_fun = stax.serial(
        stax.Dense(1, W_std=W_std, b_std=0.1)
    )
    ker_fun = jit(batch(ker_fun, batch_size=25, device_count=0))
    kdd = ker_fun(train_xs, None)  # train_xs is defined elsewhere
    return 0

jit(get_network)(2.0)
```
sschoenholz referenced this issue Sep 21, 2019
…by broadcasting numpy arrays and closing over non-arrays.

The first argument has to be a numpy array with a leading dimension of size `device_count` for pmapping. No arrays are allowed in keyword arguments.

PiperOrigin-RevId: 270212471
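The leading-axis requirement mentioned in that commit can be sketched with plain NumPy (the names `device_count` and `batched_xs` here are illustrative, not from the library):

```python
import numpy as np

# To pmap over `device_count` devices, the first positional argument must
# carry a leading axis of that size; each device then receives one slice.
device_count = 2
train_xs = np.ones((50, 3))  # 50 examples, 3 features

# Reshape so axis 0 matches the device count: each device gets 25 examples.
batched_xs = train_xs.reshape(device_count, -1, train_xs.shape[-1])
print(batched_xs.shape)  # (2, 25, 3)
```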
@sschoenholz
Contributor

sschoenholz commented Sep 21, 2019

Thanks for bringing this up. I believe this issue should have been fixed by Roman's work on jit compilation in batching. I've tried your repro in Colab and it seems to work, but let me know if any problems persist.

@ravidziv
Author

Thanks, but it didn't solve it.
Colab code

@sschoenholz
Contributor

Ah, I see, thanks! So this appears to work if you remove the `jit` around `get_network`, interesting! We'll look into it soon, but perhaps not until after ICLR.

One point that we should make clearer in the docs, especially once we have docs, is that applying `jit` to `batch` has poor memory characteristics. We think this is an issue on the JAX / XLA side, but we haven't had time to put together a simplified repro for them. If you run into OOM errors, it may be better to jit-before-batch, though we expect this situation to be temporary.
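For intuition about why the composition order matters, here is a pure-NumPy stand-in for what `batch` does (a simplified sketch, not the library's implementation: it computes the kernel block by block to bound peak memory, then reassembles the full matrix):

```python
import numpy as np

def ker_fun(x1, x2):
    # Toy stand-in for a kernel function: the Gram matrix of dot products.
    return x1 @ x2.T

def batch(kernel, batch_size):
    # Simplified stand-in for neural-tangents' `batch`: evaluate the kernel
    # on batch_size x batch_size blocks, then concatenate the blocks.
    def batched(x1, x2=None):
        x2 = x1 if x2 is None else x2
        rows = []
        for i in range(0, len(x1), batch_size):
            cols = [kernel(x1[i:i + batch_size], x2[j:j + batch_size])
                    for j in range(0, len(x2), batch_size)]
            rows.append(np.concatenate(cols, axis=1))
        return np.concatenate(rows, axis=0)
    return batched

x = np.random.randn(100, 4)
full = ker_fun(x, x)
blocked = batch(ker_fun, batch_size=25)(x)
print(np.allclose(full, blocked))  # True
```

Wrapping the whole blocked loop in `jit` forces XLA to trace and hold every block at once, which is one plausible source of the memory blow-up; jitting only the inner `kernel` keeps the loop as ordinary Python.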

@sschoenholz sschoenholz added the bug Something isn't working label Sep 22, 2019
@ravidziv
Author

I understand, thanks!
Do you have an idea of how I can vmap the batch over different networks?

@romanngg
Contributor

It looks like the latest JAX version has this fixed, and both `jit` and `vmap` work! Here's an adapted example from above: https://colab.research.google.com/gist/romanngg/ffdd9a41fdf5479eaeac95772c259d27/jit_or_vmap_of_network.ipynb

Please note that I'm also unsure how reasonable it is to apply `jit` to `batch`, especially with a `vmap` on top. You may also want to try:
a) `vmap` / `pmap` on `get_network`, with no `batch` or `jit` anywhere inside the function.
b) a simple Python for loop over the `W_std`s of the jitted `get_network`, with no `batch` or `jit` inside, or with `batch` but no `jit` inside.
etc.
I suggest this because the purpose of `batch` is to trade parallelization for lower memory cost, so it seems redundant / sub-optimal to first call `batch` and then parallelize it with `vmap` or `jit` afterwards. But I'm no expert on how these JAX optimizations work, so I may well be wrong here...
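Option (b) above can be sketched in plain NumPy (a toy stand-in: `get_kernel` is a hypothetical, illustrative analogue of the `W_std`-parameterized kernel of the one-layer network in the repro, not the library's `ker_fun`):

```python
import numpy as np

def get_kernel(W_std, xs):
    # Toy analogue of a one-layer network's kernel: scale the Gram matrix by
    # W_std**2 and add the bias variance b_std**2 = 0.1**2 from the repro.
    return W_std ** 2 * (xs @ xs.T) + 0.1 ** 2

# A plain Python loop over W_std values: each call could be jitted on its
# own, with no batch/jit nesting around the loop itself.
xs = np.random.randn(10, 3)
kernels = [get_kernel(w, xs) for w in (0.5, 1.0, 2.0)]
print(len(kernels), kernels[0].shape)  # 3 (10, 10)
```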

@romanngg
Contributor

Closing, please let me know if there are still related issues!
