# `jax.debug.print` for simple inspection

Here is a rule of thumb:

- Use `jax.debug.print()` for traced (dynamic) array values inside `jax.jit()`, `jax.vmap()`, and other transformations.

- Use Python `print()` for static values, such as dtypes and array shapes.

Recall from Just-in-time compilation that when you transform a function with `jax.jit()`, JAX runs the Python code with abstract tracers in place of your arrays. As a result, Python's `print()` shows only the tracer, not the runtime value:

In [2]:
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    print("print(x) ->", x)
    y = jnp.sin(x)
    print("print(y) ->", y)
    return y


result = f(2.0)

print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


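Python's `print()` executes at trace time, before the runtime values exist. To print the actual runtime values, use `jax.debug.print()`:

In [3]: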
@jax.jit
def f(x):
    jax.debug.print("jax.debug.print(x) -> {x}", x=x)
    y = jnp.sin(x)
    jax.debug.print("jax.debug.print(y) -> {y}", y=y)
    return y


result = f(2.0)

jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974662780762
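
The rule of thumb above can be seen in a single function. The following is a minimal sketch (the function name `describe` and the input array are illustrative and not part of the original example): the shape and dtype are static, so Python's `print()` can show them at trace time, while the sum is a traced value and needs `jax.debug.print()`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def describe(x):  # hypothetical example function
    # Shape and dtype are static metadata, known at trace time,
    # so plain print() is enough here.
    print("static  ->", x.shape, x.dtype)
    total = x.sum()
    # The sum is a traced value, so it is printed at run time instead.
    jax.debug.print("dynamic -> {total}", total=total)
    return total


total = describe(jnp.ones((2, 3)))
```

Calling `describe` again with an input of the same shape and dtype reuses the cached compilation, so the Python `print()` line does not appear a second time, while the `jax.debug.print()` line appears on every call.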


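Similarly, inside `jax.vmap()`, Python's `print()` shows only the tracer. To print the values being mapped over, use `jax.debug.print()`:

In [4]: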
def f(x):
    jax.debug.print("jax.debug.print(x) -> {}", x)
    y = jnp.sin(x)
    jax.debug.print("jax.debug.print(y) -> {}", y)
    return y


xs = jnp.arange(3.0)

result = jax.vmap(f)(xs)

jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414710164070129
jax.debug.print(y) -> 0.9092974662780762


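With `jax.lax.map()`, which applies `f` to one element at a time, the `x` and `y` lines are interleaved instead of grouped:

In [5]: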
# a sequential map rather than a vectorization
result = jax.lax.map(f, xs)

jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414710164070129
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974662780762


`jax.debug.print()` also works under `jax.grad()`, where it prints only the forward-pass value:

In [6]:
def f(x):
    jax.debug.print("jax.debug.print(x) -> {}", x)
    return x**2


result = jax.grad(f)(1.0)

jax.debug.print(x) -> 1.0


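When calls to `jax.debug.print()` do not depend on one another, JAX may run them in a different order once the function is staged out by a transformation. If you need the original program order (`x` first, then `y`), pass `ordered=True`, as in the second definition below:

In [None]: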
@jax.jit
def f(x, y):
    jax.debug.print("jax.debug.print(x) -> {}", x)
    jax.debug.print("jax.debug.print(y) -> {}", y)
    return x + y


f(1, 2)


@jax.jit
def f(x, y):
    jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
    jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
    return x + y


f(1, 2)

jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2


Array(3, dtype=int32, weak_type=True)
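
One way to check the effect of `ordered=True` programmatically is to record events on the host with `jax.debug.callback()` (introduced later in this section), which accepts the same `ordered` argument. The sketch below is illustrative only: the `events` list and the `record` and `add` functions are hypothetical names, and `jax.effects_barrier()` is used to wait for pending side effects before the list is read.

```python
import functools

import jax

events = []  # hypothetical host-side log of (name, value) pairs


def record(name, value):
    # Runs on the host; `value` arrives as a concrete array.
    events.append((name, int(value)))


@jax.jit
def add(x, y):
    # ordered=True asks JAX to keep these two callbacks in program order.
    jax.debug.callback(functools.partial(record, "x"), x, ordered=True)
    jax.debug.callback(functools.partial(record, "y"), y, ordered=True)
    return x + y


total = add(1, 2)
jax.effects_barrier()  # wait until the callbacks have finished
```

After the barrier, `events` holds the `x` entry followed by the `y` entry.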

# `jax.debug.breakpoint` for `pdb`-like debugging

To pause a compiled JAX program at a specific point while debugging, use `jax.debug.breakpoint()`. The prompt is similar to Python's `pdb` and lets you inspect values in the call stack. In fact, `jax.debug.breakpoint()` is an application of `jax.debug.callback()` that captures information about the call stack.

To print all available commands during a `breakpoint` debugging session, use the `help` command. (The full set of debugger commands, along with the debugger's Sharp Bits, strengths, and limitations, is covered in Advanced debugging.)

In [9]:
@jax.jit
def f(x):
    y, z = jnp.sin(x), jnp.cos(x)
    jax.debug.breakpoint()
    return y * z


f(2.0)  # ==> Pauses during execution

Entering jdb:
Array(2., dtype=float32)
Array(0.90929747, dtype=float32)
Array(-0.4161468, dtype=float32)
Array(1., dtype=float32)

Documented commands (type help <topic>):
EOF  c     continue  down  help  list  pp  quit  up  where
bt   cont  d         exit  l     p     q   u     w 



Array(-0.37840125, dtype=float32, weak_type=True)

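Because `jax.debug.breakpoint()` is built on `jax.debug.callback()`, it can be placed inside control flow such as `jax.lax.cond()` to pause only when a condition holds, for example when a value is non-finite:

In [10]: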
def breakpoint_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        pass

    def false_fn(x):
        jax.debug.breakpoint()

    jax.lax.cond(is_finite, true_fn, false_fn, x)


@jax.jit
def f(x, y):
    z = x / y
    breakpoint_if_nonfinite(z)
    return z


f(2.0, 1.0)  # ==> No breakpoint

Array(2., dtype=float32, weak_type=True)
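
When an interactive prompt is not practical (for example, in a batch job), the same `jax.lax.cond()` pattern can report the problem instead of pausing. The following is a sketch of that variation and not part of the original example: `report_if_nonfinite` and `safe_divide` are hypothetical names, and `jax.debug.print()` is swapped in for `jax.debug.breakpoint()`.

```python
import jax
import jax.numpy as jnp


def report_if_nonfinite(x):
    is_finite = jnp.isfinite(x).all()

    def finite_branch(x):
        # Nothing to report when every element is finite.
        return None

    def nonfinite_branch(x):
        # Print the offending value at run time instead of pausing.
        jax.debug.print("non-finite value encountered: {}", x)
        return None

    jax.lax.cond(is_finite, finite_branch, nonfinite_branch, x)


@jax.jit
def safe_divide(x, y):  # hypothetical example function
    z = x / y
    report_if_nonfinite(z)
    return z


finite = safe_divide(2.0, 1.0)    # finite result: nothing is printed
infinite = safe_divide(2.0, 0.0)  # division by zero: the message is printed
```

Both branches return `None` so that `jax.lax.cond()` sees the same output structure on each side, mirroring the `true_fn` and `false_fn` pair above.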

# `jax.debug.callback` for more control during debugging

Both `jax.debug.print()` and `jax.debug.breakpoint()` are implemented using the more flexible `jax.debug.callback()`, which gives greater control over the host-side logic executed via a Python callback. It is compatible with `jax.jit()`, `jax.vmap()`, `jax.grad()` and other transformations (refer to the Flavors of callback table in External callbacks for more information).

In [11]:
import logging


def log_value(x):
    logging.warning(f"Logged value: {x}")


@jax.jit
def f(x):
    jax.debug.callback(log_value, x)
    return x


f(1.0);



This callback is compatible with other transformations, including `jax.vmap()` and `jax.grad()`:

In [14]:
x = jnp.arange(5.0)
jax.vmap(f)(x)



Array([0., 1., 2., 3., 4.], dtype=float32)

In [15]:
jax.grad(f)(1.0)



Array(1., dtype=float32, weak_type=True)
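
Because the callback is ordinary Python running on the host, it can do more than log. As a hedged illustration (the `summaries` list and the `summarize` and `normalize` functions are hypothetical names, not part of the original tutorial), the sketch below stores summary statistics of an intermediate array so they can be inspected after the compiled function returns. It assumes `jax.effects_barrier()` to wait for pending callbacks before reading the results.

```python
import jax
import jax.numpy as jnp
import numpy as np

summaries = []  # hypothetical host-side store for inspection after the run


def summarize(x):
    # Runs on the host with the concrete value of `x`.
    arr = np.asarray(x)
    summaries.append(
        {"min": float(arr.min()), "max": float(arr.max()), "shape": arr.shape}
    )


@jax.jit
def normalize(x):  # hypothetical example function
    jax.debug.callback(summarize, x)
    return (x - x.mean()) / x.std()


out = normalize(jnp.arange(4.0))
jax.effects_barrier()  # wait until the callback has finished
```

After the barrier, `summaries` contains one entry describing the input to `normalize`.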