jax.nn.sigmoid raises Exception: Leaked trace errors #7613
Comments
Thanks so much for reporting this, and for the repro! Wow, that's a bad internal error. cc @LenaMartens in case it's Haiku-related. I'd love to try to make a repro without Haiku involved, to see if it's possible, since that would narrow things down a lot... |
The same error happens without haiku.

import jax
import jax.numpy as jnp
import numpy.random as npr
from functools import partial
import traceback

jax.config.update('jax_platform_name', 'gpu')
jax.config.update('jax_check_tracer_leaks', True)
jax.config.update('jax_log_compiles', True)
jax.config.update('jax_enable_checks', True)

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

x = jnp.ones((50,))
A = jnp.array(npr.randn(50, 50))

@jax.jit
def loss(A, x):
    h = jax.nn.sigmoid(A * x)
    return jnp.sum((h - x)**2)

try:
    grads = jax.grad(loss)(A, x)
except Exception as e:
    tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
    print(tb_str)
    raise e
|
Actually, the original doesn't repro either. Could this be fixed at HEAD? |
Ah, I confirmed that it doesn't work against jax==0.2.18 (the current pypi version), but it does work against HEAD. The solution is to cut a new pypi release of jax! |
The error also occurs in the Google Colaboratory environment: |
Many thanks for this awesome library, it naturally helps write very clean code. I am very grateful for the team contributions in this library and their contributions in the literature of AD and Neural ODEs. By the way, I wanted to post a new feature request to improve |
Great question. I don't know the history, so I can't say for sure, but I don't know of any failures. I should add a test case for this before closing the issue.
I don't think the error messages have changed. This code should run without error! Instead, this is just a really bad internal error message, and some bug must have been temporarily introduced to trigger it. (I'm not sure what the bug was... I should probably bisect it.)
Thanks for the kind words!
They stage out all branches (and in so doing execute and trace the Python callables representing each branch), as they must because their purpose is to stage out control flow which can't be executed in Python (and hence either branch could be taken later, so both branches must be traced). Is that what you mean? (It might help to link the other issues, if you have them handy and if I'm missing the point you're making.) |
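As a minimal sketch of the point above (the branch functions and the `traced` list are hypothetical names used only for illustration), a Python side effect inside each branch shows that `jax.lax.cond` traces both callables when it is staged out under `jax.jit`, even though only one branch is taken at run time:

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping: records which branch functions Python executed.
traced = []

def true_branch(x):
    traced.append("true")   # Python side effect, runs only while tracing
    return jnp.sin(x)

def false_branch(x):
    traced.append("false")  # Python side effect, runs only while tracing
    return jnp.cos(x)

@jax.jit
def f(pred, x):
    # pred is abstract under jit, so both branches must be staged out.
    return jax.lax.cond(pred, true_branch, false_branch, x)

y = f(True, 0.0)
print(sorted(set(traced)))  # both branch names appear
print(y)                    # value of the branch actually taken, sin(0.0)
```

The side effects fire at trace time only, so calling `f` again with the same argument types should not append to `traced` again.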
Just pushed jax==0.2.19 to pypi! Can you confirm the bug no longer reproduces against that version? |
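One way to confirm this might look like the sketch below. It is an adaptation of the jax-only repro from earlier in the thread, not the exact script anyone ran: it drops the GPU platform setting, uses a seeded NumPy generator, and enables leak checking through the `jax.checking_leaks()` context manager instead of the config flag.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Compare the installed version against 0.2.19 (first three components only).
installed = tuple(int(p) for p in jax.__version__.split(".")[:3])
print("jax", jax.__version__, "is at least 0.2.19:", installed >= (0, 2, 19))

x = jnp.ones((50,))
A = jnp.array(np.random.RandomState(0).randn(50, 50))

@jax.jit
def loss(A, x):
    h = jax.nn.sigmoid(A * x)
    return jnp.sum((h - x) ** 2)

# With the fix in place, this should finish without a "Leaked ..." exception.
with jax.checking_leaks():
    grads = jax.grad(loss)(A, x)

print(grads.shape)
```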
I think the reason it seems like the error got worse is that the error must've been happening in the backward pass of AD, so the
|
Solved now!
Even after first-time compilation, when branches consist of complicated functions like |
A bisection suggests that 2190734 fixed this bug (thanks, @LenaMartens!), and fd7b286 introduced it or at least exposed it (curse you, @mattjj!), or at least was the first commit where the above jax-only repro started failing. (According to my process, it seems that jax==v0.2.12 also had an error, though maybe a different one, and I had to go back to jax==v0.2.10 to find a good pypi release.) |
I think the last thing to do is to update colab. We'll get that done asap. In the meantime, I'm going to close this one. Thanks again for reporting it! |
We're in the process of updating Colab, but it'll take a few days because it's blocked on something else. |
I got lucky and found that jax.nn.sigmoid was the cause of a very vague error I had been seeing for the last two days. The following code may reproduce this problem:
And I get the following raised exception: Exception: Leaked sublevel 1., only when the loss function is jitted. The logged traceback doesn't point to the sigmoid line at all (attached output: jax-error.txt).
The version of the libraries:
However, the following older versions showed a more helpful error message pointing to the sigmoid line: