add option to globally disable jit for better debugging #252
Comments
Should be fixed by #337. I tested the solution in #337 manually with

    from jax import jit
    from jax.config import config
    config.update('jax_disable_jit', True)

    effects = []

    @jit
    def f(x):
        effects.append(1)
        return x

    f(2)
    f(2)
    assert len(effects) == 2

There are a few ways to set this option. If you want to do it programmatically from inside Python, near the top of your source file or notebook you can write

    from jax.config import config
    config.update('jax_disable_jit', True)

That also works for updating the value of the flag. Second, if you prefer to pass this option via a command line flag to a script, you can use absl like this

    from absl import app
    from jax.config import config

    def main(argv):
        # ...

    if __name__ == '__main__':
        config.config_with_absl()
        app.run(main)

You can instead call

Finally, you can set the JAX_DISABLE_JIT environment variable when launching a script:

    JAX_DISABLE_JIT=1 python my_cool_script.py

Please reopen if this wasn't what you had in mind!
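As a rough sketch of the programmatic option, the snippet below toggles the flag at runtime and counts Python side effects, mirroring the manual test above. It assumes a recent JAX release where the config object is reachable as `jax.config` (the `from jax.config import config` import used in this thread may not be available in newer versions). The function `double` and the two counters are made up for illustration.

```python
import jax

effects = []

@jax.jit
def double(x):
    # Python side effect: happens only when the Python body actually executes.
    effects.append(1)
    return x * 2

# With jit disabled, the Python body runs on every call.
jax.config.update("jax_disable_jit", True)
double(2)
double(2)
calls_without_jit = len(effects)

# With jit re-enabled, the body runs once for tracing and the compiled
# result is reused for the second call with the same argument type.
jax.config.update("jax_disable_jit", False)
effects.clear()
double(2)
double(2)
calls_with_jit = len(effects)

print(calls_without_jit, calls_with_jit)
```

Resetting the flag to False at the end keeps the debugging setting from leaking into the rest of the session.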
Ah, I forgot you had mentioned a context manager. Let's see if I can add that to #337.
I added a context manager that works like this:

    effects = []

    @api.jit
    def f(x):
        effects.append(1)
        return x

    with api.disable_jit():
        f(2)
        f(2)
    assert len(effects) == 2

    effects = []
    f(2)
    f(2)
    assert len(effects) == 1

I also changed it so that disabling doesn't turn
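To illustrate why the context manager helps with debugging, here is a minimal sketch that reads an intermediate value as a plain Python float while jit is disabled. It assumes a current JAX release, where the context manager is exposed as `jax.disable_jit()` rather than the `api.disable_jit()` spelling used above. The function `normalize` and the list `inspected` are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp

inspected = []  # intermediate values captured while debugging

@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # Converting to a Python float needs a concrete value, so this line
    # only works while jit is disabled; under tracing it would raise.
    inspected.append(float(total))
    return x / total

# Inside the block the function body runs eagerly on concrete arrays.
with jax.disable_jit():
    out = normalize(jnp.array([1.0, 3.0]))

print(inspected, out)
```

Calling `normalize` outside the `with` block would fail at the `float(total)` line, which is a reminder to remove such debugging hooks before re-enabling jit.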
add flag to disable jit globally (fixes #252)
Is there an equivalent flag for inspecting values inside The below code only prints

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.config import config
    config.update('jax_disable_jit', True)

    def forward(x):
        print(f"forward x: {x}")
        return x + jax.lax.axis_index('batch')

    x = jnp.ones([16]) * 10
    devices = np.array(jax.devices())
    with jax.experimental.maps.mesh(devices, ('dp',)):
        forward_xmap = jax.experimental.maps.xmap(fun=forward,
                                                  in_axes=["batch"],
                                                  out_axes=["batch"],
                                                  axis_resources={'batch': 'dp'})
        res = forward_xmap(x)
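The thread does not record an answer to this question. One option that may help, offered here as an assumption rather than a confirmed fix, is `jax.debug.print`, which prints runtime values from inside a transformed function without disabling the transformation. The sketch below uses `jax.jit` instead of `xmap`, assumes a JAX release that ships `jax.debug.print`, and simplifies `forward` so it needs no named axis.

```python
import jax
import jax.numpy as jnp

@jax.jit
def forward(x):
    # Prints the runtime value of x even though the function is compiled.
    jax.debug.print("forward x: {x}", x=x)
    return x + 1

out = forward(jnp.ones(4) * 10)
out.block_until_ready()  # wait for the computation to finish
```

A plain Python `print` in the same position would show an abstract tracer at trace time rather than the values seen at run time.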
(Relates to #196)
It's harder to debug jitted functions than non-jitted functions, since it's essentially impossible to introspect their intermediate state. One option is to comment out the jit decorator, but then you have to remember to put it back when you're done debugging.
It'd be nice to have a mechanism (maybe a context manager?) to globally disable jitting for debugging purposes; once bugs are fixed, you don't have to reinsert all of your @jit statements.