diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py
index bf97b4c81407..cd9aca8eec2a 100644
--- a/jax/experimental/sparse/transform.py
+++ b/jax/experimental/sparse/transform.py
@@ -55,11 +55,13 @@
 from jax import core
 from jax import lax
 from jax._src import linear_util as lu
+from jax._src import pjit
 from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
 import jax.numpy as jnp
 from jax._src.api_util import flatten_fun_nokwargs
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
+from jax.interpreters import pxla
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 from jax.util import safe_map, safe_zip, split_list
 from jax._src.config import config
@@ -693,6 +695,63 @@ def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
 
 sparse_rules[xla.xla_call_p] = _xla_call_sparse
 
+
+def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
+                 resource_env, donated_invars, name, in_positional_semantics,
+                 out_positional_semantics, keep_unused, inline):
+  """Sparsification rule for ``pjit_p``.
+
+  Sparsifies ``jaxpr`` with ``_sparsify_jaxpr``, flattens ``spvalues`` into
+  arrays, and re-binds ``pjit_p`` on the flattened operands. Per-operand
+  parameters are rebuilt to match the flattened operand and output counts;
+  all remaining parameters are forwarded unchanged.
+
+  Raises:
+    ValueError: if ``config.jax_array`` is falsy.
+    NotImplementedError: if any entry of ``donated_invars`` is truthy.
+  """
+  if not config.jax_array:
+    raise ValueError('sparse pjit is only supported with jax.Array.')
+  if any(donated_invars):
+    raise NotImplementedError("sparse pjit with donated_invars")
+
+  sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, jaxpr, *spvalues)
+  args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
+  num_in = len(args_flat)
+  num_out = len(sp_call_jaxpr.out_avals)
+
+  # Nothing is donated here (donation was rejected above), and every
+  # flattened operand/output uses GLOBAL positional semantics.
+  donated_invars = (False,) * num_in
+  in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * num_in
+  out_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * num_out
+
+  # TODO(yashkatariya, vanderplas): Flatten twice and set the correct sharding
+  # for data and indices.
+  # Until then, keep the given shardings and pad at the end with _UNSPECIFIED
+  # up to the flattened counts; a non-positive difference adds nothing.
+  in_shardings = in_shardings + (
+      (pxla._UNSPECIFIED,) * (num_in - len(in_shardings)))
+  out_shardings = out_shardings + (
+      (pxla._UNSPECIFIED,) * (num_out - len(out_shardings)))
+
+  out_flat = pjit.pjit_p.bind(
+      *args_flat,
+      jaxpr=sp_call_jaxpr,
+      in_shardings=in_shardings,
+      out_shardings=out_shardings,
+      resource_env=resource_env,
+      donated_invars=donated_invars,
+      name=name,
+      in_positional_semantics=in_positional_semantics,
+      out_positional_semantics=out_positional_semantics,
+      keep_unused=keep_unused,
+      inline=inline)
+  return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
+
+sparse_rules[pjit.pjit_p] = _pjit_sparse
+
+
 def _duplicate_for_sparse_spvalues(spvalues, params):
   for spvalue, param in safe_zip(spvalues, params):
     yield from [param, param] if spvalue.is_sparse() else [param]
diff --git a/tests/BUILD b/tests/BUILD
index 21d8657a33d9..489af392568d 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -729,6 +729,7 @@ jax_test(
         ],
         "tpu": ["optonly"],
     },
+    enable_configs = ["cpu_jit_pjit_api_merge"],
     shard_count = {
         "cpu": 40,
         "gpu": 40,