Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make tracers tree-pretty-print their contents #2591

Merged
merged 1 commit into from Apr 3, 2020
Merged

Conversation

mattjj
Copy link
Member

@mattjj mattjj commented Apr 3, 2020

Fixes #1476.

import jax.numpy as np
from jax import vmap

def f(x):
  y = x ** 2
  print(y)
  return y

vmap(f)(np.arange(4))

Before this PR, that prints:

Traced<ShapedArray(int32[])>with<BatchTrace(level=0/0)>

After this PR, it prints:

Traced<ShapedArray(int32[])>with<BatchTrace(level=0/0)>
  with val = DeviceArray([0, 1, 4, 9], dtype=int32)
       batch_dim = 0

Here's another example:

import jax.numpy as np
from jax import vmap, jacfwd

def f(x):
  y = x ** 2
  print(y)
  return y

jacfwd(f)(np.arange(4.))
Traced<ConcreteArray([0. 1. 4. 9.])>with<JVPTrace(level=1/0)>
  with primal = DeviceArray([0., 1., 4., 9.], dtype=float32)
       tangent = Traced<ShapedArray(float32[4])>with<BatchTrace(level=0/0)>
                   with val = DeviceArray([[0., 0., 0., 0.],
                                           [0., 2., 0., 0.],
                                           [0., 0., 4., 0.],
                                           [0., 0., 0., 6.]], dtype=float32)
                        batch_dim = 0

The same works in pdb.

One more example!

import jax.numpy as np
from jax import vmap, jacfwd, grad, jit


def f(x):
  y = x ** 2
  print(y)
  return y

grad(f)(3.)
grad(jit(f))(3.)
Traced<ConcreteArray(9.0)>with<JVPTrace(level=1/0)>
  with primal = DeviceArray(9., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)>

Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
  with primal = Traced<ShapedArray(float32[]):JaxprTrace(level=-1/1)>
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=0/1)>

That shows how grad specializes on concrete primal values, but traces the tangents with abstract values, unless of course there's a jit involved staging out the primal stuff too.

JAX is pretty special in that it performs many transformations on-the-fly while it's evaluating user code, rather than first staging things out of Python into some IR and then doing the transformations on that IR. We can take advantage of that to give users better debugging workflows, using standard Python tools (like prints and pdb).

This PR is just a change to the Tracer repr method.

@mattjj mattjj marked this pull request as ready for review April 3, 2020 22:42
@mattjj mattjj merged commit 60de46a into master Apr 3, 2020
@mattjj mattjj deleted the tracer-printing branch April 3, 2020 22:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Print values inside a vmap function
2 participants