Skip to content

Commit

Permalink
Add pjit rule to sparse_rules to support pjit. This is done to me…
Browse files Browse the repository at this point in the history
…rge the jit and pjit API.

PiperOrigin-RevId: 499311841
  • Loading branch information
yashk2810 authored and jax authors committed Jan 3, 2023
1 parent 9674b06 commit c3bb260
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
43 changes: 43 additions & 0 deletions jax/experimental/sparse/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@
from jax import core
from jax import lax
from jax._src import linear_util as lu
from jax._src import pjit
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
import jax.numpy as jnp
from jax._src.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config
Expand Down Expand Up @@ -693,6 +695,47 @@ def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):

sparse_rules[xla.xla_call_p] = _xla_call_sparse


def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused, inline):
if not config.jax_array:
raise ValueError('sparse pjit is only supported with jax.Array.')
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")

sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, jaxpr, *spvalues)
args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
donated_invars = tuple(False for arg in args_flat)
in_positional_semantics = tuple(pxla._PositionalSemantics.GLOBAL
for _ in args_flat)
out_positional_semantics = tuple(pxla._PositionalSemantics.GLOBAL
for _ in sp_call_jaxpr.out_avals)

# TODO(yashkatariya, vanderplas): Flatten twice and set the correct sharding
# for data and indices.
in_shardings = in_shardings + tuple(
pxla._UNSPECIFIED for _ in range(len(args_flat) - len(in_shardings)))
out_shardings = out_shardings + tuple(
pxla._UNSPECIFIED for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings)))

out_flat = pjit.pjit_p.bind(
*args_flat,
jaxpr=sp_call_jaxpr,
in_shardings=in_shardings,
out_shardings=out_shardings,
resource_env=resource_env,
donated_invars=donated_invars,
name=name,
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused,
inline=inline)
return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))

sparse_rules[pjit.pjit_p] = _pjit_sparse


def _duplicate_for_sparse_spvalues(spvalues, params):
for spvalue, param in safe_zip(spvalues, params):
yield from [param, param] if spvalue.is_sparse() else [param]
Expand Down
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ jax_test(
],
"tpu": ["optonly"],
},
enable_configs = ["cpu_jit_pjit_api_merge"],
shard_count = {
"cpu": 40,
"gpu": 40,
Expand Down

0 comments on commit c3bb260

Please sign in to comment.