How to print values inside a @jit-compiled function? #196
Thanks for this question! I want to separate out two questions here: […]

You actually asked the second one (or close to it), but I want to answer the first one first, since it sets up some useful context for the second. But a short answer to the second question is "you can't print values inside a […]".

For the first question, this is getting at some things we haven't been able to document fully yet. They're summarized in the How it works section of the readme, and you should probably take a look at that first for the rest of this comment to make sense, but it's a pretty short sketch. This comment is also not a substitute for real documentation (or a paper, which we'll write someday, really), but hopefully it'll have some additional clues.

When JAX traces a function, as it does under a […]

Here's a simpler example:

```python
from __future__ import print_function
from jax import jit

@jit
def f(x):
    print(x)
    y = 2 * x
    print(y)
    return y

f(2)
```

If you call this function without a […]
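The sketch below illustrates the point that Python side effects such as `print` run only while JAX traces the function, not on every call. It assumes a recent JAX release, and the counter `trace_count` is a hypothetical name used only to observe tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing


@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # a Python side effect, like print: runs only during tracing
    return 2 * x


f(jnp.float32(2.0))  # first call: traces f, then runs the compiled version
f(jnp.float32(3.0))  # same shape and dtype: reuses the cached trace
f(jnp.ones(3))       # new shape: traces again
```

Inside `f`, `x` is a tracer that stands for any value with that shape and dtype. That is why a plain `print` there shows a tracer object rather than a number.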
Ignoring the […]

```python
from jax import jit

@jit
def g(x):
    if x < 0:
        return x
    else:
        return 2 * x

g(2)
```

When evaluating the Python code here, the expression […]
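`g(2)` fails under `jit` because `x < 0` on a tracer has no single Python boolean value. The sketch below shows two common workarounds and assumes a recent JAX release. The names `g_where` and `g_static` are hypothetical.

- Express the branch with `jnp.where` so it stays traceable.
- Mark the argument static with `static_argnums`. JAX then specializes on its concrete value and retraces for each new value.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def g(x):
    if x < 0:  # Python control flow on a tracer: fails at trace time
        return x
    return 2 * x


try:
    g(2)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    # Recent releases raise a subclass of ConcretizationTypeError here.
    traced_ok = False


@jax.jit
def g_where(x):
    # Both branches are computed; jnp.where selects the result elementwise.
    return jnp.where(x < 0, x, 2 * x)


@partial(jax.jit, static_argnums=0)
def g_static(x):
    # x is a concrete Python value here, so an ordinary `if` works,
    # at the cost of recompiling for every distinct value of x.
    if x < 0:
        return x
    return 2 * x
```

Each option picks a different trade-off:

- `jnp.where` keeps `x` abstract, so one compiled version serves all integers.
- `static_argnums` gives up that generality in exchange for ordinary Python control flow.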
The reason we do this abstract value stuff is that JAX aims to specialize Python functions so it can transform them. Python functions on their own can mean too many different things to be automatically transformed. For example, take the function […]

To get a handle on what kinds of behaviors a Python function can have, we can try asking a more specific question. Instead of "what behaviors does this function have when applied to any possible Python object?" we can ask "what behaviors does this function have when applied to any possible integer?" That is, we can specialize the function to subsets of possible argument values. For example, if we want to figure out what […]

You can think of this as similar to having type constraints in source code. That is, if the source code specified something like […]

How far should we abstract the values on which we specialize a function? There's a tradeoff here: the more abstract the arguments, the more general the view of the function that you get. But that's only if you actually succeed in evaluating the function on your abstract value. We saw above that by abstracting to a […]

I've focused on the abstraction used for […]

So to summarize, when the print function in your example gets called, there are actually no (concrete) underlying values to be printed at all! Instead, the only times the Python code is evaluated, there's just an object representing a set of possible values.

Okay, so that was an answer to the first question. Hopefully it also sets up why the answer to the second question isn't just a simple Python […]

If we want to be able to print values inside a […]

```python
@jit
def foo(x, python_io):
    x = 2 * x
    # Hypothetical API sketched in this comment; JAX does not provide `jax.print`.
    python_io = jax.print(python_io, x)
    return x, python_io
```

That […]

The upside is that once we have something like […]

We'd love to hear any thoughts on all that! And just to make sure it's not missed, here's the short answer to your question one more time: you can't use […]
By the way, "jaxpr" is short for "JAX expression". It's a name for the intermediate representation (IR) of a function that is sometimes formed during JAX transformations. That's distinct from the abstract values (and […]). We also like saying "jaxpr" because it reminds us of our friend and colleague @JasperSnoek.
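To see a jaxpr for the earlier `f` example, you can use `jax.make_jaxpr`. It traces a function with abstract values and returns the resulting IR. This is a minimal sketch assuming a recent JAX release, and the exact printed form varies between versions.

```python
import jax


def f(x):
    y = 2 * x
    return y


closed_jaxpr = jax.make_jaxpr(f)(2)
print(closed_jaxpr)
# The input variable carries only abstract information (shape and dtype), not the value 2.
print(closed_jaxpr.jaxpr.invars)
print(closed_jaxpr.in_avals)
```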
I'm going to close this issue because it's inactive and because I think the question was answered, but please reopen if it needs more attention (or open a new one).
Hi […]

So using the experimental call host callback works! E.g. […]
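For reference, here is a minimal sketch of the callback approach using `jax.debug.callback`, which is available in recent JAX releases. This assumes a recent installed version. It is also a different API from the experimental host callback mentioned above. The callback receives concrete array values on the host. `record` and `captured` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

captured = []  # hypothetical list collecting values seen by the callback


def record(value):
    # Runs on the host with a concrete array, not a tracer.
    captured.append(np.asarray(value))


@jax.jit
def foo(x):
    x = 2 * x
    jax.debug.callback(record, x)  # staged into the compiled computation
    return x


out = foo(jnp.arange(3))
jax.effects_barrier()  # wait for any pending callbacks to finish
```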
As noted in #4615, in the latest release 0.3.16 there's now […]
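Here is a minimal sketch using `jax.debug.print`, assuming a JAX release that includes it (it lives in `jax.debug` in recent versions). Python's `print` runs once at trace time and shows a tracer. `jax.debug.print`, by contrast, is staged into the compiled computation, so it prints the concrete runtime value on each call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    print("trace-time print:", x)         # runs once, during tracing; shows a tracer
    y = 2 * x
    jax.debug.print("runtime y = {}", y)  # runs on every call with the concrete value
    return y


out = f(jnp.arange(3))
jax.effects_barrier()  # make sure the debug print has completed
```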
So is there a way to convert a dynamic array to a static one and then print it?
Hi @BraveDrXuTF – this is covered in the FAQ here: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array. If all you want is to print it, you can use […]
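The sketch below illustrates the distinction the FAQ draws and assumes a recent JAX release. `bad_step` and `step` are hypothetical names. Converting a traced array to NumPy inside `jit` fails, because there is no concrete value yet. Returning the array and converting it after the call works.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def bad_step(x):
    return np.asarray(2 * x)  # tries to build a concrete NumPy array from a tracer


@jax.jit
def step(x):
    return 2 * x  # stay inside JAX and return the value instead


try:
    bad_step(jnp.arange(3))
    converted_inside = True
except jax.errors.TracerArrayConversionError:
    converted_inside = False

y = np.asarray(step(jnp.arange(3)))  # concrete NumPy array, safe to print
print(y)
```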
This might be related to #98.

I'm just dipping my toe in the water with JAX and tweaked `update` to try to print the value of a network's parameters: […]

This prints a list consisting of entries like […]

How do I inspect the underlying values?
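One way to inspect parameter values is sketched below. It rests on assumptions: the `params` layout, `grads`, `LEARNING_RATE`, and this `update` are hypothetical stand-ins, not the original code. The idea is to keep `update` free of Python `print` and look at the returned parameters after the jitted call, for example by copying them to host NumPy arrays with `jax.device_get`.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # hypothetical step size


@jax.jit
def update(params, grads):
    # Plain SGD step applied leaf-by-leaf over a pytree of parameters.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}  # hypothetical parameters
grads = {"w": jnp.ones((2, 2)), "b": jnp.ones(2)}    # hypothetical gradients

new_params = update(params, grads)
host_params = jax.device_get(new_params)  # same dict structure, NumPy leaves
for name, value in host_params.items():
    print(name, value)
```

If the values are needed mid-computation instead, `jax.debug.print` can be called inside the jitted function.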