How to print values inside a @jit-compiled function? #196

Closed
ndronen opened this issue Jan 5, 2019 · 8 comments

@ndronen

ndronen commented Jan 5, 2019

This might be related to #98.

I'm just dipping my toe in the water with JAX and tweaked update to try to print the value of a network's parameters:

@jit
def update(i, opt_state, batch):
    params = minmax.get_params(opt_state)
    print('params', params)
    return opt_update(i, grad(loss)(params, batch), opt_state)

This prints a list consisting of entries like

Traced<ShapedArray(float32[1,64]):JaxprTrace(level=-1/1)>

How do I inspect the underlying values?

@mattjj
Collaborator

mattjj commented Jan 5, 2019

Thanks for this question!

I want to separate out two questions here:

  1. When we call print in a JAX-transformed Python function (like one with an @jit decorator), why does it print things like Traced<ShapedArray(float32[1,64]):JaxprTrace(level=-1/1)>?
  2. How do we print values in a jit-compiled function?

You actually asked the second one (or close to it), but I want to answer the first one first, since it sets up some useful context for the second.

But a short answer to the second question is "you can't print values inside a jit-compiled function yet, so in this case you should just remove @jit, and maybe move it to opt_update instead."
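
To make "move the @jit" concrete, here is a minimal sketch using the current `jax` namespace. The names `loss`, `sgd_step`, and the toy data are hypothetical stand-ins, not the code from the original question: the outer `update` stays un-jitted so `print` sees real arrays, and only the inner step is compiled.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss, standing in for the network loss in the question.
def loss(params, batch):
    inputs, targets = batch
    return jnp.mean((inputs @ params - targets) ** 2)

@jax.jit
def sgd_step(params, batch):
    # Compiled: inside this function `params` is a tracer, not a concrete array.
    return params - 0.1 * jax.grad(loss)(params, batch)

def update(params, batch):
    # Not compiled: `params` is a concrete array here, so print shows its values.
    print("params", params)
    return sgd_step(params, batch)

params = jnp.zeros(3)
batch = (jnp.ones((4, 3)), jnp.ones(4))
new_params = update(params, batch)
```

The cost is that the Python-level `update` runs on every call, but the expensive numerical work still goes through the compiled `sgd_step`.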

For the first question, this is getting at some things we haven't been able to document fully yet. They're summarized in the How it works section of the readme, and you should probably take a look at that first for the rest of this comment to make sense, but it's a pretty short sketch. This comment is also not a substitute for real documentation (or a paper, which we'll write someday, really), but hopefully it'll have some additional clues.

When JAX traces a function, as it does under a jit decorator, the function isn't evaluated on regular ndarray values. Instead, it's evaluated on abstract values, which are Python objects that model sets of possible values. In this case, the abstract value you're printing (an element of params) is a ShapedArray(float32[1,64]), which models the set of all arrays with shape (1, 64) and dtype float32. That object generally acts a lot like an ndarray, in that you can call jax.numpy functions on it and access methods and properties like .ravel() and .shape, but it's not committed to any specific numerical values. And if you try to force it to take on a specific numerical value, it'll raise an error.

Here's a simpler example:

from __future__ import print_function

from jax import jit

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y

f(2)

If you call this function without a @jit decorator, you know what to expect. But under @jit, you'll see prints like these:

Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>
Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>

Ignoring the Traced and JaxprTrace parts, you can see that the values flowing through the Python program are ShapedArray(int32[]), i.e. abstract values representing any possible scalar int32. What if we try to force these things to have specific values?

@jit
def g(x):
  if x < 0:
    return x
  else:
    return 2 * x

g(2)

When evaluating the Python code here, the expression x < 0 will be evaluated to an abstract value like ShapedArray(bool[]), which basically models the set {True, False}. Python will then try to coerce it to a concrete bool (either True or False) so that it can take the branch, resulting in...

TypeError: Abstract value passed to `bool`, which requires a concrete value.
The function to be transformed can't be traced at the required level of abstraction.
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions
instead.
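
Two common ways around this error, sketched with the current `jax` API (the function names `g_static` and `g_where` are made up for illustration). The first follows the hint in the error message and marks the argument static, so it stays a concrete Python value during tracing. The second keeps the argument abstract and expresses the branch as an array operation instead of a Python `if`.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Option 1: mark x as static. The Python `if` sees a concrete value,
# at the price of a recompile for every distinct value of x.
@partial(jax.jit, static_argnums=0)
def g_static(x):
    if x < 0:
        return x
    else:
        return 2 * x

# Option 2: keep x abstract and select between both branches elementwise.
@jax.jit
def g_where(x):
    return jnp.where(x < 0, x, 2 * x)

a = g_static(2)
b = g_where(2)
c = g_where(-3)
```

Option 1 is only practical when the argument takes a handful of distinct values; option 2 computes both branches and picks one, which is usually fine for cheap expressions.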

The reason we do this abstract value stuff is that JAX aims to specialize Python functions so it can transform them. Python functions on their own can mean too many different things to be automatically transformed. For example, take the function lambda x: x + 2. You might think that it adds two to its argument, and you'd be right if it's evaluated on an integer or a float, but what if that same function is called on an instance of a class that overrides __add__(self, x) to POST to some web server and start brewing x espresso shots? That's amazing flexibility for one bit of code to have, but it also means that we can't hope to understand and then transform raw Python functions: there are just too many behaviors to know what will happen until we actually run it.

To get a handle on what kinds of behaviors a Python function can have, we can try asking a more specific question: instead of "what behaviors does this function have when applied to any possible Python object?" we can ask "what behaviors does this function have when applied to any possible integer?" That is, we can specialize the function to subsets of possible argument values. For example, if we want to figure out what lambda x: x + 2 does to an integer, we could try applying it to an object that models the set of all possible integers (but not any one concrete integer). If we watch what happens to that object, and it makes it all the way to the other end of the function, then we can safely say we've figured out what the function lambda x: x + 2 would do to any possible integer. That is, we have a representation of the function specialized to integer values (but not specialized further). That's the kind of representation we can think of usefully compiling, where we can reuse the specialized-and-compiled version for any integer value we want to apply the function to. But if we want to figure out what lambda x: x + 2 does to a float, we'd have to repeat the process (since, for all we know, without looking at the source code, it could have things like dtype checks that switch its behavior).
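
A small sketch of that specialize-and-reuse behavior, assuming the current `jax.jit`. The `trace_log` list is a hypothetical helper: appending to it is a Python side effect, so it only happens while the function is being traced, not when the compiled version is reused.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the abstract value seen on each trace

@jax.jit
def add_two(x):
    # Runs only at trace time; x carries a shape and dtype but no value.
    trace_log.append((x.shape, str(x.dtype)))
    return x + 2

add_two(jnp.int32(1))      # traces and compiles for scalar int32
add_two(jnp.int32(5))      # same shape and dtype: reuses the compiled version
add_two(jnp.float32(1.0))  # new dtype: traced again
```

After these three calls `trace_log` should hold two entries, one for the int32 specialization and one for float32.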

You can think of this as similar to having type constraints in source code. That is, if the source code specified something like lambda x::int32 : x + 2::int32, then we'd know it can only ever apply to int32 values, and we could transform it without worrying about having to handle behaviors like making espressos. But JAX doesn't actually statically constrain the source code at all. Instead, it lets Python code be as general as Python likes to be, and attempts to take specialized views of it as needed at runtime.

How far should we abstract the values on which we specialize a function? There's a tradeoff here: the more abstract the arguments, the more general the view of the function that you get. But that's only if you actually succeed in evaluating the function on your abstract value. We saw above that by abstracting to a ShapedArray(int32[]), we could no longer successfully abstract-evaluate a Python function that had a test like if x < 0. If we had abstracted to the Unshaped(int32[]) level, then we couldn't handle a function that included Python expressions like if x.shape[0] < 5 or for row in x. With less specialization, we also wouldn't have as much information for a compiler to use to optimize the generated code. For jit, we think abstracting to shape/dtype level strikes a nice balance of traceability and also XLA optimizability (though that's actually easy to change, since JAX's core is general).
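
As a sketch of what the shape/dtype level buys you under `jit` (the function `head_sum` is a made-up example): shapes are concrete at trace time, so Python control flow that depends only on `x.shape` traces fine, even though control flow on the values of `x` would not.

```python
import jax
import jax.numpy as jnp

@jax.jit
def head_sum(x):
    # x.shape is a tuple of ordinary Python ints during tracing,
    # so this `if` is resolved once per input shape.
    if x.shape[0] < 5:
        return x.sum()
    return x[:5].sum()

short_total = head_sum(jnp.arange(3))   # takes the first branch
long_total = head_sum(jnp.arange(10))   # takes the second branch
```

Each distinct input shape triggers its own trace, and the branch taken is baked into the compiled code for that shape.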

I've focused on the abstraction used for jit tracing, but it's interesting to think about other transformations from this perspective too. For example, for automatic differentiation we don't need to abstract the values much at all, which is why grad(lambda x: x if x > 0 else -x) works (except possibly at x == 0...). So different transformations have different abstraction tradeoff profiles, and thus abstractify to different levels in our implementation.
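
The `grad` case can be checked directly. A minimal sketch (the name `abs_like` is just for illustration), assuming `grad` is applied without an enclosing `jit`:

```python
import jax

def abs_like(x):
    # A Python branch on the value of x: fine under grad alone,
    # because grad traces with concrete values.
    return x if x > 0 else -x

d_abs = jax.grad(abs_like)
slope_pos = d_abs(3.0)
slope_neg = d_abs(-2.0)
```

Wrapping the same function in `jit` would bring back the abstract-value error, since `jit` abstracts to the shape/dtype level.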

So to summarize, when the print function in your example gets called, there are actually no (concrete) underlying values to be printed at all! Instead, whenever the Python code is evaluated (that is, while it's being traced), all that's there is an object representing a set of possible values.

Okay, so that was an answer to the first question. Hopefully it also sets up why the answer to the second question isn't just a simple Python print statement. For now, the only answer we have is that you can't print things in compiled code.

If we want to be able to print values inside a jit-compiled function, we'll need to use a JAX-specific construct that stages out the printing into the compiled code. Printing is actually a side effect on the state of Python (and your terminal emulator), so print statements are harder to stage than other functions. Eventually we might support something like this:

@jit
def foo(x, python_io):
  x = 2 * x
  python_io = jax.print(python_io, x)
  return x, python_io

That python_io thing would be a special token that represents the state of the Python interpreter (and the rest of the world), and so you'd need to thread it through the parts of the computation you want to have side-effects, and it'd have to obey some linearity constraints. (It'd basically be an IO monad without any syntactic sugar.)

The upside is that once we have something like python_io, we could add support for more general statements like y, python_io = jax.py_call(some_py_fun, x, python_io) to let compiled code call back into side-effecting Python (not only to print but also to append values to lists, etc).

We'd love to hear any thoughts on all that! And just to make sure it's not missed, here's the short answer to your question one more time: you can't use print inside jit-compiled functions, and moreover there's currently no other way to print values from inside a compiled function, so you have to move jit if you want to print things.

@mattjj mattjj added the question Questions for the JAX team label Jan 5, 2019
@mattjj mattjj self-assigned this Jan 5, 2019
@mattjj mattjj changed the title from "How to get the value of a jaxpr?" to "How to print values inside a @jit-compiled function?" Jan 5, 2019
@mattjj mattjj changed the title How to print values inside a @jit-compiled function? How to print values inside a @jit-compiled function? Jan 5, 2019
@mattjj
Collaborator

mattjj commented Jan 5, 2019

By the way, "jaxpr" is short for "JAX expression", and it's a name for the intermediate representation (IR) of a function that is sometimes formed during JAX transformations. That's distinct from the abstract values (and Tracers that wrap them) that flow through user code.

We also like saying "jaxpr" because it reminds us of our friend and colleague @JasperSnoek.
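
For anyone who wants to look at a jaxpr, here is a short sketch using `jax.make_jaxpr`, which is part of the current public API. The exact printed form varies between JAX versions, so no expected output is shown.

```python
import jax

def f(x):
    return 2 * x

# Trace f on an example argument and return its jaxpr (the IR),
# without compiling or running anything on a device.
closed_jaxpr = jax.make_jaxpr(f)(2)
print(closed_jaxpr)

# The equations are the primitive operations recorded during tracing.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
```

Note that this shows the traced program, not runtime values: the jaxpr describes what will be computed for any input of that shape and dtype.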

@mattjj
Collaborator

mattjj commented Jan 12, 2019

I'm going to close this issue because it's inactive and because I think the question was answered, but please reopen if it needs more attention (or open a new one).

@AdrienCorenflos
Contributor

Hi
Just posting that you developed an experimental thingy to do this

https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html?highlight=print#jax.experimental.host_callback.id_print

@KaleabTessera

So using call from the experimental host_callback module works!

E.g.

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
  call(lambda x: print(f"x: {x}"), x)
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000000,))
selu(x)

@st--

st-- commented Aug 24, 2022

As noted in #4615, in the latest release 0.3.16 there's now jax.debug.print: https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html
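
A minimal sketch of `jax.debug.print` applied to the small `f` example from earlier in this thread. It takes a format string and the values to fill in, and the printing happens when the compiled code runs, so it shows runtime values rather than tracers.

```python
import jax

@jax.jit
def f(x):
    jax.debug.print("x = {}", x)
    y = 2 * x
    jax.debug.print("y = {}", y)
    return y

result = f(2)
```

Unlike a plain `print`, these calls fire on every invocation of the compiled function, not just on the first (tracing) call.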

@BraveDrXuTF

BraveDrXuTF commented Jan 14, 2024

So is there a way to transform a dynamic array into a static one, and then print it?

@jakevdp
Collaborator

jakevdp commented Jan 14, 2024

Hi @BraveDrXuTF – this is covered in the FAQ here: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array. If all you want is to print it, you can use jax.debug.print.
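
If the goal is to get the value out rather than just print it, one option that needs no special API is to return the intermediate from the jitted function and inspect it outside. A sketch, with `scaled_sum` as a made-up example function:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def scaled_sum(x):
    scaled = 2 * x  # a tracer inside the jitted function
    # Return the intermediate alongside the result.
    return scaled.sum(), scaled

total, scaled = scaled_sum(jnp.arange(3.0))

# Outside jit the outputs are concrete arrays and convert to NumPy as usual.
scaled_np = np.asarray(scaled)
print(scaled_np)
```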
