
jax.nn.sigmoid raises Exception: Leaked trace errors #7613

Closed
A-Alaa opened this issue Aug 13, 2021 · 14 comments
Labels: bug Something isn't working

A-Alaa commented Aug 13, 2021

It took two days of digging to find that jax.nn.sigmoid was causing a very vague error. The following code reproduces the problem:

```python
from functools import partial
import traceback

import jax
import jax.numpy as jnp
import haiku as hk

jax.config.update('jax_platform_name', 'gpu')  # same error on 'cpu' or 'gpu'
jax.config.update('jax_check_tracer_leaks', True)
jax.config.update('jax_log_compiles', True)
jax.config.update('jax_enable_checks', True)

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

class Adjust(hk.Module):
    def __init__(self, size):
        super().__init__(name=None)
        self.__f = hk.Linear(size)

    def __call__(self, x):
        h = self.__f(x)
        return jax.nn.sigmoid(h)  # <- Changing this to sigmoid(h) solves the problem!

def wrap_module(module, *module_args, **module_kwargs):
    def wrap(*args, **kwargs):
        model = module(*module_args, **module_kwargs)
        return model(*args, **kwargs)
    return wrap

prng_key = jax.random.PRNGKey(42)
x = jnp.zeros((50,))
init, adjust = hk.without_apply_rng(hk.transform(wrap_module(Adjust, size=50)))
params = init(prng_key, x)

@jax.jit
def loss(params, x):
    h = adjust(params, x)
    return jnp.sum((h - x)**2)

try:
    grads = jax.grad(loss)(params, x)
except Exception as e:
    tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
    print(tb_str)
    raise e
```

I get the following raised exception, `Exception: Leaked sublevel 1.`, but only when the loss function is jitted.
The logged traceback doesn't point to the sigmoid line at all (attached output: jax-error.txt).

The library versions:

```
dm-haiku==0.0.5.dev0
jax==0.2.18
jaxlib==0.1.69+cuda101
```

However, the following older versions produced a more helpful error message that pointed to the sigmoid line:

```
dm-haiku==0.0.4
jax==0.2.12
jaxlib==0.1.64+cuda101
```
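For context on why swapping `jax.nn.sigmoid` for the hand-written `sigmoid` changes behaviour at all: `jax.nn.sigmoid` is not just `1/(1+exp(-x))` — as far as I can tell it carries its own custom differentiation rule, so it takes a different autodiff path than the plain definition. A minimal sketch of a sigmoid with a custom JVP rule (the `stable_sigmoid` name and the exact stabilization are illustrative assumptions, not JAX's actual implementation):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_sigmoid(x):
    # Branch on the sign of x so the exponential never overflows.
    return jnp.where(x >= 0,
                     1.0 / (1.0 + jnp.exp(-jnp.abs(x))),
                     jnp.exp(-jnp.abs(x)) / (1.0 + jnp.exp(-jnp.abs(x))))

@stable_sigmoid.defjvp
def stable_sigmoid_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = stable_sigmoid(x)
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    return y, y * (1.0 - y) * t

value_at_0 = stable_sigmoid(0.0)        # 0.5
grad_at_0 = jax.grad(stable_sigmoid)(0.0)  # 0.25
```

With a rule like this, `jax.grad` differentiates through the hand-written JVP instead of tracing through `exp` and the division, which is why the two sigmoids can hit different code paths inside JAX.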
@A-Alaa A-Alaa added the bug Something isn't working label Aug 13, 2021

mattjj commented Aug 13, 2021

Thanks so much for reporting this, and for the repro!

Wow, that's a bad internal error. cc @LenaMartens in case it's Haiku-related. I'd love to try to make a repro without Haiku involved, to see if it's possible, since that would narrow things down a lot...


A-Alaa commented Aug 13, 2021

The same error happens without haiku.

```python
import jax
import jax.numpy as jnp
import numpy.random as npr
from functools import partial
import traceback

jax.config.update('jax_platform_name', 'gpu')
jax.config.update('jax_check_tracer_leaks', True)
jax.config.update('jax_log_compiles', True)
jax.config.update('jax_enable_checks', True)

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

x = jnp.ones((50,))
A = jnp.array(npr.randn(50, 50))

@jax.jit
def loss(A, x):
    h = jax.nn.sigmoid(A * x)
    return jnp.sum((h - x)**2)

try:
    grads = jax.grad(loss)(A, x)
except Exception as e:
    tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
    print(tb_str)
    raise e
```

@mattjj mattjj assigned LenaMartens and unassigned LenaMartens Aug 13, 2021

mattjj commented Aug 13, 2021

@A-Alaa Hrm, that repro script in your latest message doesn't crash for me (against the current JAX main 74f96f1), at least on the CPU backend (and this error doesn't seem like it would be backend-dependent...).


mattjj commented Aug 13, 2021

Actually, the original doesn't repro either. Could this be fixed at HEAD?


mattjj commented Aug 13, 2021

Ah, I confirmed that it doesn't work against jax==0.2.18 (the current pypi version), but it does work against HEAD.

The solution is to cut a new pypi release of jax!

@mattjj mattjj self-assigned this Aug 13, 2021

A-Alaa commented Aug 13, 2021

The error also occurs in a Google Colaboratory environment:
https://colab.research.google.com/drive/1AlSWZzRO8RqGz7WkSFXoQrm6SikjjcwL?usp=sharing


A-Alaa commented Aug 13, 2021

  • Did it pass the unit tests at the corresponding versions?
  • Also, the error messages were not as helpful as in older versions; a clearer message could have saved a lot of time :(

Many thanks for this awesome library; it naturally helps write very clean code. I am very grateful for the team's contributions to this library and to the literature on AD and Neural ODEs.

By the way, I wanted to post a new feature request to improve lax.cond and lax.switch, because they actually run all branches and select the results afterwards. I hope they could execute a single branch instead. I have seen similar requests, but they have not been updated for many months.


mattjj commented Aug 13, 2021

> Did it pass the unit tests at the corresponding versions?

Great question. I don't know the history, so I can't say for sure, but I don't know of any failures. I should add a test case for this before closing the issue.

> Also, the error messages were not as helpful as in older versions; a clearer message could have saved a lot of time :(

I don't think the error messages have changed. This code should run without error! Instead, this is just a really bad internal error message, and some bug must have been temporarily introduced to trigger it. (I'm not sure what the bug was... I should probably bisect it.)

> Many thanks for this awesome library; it naturally helps write very clean code. I am very grateful for the team's contributions to this library and to the literature on AD and Neural ODEs.

Thanks for the kind words!

> By the way, I wanted to post a new feature request to improve lax.cond and lax.switch, because they actually run all branches and select the results afterwards. I hope they could execute a single branch instead. I have seen similar requests, but they have not been updated for many months.

They stage out all branches (and in so doing execute and trace the Python callables representing each branch), as they must because their purpose is to stage out control flow which can't be executed in Python (and hence either branch could be taken later, so both branches must be traced). Is that what you mean? (It might help to link the other issues, if you have them handy and if I'm missing the point you're making.)
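The trace-time behaviour described above can be demonstrated with a small sketch (the branch functions and the `traced` list are hypothetical illustrations): Python side effects in both branches fire while `cond` is being traced, even though only one branch's result is selected at run time.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced = []

def true_branch(x):
    traced.append('true')   # runs when this branch is traced
    return x + 1.0

def false_branch(x):
    traced.append('false')  # also runs at trace time
    return x - 1.0

@jax.jit
def f(pred, x):
    return lax.cond(pred, true_branch, false_branch, x)

out = f(True, jnp.float32(0.0))
# Both Python callables were executed once during tracing, even though
# only the true branch's computation contributes to the result.
```

Here `sorted(traced)` comes out as `['false', 'true']` and `out` is `1.0`: both branches were staged out, but only the selected branch's value is returned.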


mattjj commented Aug 13, 2021

Just pushed jax==0.2.19 to pypi! Can you confirm the bug no longer reproduces against that version?


mattjj commented Aug 13, 2021

I think the reason the error seems to have gotten worse is that it must have been happening in the backward pass of AD, so the sigmoid function from the forward pass was already out of the picture. That leads to this traceback from your text file (where JAX-internal backward-pass frames are hidden in IPython, and no user-code frames are involved):

```
Exception                                 Traceback (most recent call last)
<ipython-input-19-2730c27a130a> in <module>
      9     tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))
     10     print(tb_str)
---> 11     raise e
     12 

<ipython-input-19-2730c27a130a> in <module>
      5 
      6 try:
----> 7     grads = jax.grad(loss)(params, x)
      8 except Exception as e:
      9     tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__))

    [... skipping hidden 28 frame]

/opt/anaconda3/lib/python3.8/contextlib.py in __exit__(self, type, value, traceback)
    118         if type is None:
    119             try:
--> 120                 next(self.gen)
    121             except StopIteration:
    122                 return False

~/.local/lib/python3.8/site-packages/jax/core.py in new_sublevel()
    805     del sublevel
    806     if t() is not None:
--> 807       raise Exception(f'Leaked sublevel {t()}.')
    808 
    809 def full_lower(val):

Exception: Leaked sublevel 1.
```

mattjj added a commit that referenced this issue Aug 13, 2021
mattjj added a commit that referenced this issue Aug 13, 2021

A-Alaa commented Aug 13, 2021

> Just pushed jax==0.2.19 to pypi! Can you confirm the bug no longer reproduces against that version?

Solved now!
Thanks a lot

> They stage out all branches (and in so doing execute and trace the Python callables representing each branch), as they must because their purpose is to stage out control flow which can't be executed in Python (and hence either branch could be taken later, so both branches must be traced). Is that what you mean? (It might help to link the other issues, if you have them handy and if I'm missing the point you're making.)

Even after first-time compilation, when the branches consist of complicated functions like odeint, lax.cond executes all branches and then selects the result based on the condition value. I will post a new request after this issue is closed.
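One situation where both branches really are executed at run time, not just at trace time (a general JAX fact, offered here as a possible explanation rather than a diagnosis of the odeint case), is `cond` under `vmap` with a batched predicate: batching lowers the conditional to a select, so both branch computations run for every batch element before one result is picked.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(pred, x):
    # Per-example conditional: +1 on the true branch, -1 on the false branch.
    return lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)

preds = jnp.array([True, False, True])
xs = jnp.zeros(3)
# With a batched predicate, cond cannot pick a single branch for the whole
# batch, so it computes both branches elementwise and selects per element.
out = jax.vmap(f)(preds, xs)  # [1., -1., 1.]
```

With a scalar (unbatched) predicate outside of `vmap`, the compiled conditional does take only one branch at run time; it is the batching transformation that forces the select-both-branches lowering.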


mattjj commented Aug 13, 2021

A bisection suggests that 2190734 fixed this bug (thanks, @LenaMartens!), and fd7b286 introduced it or at least exposed it (curse you, @mattjj!), or at least was the first commit where the above jax-only repro started failing. (According to my process, it seems that jax==v0.2.12 also had an error, though maybe a different one, and I had to go back to jax==v0.2.10 to find a good pypi release.)


mattjj commented Aug 13, 2021

I think the last thing to do is to update colab. We'll get that done asap. In the meantime, I'm going to close this one.

Thanks again for reporting it!

@mattjj mattjj closed this as completed Aug 13, 2021

mattjj commented Aug 13, 2021

We're in the process of updating Colab, but it'll take a few days because it's blocked on something else.
