
add option to globally disable jit for better debugging #252

Closed
matthewdhoffman opened this issue Jan 16, 2019 · 4 comments
Labels
enhancement New feature or request


@matthewdhoffman

(Relates to #196)

It's harder to debug jitted functions than non-jitted functions, since it's essentially impossible to introspect their intermediate state. One option is to comment out the jit decorator, but then you have to remember to put it back when you're done debugging.

It'd be nice to have a mechanism (maybe a context manager?) to globally disable jitting for debugging purposes, so that once bugs are fixed, you don't have to reinsert all of your @jit decorators.
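To illustrate the debugging problem, here is a minimal sketch assuming a recent JAX release. The function `step` and the list `seen_types` are hypothetical names used only for this example. Inside a jitted function, Python code runs only while the function is being traced, so it sees abstract tracers instead of concrete intermediate values, and it does not run again on later calls that hit the compilation cache.

```python
import jax
import jax.numpy as jnp

seen_types = []  # hypothetical record of what the Python body observes


@jax.jit
def step(x):
    y = jnp.sin(x) * 2.0
    # Under jit this line runs only once, during tracing, and y is an
    # abstract tracer rather than the concrete intermediate value.
    seen_types.append(type(y).__name__)
    return y + 1.0


step(jnp.float32(1.0))
step(jnp.float32(2.0))  # same shape/dtype: reuses the compiled version
print(seen_types)  # a single entry naming a tracer class
```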

@mattjj mattjj added the enhancement New feature or request label Jan 16, 2019
@mattjj mattjj self-assigned this Feb 7, 2019
@mattjj
Collaborator

mattjj commented Feb 7, 2019

Should be fixed by #337.

I tested the solution in #337 manually with

from jax import jit
from jax.config import config

config.update('jax_disable_jit', True)

effects = []

@jit
def f(x):
  effects.append(1)
  return x

f(2)
f(2)

assert len(effects) == 2

There are a few ways to set this option.

First, if you want to do it programmatically from inside Python, you can write this near the top of your source file or notebook:

from jax.config import config
config.update('jax_disable_jit', True)

The same call also works for changing the flag's value later on (for example, setting it back to False).
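A minimal sketch of toggling the flag at runtime, assuming a recent JAX where the config object is available as `jax.config` (releases from around the time of this thread used `from jax.config import config` instead). The function `double` and the list `calls` are hypothetical; the Python side effect only serves to show when the function body actually runs.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical side-effect log


@jax.jit
def double(x):
    # With jit disabled this runs on every call; with jit enabled it runs
    # only while the function is being traced.
    calls.append(1)
    return 2 * x


# Disable jit globally: double now runs as ordinary Python.
jax.config.update('jax_disable_jit', True)
double(jnp.float32(3.0))
double(jnp.float32(3.0))
assert len(calls) == 2

# Re-enable jit: the first call traces and compiles, the second reuses it.
jax.config.update('jax_disable_jit', False)
calls.clear()
double(jnp.float32(3.0))
double(jnp.float32(3.0))
assert len(calls) == 1
```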

Second, if you prefer to pass this option as a command-line flag to a script, you can use absl like this:

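from absl import app
from jax.config import config

def main(argv):
  del argv  # unused
  # ... your program goes here ...

if __name__ == '__main__':
  # Register JAX's config options (including jax_disable_jit) as absl flags.
  config.config_with_absl()
  app.run(main)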

You can instead call config.parse_flags_with_absl() if you want to skip app.run(main).

Finally, you can set the JAX_DISABLE_JIT environment variable to something truthy before importing jax, e.g.

JAX_DISABLE_JIT=1 python my_cool_script.py
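If you can't easily control the shell environment (in a notebook, say), a minimal sketch is to set the same variable from Python, assuming this happens before `jax` is first imported in the process, since the flag's default is read from the environment when JAX's config is initialized. The function `f` mirrors the test above.

```python
import os

# Must run before the first `import jax` in this process.
os.environ['JAX_DISABLE_JIT'] = '1'

import jax

effects = []


@jax.jit
def f(x):
    effects.append(1)
    return x


f(2)
f(2)
# With jit disabled, the Python body (and its side effect) runs on every call.
assert len(effects) == 2
```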

Please reopen if this wasn't what you had in mind!

@mattjj
Collaborator

mattjj commented Feb 7, 2019

Ah, I forgot you had mentioned a context manager. Let's see if I can add that to #337.

@mattjj
Collaborator

mattjj commented Feb 7, 2019

I added a context manager that works like this:

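# Originally written against the jax.api module (api.jit, api.disable_jit).
import jax

effects = []

@jax.jit
def f(x):
  effects.append(1)
  return x

# With jit disabled, f runs as plain Python, so the side effect happens on every call.
with jax.disable_jit():
  f(2)
  f(2)
assert len(effects) == 2

# Outside the context manager, f is traced once and the compiled version is reused.
effects = []
f(2)
f(2)
assert len(effects) == 1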

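I also changed it so that disabling jit doesn't turn the @jit decorator itself into a no-op; instead, every call to an @jit-decorated function checks the flag. That way the setting (and the context manager) also takes effect for functions that were decorated before jit was disabled.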

@mattjj mattjj closed this as completed in 75a2745 Feb 7, 2019
mattjj added a commit that referenced this issue Feb 7, 2019
add flag to disable jit globally (fixes #252)
@starcrest

Is there an equivalent flag for inspecting values inside xmap() functions?

The code below only prints Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>, even after disabling jit.

import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config
config.update('jax_disable_jit', True)

def forward(x):
  print(f"forward x: {x}")
  return x + jax.lax.axis_index('batch')

x = jnp.ones([16]) * 10

devices = np.array(jax.devices())
with jax.experimental.maps.mesh(devices, ('dp',)):
  forward_xmap = jax.experimental.maps.xmap(
      fun=forward,
      in_axes=["batch"],
      out_axes=["batch"],
      axis_resources={'batch': 'dp'})
  res = forward_xmap(x)
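One possible alternative in more recent JAX versions is `jax.debug.print`, which is staged into the traced computation and prints runtime values instead of tracers. `xmap` has since been removed from JAX, so this sketch assumes `jax.vmap` with a named axis as a stand-in; it is not the mesh/`xmap` setup from the question.

```python
import jax
import jax.numpy as jnp


def forward(x):
    # A plain print(x) here would only show a tracer; jax.debug.print
    # prints the concrete values when the computation actually runs.
    jax.debug.print("forward x: {x}, index: {i}",
                    x=x, i=jax.lax.axis_index('batch'))
    return x + jax.lax.axis_index('batch')


x = jnp.ones([16]) * 10
res = jax.vmap(forward, axis_name='batch')(x)
```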
