
Commit a4c1bee

yashk2810 authored and Google-ML-Automation committed

Add a config option to remove size 1 mesh axis from ShapedArray.sharding.

PiperOrigin-RevId: 804088314
1 parent 4d71185 commit a4c1bee

File tree

4 files changed: +48 -2 lines changed

jax/_src/config.py
jax/_src/core.py
jax/_src/pjit.py
tests/pjit_test.py

jax/_src/config.py

Lines changed: 8 additions & 0 deletions
@@ -233,6 +233,7 @@ def trace_context():
           eager_constant_folding.value,
           numpy_dtype_promotion.value,
           default_device.value, random_seed_offset.value,
+          remove_size_one_mesh_axis_from_type.value,
           threefry_partitionable.value,
           threefry_gpu_kernel_lowering.value,
           use_direct_linearize.value,
@@ -1150,6 +1151,13 @@ def _safer_randint_deprecation(new_val):
           'DO NOT RELY ON THIS FLAG.'),
     include_in_jit_key=True)
 
+remove_size_one_mesh_axis_from_type = bool_state(
+    name='jax_remove_size_one_mesh_axis_from_type',
+    default=False,
+    upgrade=True,
+    help="Removes mesh axes of size 1 from ShapedArray.sharding",
+    include_in_jit_key=True)
+
 # TODO make it so people don't use this, this is internal...
 _check_vma = bool_state(
     name='check_vma',
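
Since the option is a bool_state with include_in_jit_key=True and is folded into trace_context() above, flipping it changes the jit cache key, so previously traced functions retrace rather than reuse lowerings built under the other setting. A minimal usage sketch, assuming a JAX build that includes this commit:

import jax
from jax._src import config  # internal module, mirroring the test's usage

# Global toggle; the string matches the name= in the bool_state call above.
jax.config.update('jax_remove_size_one_mesh_axis_from_type', True)

# Scoped toggle: bool_state options also work as context managers and as
# decorators, which is how the new test below applies the flag.
with config.remove_size_one_mesh_axis_from_type(True):
  pass  # traces here see size-1 mesh axes stripped from ShapedArray.sharding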

jax/_src/core.py

Lines changed: 15 additions & 2 deletions
@@ -1684,7 +1684,7 @@ def mem_space_to_kind(mem_space: MemorySpace) -> str:
   assert False, "unreachable"
 
 
-@cache(max_size=4096, trace_context_in_key=False)
+@cache(max_size=4096, trace_context_in_key=True)
 def update_aval_with_sharding(aval, sharding):
   if isinstance(sharding, NamedSharding):
     return aval.update(
@@ -2082,6 +2082,17 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
                  if mesh._name_to_type[u] == AxisType.Explicit}
   return P(*new_spec, unreduced=new_unreduced, reduced=new_reduced)
 
+def remove_size_one_mesh_axis(spec, mesh) -> P:
+  new_spec = []  # type: ignore
+  for s in spec:
+    if s is None:
+      new_spec.append(s)  # type: ignore
+    elif isinstance(s, tuple):
+      new_spec.append(tuple(i for i in s if mesh.shape[i] != 1))
+    else:
+      new_spec.append(None if mesh.shape[s] == 1 else s)  # type: ignore
+  return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced)
+
 def _maybe_modify_sharding(sharding, ndim):
   if len(sharding.spec) == 0 or all(s is None for s in sharding.spec):
     out = sharding
@@ -2090,6 +2101,8 @@ def _maybe_modify_sharding(sharding, ndim):
   else:
     out = sharding.update(spec=modify_spec_for_auto_manual(
         sharding.spec, sharding.mesh))
+  if config.remove_size_one_mesh_axis_from_type.value:
+    out = out.update(spec=remove_size_one_mesh_axis(out.spec, out.mesh))
   if len(out.spec) != ndim:
     out = _make_lengths_same(out, ndim)
   return out
@@ -2108,7 +2121,7 @@ def _check_divisibility(sharding, shape):
         f" {size} times, but does not evenly divide the dimension size {sh}."
         f" Got shape: {shape} and sharding {sharding}")
 
-@cache(max_size=4096, trace_context_in_key=False)
+@cache(max_size=4096, trace_context_in_key=True)
 def get_sharding(sharding, shape):
   """Modifies and checks the sharding.
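
To make the rewrite concrete, here is a self-contained sketch of the same logic, with a plain dict standing in for mesh.shape (which maps axis names to sizes) and the unreduced/reduced bookkeeping omitted:

from jax.sharding import PartitionSpec as P

def strip_size_one_axes(spec, mesh_shape):
  # Mirrors remove_size_one_mesh_axis above, minus unreduced/reduced.
  new_spec = []
  for s in spec:
    if s is None:
      new_spec.append(s)  # unsharded dims stay unsharded
    elif isinstance(s, tuple):
      # Multi-axis entry like ('x', 'y'): drop the size-1 names.
      new_spec.append(tuple(i for i in s if mesh_shape[i] != 1))
    else:
      # A bare name over a size-1 axis carries no real sharding.
      new_spec.append(None if mesh_shape[s] == 1 else s)
  return P(*new_spec)

shape = {'x': 2, 'y': 1}
assert strip_size_one_axes(P('x', 'y'), shape) == P('x', None)
assert strip_size_one_axes(P(('x', 'y')), shape) == P(('x',))
assert strip_size_one_axes(P('y'), shape) == P(None)

In _maybe_modify_sharding the result is then padded back to the array's rank by _make_lengths_same, which is why the test below expects P('x', None) for an input spec of P(('x', 'y')) on a rank-2 array.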

jax/_src/pjit.py

Lines changed: 2 additions & 0 deletions
@@ -2725,6 +2725,8 @@ def check_shardings_are_auto(s: Sharding) -> None:
 def assert_shardings_equal(x_aval, user_sharding: NamedSharding):
   x_spec = x_aval.sharding.spec
   user_spec = user_sharding.spec._normalized_spec_for_aval(x_aval.ndim)
+  if config.remove_size_one_mesh_axis_from_type.value:
+    user_spec = core.remove_size_one_mesh_axis(user_spec, user_sharding.mesh)
   for x, s in zip(x_spec, user_spec):
     if s is PartitionSpec.UNCONSTRAINED:
       continue
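
Without this stripping, a with_sharding_constraint under the new flag would spuriously fail: the aval's spec has already had size-1 axes removed, while the constraint would be compared as the user wrote it. A small sketch of the comparison on the (2, 1) mesh used by the test below, with a simplified strip helper (bare axis names only) standing in for core.remove_size_one_mesh_axis:

from jax.sharding import PartitionSpec as P

mesh_shape = {'x': 2, 'y': 1}  # the (2, 1) mesh from the test below

def strip(spec):
  # Simplified: handles bare axis names only, not tuple entries.
  return P(*(None if isinstance(s, str) and mesh_shape[s] == 1 else s
             for s in spec))

aval_spec = P('x', None)  # what the aval carries once the flag is on
user_spec = P('x', 'y')   # the constraint exactly as the user wrote it
assert aval_spec != user_spec         # a naive comparison would reject it
assert strip(user_spec) == aval_spec  # comparison after stripping passes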

tests/pjit_test.py

Lines changed: 23 additions & 0 deletions
@@ -8658,6 +8658,29 @@ def f(x):
     f(inp)
     self.assertEqual(tracing_count(), 2)  # twice for f
 
+  @parameterized.named_parameters(
+      ('1', P('x', 'y'), P('x', None)),
+      ('2', P(('x', 'y')), P('x', None)),
+      ('3', P('y'), P(None, None)),
+  )
+  @config.remove_size_one_mesh_axis_from_type(True)
+  @jtu.with_explicit_mesh((2, 1), ('x', 'y'))
+  def test_remove_size_one_mesh_axis(self, arr_spec, type_spec, mesh):
+    arr = jax.device_put(np.arange(16).reshape(8, 2), arr_spec)
+
+    @jax.jit
+    def f(x):
+      self.assertEqual(x.aval.sharding.spec, type_spec)
+      out = x * 2
+      self.assertEqual(out.aval.sharding.spec, type_spec)
+      # wsc should act as an assert.
+      out = with_sharding_constraint(out, arr_spec)
+      return out
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, type_spec))
+    self.assertArraysEqual(out, arr * 2)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
