Add the input avals to Lowered and Compiled.
PiperOrigin-RevId: 433505462
jblespiau authored and jax authors committed Mar 9, 2022
1 parent 41a8de6 commit 8a85544
Showing 6 changed files with 121 additions and 80 deletions.
48 changes: 34 additions & 14 deletions jax/_src/api.py
@@ -502,14 +502,18 @@ class Lowered:
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_lowering',
'_no_kwargs']
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDefs for the positional and keyword arguments,
# use `in_tree.children()`, which returns a sequence of 2 PyTreeDefs.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
@@ -520,6 +524,7 @@ class Lowered:
def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
@@ -534,14 +539,15 @@ def __init__(self,
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.out_tree,
self.donate_argnums, self._no_kwargs)
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
@@ -564,22 +570,28 @@ class Compiled:
common API for querying properties of compiled computations across
JAX's various compilation paths and backends.
"""
__slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_executable',
'_no_kwargs']
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
"_no_kwargs"
]


# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]
_no_kwargs: bool

def __init__(self, executable, in_tree, out_tree, donate_argnums,
def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
no_kwargs=False):
self._executable = executable
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
@@ -666,10 +678,17 @@ def lower(*args, **kwargs) -> Lowered:
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
name = flat_fun.__name__
arg_specs = unsafe_map(arg_spec, args_flat)
computation = dispatch.lower_xla_callable(
flat_fun, device, backend, name, donated_invars, *arg_specs)
return Lowered(computation, in_tree, out_tree(), donate_argnums)
arg_specs_and_device = list(unsafe_map(arg_spec, args_flat))
# Only do this if the list is not empty
if arg_specs_and_device:
arg_specs = list(zip(*arg_specs_and_device))[0]
else:
arg_specs = []
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
donated_invars,
*arg_specs_and_device)
return Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
out_tree(), donate_argnums)

return lower

@@ -2178,7 +2197,7 @@ def lower(*args, **kwargs) -> Lowered:
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, args, kwargs)
abstract_args = map(xla.abstractify, p.flat_args)
abstract_args = list(map(xla.abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=axis_size,
@@ -2189,7 +2208,8 @@ def lower(*args, **kwargs) -> Lowered:
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
avals=abstract_args)
return Lowered(computation, p.in_tree, p.out_tree(), donate_tuple)
return Lowered(computation, p.in_tree, p.in_tree.unflatten(abstract_args),
p.out_tree(), donate_tuple)

return lower

@@ -2879,7 +2899,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
def _device_put_replicated(x):
aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
core.raise_to_shaped(core.get_aval(x)))
assert (isinstance(aval, core.ShapedArray) and
assert (isinstance(aval, ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1)
buf, = dispatch.device_put(x, devices[0])
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
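
For context, a minimal usage sketch of the new field on the jit path (not part of the diff; the printed repr below assumes a default float32 configuration): the results of `lower()` and `compile()` now both carry `in_avals`, a (positional args, kwargs) pair of `ShapedArray` abstract values mirroring `in_tree`.

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sqrt(x ** 2) + 1.

lowered = jax.jit(f).lower(1.)
compiled = lowered.compile()

# Both stages expose the same avals, structured as (positional args, kwargs).
assert lowered.in_avals == compiled.in_avals
print(lowered.in_avals)
# ((ShapedArray(float32[], weak_type=True),), {})   # float64 when x64 is enabled
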
7 changes: 5 additions & 2 deletions jax/experimental/maps.py
@@ -659,9 +659,12 @@ def lower(*args):
params['resource_env'], params['backend'], params['spmd_in_axes'],
params['spmd_out_axes_thunk'], params['in_positional_semantics'],
params['out_positional_semantics'], *avals_flat)

in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
return Lowered(
computation, treedef_tuple([in_tree, tree_flatten({})[1]]),
out_tree(), donate_argnums, no_kwargs=True)
computation, in_tree, in_avals, out_tree(), donate_argnums,
no_kwargs=True)

fun_mapped = wraps(fun)(decorate_serial(fun_mapped))
fun_mapped.lower = decorate_serial(lower)
14 changes: 8 additions & 6 deletions jax/experimental/pjit.py
@@ -256,28 +256,30 @@ def infer_params(*args, **kwargs):
name=getattr(flat_fun, '__name__', '<unnamed function>'),
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics)
return args_flat, params, in_tree, out_tree(), donate_argnums
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)

@wraps(fun)
def wrapped(*args, **kwargs):
args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
for arg in args_flat:
_check_arg(arg)
out = pjit_p.bind(*args_flat, **params)
return tree_unflatten(out_tree, out)

def lower(*args, **kwargs):
args_flat, params, in_tree, out_tree, donate_argnums = \
infer_params(*args, **kwargs)
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, **kwargs)
lowering = _pjit_lower(
params['jaxpr'], params['in_axis_resources'],
params['out_axis_resources'], params['resource_env'],
params['donated_invars'], params['name'],
params['in_positional_semantics'], params['out_positional_semantics'])

args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
return Lowered(lowering, args_kwargs_in_tree, out_tree, donate_argnums,
no_kwargs=True)
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return Lowered(lowering, args_kwargs_in_tree, local_in_avals, out_tree,
donate_argnums, no_kwargs=True)

wrapped.lower = lower
return wrapped
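
The `treedef_tuple([in_tree, tree_flatten({})[1]])` pattern in the xmap and pjit paths above pairs the positional-args treedef with an empty-kwargs treedef, so `in_avals` keeps the same ((args...), {}) shape as on the jit path. A minimal sketch of that bookkeeping using public tree utilities (the array shapes here are illustrative, not from the diff):

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, treedef_tuple

args = (jnp.ones((4, 4)), jnp.ones((4, 4)))
flat_args, args_tree = tree_flatten(args)
flat_avals = [jax.ShapedArray(a.shape, a.dtype) for a in flat_args]

# Pair the positional-args treedef with an empty-kwargs treedef, then
# unflatten the flat avals back into the ((args...), {}) structure.
args_kwargs_tree = treedef_tuple([args_tree, tree_flatten({})[1]])
in_avals = args_kwargs_tree.unflatten(flat_avals)
# in_avals == ((ShapedArray(float32[4,4]), ShapedArray(float32[4,4])), {})
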
105 changes: 55 additions & 50 deletions tests/api_test.py
@@ -758,11 +758,16 @@ def f(x):
return jnp.sqrt(x ** 2) + 1.

f_jit = self.jit(f)
f_low = f_jit.lower(1.)
f_exe = f_low.compile()
self.assertAllClose(f_exe(1.), 2.)

self.assertEqual(f_exe.in_tree, jax.tree_flatten(((0,), {}))[1])
lowered = f_jit.lower(1.)
compiled = lowered.compile()
self.assertAllClose(compiled(1.), 2.)
self.assertEqual(lowered.in_avals, compiled.in_avals)
expected_dtype = np.float64 if config.x64_enabled else np.float32
for obj in [lowered, compiled]:
self.assertEqual(
obj.in_avals,
((jax.ShapedArray([], expected_dtype, weak_type=True),), {}))
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])

def test_jit_lower_duck_typing(self):
f_jit = self.jit(lambda x: 2 * x)
@@ -3362,8 +3367,8 @@ def sigmoid(x):

@jax.jit
def loss(A, x):
h = jax.nn.sigmoid(A * x)
return jnp.sum((h - x)**2)
h = jax.nn.sigmoid(A * x)
return jnp.sum((h - x)**2)

with jax.checking_leaks():
_ = jax.grad(loss)(A, x) # doesn't crash
@@ -5357,44 +5362,44 @@ def test_vmap_inside_defjvp(self):

@jax.custom_jvp
def f(mat, aux):
num_rows, num_cols = mat.shape
return jnp.ones((num_rows, 1)) / num_cols
num_rows, num_cols = mat.shape
return jnp.ones((num_rows, 1)) / num_cols

@f.defjvp
def f_jvp(primals, tangents):
mat, aux = primals
vec, _ = tangents
output = f(*primals)
num_rows, num_cols = mat.shape
size = num_rows * num_cols
# -----
bd_mat = mat.reshape(1, 1, num_rows, num_cols)
bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols))
bd_mat = bd_mat.reshape(size, num_rows, num_cols)
# -----
rowsum = jnp.sum(mat, axis=1, keepdims=True)
colsum = jnp.sum(mat, axis=0, keepdims=True)
bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows))
bd_colsum = jnp.tile(colsum, reps=(num_cols, 1))
# -----
bd_vec = vec.reshape(size, 1)
# -----
def operate(mx, val):
buf = 0
for i in range(2):
buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
buf = jnp.matmul(bd_rowsum, buf)
return buf * val[None, :]
# -----
# Vectorizing will raise a shape error
bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
# -----
bd_buf = bd_buf / aux
jvp = jnp.sum(bd_buf, axis=0)
jvp = jnp.mean(jvp, axis=1, keepdims=True)
# -----
# JVP ends successfully, but still raises an error
return (output, jvp)
mat, aux = primals
vec, _ = tangents
output = f(*primals)
num_rows, num_cols = mat.shape
size = num_rows * num_cols
# -----
bd_mat = mat.reshape(1, 1, num_rows, num_cols)
bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols))
bd_mat = bd_mat.reshape(size, num_rows, num_cols)
# -----
rowsum = jnp.sum(mat, axis=1, keepdims=True)
colsum = jnp.sum(mat, axis=0, keepdims=True)
bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows))
bd_colsum = jnp.tile(colsum, reps=(num_cols, 1))
# -----
bd_vec = vec.reshape(size, 1)
# -----
def operate(mx, val):
buf = 0
for i in range(2):
buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
buf = jnp.matmul(bd_rowsum, buf)
return buf * val[None, :]
# -----
# Vectorizing will raise a shape error
bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
# -----
bd_buf = bd_buf / aux
jvp = jnp.sum(bd_buf, axis=0)
jvp = jnp.mean(jvp, axis=1, keepdims=True)
# -----
# JVP ends successfully, but still raises an error
return (output, jvp)

jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash

@@ -6330,17 +6335,17 @@ def test_custom_vjp_scan_batching_edge_case(self):
def mul(x, coeff): return x * coeff
def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff)
def mul_bwd(res, g):
x, coeff = res
g_x = g * coeff
g_coeff = (x * g).sum()
return g_x, g_coeff
x, coeff = res
g_x = g * coeff
g_coeff = (x * g).sum()
return g_x, g_coeff
mul.defvjp(mul_fwd, mul_bwd)

def scan_over_mul(x, coeff):
def f_(x, t):
return mul(x, coeff), None
y, _ = jax.lax.scan(f_, x, jnp.arange(3))
return y
def f_(x, t):
return mul(x, coeff), None
y, _ = jax.lax.scan(f_, x, jnp.arange(3))
return y

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)
15 changes: 11 additions & 4 deletions tests/pjit_test.py
@@ -679,8 +679,14 @@ def f(x, y):
x = jnp.arange(np.prod(shape)).reshape(shape)
expected = x @ (x + 1)

exe = f.lower(x, x + 1).compile()
actual = exe(x, x + 1)
lowered = f.lower(x, x + 1)
compiled = lowered.compile()
actual = compiled(x, x + 1)

self.assertEqual(lowered.in_avals, compiled.in_avals)
self.assertEqual(
lowered.in_avals,
((jax.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))

splits = np.split(expected, 4)
self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
@@ -692,8 +698,9 @@ def f(x, y):
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
check_dtypes=False)

self.assertTrue(exe._no_kwargs, True)
self.assertEqual(exe.in_tree, jax.tree_flatten(((0, 0), {}))[1])
for obj in [lowered, compiled]:
self.assertTrue(obj._no_kwargs, True)
self.assertEqual(obj.in_tree, jax.tree_flatten(((0, 0), {}))[1])

@jtu.with_mesh([('x', 2), ('y', 2)])
def testLowerCompileWithKwargs(self):
12 changes: 8 additions & 4 deletions tests/pmap_test.py
@@ -159,13 +159,17 @@ def testLowerCompile(self):
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = f(x)
f_exe = f.lower(x).compile()
ans = f_exe(x)
lowered = f.lower(x)
compiled = lowered.compile()
ans = compiled(x)

self.assertAllClose(ans, expected)

# in_tree is a pair: (the positional args as a tuple of their structures, the kwargs).
self.assertFalse(f_exe._no_kwargs)
self.assertEqual(f_exe.in_tree, jax.tree_flatten(((0,), {}))[1])
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_avals, ((jax.ShapedArray(x.shape, x.dtype),), {}))

def testLowerCompileInTreeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
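
A matching sketch for the pmap path, mirroring testLowerCompile above (assumes at least one local device; the aval equality follows the assertion added in the test):

import numpy as np
import jax
from jax import lax

f = jax.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
x = np.arange(jax.device_count() * 4, dtype=np.float32).reshape(
    (jax.device_count(), 4))

lowered = f.lower(x)
compiled = lowered.compile()

# Concrete ndarray inputs give strongly-typed (weak_type=False) avals,
# unlike the Python-scalar example in api_test.py above.
assert lowered.in_avals == ((jax.ShapedArray(x.shape, x.dtype),), {})
assert compiled.in_avals == lowered.in_avals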
