Skip to content

Commit

Permalink
When passing ShapedArray as input to pjit.lower() (which is not all…
Browse files Browse the repository at this point in the history
…owed normally i.e `pjit(f)(*args)`), consider it as Global. This will also help during the auto sharding change.

PiperOrigin-RevId: 437094594
  • Loading branch information
yashk2810 authored and jax authors committed Mar 24, 2022
1 parent 45f80c0 commit 58efa00
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
28 changes: 17 additions & 11 deletions jax/experimental/pjit.py
Expand Up @@ -200,7 +200,7 @@ def pjit(fun: Callable,
donate_argnums = _ensure_index_tuple(donate_argnums)
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

def infer_params(*args, **kwargs):
def infer_params(*args, _global_avals=False, **kwargs):
if kwargs:
raise NotImplementedError("pjit does not support kwargs")
if max(static_argnums + donate_argnums, default=-1) >= len(args):
Expand Down Expand Up @@ -235,9 +235,8 @@ def infer_params(*args, **kwargs):
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
in_positional_semantics = tuple(
maps._PositionalSemantics.GLOBAL if isinstance(a, GDA) else maps._positional_semantics.val
for a in args_flat)
in_positional_semantics = tuple(tree_map(
partial(_get_in_positional_semantics, _global_avals), args_flat))
out_positional_semantics = maps._positional_semantics.val

global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources(
Expand Down Expand Up @@ -270,9 +269,9 @@ def wrapped(*args, **kwargs):
out = pjit_p.bind(*args_flat, **params)
return tree_unflatten(out_tree, out)

def lower(*args, **kwargs):
def lower(*args, _global_avals=False, **kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, **kwargs)
donate_argnums) = infer_params(*args, _global_avals=_global_avals, **kwargs)
in_is_global = _calc_is_global_sequence(
params['in_positional_semantics'], params['in_axis_resources'])
lowering = _pjit_lower(
Expand Down Expand Up @@ -370,14 +369,15 @@ def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk,
# will be raised because get_array_mapping (in local_to_global) of a
# FROM_GDA cannot happen.
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gda)
# If all inputs are either GDAs or fully replicated, then the avals are
# If all inputs have global semantics or fully replicated, then the avals are
# global and the mesh should also be global. This split is because
# non-contiguous mesh can only be used if all inputs are either GDAs or fully
# replicated.
# non-contiguous mesh can only be used if all inputs have global semantics or
# fully replicated.
# Use canonicalized in_axis_resources here because we want to treat P(None)
# and None (for example) as equivalent.
if all(((not _is_from_gda(p) and p.partitions == ()) or ig)
for p, ig in safe_zip(canonicalized_in_axis_resources_flat, is_gda)):
if all(
(not _is_from_gda(p) and p.partitions == ()) or ips == maps._PositionalSemantics.GLOBAL
for p, ips in safe_zip(canonicalized_in_axis_resources_flat, in_positional_semantics)):
# Shapes should be checked against non canonicalized in_axis_resources.
# For example, partitions of () and ((),) are not equivalent, since the
# first one is a valid spec for a scalar value, while the second is not!
Expand Down Expand Up @@ -1069,6 +1069,12 @@ def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
for ips, p in safe_zip(in_positional_semantics, in_axis_resources))

def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSemantics:
if isinstance(arg, GDA):
return maps._PositionalSemantics.GLOBAL
if global_avals and isinstance(arg, core.ShapedArray):
return maps._PositionalSemantics.GLOBAL
return maps._positional_semantics.val

def _create_cpspec(x):
return x if _is_from_gda(x) else CanonicalizedParsedPartitionSpec(x)
Expand Down
16 changes: 16 additions & 0 deletions tests/pjit_test.py
Expand Up @@ -27,6 +27,7 @@
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax import stages
from jax.errors import JAXTypeError
from jax import lax
# TODO(skye): do we still wanna call this PartitionSpec?
Expand Down Expand Up @@ -807,6 +808,21 @@ def f(x, y):
self.assertEqual(f(1, 'hi' ), 4)
self.assertEqual(f(1, 'bye'), 5)

@jtu.with_mesh([('x', 4), ('y', 2)])
def testLowerCompileWithAvals(self):
@partial(pjit,
in_axis_resources=P(('x', 'y'),),
out_axis_resources=P(('x', 'y'),))
def f(x, y):
return x @ y

shape = (8, 8)
aval = jax.ShapedArray(shape, jnp.int64)
x = jnp.arange(np.prod(shape)).reshape(shape)
exe = f.lower(aval, x, _global_avals=True).compile()
self.assertIsInstance(exe, stages.Compiled)
self.assertArraysEqual(exe(x, x), x @ x)


class GDAPjitTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 58efa00

Please sign in to comment.